diff --git a/zhenxun/services/llm/__init__.py b/zhenxun/services/llm/__init__.py index eaa6f8a3..a870cae7 100644 --- a/zhenxun/services/llm/__init__.py +++ b/zhenxun/services/llm/__init__.py @@ -34,7 +34,13 @@ from .manager import ( list_model_identifiers, set_global_default_model_name, ) -from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend +from .memory import ( + AIConfig, + BaseMemory, + MemoryProcessor, + set_default_memory_backend, +) +from .session import AI from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager from .types import ( EmbeddingTaskType, @@ -62,6 +68,7 @@ from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_l __all__ = [ "AI", "AIConfig", + "BaseMemory", "CommonOverrides", "EmbeddingTaskType", "GeminiCodeExecution", diff --git a/zhenxun/services/llm/memory.py b/zhenxun/services/llm/memory.py new file mode 100644 index 00000000..c0b8b80f --- /dev/null +++ b/zhenxun/services/llm/memory.py @@ -0,0 +1,243 @@ +""" +LLM 服务 - 会话记忆模块 + +定义了LLM会话记忆的存储、策略和处理接口。 +""" + +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Callable +from typing import Any + +from pydantic import BaseModel, Field + +from zhenxun.services.llm.types import LLMMessage +from zhenxun.services.log import logger + + +class AIConfig(BaseModel): + """AI配置类 (为保持独立性而在此处保留一个副本,实际使用中可能来自更高层)""" + + model: Any = None + default_embedding_model: Any = None + default_preserve_media_in_history: bool = False + tool_providers: list[Any] = Field(default_factory=list) + + def __post_init__(self): + """初始化后从配置中读取默认值""" + pass + + +class BaseMessageStore(ABC): + """ + 底层存储接口 (DAO - Data Access Object)。 + + 这是一个抽象基类,定义了消息数据最底层的 **持久化与检索 (CRUD)** 接口。 + 它只关心数据的存取,不涉及任何业务逻辑(如历史记录修剪)。 + + 开发者如果希望将对话历史存储到 Redis、数据库或其他持久化后端, + 应当实现这个接口。 + """ + + @abstractmethod + async def get_messages(self, session_id: str) -> list[LLMMessage]: + """ + 根据会话ID获取完整的消息列表。 + """ + raise NotImplementedError + + @abstractmethod + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """追加消息""" + raise NotImplementedError + + @abstractmethod + async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """ + 完全覆盖指定会话ID的消息列表。 + 主要用于历史记录修剪等场景。 + """ + raise NotImplementedError + + @abstractmethod + async def clear(self, session_id: str) -> None: + """清空指定会话ID的所有消息数据。""" + raise NotImplementedError + + +class InMemoryMessageStore(BaseMessageStore): + """ + 一个基于内存的 `BaseMessageStore` 实现。 + + 它使用一个Python字典来存储所有会话的消息,提供了最简单、最快速的存储方案。 + 这是框架的默认存储方式,实现了开箱即用。 + + 注意:此实现是 **非持久化** 的,当应用程序重启时,所有对话历史都会丢失。 + 适用于测试、简单应用或不需要长期记忆的场景。 + """ + + def __init__(self): + self._data: dict[str, list[LLMMessage]] = defaultdict(list) + + async def get_messages(self, session_id: str) -> list[LLMMessage]: + """从内存字典中获取消息列表的副本。""" + return self._data.get(session_id, []).copy() + + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """向内存中的消息列表追加消息。""" + self._data[session_id].extend(messages) + + async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """在内存中直接替换指定会话的消息列表。""" + self._data[session_id] = messages + + async def clear(self, session_id: str) -> None: + """从内存字典中删除指定会话的条目。""" + if session_id in self._data: + del self._data[session_id] + + +class BaseMemory(ABC): + """ + 记忆系统上层逻辑基类 (Strategy Layer)。 + + 此抽象基类定义了记忆系统的 **策略层** 接口。它负责对外提供统一的记忆操作 + 接口,并封装了具体的记忆管理策略,如历史记录的修剪、摘要生成等。 + + `AI` 会话客户端直接与此接口交互,而不关心底层的存储实现。 + + 开发者可以通过实现此接口来创建自定义的记忆管理策略,例如: + - `SummarizationMemory`: 在历史记录过长时,自动调用LLM生成摘要来压缩历史。 + - `VectorStoreMemory`: 将对话历史向量化并存入向量数据库,实现长期记忆检索。 + """ + + @abstractmethod + async def get_history(self, session_id: str) -> list[LLMMessage]: + """获取用于构建模型输入的完整历史消息列表。""" + raise NotImplementedError + + async def add_message(self, session_id: str, message: LLMMessage) -> None: + """向记忆中添加单条消息。默认实现是调用 `add_messages`。""" + await self.add_messages(session_id, [message]) + + @abstractmethod + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """向记忆中添加多条消息,并可能触发内部的记忆管理策略(如修剪)。""" + raise NotImplementedError + + @abstractmethod + async def clear_history(self, session_id: str) -> None: + """清空指定会话的全部记忆。""" + raise NotImplementedError + + +class ChatMemory(BaseMemory): + """ + 标准聊天记忆实现:组合 Store + 滑动窗口策略。 + + 这是 `BaseMemory` 的默认实现,它通过组合一个 `BaseMessageStore` 实例来 + 完成实际的数据存储,并在此之上实现了一个简单的“滑动窗口”记忆修剪策略。 + """ + + def __init__(self, store: BaseMessageStore, max_messages: int = 50): + self.store = store + self._max_messages = max_messages + + async def _trim_history(self, session_id: str) -> None: + """ + 记忆修剪策略:确保历史记录不超过 `_max_messages` 条。 + + 如果存在系统消息 (System Prompt),它将被永久保留在列表的第一位。 + """ + history = await self.store.get_messages(session_id) + if len(history) <= self._max_messages: + return + + has_system = history and history[0].role == "system" + new_history: list[LLMMessage] = [] + + if has_system: + keep_count = max(0, self._max_messages - 1) + new_history = [history[0], *history[-keep_count:]] + else: + new_history = history[-self._max_messages :] + + await self.store.set_messages(session_id, new_history) + + async def get_history(self, session_id: str) -> list[LLMMessage]: + """直接从底层存储获取历史记录。""" + return await self.store.get_messages(session_id) + + async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: + """添加消息到历史记录,并立即执行修剪策略。""" + await self.store.add_messages(session_id, messages) + await self._trim_history(session_id) + + async def clear_history(self, session_id: str) -> None: + """清空底层存储中的历史记录。""" + await self.store.clear(session_id) + + +class MemoryProcessor(ABC): + """ + 记忆处理器接口 (Hook/Observer)。 + + 这是一个扩展接口,允许开发者创建自定义的“记忆处理器”,以在记忆被修改后 + 执行额外的操作(“钩子”)。 + + 当 `AI` 实例的记忆更新时,它会依次调用所有注册的 `MemoryProcessor`。 + + 使用场景示例: + - `LoggingMemoryProcessor`: 将每一轮对话异步记录到外部日志系统。 + - `SummarizationProcessor`: 在后台任务中检查对话长度,并在需要时生成摘要。 + - `EntityExtractionProcessor`: 从对话中提取关键实体(如人名、地名)并存储。 + """ + + @abstractmethod + async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None: + """处理新添加到记忆中的消息。""" + pass + + +_default_memory_factory: Callable[[], BaseMemory] | None = None + + +def set_default_memory_backend(factory: Callable[[], BaseMemory]): + """ + 设置全局默认记忆后端工厂,允许统一替换会话的记忆实现。 + + 这是一个高级依赖注入函数,允许插件或项目在启动时用自定义的 `BaseMemory` + 实现替换掉默认的 `ChatMemory(InMemoryMessageStore())`。 + + Args: + factory: 一个无参数的、返回 `BaseMemory` 实例的函数或类。 + """ + global _default_memory_factory + _default_memory_factory = factory + + +def _get_default_memory() -> BaseMemory: + """ + [内部函数] 获取一个默认的记忆后端实例。 + + 它会首先检查是否有通过 `set_default_memory_backend` 设置的全局工厂, + 如果有,则使用该工厂创建实例;否则,返回一个标准的内存记忆实例。 + """ + if _default_memory_factory: + logger.debug("使用自定义的默认记忆后端工厂构建实例。") + return _default_memory_factory() + + logger.debug("未配置自定义记忆后端,使用默认的 ChatMemory。") + return ChatMemory(store=InMemoryMessageStore()) + + +__all__ = [ + "AIConfig", + "BaseMemory", + "BaseMessageStore", + "ChatMemory", + "InMemoryMessageStore", + "MemoryProcessor", + "_get_default_memory", + "set_default_memory_backend", +] diff --git a/zhenxun/services/llm/session.py b/zhenxun/services/llm/session.py index c7df1476..a07a7987 100644 --- a/zhenxun/services/llm/session.py +++ b/zhenxun/services/llm/session.py @@ -4,11 +4,8 @@ LLM 服务 - 会话客户端 提供一个有状态的、面向会话的 LLM 客户端,用于进行多轮对话和复杂交互。 """ -from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Awaitable, Callable import copy -from dataclasses import dataclass, field import json from typing import Any, TypeVar, cast import uuid @@ -28,8 +25,14 @@ from .config import ( LLMGenerationConfig, ) from .config.generation import OutputConfig -from .config.providers import get_ai_config, get_llm_config +from .config.providers import get_llm_config from .manager import get_global_default_model_name, get_model_instance +from .memory import ( + AIConfig, + BaseMemory, + MemoryProcessor, + _get_default_memory, +) from .tools import tool_provider_manager from .types import ( LLMContentPart, @@ -42,7 +45,6 @@ from .types import ( StructuredOutputStrategy, ToolChoice, ToolExecutable, - ToolProvider, ) from .types.models import ( GeminiCodeExecution, @@ -57,105 +59,6 @@ from .utils import ( T = TypeVar("T", bound=BaseModel) - -@dataclass -class AIConfig: - """AI配置类""" - - model: ModelName = None - default_embedding_model: ModelName = None - default_preserve_media_in_history: bool = False - tool_providers: list[ToolProvider] = field(default_factory=list) - - def __post_init__(self): - """初始化后从配置中读取默认值""" - ai_config = get_ai_config() - if self.model is None: - self.model = ai_config.get("default_model_name") - - -class BaseMemory(ABC): - """记忆系统的抽象基类。""" - - @abstractmethod - async def get_history(self, session_id: str) -> list[LLMMessage]: - raise NotImplementedError - - @abstractmethod - async def add_message(self, session_id: str, message: LLMMessage) -> None: - raise NotImplementedError - - @abstractmethod - async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: - raise NotImplementedError - - @abstractmethod - async def clear_history(self, session_id: str) -> None: - raise NotImplementedError - - -class InMemoryMemory(BaseMemory): - """一个简单的、默认的内存记忆后端。""" - - def __init__(self, max_messages: int = 50, **kwargs: Any): - self._history: dict[str, list[LLMMessage]] = defaultdict(list) - self._max_messages = max_messages - - def _trim_history(self, session_id: str) -> None: - """修剪历史记录,确保不超过最大长度,同时保留 System Prompt""" - history = self._history[session_id] - if len(history) <= self._max_messages: - return - - has_system = history and history[0].role == "system" - - if has_system: - keep_count = max(0, self._max_messages - 1) - self._history[session_id] = [history[0], *history[-keep_count:]] - else: - self._history[session_id] = history[-self._max_messages :] - - async def get_history(self, session_id: str) -> list[LLMMessage]: - return self._history.get(session_id, []).copy() - - async def add_message(self, session_id: str, message: LLMMessage) -> None: - self._history[session_id].append(message) - self._trim_history(session_id) - - async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None: - self._history[session_id].extend(messages) - self._trim_history(session_id) - - async def clear_history(self, session_id: str) -> None: - if session_id in self._history: - del self._history[session_id] - - -class MemoryProcessor(ABC): - """记忆处理器接口""" - - @abstractmethod - async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None: - pass - - -_default_memory_factory: Callable[[], BaseMemory] | None = None - - -def set_default_memory_backend(factory: Callable[[], BaseMemory]): - """ - 设置全局默认记忆后端工厂,允许统一替换会话的记忆实现。 - """ - global _default_memory_factory - _default_memory_factory = factory - - -def _get_default_memory() -> BaseMemory: - if _default_memory_factory: - return _default_memory_factory() - return InMemoryMemory() - - DEFAULT_IVR_TEMPLATE = ( "你的响应未能通过结构校验。\n" "错误详情: {error_msg}\n\n" @@ -186,7 +89,8 @@ class AI: 参数: session_id: 唯一的会话ID,用于隔离记忆。 config: AI 配置. - memory: 可选的自定义记忆后端。如果为None,则使用默认的InMemoryMemory。 + memory: 可选的自定义记忆后端。如果为None,则使用默认的 ChatMemory + (InMemoryMessageStore)。 default_generation_config: 此AI实例的默认生成配置。 processors: 记忆处理器列表,在添加记忆后触发。 """ @@ -587,9 +491,7 @@ class AI: final_config = LLMGenerationConfig() if max_validation_retries is None: - max_validation_retries = ( - get_llm_config().client_settings.structured_retries - ) + max_validation_retries = get_llm_config().client_settings.structured_retries resolved_model_name = self._resolve_model_name(model or self.config.model)