Mirror of https://github.com/zhenxun-org/zhenxun_bot.git, synced 2025-12-14 21:52:56 +08:00
♻️ refactor(llm.session): refactor the memory system to separate storage from strategy
This commit is contained in:
parent 61279f8b3d
commit a4decd8524
@@ -34,7 +34,13 @@ from .manager import (
    list_model_identifiers,
    set_global_default_model_name,
)
from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend
from .memory import (
    AIConfig,
    BaseMemory,
    MemoryProcessor,
    set_default_memory_backend,
)
from .session import AI
from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager
from .types import (
    EmbeddingTaskType,
@@ -62,6 +68,7 @@ from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_l
__all__ = [
    "AI",
    "AIConfig",
    "BaseMemory",
    "CommonOverrides",
    "EmbeddingTaskType",
    "GeminiCodeExecution",
zhenxun/services/llm/memory.py (new file, 243 lines)
@@ -0,0 +1,243 @@
"""
LLM Service - Session Memory Module

Defines the storage, strategy, and processing interfaces for LLM session memory.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field

from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.log import logger
class AIConfig(BaseModel):
    """AI configuration class (a copy is kept here for module independence; in practice it may come from a higher layer)."""

    model: Any = None
    default_embedding_model: Any = None
    default_preserve_media_in_history: bool = False
    tool_providers: list[Any] = Field(default_factory=list)

    def __post_init__(self):
        """Read default values from the configuration after initialization."""
        pass
class BaseMessageStore(ABC):
    """
    Low-level storage interface (DAO - Data Access Object).

    This abstract base class defines the lowest-level **persistence and retrieval
    (CRUD)** interface for message data. It is only concerned with storing and
    fetching data and contains no business logic (such as history trimming).

    Developers who want to persist conversation history to Redis, a database, or
    another backend should implement this interface.
    """

    @abstractmethod
    async def get_messages(self, session_id: str) -> list[LLMMessage]:
        """Return the full message list for the given session ID."""
        raise NotImplementedError

    @abstractmethod
    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Append messages."""
        raise NotImplementedError

    @abstractmethod
    async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """
        Completely overwrite the message list for the given session ID.
        Mainly used for scenarios such as history trimming.
        """
        raise NotImplementedError

    @abstractmethod
    async def clear(self, session_id: str) -> None:
        """Delete all message data for the given session ID."""
        raise NotImplementedError
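# Illustrative sketch (hypothetical; not part of the file above): a Redis-backed
# BaseMessageStore along the lines suggested by the docstring. It assumes
# `redis.asyncio` is available and that LLMMessage is a pydantic model exposing
# model_dump_json() / model_validate_json(); adjust the (de)serialization if the
# real LLMMessage type differs.
from redis.asyncio import Redis


class RedisMessageStore(BaseMessageStore):
    def __init__(self, redis: Redis, key_prefix: str = "llm:history:"):
        self._redis = redis
        self._prefix = key_prefix

    def _key(self, session_id: str) -> str:
        return f"{self._prefix}{session_id}"

    async def get_messages(self, session_id: str) -> list[LLMMessage]:
        # Each history entry is stored as one JSON string in a Redis list.
        raw = await self._redis.lrange(self._key(session_id), 0, -1)
        return [LLMMessage.model_validate_json(item) for item in raw]

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        if messages:
            await self._redis.rpush(
                self._key(session_id),
                *(message.model_dump_json() for message in messages),
            )

    async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        key = self._key(session_id)
        await self._redis.delete(key)
        if messages:
            await self._redis.rpush(
                key, *(message.model_dump_json() for message in messages)
            )

    async def clear(self, session_id: str) -> None:
        await self._redis.delete(self._key(session_id))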
class InMemoryMessageStore(BaseMessageStore):
    """
    An in-memory implementation of `BaseMessageStore`.

    It keeps all session messages in a Python dict, providing the simplest and
    fastest storage option. This is the framework's default store and works out
    of the box.

    Note: this implementation is **non-persistent**; all conversation history is
    lost when the application restarts. It is suitable for tests, simple
    applications, or scenarios that do not need long-term memory.
    """

    def __init__(self):
        self._data: dict[str, list[LLMMessage]] = defaultdict(list)

    async def get_messages(self, session_id: str) -> list[LLMMessage]:
        """Return a copy of the message list from the in-memory dict."""
        return self._data.get(session_id, []).copy()

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Append messages to the in-memory message list."""
        self._data[session_id].extend(messages)

    async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Replace the message list for the given session in memory."""
        self._data[session_id] = messages

    async def clear(self, session_id: str) -> None:
        """Remove the entry for the given session from the in-memory dict."""
        if session_id in self._data:
            del self._data[session_id]
class BaseMemory(ABC):
    """
    Base class for the upper-level memory logic (strategy layer).

    This abstract base class defines the **strategy layer** of the memory system.
    It exposes a unified memory API and encapsulates the concrete memory-management
    policy, such as history trimming or summary generation.

    The `AI` session client interacts with this interface directly and does not
    care about the underlying storage implementation.

    Developers can implement this interface to create custom memory-management
    strategies, for example:
    - `SummarizationMemory`: when the history grows too long, automatically call an
      LLM to summarize and compress it.
    - `VectorStoreMemory`: embed the conversation history into a vector database for
      long-term memory retrieval.
    """

    @abstractmethod
    async def get_history(self, session_id: str) -> list[LLMMessage]:
        """Return the full message history used to build the model input."""
        raise NotImplementedError

    async def add_message(self, session_id: str, message: LLMMessage) -> None:
        """Add a single message to memory. The default implementation delegates to `add_messages`."""
        await self.add_messages(session_id, [message])

    @abstractmethod
    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Add multiple messages to memory, possibly triggering internal management policies (such as trimming)."""
        raise NotImplementedError

    @abstractmethod
    async def clear_history(self, session_id: str) -> None:
        """Clear all memory for the given session."""
        raise NotImplementedError
class ChatMemory(BaseMemory):
    """
    Standard chat memory implementation: a store combined with a sliding-window policy.

    This is the default implementation of `BaseMemory`. It composes a
    `BaseMessageStore` instance for the actual data storage and adds a simple
    "sliding window" trimming policy on top of it.
    """

    def __init__(self, store: BaseMessageStore, max_messages: int = 50):
        self.store = store
        self._max_messages = max_messages

    async def _trim_history(self, session_id: str) -> None:
        """
        Trimming policy: keep the history at no more than `_max_messages` entries.

        If a system message (system prompt) exists, it is always kept as the first
        element of the list.
        """
        history = await self.store.get_messages(session_id)
        if len(history) <= self._max_messages:
            return

        has_system = history and history[0].role == "system"
        new_history: list[LLMMessage] = []

        if has_system:
            keep_count = max(0, self._max_messages - 1)
            new_history = [history[0], *history[-keep_count:]]
        else:
            new_history = history[-self._max_messages :]

        await self.store.set_messages(session_id, new_history)

    async def get_history(self, session_id: str) -> list[LLMMessage]:
        """Fetch history directly from the underlying store."""
        return await self.store.get_messages(session_id)

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Add messages to the history and immediately apply the trimming policy."""
        await self.store.add_messages(session_id, messages)
        await self._trim_history(session_id)

    async def clear_history(self, session_id: str) -> None:
        """Clear the history in the underlying store."""
        await self.store.clear(session_id)
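# Illustrative sketch (hypothetical; not part of the file above): a custom
# BaseMemory strategy in the spirit of the `SummarizationMemory` example named in
# the BaseMemory docstring. The `summarize(messages) -> str` coroutine and the
# LLMMessage(role=..., content=...) constructor are assumptions, not APIs
# guaranteed by this module.
from collections.abc import Awaitable  # not imported by memory.py itself


class RollingSummaryMemory(BaseMemory):
    def __init__(
        self,
        store: BaseMessageStore,
        summarize: Callable[[list[LLMMessage]], Awaitable[str]],
        max_messages: int = 50,
    ):
        self.store = store
        self._summarize = summarize
        self._max_messages = max_messages

    async def get_history(self, session_id: str) -> list[LLMMessage]:
        return await self.store.get_messages(session_id)

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        await self.store.add_messages(session_id, messages)
        history = await self.store.get_messages(session_id)
        if len(history) <= self._max_messages:
            return
        # Collapse everything except the most recent half of the window into a
        # single summary message, then keep [summary, recent...] as the history.
        keep = max(1, self._max_messages // 2)
        old, recent = history[:-keep], history[-keep:]
        summary_text = await self._summarize(old)
        summary = LLMMessage(  # assumed constructor shape
            role="system", content=f"Conversation summary: {summary_text}"
        )
        await self.store.set_messages(session_id, [summary, *recent])

    async def clear_history(self, session_id: str) -> None:
        await self.store.clear(session_id)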
class MemoryProcessor(ABC):
    """
    Memory processor interface (hook/observer).

    This is an extension interface that lets developers create custom "memory
    processors" which run additional work (hooks) after the memory has been modified.

    Whenever an `AI` instance's memory is updated, it calls every registered
    `MemoryProcessor` in turn.

    Example use cases:
    - `LoggingMemoryProcessor`: asynchronously record each conversation turn to an
      external logging system.
    - `SummarizationProcessor`: check the conversation length in a background task
      and generate a summary when needed.
    - `EntityExtractionProcessor`: extract key entities (names, places, ...) from
      the conversation and store them.
    """

    @abstractmethod
    async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
        """Process messages newly added to memory."""
        pass
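# Illustrative sketch (hypothetical; not part of the file above): a minimal
# MemoryProcessor in the spirit of the `LoggingMemoryProcessor` example from the
# docstring. It only uses `logger` and the message count, so it makes no
# assumptions about the LLMMessage schema.
class LoggingMemoryProcessor(MemoryProcessor):
    async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
        logger.debug(
            f"Session {session_id}: {len(new_messages)} new message(s) added to memory."
        )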
_default_memory_factory: Callable[[], BaseMemory] | None = None


def set_default_memory_backend(factory: Callable[[], BaseMemory]):
    """
    Set the global default memory backend factory, allowing the memory
    implementation used by sessions to be replaced in one place.

    This is an advanced dependency-injection hook that lets a plugin or project
    replace the default `ChatMemory(InMemoryMessageStore())` with a custom
    `BaseMemory` implementation at startup.

    Args:
        factory: A zero-argument function or class that returns a `BaseMemory` instance.
    """
    global _default_memory_factory
    _default_memory_factory = factory
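# Illustrative sketch (hypothetical usage, typically placed in a plugin's startup
# code rather than in this module): registering a custom default memory backend.
# The factory must take no arguments and return a BaseMemory; here it reuses the
# in-memory store but widens the sliding window.
def _my_memory_factory() -> BaseMemory:
    return ChatMemory(store=InMemoryMessageStore(), max_messages=200)


set_default_memory_backend(_my_memory_factory)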
def _get_default_memory() -> BaseMemory:
    """
    [Internal] Return a default memory backend instance.

    It first checks whether a global factory was registered via
    `set_default_memory_backend`; if so, that factory is used to create the
    instance. Otherwise a standard in-memory backend is returned.
    """
    if _default_memory_factory:
        logger.debug("Building the default memory backend from the custom factory.")
        return _default_memory_factory()

    logger.debug("No custom memory backend configured; using the default ChatMemory.")
    return ChatMemory(store=InMemoryMessageStore())
__all__ = [
    "AIConfig",
    "BaseMemory",
    "BaseMessageStore",
    "ChatMemory",
    "InMemoryMessageStore",
    "MemoryProcessor",
    "_get_default_memory",
    "set_default_memory_backend",
]
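# Illustrative sketch (hypothetical; not part of the commit above): a minimal
# end-to-end check of the sliding-window policy documented in ChatMemory. It
# assumes LLMMessage can be constructed as LLMMessage(role=..., content=...);
# adapt to the real constructor or factory in zhenxun.services.llm.types if
# that assumption does not hold.
import asyncio


async def _demo() -> None:
    memory = ChatMemory(store=InMemoryMessageStore(), max_messages=4)
    await memory.add_message(
        "demo", LLMMessage(role="system", content="You are a helpful bot.")
    )
    for i in range(10):
        await memory.add_message("demo", LLMMessage(role="user", content=f"message {i}"))
    history = await memory.get_history("demo")
    # Trimmed to 4 entries, with the system prompt still pinned at index 0.
    assert len(history) == 4
    assert history[0].role == "system"


if __name__ == "__main__":
    asyncio.run(_demo())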
@@ -4,11 +4,8 @@ LLM Service - Session Client
Provides a stateful, session-oriented LLM client for multi-turn conversations and complex interactions.
"""

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable
import copy
from dataclasses import dataclass, field
import json
from typing import Any, TypeVar, cast
import uuid
@@ -28,8 +25,14 @@ from .config import (
    LLMGenerationConfig,
)
from .config.generation import OutputConfig
from .config.providers import get_ai_config, get_llm_config
from .config.providers import get_llm_config
from .manager import get_global_default_model_name, get_model_instance
from .memory import (
    AIConfig,
    BaseMemory,
    MemoryProcessor,
    _get_default_memory,
)
from .tools import tool_provider_manager
from .types import (
    LLMContentPart,
@@ -42,7 +45,6 @@ from .types import (
    StructuredOutputStrategy,
    ToolChoice,
    ToolExecutable,
    ToolProvider,
)
from .types.models import (
    GeminiCodeExecution,
@@ -57,105 +59,6 @@ from .utils import (
T = TypeVar("T", bound=BaseModel)


@dataclass
class AIConfig:
    """AI configuration class."""

    model: ModelName = None
    default_embedding_model: ModelName = None
    default_preserve_media_in_history: bool = False
    tool_providers: list[ToolProvider] = field(default_factory=list)

    def __post_init__(self):
        """Read default values from the configuration after initialization."""
        ai_config = get_ai_config()
        if self.model is None:
            self.model = ai_config.get("default_model_name")


class BaseMemory(ABC):
    """Abstract base class of the memory system."""

    @abstractmethod
    async def get_history(self, session_id: str) -> list[LLMMessage]:
        raise NotImplementedError

    @abstractmethod
    async def add_message(self, session_id: str, message: LLMMessage) -> None:
        raise NotImplementedError

    @abstractmethod
    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        raise NotImplementedError

    @abstractmethod
    async def clear_history(self, session_id: str) -> None:
        raise NotImplementedError
class InMemoryMemory(BaseMemory):
    """A simple, default in-memory memory backend."""

    def __init__(self, max_messages: int = 50, **kwargs: Any):
        self._history: dict[str, list[LLMMessage]] = defaultdict(list)
        self._max_messages = max_messages

    def _trim_history(self, session_id: str) -> None:
        """Trim the history so it does not exceed the maximum length while keeping the system prompt."""
        history = self._history[session_id]
        if len(history) <= self._max_messages:
            return

        has_system = history and history[0].role == "system"

        if has_system:
            keep_count = max(0, self._max_messages - 1)
            self._history[session_id] = [history[0], *history[-keep_count:]]
        else:
            self._history[session_id] = history[-self._max_messages :]

    async def get_history(self, session_id: str) -> list[LLMMessage]:
        return self._history.get(session_id, []).copy()

    async def add_message(self, session_id: str, message: LLMMessage) -> None:
        self._history[session_id].append(message)
        self._trim_history(session_id)

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        self._history[session_id].extend(messages)
        self._trim_history(session_id)

    async def clear_history(self, session_id: str) -> None:
        if session_id in self._history:
            del self._history[session_id]


class MemoryProcessor(ABC):
    """Memory processor interface."""

    @abstractmethod
    async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
        pass
_default_memory_factory: Callable[[], BaseMemory] | None = None


def set_default_memory_backend(factory: Callable[[], BaseMemory]):
    """
    Set the global default memory backend factory, allowing the memory
    implementation used by sessions to be replaced in one place.
    """
    global _default_memory_factory
    _default_memory_factory = factory


def _get_default_memory() -> BaseMemory:
    if _default_memory_factory:
        return _default_memory_factory()
    return InMemoryMemory()


DEFAULT_IVR_TEMPLATE = (
    "Your response failed structured validation.\n"
    "Error details: {error_msg}\n\n"
@@ -186,7 +89,8 @@ class AI:
    Parameters:
        session_id: Unique session ID used to isolate memory.
        config: AI configuration.
        memory: Optional custom memory backend. If None, the default InMemoryMemory is used.
        memory: Optional custom memory backend. If None, the default ChatMemory
            (InMemoryMessageStore) is used.
        default_generation_config: Default generation configuration for this AI instance.
        processors: List of memory processors triggered after memory is added.
    """
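# Illustrative sketch (hypothetical usage, inferred from the parameter list
# documented above; the exact AI constructor signature is an assumption):
#
# from zhenxun.services.llm import AI
# from zhenxun.services.llm.memory import ChatMemory, InMemoryMessageStore
#
# ai = AI(
#     session_id="group-42",
#     memory=ChatMemory(store=InMemoryMessageStore(), max_messages=100),
# )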
@@ -587,9 +491,7 @@ class AI:
        final_config = LLMGenerationConfig()

        if max_validation_retries is None:
            max_validation_retries = (
                get_llm_config().client_settings.structured_retries
            )
            max_validation_retries = get_llm_config().client_settings.structured_retries

        resolved_model_name = self._resolve_model_name(model or self.config.model)