import asyncio from collections.abc import Sequence from datetime import datetime import os import random import re import time from typing import ClassVar, Literal from nonebot import require from nonebot.adapters import Bot from nonebot.compat import model_dump from nonebot_plugin_alconna import Text, UniMessage, UniMsg from nonebot_plugin_uninfo import Uninfo from zhenxun.configs.config import BotConfig, Config from zhenxun.configs.path_config import IMAGE_PATH from zhenxun.configs.utils import AICallableTag from zhenxun.models.sign_user import SignUser from zhenxun.services.log import logger from zhenxun.utils.decorator.retry import Retry from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.message import MessageUtils from .call_tool import AiCallTool from .exception import CallApiParamException, NotResultException from .models.bym_chat import BymChat from .models.bym_gift_log import GiftLog require("sign_in") from zhenxun.builtin_plugins.sign_in.utils import ( get_level_and_next_impression, level2attitude, ) from .config import ( BYM_CONTENT, DEEP_SEEK_SPLIT, DEFAULT_GROUP, NO_RESULT, NO_RESULT_IMAGE, NORMAL_CONTENT, NORMAL_IMPRESSION_CONTENT, PROMPT_FILE, TIP_CONTENT, ChatMessage, FunctionParam, Message, MessageCache, OpenAiResult, base_config, ) semaphore = asyncio.Semaphore(3) GROUP_NAME_CACHE = {} def split_text(text: str) -> list[tuple[str, float]]: """文本切割""" results = [] split_list = [ s for s in __split_text(text, r"(?" ] for r in split_list: next_char_index = text.find(r) + len(r) if next_char_index < len(text) and text[next_char_index] == "?": r += "?" results.append((r, min(len(r) * 0.2, 3.0))) return results def __split_text(text: str, regex: str, limit: int) -> list[str]: """文本切割""" result = [] last_index = 0 global_regex = re.compile(regex) for match in global_regex.finditer(text): if len(result) >= limit - 1: break result.append(text[last_index : match.start()]) last_index = match.end() result.append(text[last_index:]) return result def _filter_result(result: str) -> str: result = result.replace("", "").strip() return re.sub(r"(.)\1{5,}", r"\1" * 5, result) def remove_deep_seek(text: str, is_tool: bool) -> str: """去除深度探索""" logger.debug(f"去除深度思考前原文:{text}", "BYM_AI") if "```" in text.strip() and not text.strip().endswith("```"): text += "```" match_text = None if match := re.findall(r"([\s\S]*?)", text, re.DOTALL): match_text = match[-1] elif match := re.findall(r"```([\s\S]*?)```", text, re.DOTALL): match_text = match[-1] elif match := re.findall(r"```xml([\s\S]*?)```", text, re.DOTALL): match_text = match[-1] elif match := re.findall(r"```content([\s\S]*?)```", text, re.DOTALL): match_text = match[-1] elif match := re.search(r"instruction[:,:](.*)<\/code>", text, re.DOTALL): match_text = match[2] elif match := re.findall(r"\n(.*?)\n", text, re.DOTALL): match_text = match[1] elif len(re.split(r"最终(回复|结果)[:,:]", text, re.DOTALL)) > 1: match_text = re.split(r"最终(回复|结果)[:,:]", text, re.DOTALL)[-1] elif match := re.search(r"Response[:,:]\*?\*?(.*)", text, re.DOTALL): match_text = match[2] elif "回复用户" in text: match_text = re.split("回复用户.{0,1}", text)[-1] elif "最终回复" in text: match_text = re.split("最终回复.{0,1}", text)[-1] elif "Response text:" in text: match_text = re.split("Response text[:,:]", text)[-1] if match_text: match_text = re.sub(r"```tool_code([\s\S]*?)```", "", match_text).strip() match_text = re.sub(r"```json([\s\S]*?)```", "", match_text).strip() match_text = re.sub( r"([\s\S]*?)", "", match_text ).strip() match_text = re.sub( r"\[\/?instruction\]([\s\S]*?)\[\/?instruction\]", "", match_text ).strip() match_text = re.sub(r"([\s\S]*?)", "", match_text).strip() return re.sub(r"<\/?content>", "", match_text) else: text = re.sub(r"```tool_code([\s\S]*?)```", "", text).strip() text = re.sub(r"```json([\s\S]*?)```", "", text).strip() text = re.sub(r"([\s\S]*?)", "", text).strip() text = re.sub(r"([\s\S]*?)", "", text).strip() if is_tool: if DEEP_SEEK_SPLIT in text: return text.split(DEEP_SEEK_SPLIT, 1)[-1].strip() if match := re.search(r"```text\n([\s\S]*?)\n```", text, re.DOTALL): text = match[1] if text.endswith("```"): text = text[:-3].strip() if match := re.search(r"\n([\s\S]*?)\n", text, re.DOTALL): text = match[1] elif match := re.search(r"\n([\s\S]*?)\n", text, re.DOTALL): text = match[1] elif "think" in text: if text.count("think") == 2: text = re.split("<.{0,1}think.*>", text)[1] else: text = re.split("<.{0,1}think.*>", text)[-1] else: arr = text.split("\n") index = next((i for i, a in enumerate(arr) if not a.strip()), 0) if index != 0: text = "\n".join(arr[index + 1 :]) text = re.sub(r"^[\s\S]*?结果[:,:]\n", "", text) return ( re.sub(r"深度思考:[\s\S]*?\n\s*\n", "", text) .replace("深度思考结束。", "") .strip() ) else: text = text.strip().split("\n")[-1] text = re.sub(r"^[\s\S]*?结果[:,:]\n", "", text) return re.sub(r"<\/?content>", "", text).replace("深度思考结束。", "").strip() class TokenCounter: def __init__(self): if tokens := base_config.get("BYM_AI_CHAT_TOKEN"): if isinstance(tokens, str): tokens = [tokens] self.tokens = dict.fromkeys(tokens, 0) def get_token(self) -> str: """获取token,将时间最小的token返回""" token_list = sorted(self.tokens.keys(), key=lambda x: self.tokens[x]) result_token = token_list[0] self.tokens[result_token] = int(time.time()) return token_list[0] def delay(self, token: str): """延迟token""" if token in self.tokens: """等待15分钟""" self.tokens[token] = int(time.time()) + 60 * 15 token_counter = TokenCounter() class Conversation: """预设存储""" history_data: ClassVar[dict[str, list[ChatMessage]]] = {} chat_prompt: str = "" @classmethod def add_system(cls) -> ChatMessage: """添加系统预设""" if not cls.chat_prompt: cls.chat_prompt = PROMPT_FILE.open(encoding="utf8").read() return ChatMessage(role="system", content=cls.chat_prompt) @classmethod async def get_db_data( cls, user_id: str | None, group_id: str | None = None ) -> list[ChatMessage]: """从数据库获取记录 参数: user_id: 用户id group_id: 群组id,获取群组内记录时使用 返回: list[ChatMessage]: 记录列表 """ conversation = [] enable_group_chat = base_config.get("ENABLE_GROUP_CHAT") if enable_group_chat and group_id: db_filter = BymChat.filter(group_id=group_id) elif enable_group_chat: db_filter = BymChat.filter(user_id=user_id, group_id=None) else: db_filter = BymChat.filter(user_id=user_id) db_data_list = ( await db_filter.order_by("-id") .limit(int(base_config.get("CACHE_SIZE") / 2)) .all() ) for db_data in db_data_list: if db_data.is_reset: break conversation.extend( ( ChatMessage(role="assistant", content=db_data.result), ChatMessage(role="user", content=db_data.plain_text), ) ) conversation.reverse() return conversation @classmethod async def get_conversation( cls, user_id: str | None, group_id: str | None ) -> list[ChatMessage]: """获取预设 参数: user_id: 用户id 返回: list[ChatMessage]: 预设数据 """ conversation = [] if ( base_config.get("ENABLE_GROUP_CHAT") and group_id and group_id in cls.history_data ): conversation = cls.history_data[group_id] elif user_id and user_id in cls.history_data: conversation = cls.history_data[user_id] # 尝试从数据库中获取历史对话 if not conversation: conversation = await cls.get_db_data(user_id, group_id) # 必须带有人设 conversation = [c for c in conversation if c.role != "system"] conversation.insert(0, cls.add_system()) return conversation @classmethod def set_history( cls, user_id: str, group_id: str | None, conversation: list[ChatMessage] ): """设置历史预设 参数: user_id: 用户id conversation: 消息记录 """ cache_size = base_config.get("CACHE_SIZE") group_cache_size = base_config.get("GROUP_CACHE_SIZE") size = group_cache_size if group_id else cache_size if len(conversation) > size: conversation = conversation[-size:] if base_config.get("ENABLE_GROUP_CHAT") and group_id: cls.history_data[group_id] = conversation else: cls.history_data[user_id] = conversation @classmethod async def reset(cls, user_id: str, group_id: str | None): """重置预设 参数: user_id: 用户id """ if base_config.get("ENABLE_GROUP_CHAT") and group_id: # 群组内重置 if ( db_data := await BymChat.filter(group_id=group_id) .order_by("-id") .first() ): db_data.is_reset = True await db_data.save(update_fields=["is_reset"]) if group_id in cls.history_data: del cls.history_data[group_id] elif user_id: # 个人重置 if ( db_data := await BymChat.filter(user_id=user_id, group_id=None) .order_by("-id") .first() ): db_data.is_reset = True await db_data.save(update_fields=["is_reset"]) if user_id in cls.history_data: del cls.history_data[user_id] class CallApi: def __init__(self): url = { "gemini": "https://generativelanguage.googleapis.com/v1beta/chat/completions", "DeepSeek": "https://api.deepseek.com", "硅基流动": "https://api.siliconflow.cn/v1", "阿里云百炼": "https://dashscope.aliyuncs.com/compatible-mode/v1", "百度智能云": "https://qianfan.baidubce.com/v2", "字节火山引擎": "https://ark.cn-beijing.volces.com/api/v3", } # 对话 chat_url = base_config.get("BYM_AI_CHAT_URL") self.chat_url = url.get(chat_url, chat_url) self.chat_model = base_config.get("BYM_AI_CHAT_MODEL") self.tool_model = base_config.get("BYM_AI_TOOL_MODEL") self.chat_token = token_counter.get_token() # tts语音 self.tts_url = Config.get_config("bym_ai", "BYM_AI_TTS_URL") self.tts_token = Config.get_config("bym_ai", "BYM_AI_TTS_TOKEN") self.tts_voice = Config.get_config("bym_ai", "BYM_AI_TTS_VOICE") @Retry.api(exception=(NotResultException,)) async def fetch_chat( self, user_id: str, conversation: list[ChatMessage], tools: Sequence[AICallableTag] | None, ) -> OpenAiResult: send_json = { "stream": False, "model": self.tool_model if tools else self.chat_model, "temperature": 0.7, } if tools: send_json["tools"] = [ {"type": "function", "function": tool.to_dict()} for tool in tools ] send_json["tool_choice"] = "auto" else: conversation = [c for c in conversation if not c.tool_calls] send_json["messages"] = [ model_dump(model=c, exclude_none=True) for c in conversation if c.content ] response = await AsyncHttpx.post( self.chat_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.chat_token}", }, json=send_json, verify=False, ) if response.status_code == 429: logger.debug( f"fetch_chat 请求失败: 限速, token: {self.chat_token} 延迟 15 分钟", "BYM_AI", session=user_id, ) token_counter.delay(self.chat_token) if response.status_code == 400: logger.warning("请求接口错误 code: 400", "BYM_AI") raise CallApiParamException() response.raise_for_status() result = OpenAiResult(**response.json()) if not result.choices: logger.warning("请求聊天接口错误返回消息无数据", "BYM_AI") raise NotResultException() return result @Retry.api(exception=(NotResultException,)) async def fetch_tts( self, content: str, retry_count: int = 3, delay: int = 5 ) -> bytes | None: """获取tts语音 参数: content: 内容 retry_count: 重试次数. delay: 重试延迟. 返回: bytes | None: 语音数据 """ if not self.tts_url or not self.tts_token or not self.tts_voice: return None headers = {"Authorization": f"Bearer {self.tts_token}"} payload = {"model": "hailuo", "input": content, "voice": self.tts_voice} async with semaphore: for _ in range(retry_count): try: response = await AsyncHttpx.post( self.tts_url, headers=headers, json=payload ) response.raise_for_status() if "audio/mpeg" in response.headers.get("Content-Type", ""): return response.content logger.warning(f"fetch_tts 请求失败: {response.content}", "BYM_AI") await asyncio.sleep(delay) except Exception as e: logger.error("fetch_tts 请求失败", "BYM_AI", e=e) return None class ChatManager: group_cache: ClassVar[dict[str, list[MessageCache]]] = {} user_impression: ClassVar[dict[str, float]] = {} @classmethod def format( cls, type: Literal["system", "user", "text"], data: str ) -> dict[str, str]: """格式化数据 参数: data: 文本 返回: dict[str, str]: 格式化字典文本 """ return { "type": type, "text": data, } @classmethod def __build_content(cls, message: UniMsg) -> list[dict[str, str]]: """获取消息文本内容 参数: message: 消息内容 返回: list[dict[str, str]]: 文本列表 """ return [ cls.format("text", seg.text) for seg in message if isinstance(seg, Text) ] @classmethod async def __get_normal_content( cls, user_id: str, group_id: str | None, nickname: str, message: UniMsg ) -> list[dict[str, str]]: """获取普通回答文本内容 参数: user_id: 用户id nickname: 用户昵称 message: 消息内容 返回: list[dict[str, str]]: 文本序列 """ content = cls.__build_content(message) if user_id not in cls.user_impression: sign_user = await SignUser.get_user(user_id) cls.user_impression[user_id] = float(sign_user.impression) gift_count = await GiftLog.filter( user_id=user_id, create_time__gte=datetime.now().date() ).count() level, _, _ = get_level_and_next_impression(cls.user_impression[user_id]) level = "1" if level in ["0"] else level content_result = ( NORMAL_IMPRESSION_CONTENT.format( time=datetime.now(), nickname=nickname, user_id=user_id, impression=cls.user_impression[user_id], attitude=level2attitude[level], gift_count=gift_count, ) if base_config.get("ENABLE_IMPRESSION") else NORMAL_CONTENT.format( nickname=nickname, user_id=user_id, ) ) # if group_id and base_config.get("ENABLE_GROUP_CHAT"): # if group_id not in GROUP_NAME_CACHE: # if group := await GroupConsole.get_group(group_id): # GROUP_NAME_CACHE[group_id] = group.group_name # content_result = ( # GROUP_CONTENT.format( # group_id=group_id, group_name=GROUP_NAME_CACHE.get(group_id, "") # ) # + content_result # ) content.insert( 0, cls.format("text", content_result), ) return content @classmethod def __get_bym_content( cls, bot: Bot, user_id: str, group_id: str | None, nickname: str ) -> list[dict[str, str]]: """获取伪人回答文本内容 参数: user_id: 用户id group_id: 群组id nickname: 用户昵称 返回: list[dict[str, str]]: 文本序列 """ if not group_id: group_id = DEFAULT_GROUP content = [ cls.format( "text", BYM_CONTENT.format( user_id=user_id, group_id=group_id, nickname=nickname, self_id=bot.self_id, ), ) ] if group_message := cls.group_cache.get(group_id): for message in group_message: content.append( cls.format( "text", f"用户昵称:{message.nickname} 用户ID:{message.user_id}", ) ) content.extend(cls.__build_content(message.message)) content.append(cls.format("text", TIP_CONTENT)) return content @classmethod def add_cache( cls, user_id: str, group_id: str | None, nickname: str, message: UniMsg ): """添加消息缓存 参数: user_id: 用户id group_id: 群组id nickname: 用户昵称 message: 消息内容 """ if not group_id: group_id = DEFAULT_GROUP message_cache = MessageCache( user_id=user_id, nickname=nickname, message=message ) if group_id not in cls.group_cache: cls.group_cache[group_id] = [message_cache] else: cls.group_cache[group_id].append(message_cache) if len(cls.group_cache[group_id]) >= 30: cls.group_cache[group_id].pop(0) @classmethod def check_is_call_tool(cls, result: OpenAiResult) -> bool: if not base_config.get("BYM_AI_TOOL_MODEL"): return False if result.choices and (msg := result.choices[0].message): return bool(msg.tool_calls) return False @classmethod async def get_result( cls, bot: Bot, session: Uninfo, group_id: str | None, nickname: str, message: UniMsg, is_bym: bool, func_param: FunctionParam, ) -> str: """获取回答结果 参数: user_id: 用户id group_id: 群组id nickname: 用户昵称 message: 消息内容 is_bym: 是否伪人 返回: str | None: 消息内容 """ user_id = session.user.id cls.add_cache(user_id, group_id, nickname, message) if is_bym: content = cls.__get_bym_content(bot, user_id, group_id, nickname) conversation = await Conversation.get_conversation(None, group_id) else: content = await cls.__get_normal_content( user_id, group_id, nickname, message ) conversation = await Conversation.get_conversation(user_id, group_id) conversation.append(ChatMessage(role="user", content=content)) tools = list(AiCallTool.tools.values()) # 首次调用,查看是否是调用工具 if ( base_config.get("BYM_AI_CHAT_SMART") and base_config.get("BYM_AI_TOOL_MODEL") and tools ): try: result = await CallApi().fetch_chat(user_id, conversation, tools) if cls.check_is_call_tool(result): result = await cls._tool_handle( bot, session, conversation, result, tools, func_param ) or await cls._chat_handle(session, conversation) else: result = await cls._chat_handle(session, conversation) except CallApiParamException: logger.warning("尝试调用工具函数失败 code: 400", "BYM_AI") result = await cls._chat_handle(session, conversation) else: result = await cls._chat_handle(session, conversation) if res := _filter_result(result): cls.add_cache( bot.self_id, group_id, BotConfig.self_nickname, MessageUtils.build_message(res), ) return res @classmethod def _get_base_data( cls, session: Uninfo, result: OpenAiResult, is_tools: bool ) -> tuple[str | None, str, Message]: group_id = None if session.group: group_id = ( session.group.parent.id if session.group.parent else session.group.id ) assistant_reply = "" message = None if result.choices and (message := result.choices[0].message): if message.content: assistant_reply = message.content.strip() if not message: raise ValueError("API响应结果不合法") return group_id, remove_deep_seek(assistant_reply, is_tools), message @classmethod async def _chat_handle( cls, session: Uninfo, conversation: list[ChatMessage], ) -> str: """响应api 参数: session: Uninfo conversation: 消息记录 result: API返回结果 返回: str: 最终结果 """ result = await CallApi().fetch_chat(session.user.id, conversation, []) group_id, assistant_reply, _ = cls._get_base_data(session, result, False) conversation.append(ChatMessage(role="assistant", content=assistant_reply)) Conversation.set_history(session.user.id, group_id, conversation) return assistant_reply @classmethod async def _tool_handle( cls, bot: Bot, session: Uninfo, conversation: list[ChatMessage], result: OpenAiResult, tools: Sequence[AICallableTag], func_param: FunctionParam, ) -> str: """处理API响应并处理工具回调 参数: user_id: 用户id conversation: 当前对话 result: API响应结果 tools: 可用的工具列表 func_param: 函数参数 返回: str: 处理后的消息内容 """ group_id, assistant_reply, message = cls._get_base_data(session, result, True) if assistant_reply: conversation.append( ChatMessage( role="assistant", content=assistant_reply, tool_calls=message.tool_calls, ) ) # 处理工具回调 if message.tool_calls: # temp_conversation = conversation.copy() call_result = await AiCallTool.build_conversation( message.tool_calls, func_param ) if call_result: conversation.append(ChatMessage(role="assistant", content=call_result)) # temp_conversation.extend( # await AiCallTool.build_conversation(message.tool_calls, func_param) # ) result = await CallApi().fetch_chat(session.user.id, conversation, []) group_id, assistant_reply, message = cls._get_base_data( session, result, True ) conversation.append( ChatMessage(role="assistant", content=assistant_reply) ) # _, assistant_reply, _ = cls._get_base_data(session, result, True) # if res := await cls._tool_handle( # bot, session, conversation, result, tools, func_param # ): # if _filter_result(res): # assistant_reply = res Conversation.set_history(session.user.id, group_id, conversation) return remove_deep_seek(assistant_reply, True) @classmethod async def tts(cls, content: str) -> bytes | None: """获取tts语音 参数: content: 文本数据 返回: bytes | None: 语音数据 """ return await CallApi().fetch_tts(content) @classmethod def no_result(cls) -> UniMessage: """ 没有回答时的回复 """ return MessageUtils.build_message( [ random.choice(NO_RESULT), IMAGE_PATH / "noresult" / random.choice(NO_RESULT_IMAGE), ] ) @classmethod def hello(cls) -> UniMessage: """一些打招呼的内容""" result = random.choice( ( "哦豁?!", "你好!Ov<", f"库库库,呼唤{BotConfig.self_nickname}做什么呢", "我在呢!", "呼呼,叫俺干嘛", ) ) img = random.choice(os.listdir(IMAGE_PATH / "zai")) return MessageUtils.build_message([IMAGE_PATH / "zai" / img, result])