Compare commits

...

3 Commits

Author SHA1 Message Date
Rumio
bac6040336
Merge 7b36f864ef into e5b2a872d3 2025-12-08 14:58:17 +00:00
pre-commit-ci[bot]
7b36f864ef 🚨 auto fix by pre-commit hooks 2025-12-08 14:58:13 +00:00
webjoin111
a4decd8524 ♻️ refactor(llm.session): 重构记忆系统以分离存储和策略 2025-12-08 22:57:36 +08:00
5 changed files with 262 additions and 112 deletions

View File

@ -34,7 +34,13 @@ from .manager import (
list_model_identifiers, list_model_identifiers,
set_global_default_model_name, set_global_default_model_name,
) )
from .session import AI, AIConfig, MemoryProcessor, set_default_memory_backend from .memory import (
AIConfig,
BaseMemory,
MemoryProcessor,
set_default_memory_backend,
)
from .session import AI
from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager from .tools import RunContext, ToolInvoker, function_tool, tool_provider_manager
from .types import ( from .types import (
EmbeddingTaskType, EmbeddingTaskType,
@ -62,6 +68,7 @@ from .utils import create_multimodal_message, message_to_unimessage, unimsg_to_l
__all__ = [ __all__ = [
"AI", "AI",
"AIConfig", "AIConfig",
"BaseMemory",
"CommonOverrides", "CommonOverrides",
"EmbeddingTaskType", "EmbeddingTaskType",
"GeminiCodeExecution", "GeminiCodeExecution",

View File

@ -119,8 +119,7 @@ class GeminiAdapter(BaseAdapter):
system_instruction_parts = [{"text": msg.content}] system_instruction_parts = [{"text": msg.content}]
elif isinstance(msg.content, list): elif isinstance(msg.content, list):
system_instruction_parts = [ system_instruction_parts = [
await converter.convert_part(part) await converter.convert_part(part) for part in msg.content
for part in msg.content
] ]
continue continue

View File

@ -302,7 +302,6 @@ async def _generate_image_from_message(
messages = await normalize_to_llm_messages(message) messages = await normalize_to_llm_messages(message)
async with await get_model_instance(model) as model_instance: async with await get_model_instance(model) as model_instance:
response = await model_instance.generate_response(messages, config=config) response = await model_instance.generate_response(messages, config=config)
if not response.images: if not response.images:

View File

@ -0,0 +1,243 @@
"""
LLM 服务 - 会话记忆模块
定义了LLM会话记忆的存储策略和处理接口
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel, Field
from zhenxun.services.llm.types import LLMMessage
from zhenxun.services.log import logger
class AIConfig(BaseModel):
    """AI configuration class.

    A standalone copy kept here so this module stays independent; in actual
    use the effective configuration may come from a higher layer.
    """

    # Default model identifier/instance; None means "use the global default".
    model: Any = None
    # Embedding model used when none is specified explicitly.
    default_embedding_model: Any = None
    # Whether media parts are preserved in conversation history by default.
    default_preserve_media_in_history: bool = False
    # Extra tool providers injected into sessions built from this config.
    tool_providers: list[Any] = Field(default_factory=list)

    def __post_init__(self):
        """Read default values from configuration after initialization.

        NOTE(review): `__post_init__` is a dataclasses hook and is NOT called
        by pydantic `BaseModel` — this method is currently dead code (and a
        no-op). If post-init logic is ever needed, use pydantic's
        `model_post_init` instead — confirm against the pydantic version in use.
        """
        pass
class BaseMessageStore(ABC):
    """Low-level storage interface (DAO — Data Access Object).

    Abstract base class defining the lowest-level persistence and retrieval
    (CRUD) interface for message data. It is concerned purely with storing
    and fetching data and carries no business logic such as history trimming.

    Implement this interface to persist conversation history in Redis, a
    database, or any other backend.
    """

    @abstractmethod
    async def get_messages(self, session_id: str) -> list[LLMMessage]:
        """Return the complete message list for the given session id."""
        raise NotImplementedError

    @abstractmethod
    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Append messages to the session."""
        raise NotImplementedError

    @abstractmethod
    async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Replace the session's message list wholesale.

        Intended mainly for scenarios such as history trimming.
        """
        raise NotImplementedError

    @abstractmethod
    async def clear(self, session_id: str) -> None:
        """Remove all message data stored for the given session id."""
        raise NotImplementedError
class InMemoryMessageStore(BaseMessageStore):
    """In-memory implementation of `BaseMessageStore`.

    Stores every session's messages in a plain Python dict — the simplest and
    fastest storage option, and the framework default so things work out of
    the box.

    NOTE: this implementation is **non-persistent**: all conversation history
    is lost when the application restarts. Suitable for tests, simple apps,
    or scenarios that do not need long-term memory.
    """

    def __init__(self):
        # session_id -> message list; defaultdict(list) so appends need no setup.
        self._data: dict[str, list[LLMMessage]] = defaultdict(list)

    async def get_messages(self, session_id: str) -> list[LLMMessage]:
        """Return a copy of the session's message list (empty if unknown).

        Uses ``.get`` rather than subscripting so a read never creates an
        empty entry in the defaultdict.
        """
        return self._data.get(session_id, []).copy()

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Append messages to the in-memory list."""
        self._data[session_id].extend(messages)

    async def set_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Replace the session's message list.

        A defensive copy is taken so that later mutation of the caller's list
        cannot silently alter the stored history (the original stored the list
        by reference).
        """
        self._data[session_id] = list(messages)

    async def clear(self, session_id: str) -> None:
        """Delete the session's entry, if present."""
        self._data.pop(session_id, None)
class BaseMemory(ABC):
    """Strategy-layer base class of the memory system.

    This abstract base class defines the **strategy layer** of the memory
    system: it exposes the unified memory operations and encapsulates the
    concrete memory-management policy (history trimming, summary generation,
    and so on). The `AI` session client talks to this interface directly and
    never cares about the underlying storage implementation.

    Implement this interface to build custom memory strategies, for example:
    - `SummarizationMemory`: compress long histories by asking the LLM for a summary.
    - `VectorStoreMemory`: vectorize dialogue history into a vector DB for long-term recall.
    """

    @abstractmethod
    async def get_history(self, session_id: str) -> list[LLMMessage]:
        """Return the full history used to build the model input."""
        raise NotImplementedError

    async def add_message(self, session_id: str, message: LLMMessage) -> None:
        """Add a single message; the default implementation delegates to `add_messages`."""
        await self.add_messages(session_id, [message])

    @abstractmethod
    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Add several messages, possibly triggering internal policy (e.g. trimming)."""
        raise NotImplementedError

    @abstractmethod
    async def clear_history(self, session_id: str) -> None:
        """Wipe all memory for the given session."""
        raise NotImplementedError
class ChatMemory(BaseMemory):
    """Standard chat memory: a Store combined with a sliding-window policy.

    The default `BaseMemory` implementation. It composes a `BaseMessageStore`
    instance for the actual data storage and layers a simple sliding-window
    trimming policy on top.
    """

    def __init__(self, store: BaseMessageStore, max_messages: int = 50):
        # Underlying storage backend (DAO).
        self.store = store
        # Sliding-window size; a system message, if present, is kept in
        # addition to (and counted inside) this budget.
        self._max_messages = max_messages

    async def _trim_history(self, session_id: str) -> None:
        """Trimming policy: keep history within `_max_messages` messages.

        If a system message (System Prompt) is present it is permanently
        retained as the first element.
        """
        history = await self.store.get_messages(session_id)
        if len(history) <= self._max_messages:
            return
        has_system = bool(history) and history[0].role == "system"
        if has_system:
            keep_count = max(0, self._max_messages - 1)
            # BUG FIX: the original used `history[-keep_count:]`, but when
            # keep_count == 0 a `-0` slice returns the WHOLE list, so nothing
            # was trimmed. Slice by explicit start index instead.
            tail = history[len(history) - keep_count:] if keep_count else []
            new_history: list[LLMMessage] = [history[0], *tail]
        else:
            # Same fix for max_messages == 0: `history[-0:]` kept everything.
            keep_count = self._max_messages
            new_history = history[len(history) - keep_count:] if keep_count else []
        await self.store.set_messages(session_id, new_history)

    async def get_history(self, session_id: str) -> list[LLMMessage]:
        """Fetch history straight from the underlying store."""
        return await self.store.get_messages(session_id)

    async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
        """Append messages to history, then immediately apply trimming."""
        await self.store.add_messages(session_id, messages)
        await self._trim_history(session_id)

    async def clear_history(self, session_id: str) -> None:
        """Clear the history held by the underlying store."""
        await self.store.clear(session_id)
class MemoryProcessor(ABC):
    """Memory-processor interface (Hook/Observer).

    An extension point: implement this to run extra work (hooks) after the
    memory has been modified. Whenever an `AI` instance's memory is updated,
    every registered `MemoryProcessor` is invoked in turn.

    Example use cases:
    - `LoggingMemoryProcessor`: asynchronously record each turn to an external log system.
    - `SummarizationProcessor`: check conversation length in a background task and summarize when needed.
    - `EntityExtractionProcessor`: extract key entities (names, places, ...) from the dialogue and store them.
    """

    @abstractmethod
    async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
        """Handle messages newly added to memory."""
        pass
_default_memory_factory: Callable[[], BaseMemory] | None = None
def set_default_memory_backend(factory: Callable[[], BaseMemory]):
    """Install a global default memory-backend factory.

    An advanced dependency-injection hook: plugins or projects can call this at
    startup to replace the default ``ChatMemory(InMemoryMessageStore())`` with
    their own `BaseMemory` implementation, uniformly for all sessions.

    Args:
        factory: a zero-argument callable (function or class) returning a
            `BaseMemory` instance.
    """
    global _default_memory_factory
    _default_memory_factory = factory
def _get_default_memory() -> BaseMemory:
    """[internal] Build a default memory-backend instance.

    Checks first for a factory registered via `set_default_memory_backend`;
    when one is present it is used to build the instance, otherwise a standard
    in-memory `ChatMemory` is returned.
    """
    factory = _default_memory_factory
    if factory:
        logger.debug("使用自定义的默认记忆后端工厂构建实例。")
        return factory()
    logger.debug("未配置自定义记忆后端,使用默认的 ChatMemory。")
    return ChatMemory(store=InMemoryMessageStore())
__all__ = [
"AIConfig",
"BaseMemory",
"BaseMessageStore",
"ChatMemory",
"InMemoryMessageStore",
"MemoryProcessor",
"_get_default_memory",
"set_default_memory_backend",
]

View File

@ -4,11 +4,8 @@ LLM 服务 - 会话客户端
提供一个有状态的面向会话的 LLM 客户端用于进行多轮对话和复杂交互 提供一个有状态的面向会话的 LLM 客户端用于进行多轮对话和复杂交互
""" """
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import copy import copy
from dataclasses import dataclass, field
import json import json
from typing import Any, TypeVar, cast from typing import Any, TypeVar, cast
import uuid import uuid
@ -28,8 +25,14 @@ from .config import (
LLMGenerationConfig, LLMGenerationConfig,
) )
from .config.generation import OutputConfig from .config.generation import OutputConfig
from .config.providers import get_ai_config, get_llm_config from .config.providers import get_llm_config
from .manager import get_global_default_model_name, get_model_instance from .manager import get_global_default_model_name, get_model_instance
from .memory import (
AIConfig,
BaseMemory,
MemoryProcessor,
_get_default_memory,
)
from .tools import tool_provider_manager from .tools import tool_provider_manager
from .types import ( from .types import (
LLMContentPart, LLMContentPart,
@ -42,7 +45,6 @@ from .types import (
StructuredOutputStrategy, StructuredOutputStrategy,
ToolChoice, ToolChoice,
ToolExecutable, ToolExecutable,
ToolProvider,
) )
from .types.models import ( from .types.models import (
GeminiCodeExecution, GeminiCodeExecution,
@ -57,105 +59,6 @@ from .utils import (
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
@dataclass
class AIConfig:
"""AI配置类"""
model: ModelName = None
default_embedding_model: ModelName = None
default_preserve_media_in_history: bool = False
tool_providers: list[ToolProvider] = field(default_factory=list)
def __post_init__(self):
"""初始化后从配置中读取默认值"""
ai_config = get_ai_config()
if self.model is None:
self.model = ai_config.get("default_model_name")
class BaseMemory(ABC):
"""记忆系统的抽象基类。"""
@abstractmethod
async def get_history(self, session_id: str) -> list[LLMMessage]:
raise NotImplementedError
@abstractmethod
async def add_message(self, session_id: str, message: LLMMessage) -> None:
raise NotImplementedError
@abstractmethod
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
raise NotImplementedError
@abstractmethod
async def clear_history(self, session_id: str) -> None:
raise NotImplementedError
class InMemoryMemory(BaseMemory):
"""一个简单的、默认的内存记忆后端。"""
def __init__(self, max_messages: int = 50, **kwargs: Any):
self._history: dict[str, list[LLMMessage]] = defaultdict(list)
self._max_messages = max_messages
def _trim_history(self, session_id: str) -> None:
"""修剪历史记录,确保不超过最大长度,同时保留 System Prompt"""
history = self._history[session_id]
if len(history) <= self._max_messages:
return
has_system = history and history[0].role == "system"
if has_system:
keep_count = max(0, self._max_messages - 1)
self._history[session_id] = [history[0], *history[-keep_count:]]
else:
self._history[session_id] = history[-self._max_messages :]
async def get_history(self, session_id: str) -> list[LLMMessage]:
return self._history.get(session_id, []).copy()
async def add_message(self, session_id: str, message: LLMMessage) -> None:
self._history[session_id].append(message)
self._trim_history(session_id)
async def add_messages(self, session_id: str, messages: list[LLMMessage]) -> None:
self._history[session_id].extend(messages)
self._trim_history(session_id)
async def clear_history(self, session_id: str) -> None:
if session_id in self._history:
del self._history[session_id]
class MemoryProcessor(ABC):
"""记忆处理器接口"""
@abstractmethod
async def process(self, session_id: str, new_messages: list[LLMMessage]) -> None:
pass
_default_memory_factory: Callable[[], BaseMemory] | None = None
def set_default_memory_backend(factory: Callable[[], BaseMemory]):
"""
设置全局默认记忆后端工厂允许统一替换会话的记忆实现
"""
global _default_memory_factory
_default_memory_factory = factory
def _get_default_memory() -> BaseMemory:
if _default_memory_factory:
return _default_memory_factory()
return InMemoryMemory()
DEFAULT_IVR_TEMPLATE = ( DEFAULT_IVR_TEMPLATE = (
"你的响应未能通过结构校验。\n" "你的响应未能通过结构校验。\n"
"错误详情: {error_msg}\n\n" "错误详情: {error_msg}\n\n"
@ -186,7 +89,8 @@ class AI:
参数: 参数:
session_id: 唯一的会话ID用于隔离记忆 session_id: 唯一的会话ID用于隔离记忆
config: AI 配置. config: AI 配置.
memory: 可选的自定义记忆后端如果为None则使用默认的InMemoryMemory memory: 可选的自定义记忆后端如果为None则使用默认的 ChatMemory
(InMemoryMessageStore)
default_generation_config: 此AI实例的默认生成配置 default_generation_config: 此AI实例的默认生成配置
processors: 记忆处理器列表在添加记忆后触发 processors: 记忆处理器列表在添加记忆后触发
""" """
@ -587,9 +491,7 @@ class AI:
final_config = LLMGenerationConfig() final_config = LLMGenerationConfig()
if max_validation_retries is None: if max_validation_retries is None:
max_validation_retries = ( max_validation_retries = get_llm_config().client_settings.structured_retries
get_llm_config().client_settings.structured_retries
)
resolved_model_name = self._resolve_model_name(model or self.config.model) resolved_model_name = self._resolve_model_name(model or self.config.model)