mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-14 21:52:56 +08:00
Compare commits
13 Commits
c41875b401
...
82bc83b85d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82bc83b85d | ||
|
|
ff5b4e60c7 | ||
|
|
f94121080f | ||
|
|
761c8daac4 | ||
|
|
c667fc215e | ||
|
|
07be73c1b7 | ||
|
|
7e6896fa01 | ||
|
|
3cc882b116 | ||
|
|
ee699fb345 | ||
|
|
631e66d54f | ||
|
|
c7ef6fdb17 | ||
|
|
fb0a9813e1 | ||
|
|
6940c2f37b |
@ -7,7 +7,7 @@ ci:
|
||||
autoupdate_commit_msg: ":arrow_up: auto update by pre-commit hooks"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
rev: v0.13.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
|
||||
@ -84,13 +84,16 @@ async def _(
|
||||
):
|
||||
result = ""
|
||||
await MessageUtils.build_message("正在进行检查更新...").send(reply_to=True)
|
||||
|
||||
if not ver_type.available:
|
||||
result += await UpdateManager.check_version()
|
||||
logger.info("查看当前版本...", "检查更新", session=session)
|
||||
await MessageUtils.build_message(result).finish()
|
||||
return
|
||||
|
||||
ver_type_str = ver_type.result
|
||||
source_str = source.result
|
||||
if ver_type_str in {"main", "release"}:
|
||||
if not ver_type.available:
|
||||
result += await UpdateManager.check_version()
|
||||
logger.info("查看当前版本...", "检查更新", session=session)
|
||||
await MessageUtils.build_message(result).finish()
|
||||
try:
|
||||
result += await UpdateManager.update_zhenxun(
|
||||
bot,
|
||||
|
||||
@ -1,37 +1,135 @@
|
||||
import asyncio
|
||||
from typing import Literal
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
from zhenxun.utils.manager.zhenxun_repo_manager import (
|
||||
ZhenxunRepoConfig,
|
||||
ZhenxunRepoManager,
|
||||
)
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
from zhenxun.utils.repo_utils import RepoFileManager
|
||||
|
||||
LOG_COMMAND = "AutoUpdate"
|
||||
|
||||
|
||||
class UpdateManager:
|
||||
@staticmethod
|
||||
async def _get_latest_commit_date(owner: str, repo: str, path: str) -> str:
|
||||
"""获取文件最新 commit 日期"""
|
||||
api_url = f"https://api.github.com/repos/{owner}/{repo}/commits"
|
||||
params = {"path": path, "page": 1, "per_page": 1}
|
||||
try:
|
||||
data = await AsyncHttpx.get_json(api_url, params=params)
|
||||
if data and isinstance(data, list) and data[0]:
|
||||
date_str = data[0]["commit"]["committer"]["date"]
|
||||
return date_str.split("T")[0]
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 {owner}/{repo}/{path} 的 commit 日期失败", e=e)
|
||||
return "获取失败"
|
||||
|
||||
@classmethod
|
||||
async def check_version(cls) -> str:
|
||||
"""检查更新版本
|
||||
"""检查真寻和资源的版本"""
|
||||
bot_cur_version = cls.__get_version()
|
||||
|
||||
返回:
|
||||
str: 更新信息
|
||||
"""
|
||||
cur_version = cls.__get_version()
|
||||
release_data = await ZhenxunRepoManager.zhenxun_get_latest_releases_data()
|
||||
if not release_data:
|
||||
return "检查更新获取版本失败..."
|
||||
return (
|
||||
"检测到当前版本更新\n"
|
||||
f"当前版本:{cur_version}\n"
|
||||
f"最新版本:{release_data.get('name')}\n"
|
||||
f"创建日期:{release_data.get('created_at')}\n"
|
||||
f"更新内容:\n{release_data.get('body')}"
|
||||
release_task = ZhenxunRepoManager.zhenxun_get_latest_releases_data()
|
||||
dev_version_task = RepoFileManager.get_file_content(
|
||||
ZhenxunRepoConfig.ZHENXUN_BOT_GITHUB_URL, "__version__"
|
||||
)
|
||||
bot_commit_date_task = cls._get_latest_commit_date(
|
||||
"HibiKier", "zhenxun_bot", "__version__"
|
||||
)
|
||||
res_commit_date_task = cls._get_latest_commit_date(
|
||||
"zhenxun-org", "zhenxun-bot-resources", "__version__"
|
||||
)
|
||||
|
||||
(
|
||||
release_data,
|
||||
dev_version_text,
|
||||
bot_commit_date,
|
||||
res_commit_date,
|
||||
) = await asyncio.gather(
|
||||
release_task,
|
||||
dev_version_task,
|
||||
bot_commit_date_task,
|
||||
res_commit_date_task,
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if isinstance(release_data, dict):
|
||||
bot_release_version = release_data.get("name", "获取失败")
|
||||
bot_release_date = release_data.get("created_at", "").split("T")[0]
|
||||
else:
|
||||
bot_release_version = "获取失败"
|
||||
bot_release_date = "获取失败"
|
||||
logger.warning(f"获取 Bot release 信息失败: {release_data}")
|
||||
|
||||
if isinstance(dev_version_text, str):
|
||||
bot_dev_version = dev_version_text.split(":")[-1].strip()
|
||||
else:
|
||||
bot_dev_version = "获取失败"
|
||||
bot_commit_date = "获取失败"
|
||||
logger.warning(f"获取 Bot dev 版本信息失败: {dev_version_text}")
|
||||
|
||||
bot_update_hint = ""
|
||||
try:
|
||||
cur_base_v = bot_cur_version.split("-")[0].lstrip("v")
|
||||
dev_base_v = bot_dev_version.split("-")[0].lstrip("v")
|
||||
|
||||
if Version(cur_base_v) < Version(dev_base_v):
|
||||
bot_update_hint = "\n-> 发现新开发版本, 可用 `检查更新 main` 更新"
|
||||
elif (
|
||||
Version(cur_base_v) == Version(dev_base_v)
|
||||
and bot_cur_version != bot_dev_version
|
||||
):
|
||||
bot_update_hint = "\n-> 发现新开发版本, 可用 `检查更新 main` 更新"
|
||||
except (InvalidVersion, TypeError, IndexError):
|
||||
if bot_cur_version != bot_dev_version and bot_dev_version != "获取失败":
|
||||
bot_update_hint = "\n-> 发现新开发版本, 可用 `检查更新 main` 更新"
|
||||
|
||||
bot_update_info = (
|
||||
f"当前版本: {bot_cur_version}\n"
|
||||
f"最新开发版: {bot_dev_version} (更新于: {bot_commit_date})\n"
|
||||
f"最新正式版: {bot_release_version} (发布于: {bot_release_date})"
|
||||
f"{bot_update_hint}"
|
||||
)
|
||||
|
||||
res_version_file = ZhenxunRepoConfig.RESOURCE_PATH / "__version__"
|
||||
res_cur_version = "未找到"
|
||||
if res_version_file.exists():
|
||||
if text := res_version_file.open(encoding="utf8").readline():
|
||||
res_cur_version = text.split(":")[-1].strip()
|
||||
|
||||
res_latest_version = "获取失败"
|
||||
try:
|
||||
res_latest_version_text = await RepoFileManager.get_file_content(
|
||||
ZhenxunRepoConfig.RESOURCE_GITHUB_URL, "__version__"
|
||||
)
|
||||
res_latest_version = res_latest_version_text.split(":")[-1].strip()
|
||||
except Exception as e:
|
||||
res_commit_date = "获取失败"
|
||||
logger.warning(f"获取资源版本信息失败: {e}")
|
||||
|
||||
res_update_hint = ""
|
||||
try:
|
||||
if Version(res_cur_version) < Version(res_latest_version):
|
||||
res_update_hint = "\n-> 发现新资源版本, 可用 `检查更新 resource` 更新"
|
||||
except (InvalidVersion, TypeError):
|
||||
pass
|
||||
|
||||
res_update_info = (
|
||||
f"当前版本: {res_cur_version}\n"
|
||||
f"最新版本: {res_latest_version} (更新于: {res_commit_date})"
|
||||
f"{res_update_hint}"
|
||||
)
|
||||
|
||||
return f"『绪山真寻 Bot』\n{bot_update_info}\n\n『真寻资源』\n{res_update_info}"
|
||||
|
||||
@classmethod
|
||||
async def update_webui(
|
||||
@ -125,6 +223,7 @@ class UpdateManager:
|
||||
f"检测真寻已更新,当前版本:{cur_version}\n开始更新...",
|
||||
user_id,
|
||||
)
|
||||
result_message = ""
|
||||
if zip:
|
||||
new_version = await ZhenxunRepoManager.zhenxun_zip_update(version_type)
|
||||
await PlatformUtils.send_superuser(
|
||||
@ -133,7 +232,7 @@ class UpdateManager:
|
||||
await VirtualEnvPackageManager.install_requirement(
|
||||
ZhenxunRepoConfig.REQUIREMENTS_FILE
|
||||
)
|
||||
return (
|
||||
result_message = (
|
||||
f"版本更新完成!\n版本: {cur_version} -> {new_version}\n"
|
||||
"请重新启动真寻以完成更新!"
|
||||
)
|
||||
@ -155,13 +254,54 @@ class UpdateManager:
|
||||
await VirtualEnvPackageManager.install_requirement(
|
||||
ZhenxunRepoConfig.REQUIREMENTS_FILE
|
||||
)
|
||||
return (
|
||||
result_message = (
|
||||
f"版本更新完成!\n"
|
||||
f"版本: {cur_version} -> {result.new_version}\n"
|
||||
f"变更文件个数: {len(result.changed_files)}"
|
||||
f"{'' if source == 'git' else '(阿里云更新不支持查看变更文件)'}\n"
|
||||
"请重新启动真寻以完成更新!"
|
||||
)
|
||||
resource_warning = ""
|
||||
if version_type == "main":
|
||||
try:
|
||||
spec_content = await RepoFileManager.get_file_content(
|
||||
ZhenxunRepoConfig.ZHENXUN_BOT_GITHUB_URL, "resources.spec"
|
||||
)
|
||||
required_spec_str = None
|
||||
for line in spec_content.splitlines():
|
||||
if line.startswith("require_resources_version:"):
|
||||
required_spec_str = line.split(":", 1)[1].strip().strip("\"'")
|
||||
break
|
||||
if required_spec_str:
|
||||
res_version_file = ZhenxunRepoConfig.RESOURCE_PATH / "__version__"
|
||||
local_res_version_str = "0.0.0"
|
||||
if res_version_file.exists():
|
||||
if text := res_version_file.open(encoding="utf8").readline():
|
||||
local_res_version_str = text.split(":")[-1].strip()
|
||||
|
||||
spec = SpecifierSet(required_spec_str)
|
||||
local_ver = Version(local_res_version_str)
|
||||
if not spec.contains(local_ver):
|
||||
warning_header = (
|
||||
f"⚠️ **资源版本不兼容!**\n"
|
||||
f"当前代码需要资源版本: `{required_spec_str}`\n"
|
||||
f"您当前的资源版本是: `{local_res_version_str}`\n"
|
||||
"**将自动为您更新资源文件...**"
|
||||
)
|
||||
await PlatformUtils.send_superuser(bot, warning_header, user_id)
|
||||
resource_update_source = None if zip else source
|
||||
resource_update_result = await cls.update_resources(
|
||||
source=resource_update_source, force=force
|
||||
)
|
||||
resource_warning = (
|
||||
f"\n\n{warning_header}\n{resource_update_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"检查资源版本兼容性时出错: {e}", LOG_COMMAND, e=e)
|
||||
resource_warning = (
|
||||
"\n\n⚠️ 检查资源版本兼容性时出错,建议手动运行 `检查更新 resource`"
|
||||
)
|
||||
return result_message + resource_warning
|
||||
|
||||
@classmethod
|
||||
def __get_version(cls) -> str:
|
||||
|
||||
@ -19,12 +19,12 @@ from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.utils import Command, PluginExtraData, RegisterConfig
|
||||
from zhenxun.models.chat_history import ChatHistory
|
||||
from zhenxun.models.group_member_info import GroupInfoUser
|
||||
from zhenxun.services import avatar_service
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.ui.builders import TableBuilder
|
||||
from zhenxun.ui.models import ImageCell, TextCell
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
__plugin_meta__ = PluginMetadata(
|
||||
name="消息统计",
|
||||
@ -147,12 +147,14 @@ async def _(
|
||||
user_in_group.user_name if user_in_group else f"{uid_str}(已退群)"
|
||||
)
|
||||
|
||||
avatar_url = PlatformUtils.get_user_avatar_url(uid_str, platform)
|
||||
avatar_path = await avatar_service.get_avatar_path(platform, uid_str)
|
||||
|
||||
rows_data.append(
|
||||
[
|
||||
TextCell(content=str(len(rows_data) + 1)),
|
||||
ImageCell(src=avatar_url or "", shape="circle"),
|
||||
ImageCell(
|
||||
src=avatar_path.as_uri() if avatar_path else "", shape="circle"
|
||||
),
|
||||
TextCell(content=user_name),
|
||||
TextCell(content=str(num), bold=True),
|
||||
]
|
||||
|
||||
@ -26,7 +26,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
""".strip(),
|
||||
extra=PluginExtraData(
|
||||
author="HibiKier",
|
||||
version="0.1",
|
||||
version="0.2",
|
||||
plugin_type=PluginType.SUPERUSER,
|
||||
configs=[
|
||||
RegisterConfig(
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from pathlib import Path
|
||||
@ -18,7 +19,47 @@ BAIDU_URL = "https://www.baidu.com/"
|
||||
GOOGLE_URL = "https://www.google.com/"
|
||||
|
||||
VERSION_FILE = Path() / "__version__"
|
||||
ARM_KEY = "aarch64"
|
||||
|
||||
|
||||
def get_arm_cpu_freq_safe():
|
||||
"""获取ARM设备CPU频率"""
|
||||
# 方法1: 优先从系统频率文件读取
|
||||
freq_files = [
|
||||
"/sys/devices/system/cpu/cpu0/cpufreq/cpuinfo_max_freq",
|
||||
"/sys/devices/system/cpu/cpu0/cpufreq/scaling_max_freq",
|
||||
"/sys/devices/system/cpu/cpu0/cpufreq/cpuinfo_cur_freq",
|
||||
"/sys/devices/system/cpu/cpu0/cpufreq/scaling_cur_freq",
|
||||
]
|
||||
|
||||
for freq_file in freq_files:
|
||||
try:
|
||||
with open(freq_file) as f:
|
||||
frequency = int(f.read().strip())
|
||||
return round(frequency / 1000000, 2) # 转换为GHz
|
||||
except (OSError, ValueError):
|
||||
continue
|
||||
|
||||
# 方法2: 解析/proc/cpuinfo
|
||||
with contextlib.suppress(OSError, FileNotFoundError, ValueError, PermissionError):
|
||||
with open("/proc/cpuinfo") as f:
|
||||
for line in f:
|
||||
if "CPU MHz" in line:
|
||||
freq = float(line.split(":")[1].strip())
|
||||
return round(freq / 1000, 2) # 转换为GHz
|
||||
# 方法3: 使用lscpu命令
|
||||
with contextlib.suppress(OSError, subprocess.SubprocessError, ValueError):
|
||||
env = os.environ.copy()
|
||||
env["LC_ALL"] = "C"
|
||||
result = subprocess.run(
|
||||
["lscpu"], capture_output=True, text=True, env=env, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.split("\n"):
|
||||
if "CPU max MHz" in line or "CPU MHz" in line:
|
||||
freq = float(line.split(":")[1].strip())
|
||||
return round(freq / 1000, 2) # 转换为GHz
|
||||
return 0 # 如果所有方法都失败,返回0
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -37,7 +78,7 @@ class CPUInfo:
|
||||
if _cpu_freq := psutil.cpu_freq():
|
||||
cpu_freq = round(_cpu_freq.current / 1000, 2)
|
||||
else:
|
||||
cpu_freq = 0
|
||||
cpu_freq = get_arm_cpu_freq_safe()
|
||||
return CPUInfo(core=cpu_core, usage=cpu_usage, freq=cpu_freq)
|
||||
|
||||
|
||||
@ -160,44 +201,13 @@ def __get_version() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def __get_arm_cpu():
|
||||
env = os.environ.copy()
|
||||
env["LC_ALL"] = "en_US.UTF-8"
|
||||
cpu_info = subprocess.check_output(["lscpu"], env=env).decode()
|
||||
model_name = ""
|
||||
cpu_freq = 0
|
||||
for line in cpu_info.splitlines():
|
||||
if "Model name" in line:
|
||||
model_name = line.split(":")[1].strip()
|
||||
if "CPU MHz" in line:
|
||||
cpu_freq = float(line.split(":")[1].strip())
|
||||
return model_name, cpu_freq
|
||||
|
||||
|
||||
def __get_arm_oracle_cpu_freq():
|
||||
cpu_freq = subprocess.check_output(
|
||||
["dmidecode", "-s", "processor-frequency"]
|
||||
).decode()
|
||||
return round(float(cpu_freq.split()[0]) / 1000, 2)
|
||||
|
||||
|
||||
async def get_status_info() -> dict:
|
||||
"""获取信息"""
|
||||
data = await __build_status()
|
||||
|
||||
system = platform.uname()
|
||||
if system.machine == ARM_KEY and not (
|
||||
cpuinfo.get_cpu_info().get("brand_raw") and data.cpu.freq
|
||||
):
|
||||
model_name, cpu_freq = __get_arm_cpu()
|
||||
if not data.cpu.freq:
|
||||
data.cpu.freq = cpu_freq or __get_arm_oracle_cpu_freq()
|
||||
data = data.get_system_info()
|
||||
data["brand_raw"] = model_name
|
||||
else:
|
||||
data = data.get_system_info()
|
||||
data["brand_raw"] = cpuinfo.get_cpu_info().get("brand_raw", "Unknown")
|
||||
|
||||
data = data.get_system_info()
|
||||
data["brand_raw"] = cpuinfo.get_cpu_info().get("brand_raw", "Unknown")
|
||||
baidu, google = await __get_network_info()
|
||||
data["baidu"] = "#8CC265" if baidu else "red"
|
||||
data["google"] = "#8CC265" if google else "red"
|
||||
|
||||
@ -13,6 +13,7 @@ from zhenxun.models.statistics import Statistics
|
||||
from zhenxun.services import (
|
||||
LLMException,
|
||||
LLMMessage,
|
||||
avatar_service,
|
||||
generate,
|
||||
)
|
||||
from zhenxun.services.log import logger
|
||||
@ -105,7 +106,8 @@ async def create_help_img(
|
||||
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
bot_id = BotConfig.get_qbot_uid(session.self_id) or session.self_id
|
||||
bot_avatar_url = PlatformUtils.get_user_avatar_url(bot_id, platform) or ""
|
||||
bot_avatar_path = await avatar_service.get_avatar_path(platform, bot_id)
|
||||
bot_avatar_url = bot_avatar_path.as_uri() if bot_avatar_path else ""
|
||||
|
||||
builder = PluginMenuBuilder(
|
||||
bot_name=BotConfig.self_nickname,
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from nonebot.adapters import Bot, Message
|
||||
from nonebot.adapters.onebot.v11 import MessageSegment
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.models.bot_message_store import BotMessageStore
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.enum import BotSentType
|
||||
from zhenxun.utils.log_sanitizer import sanitize_for_logging
|
||||
from zhenxun.utils.manager.message_manager import MessageManager
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
@ -41,35 +41,6 @@ def replace_message(message: Message) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def format_message_for_log(message: Message) -> str:
|
||||
"""
|
||||
将消息对象转换为适合日志记录的字符串,对base64等长内容进行摘要处理。
|
||||
"""
|
||||
if not isinstance(message, Message):
|
||||
return str(message)
|
||||
|
||||
log_parts = []
|
||||
for seg in message:
|
||||
seg: MessageSegment
|
||||
if seg.type == "text":
|
||||
log_parts.append(seg.data.get("text", ""))
|
||||
elif seg.type in ("image", "record", "video"):
|
||||
file_info = seg.data.get("file", "")
|
||||
if isinstance(file_info, str) and file_info.startswith("base64://"):
|
||||
b64_data = file_info[9:]
|
||||
data_size_bytes = (len(b64_data) * 3) / 4 - b64_data.count("=", -2)
|
||||
log_parts.append(
|
||||
f"[{seg.type}: base64, size={data_size_bytes / 1024:.2f}KB]"
|
||||
)
|
||||
else:
|
||||
log_parts.append(f"[{seg.type}]")
|
||||
elif seg.type == "at":
|
||||
log_parts.append(f"[@{seg.data.get('qq', 'unknown')}]")
|
||||
else:
|
||||
log_parts.append(f"[{seg.type}]")
|
||||
return "".join(log_parts)
|
||||
|
||||
|
||||
@Bot.on_called_api
|
||||
async def handle_api_result(
|
||||
bot: Bot, exception: Exception | None, api: str, data: dict[str, Any], result: Any
|
||||
@ -82,7 +53,6 @@ async def handle_api_result(
|
||||
message: Message = data.get("message", "")
|
||||
message_type = data.get("message_type")
|
||||
try:
|
||||
# 记录消息id
|
||||
if user_id and message_id:
|
||||
MessageManager.add(str(user_id), str(message_id))
|
||||
logger.debug(
|
||||
@ -108,7 +78,8 @@ async def handle_api_result(
|
||||
else replace_message(message),
|
||||
platform=PlatformUtils.get_platform(bot),
|
||||
)
|
||||
logger.debug(f"消息发送记录,message: {format_message_for_log(message)}")
|
||||
sanitized_message = sanitize_for_logging(message, context="nonebot_message")
|
||||
logger.debug(f"消息发送记录,message: {sanitized_message}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||
|
||||
@ -11,6 +11,7 @@ from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.statistics import Statistics
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services import avatar_service
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
RACE = [
|
||||
@ -139,9 +140,8 @@ async def get_user_info(
|
||||
bytes: 图片数据
|
||||
"""
|
||||
platform = PlatformUtils.get_platform(session) or "qq"
|
||||
avatar_url = (
|
||||
PlatformUtils.get_user_avatar_url(user_id, platform, session.self_id) or ""
|
||||
)
|
||||
avatar_path = await avatar_service.get_avatar_path(platform, user_id)
|
||||
avatar_url = avatar_path.as_uri() if avatar_path else ""
|
||||
|
||||
user = await UserConsole.get_user(user_id, platform)
|
||||
permission_level = await LevelUser.get_user_level(user_id, group_id)
|
||||
|
||||
@ -11,6 +11,7 @@ from zhenxun.models.mahiro_bank import MahiroBank
|
||||
from zhenxun.models.mahiro_bank_log import MahiroBankLog
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services import avatar_service
|
||||
from zhenxun.utils.enum import BankHandleType, GoldHandle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
@ -210,9 +211,8 @@ class BankManager:
|
||||
for deposit in user_today_deposit
|
||||
]
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
avatar_url = PlatformUtils.get_user_avatar_url(
|
||||
user_id, platform, session.self_id
|
||||
)
|
||||
avatar_path = await avatar_service.get_avatar_path(platform, user_id)
|
||||
avatar_url = avatar_path.as_uri() if avatar_path else ""
|
||||
return {
|
||||
"name": uname,
|
||||
"rank": rank + 1,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
import shutil
|
||||
@ -10,6 +11,7 @@ from zhenxun.configs.path_config import TEMP_PATH
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.services.plugin_init import PluginInitManager
|
||||
from zhenxun.utils.enum import PluginType
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
from zhenxun.utils.manager.virtual_env_package_manager import VirtualEnvPackageManager
|
||||
from zhenxun.utils.repo_utils import RepoFileManager
|
||||
@ -183,6 +185,8 @@ class StoreManager:
|
||||
StorePluginInfo: 插件信息
|
||||
bool: 是否是外部插件
|
||||
"""
|
||||
plugin_list: list[StorePluginInfo]
|
||||
extra_plugin_list: list[StorePluginInfo]
|
||||
plugin_list, extra_plugin_list = await cls.get_data()
|
||||
plugin_info = None
|
||||
is_external = False
|
||||
@ -206,6 +210,12 @@ class StoreManager:
|
||||
if is_remove:
|
||||
if plugin_info.module not in modules:
|
||||
raise PluginStoreException(f"插件 {plugin_info.name} 未安装,无法移除")
|
||||
if plugin_obj := await PluginInfo.get_plugin(
|
||||
module=plugin_info.module, plugin_type=PluginType.PARENT
|
||||
):
|
||||
plugin_info.module_path = plugin_obj.module_path
|
||||
elif plugin_obj := await PluginInfo.get_plugin(module=plugin_info.module):
|
||||
plugin_info.module_path = plugin_obj.module_path
|
||||
return plugin_info, is_external
|
||||
|
||||
if is_update:
|
||||
@ -237,9 +247,7 @@ class StoreManager:
|
||||
plugin_info.github_url = f"{github_url_split[0]}/tree/{version_split[1]}"
|
||||
logger.info(f"正在安装插件 {plugin_info.name}...", LOG_COMMAND)
|
||||
await cls.install_plugin_with_repo(
|
||||
plugin_info.github_url,
|
||||
plugin_info.module_path,
|
||||
plugin_info.is_dir,
|
||||
plugin_info,
|
||||
is_external,
|
||||
source,
|
||||
)
|
||||
@ -248,9 +256,7 @@ class StoreManager:
|
||||
@classmethod
|
||||
async def install_plugin_with_repo(
|
||||
cls,
|
||||
github_url: str,
|
||||
module_path: str,
|
||||
is_dir: bool,
|
||||
plugin_info: StorePluginInfo,
|
||||
is_external: bool = False,
|
||||
source: str | None = None,
|
||||
):
|
||||
@ -267,18 +273,26 @@ class StoreManager:
|
||||
repo_type = RepoType.ALIYUN
|
||||
elif source == "git":
|
||||
repo_type = RepoType.GITHUB
|
||||
replace_module_path = module_path.replace(".", "/")
|
||||
plugin_name = module_path.split(".")[-1]
|
||||
module_path = plugin_info.module_path
|
||||
is_dir = plugin_info.is_dir
|
||||
github_url = plugin_info.github_url
|
||||
assert github_url
|
||||
replace_module_path = module_path.replace(".", "/").lstrip("/")
|
||||
plugin_name = module_path.split(".")[-1] or plugin_info.module
|
||||
if is_dir:
|
||||
files = await RepoFileManager.list_directory_files(
|
||||
github_url, replace_module_path, repo_type=repo_type
|
||||
)
|
||||
else:
|
||||
files = [RepoFileInfo(path=f"{replace_module_path}.py", is_dir=False)]
|
||||
local_path = BASE_PATH / "plugins" if is_external else BASE_PATH
|
||||
target_dir = BASE_PATH / "plugins" / plugin_name
|
||||
if not is_external:
|
||||
target_dir = BASE_PATH
|
||||
elif is_dir and module_path == ".":
|
||||
target_dir = BASE_PATH / "plugins" / plugin_name
|
||||
else:
|
||||
target_dir = BASE_PATH / "plugins"
|
||||
files = [file for file in files if not file.is_dir]
|
||||
download_files = [(file.path, local_path / file.path) for file in files]
|
||||
download_files = [(file.path, target_dir / file.path) for file in files]
|
||||
result = await RepoFileManager.download_files(
|
||||
github_url,
|
||||
download_files,
|
||||
@ -298,7 +312,7 @@ class StoreManager:
|
||||
|
||||
is_install_req = False
|
||||
for requirement_path in requirement_paths:
|
||||
requirement_file = local_path / requirement_path.path
|
||||
requirement_file = target_dir / requirement_path.path
|
||||
if requirement_file.exists():
|
||||
is_install_req = True
|
||||
await VirtualEnvPackageManager.install_requirement(requirement_file)
|
||||
@ -341,13 +355,11 @@ class StoreManager:
|
||||
str: 返回消息
|
||||
"""
|
||||
plugin_info, _ = await cls.get_plugin_by_value(index_or_module, is_remove=True)
|
||||
path = BASE_PATH
|
||||
if plugin_info.github_url:
|
||||
path = BASE_PATH / "plugins"
|
||||
for p in plugin_info.module_path.split("."):
|
||||
path = path / p
|
||||
module_path = plugin_info.module_path
|
||||
module = module_path.split(".")[-1]
|
||||
path = BASE_PATH.parent / Path(module_path.replace(".", os.sep))
|
||||
if not plugin_info.is_dir:
|
||||
path = Path(f"{path}.py")
|
||||
path = path.parent / f"{module}.py"
|
||||
if not path.exists():
|
||||
return f"插件 {plugin_info.name} 不存在..."
|
||||
logger.debug(f"尝试移除插件 {plugin_info.name} 文件: {path}", LOG_COMMAND)
|
||||
@ -356,7 +368,7 @@ class StoreManager:
|
||||
shutil.rmtree(path, onerror=win_on_rm_error)
|
||||
else:
|
||||
path.unlink()
|
||||
await PluginInitManager.remove(f"zhenxun.{plugin_info.module_path}")
|
||||
await PluginInitManager.remove(module_path)
|
||||
return f"插件 {plugin_info.name} 移除成功! 重启后生效"
|
||||
|
||||
@classmethod
|
||||
@ -423,9 +435,7 @@ class StoreManager:
|
||||
if plugin_info.github_url is None:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
await cls.install_plugin_with_repo(
|
||||
plugin_info.github_url,
|
||||
plugin_info.module_path,
|
||||
plugin_info.is_dir,
|
||||
plugin_info,
|
||||
is_external,
|
||||
)
|
||||
return f"插件 {plugin_info.name} 更新成功! 重启后生效"
|
||||
@ -473,9 +483,7 @@ class StoreManager:
|
||||
plugin_info.github_url = DEFAULT_GITHUB_URL
|
||||
is_external = False
|
||||
await cls.install_plugin_with_repo(
|
||||
plugin_info.github_url,
|
||||
plugin_info.module_path,
|
||||
plugin_info.is_dir,
|
||||
plugin_info,
|
||||
is_external,
|
||||
)
|
||||
update_success_list.append(plugin_info.name)
|
||||
|
||||
@ -153,7 +153,7 @@ async def _(session: Uninfo, arparma: Arparma, nickname: str = UserName()):
|
||||
nickname,
|
||||
PlatformUtils.get_platform(session),
|
||||
):
|
||||
await MessageUtils.build_message(image.pic2bytes()).finish(reply_to=True) # type: ignore
|
||||
await MessageUtils.build_message(image).finish(reply_to=True) # type: ignore
|
||||
return await MessageUtils.build_message("你的道具为空捏...").send(reply_to=True)
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ from zhenxun.models.group_member_info import GroupInfoUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.models.user_gold_log import UserGoldLog
|
||||
from zhenxun.models.user_props_log import UserPropsLog
|
||||
from zhenxun.services import avatar_service
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.ui.models import ImageCell, TextCell
|
||||
from zhenxun.utils.enum import GoldHandle, PropHandle
|
||||
@ -123,12 +124,14 @@ async def gold_rank(session: Uninfo, group_id: str | None, num: int) -> bytes |
|
||||
data_list = []
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
for i, user in enumerate(user_list):
|
||||
ava_url = PlatformUtils.get_user_avatar_url(user[0], platform, session.self_id)
|
||||
avatar_path = await avatar_service.get_avatar_path(platform, user[0])
|
||||
data_list.append(
|
||||
[
|
||||
TextCell(content=f"{i + 1}"),
|
||||
ImageCell(src=ava_url or "", shape="circle")
|
||||
if platform == "qq"
|
||||
ImageCell(
|
||||
src=avatar_path.as_uri() if avatar_path else "", shape="circle"
|
||||
)
|
||||
if avatar_path
|
||||
else TextCell(content=""),
|
||||
TextCell(content=uid2name.get(user[0]) or user[0]),
|
||||
TextCell(content=str(user[1]), bold=True),
|
||||
@ -529,10 +532,10 @@ class ShopManage:
|
||||
if not prop:
|
||||
continue
|
||||
|
||||
icon = ""
|
||||
icon = None
|
||||
if prop.icon:
|
||||
icon_path = ICON_PATH / prop.icon
|
||||
icon = (icon_path, 33, 33) if icon_path.exists() else ""
|
||||
icon = icon_path if icon_path.exists() else None
|
||||
|
||||
table_rows.append(
|
||||
[
|
||||
|
||||
@ -13,6 +13,7 @@ from zhenxun.models.group_member_info import GroupInfoUser
|
||||
from zhenxun.models.sign_log import SignLog
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.avatar_service import avatar_service
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.ui.models import ImageCell, TextCell
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
@ -79,14 +80,16 @@ class SignManage:
|
||||
data_list = []
|
||||
platform = PlatformUtils.get_platform(session)
|
||||
for i, user in enumerate(user_list):
|
||||
ava_url = PlatformUtils.get_user_avatar_url(
|
||||
user[0], platform, session.self_id
|
||||
avatar_path = await avatar_service.get_avatar_path(
|
||||
platform=user[3] or "qq", identifier=user[0]
|
||||
)
|
||||
data_list.append(
|
||||
[
|
||||
TextCell(content=f"{i + 1}"),
|
||||
ImageCell(src=ava_url or "", shape="circle")
|
||||
if user[3] == "qq"
|
||||
ImageCell(
|
||||
src=avatar_path.as_uri() if avatar_path else "", shape="circle"
|
||||
)
|
||||
if avatar_path
|
||||
else TextCell(content=""),
|
||||
TextCell(content=uid2name.get(user[0]) or user[0]),
|
||||
TextCell(content=str(user[1]), bold=True),
|
||||
|
||||
@ -11,6 +11,7 @@ from nonebot_plugin_uninfo import Uninfo
|
||||
from zhenxun import ui
|
||||
from zhenxun.configs.config import BotConfig, Config
|
||||
from zhenxun.models.sign_user import SignUser
|
||||
from zhenxun.services import avatar_service
|
||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
@ -21,9 +22,9 @@ from .config import (
|
||||
lik2relation,
|
||||
)
|
||||
|
||||
assert (
|
||||
len(level2attitude) == len(lik2level) == len(lik2relation)
|
||||
), "好感度态度、等级、关系长度不匹配!"
|
||||
assert len(level2attitude) == len(lik2level) == len(lik2relation), (
|
||||
"好感度态度、等级、关系长度不匹配!"
|
||||
)
|
||||
|
||||
AVA_URL = "http://q1.qlogo.cn/g?b=qq&nk={}&s=160"
|
||||
|
||||
@ -212,13 +213,13 @@ async def _generate_html_card(
|
||||
if len(nickname) > 6:
|
||||
font_size = 27
|
||||
|
||||
avatar_path = await avatar_service.get_avatar_path(
|
||||
PlatformUtils.get_platform(session), user.user_id
|
||||
)
|
||||
user_info = {
|
||||
"nickname": nickname,
|
||||
"uid_str": uid_formatted,
|
||||
"avatar_url": PlatformUtils.get_user_avatar_url(
|
||||
user.user_id, PlatformUtils.get_platform(session), session.self_id
|
||||
)
|
||||
or "",
|
||||
"avatar_url": avatar_path.as_uri() if avatar_path else "",
|
||||
"sign_count": user.sign_count,
|
||||
"font_size": font_size,
|
||||
}
|
||||
|
||||
@ -344,7 +344,9 @@ class ConfigsManager:
|
||||
返回:
|
||||
ConfigGroup: ConfigGroup
|
||||
"""
|
||||
return self._data.get(key) or ConfigGroup(module="")
|
||||
if key not in self._data:
|
||||
self._data[key] = ConfigGroup(module=key)
|
||||
return self._data[key]
|
||||
|
||||
def save(self, path: str | Path | None = None, save_simple_data: bool = False):
|
||||
"""保存数据
|
||||
|
||||
@ -77,7 +77,7 @@ class PluginInfo(Model):
|
||||
返回:
|
||||
Self | None: 插件
|
||||
"""
|
||||
if filter_parent:
|
||||
if not kwargs.get("plugin_type") and filter_parent:
|
||||
return await cls.get_or_none(
|
||||
load_status=load_status, plugin_type__not=PluginType.PARENT, **kwargs
|
||||
)
|
||||
@ -96,7 +96,7 @@ class PluginInfo(Model):
|
||||
返回:
|
||||
list[Self]: 插件列表
|
||||
"""
|
||||
if filter_parent:
|
||||
if not kwargs.get("plugin_type") and filter_parent:
|
||||
return await cls.filter(
|
||||
load_status=load_status, plugin_type__not=PluginType.PARENT, **kwargs
|
||||
).all()
|
||||
|
||||
@ -18,6 +18,7 @@ require("nonebot_plugin_htmlrender")
|
||||
require("nonebot_plugin_uninfo")
|
||||
require("nonebot_plugin_waiter")
|
||||
|
||||
from .avatar_service import avatar_service
|
||||
from .db_context import Model, disconnect, with_db_timeout
|
||||
from .llm import (
|
||||
AI,
|
||||
@ -57,6 +58,7 @@ __all__ = [
|
||||
"Model",
|
||||
"PluginInit",
|
||||
"PluginInitManager",
|
||||
"avatar_service",
|
||||
"chat",
|
||||
"clear_model_cache",
|
||||
"code",
|
||||
|
||||
141
zhenxun/services/avatar_service.py
Normal file
141
zhenxun/services/avatar_service.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""
|
||||
头像缓存服务
|
||||
|
||||
提供一个统一的、带缓存的头像获取服务,支持多平台和可配置的过期策略。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import DATA_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.http_utils import AsyncHttpx
|
||||
from zhenxun.utils.platform import PlatformUtils
|
||||
|
||||
Config.add_plugin_config(
|
||||
"avatar_cache",
|
||||
"ENABLED",
|
||||
True,
|
||||
help="是否启用头像缓存功能",
|
||||
default_value=True,
|
||||
type=bool,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
"avatar_cache",
|
||||
"TTL_DAYS",
|
||||
7,
|
||||
help="头像缓存的有效期(天)",
|
||||
default_value=7,
|
||||
type=int,
|
||||
)
|
||||
Config.add_plugin_config(
|
||||
"avatar_cache",
|
||||
"CLEANUP_INTERVAL_HOURS",
|
||||
24,
|
||||
help="后台清理过期缓存的间隔时间(小时)",
|
||||
default_value=24,
|
||||
type=int,
|
||||
)
|
||||
|
||||
|
||||
class AvatarService:
|
||||
"""
|
||||
一个集中式的头像缓存服务,提供L1(内存)和L2(文件)两级缓存。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache_path = (DATA_PATH / "cache" / "avatars").resolve()
|
||||
self.cache_path.mkdir(parents=True, exist_ok=True)
|
||||
self._memory_cache: dict[str, Path] = {}
|
||||
|
||||
def _get_cache_path(self, platform: str, identifier: str) -> Path:
|
||||
"""
|
||||
根据平台和ID生成存储的文件路径。
|
||||
例如: data/cache/avatars/qq/123456789.png
|
||||
"""
|
||||
identifier = str(identifier)
|
||||
return self.cache_path / platform / f"{identifier}.png"
|
||||
|
||||
async def get_avatar_path(
|
||||
self, platform: str, identifier: str, force_refresh: bool = False
|
||||
) -> Path | None:
|
||||
"""
|
||||
获取用户或群组的头像本地路径。
|
||||
|
||||
参数:
|
||||
platform: 平台名称 (e.g., 'qq')
|
||||
identifier: 用户ID或群组ID
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
返回:
|
||||
Path | None: 头像的本地文件路径,如果获取失败则返回None。
|
||||
"""
|
||||
if not Config.get_config("avatar_cache", "ENABLED"):
|
||||
return None
|
||||
|
||||
cache_key = f"{platform}-{identifier}"
|
||||
if not force_refresh and cache_key in self._memory_cache:
|
||||
if self._memory_cache[cache_key].exists():
|
||||
return self._memory_cache[cache_key]
|
||||
|
||||
local_path = self._get_cache_path(platform, identifier)
|
||||
ttl_seconds = Config.get_config("avatar_cache", "TTL_DAYS", 7) * 86400
|
||||
|
||||
if not force_refresh and local_path.exists():
|
||||
try:
|
||||
file_mtime = os.path.getmtime(local_path)
|
||||
if time.time() - file_mtime < ttl_seconds:
|
||||
self._memory_cache[cache_key] = local_path
|
||||
return local_path
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
avatar_url = PlatformUtils.get_user_avatar_url(identifier, platform)
|
||||
if not avatar_url:
|
||||
return None
|
||||
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if await AsyncHttpx.download_file(avatar_url, local_path):
|
||||
self._memory_cache[cache_key] = local_path
|
||||
return local_path
|
||||
else:
|
||||
logger.warning(f"下载头像失败: {avatar_url}", "AvatarService")
|
||||
return None
|
||||
|
||||
async def _cleanup_cache(self):
|
||||
"""后台定时清理过期的缓存文件"""
|
||||
if not Config.get_config("avatar_cache", "ENABLED"):
|
||||
return
|
||||
|
||||
logger.info("开始执行头像缓存清理任务...", "AvatarService")
|
||||
ttl_seconds = Config.get_config("avatar_cache", "TTL_DAYS", 7) * 86400
|
||||
now = time.time()
|
||||
deleted_count = 0
|
||||
for root, _, files in os.walk(self.cache_path):
|
||||
for name in files:
|
||||
file_path = Path(root) / name
|
||||
try:
|
||||
if now - os.path.getmtime(file_path) > ttl_seconds:
|
||||
file_path.unlink()
|
||||
deleted_count += 1
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"头像缓存清理完成,共删除 {deleted_count} 个过期文件。", "AvatarService"
|
||||
)
|
||||
|
||||
|
||||
avatar_service = AvatarService()
|
||||
|
||||
|
||||
@scheduler.scheduled_job(
|
||||
"interval", hours=Config.get_config("avatar_cache", "CLEANUP_INTERVAL_HOURS", 24)
|
||||
)
|
||||
async def _run_avatar_cache_cleanup():
|
||||
await avatar_service._cleanup_cache()
|
||||
@ -7,6 +7,7 @@ LLM 服务模块 - 公共 API 入口
|
||||
from .api import (
|
||||
chat,
|
||||
code,
|
||||
create_image,
|
||||
embed,
|
||||
generate,
|
||||
generate_structured,
|
||||
@ -74,6 +75,7 @@ __all__ = [
|
||||
"chat",
|
||||
"clear_model_cache",
|
||||
"code",
|
||||
"create_image",
|
||||
"create_multimodal_message",
|
||||
"embed",
|
||||
"function_tool",
|
||||
|
||||
@ -3,6 +3,9 @@ LLM 适配器基类和通用数据结构
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -32,6 +35,7 @@ class ResponseData(BaseModel):
|
||||
"""响应数据封装 - 支持所有高级功能"""
|
||||
|
||||
text: str
|
||||
image_bytes: bytes | None = None
|
||||
usage_info: dict[str, Any] | None = None
|
||||
raw_response: dict[str, Any] | None = None
|
||||
tool_calls: list[LLMToolCall] | None = None
|
||||
@ -242,6 +246,38 @@ class BaseAdapter(ABC):
|
||||
if content:
|
||||
content = content.strip()
|
||||
|
||||
image_bytes: bytes | None = None
|
||||
if content and content.startswith("{") and content.endswith("}"):
|
||||
try:
|
||||
content_json = json.loads(content)
|
||||
if "b64_json" in content_json:
|
||||
image_bytes = base64.b64decode(content_json["b64_json"])
|
||||
content = "[图片已生成]"
|
||||
elif "data" in content_json and isinstance(
|
||||
content_json["data"], str
|
||||
):
|
||||
image_bytes = base64.b64decode(content_json["data"])
|
||||
content = "[图片已生成]"
|
||||
|
||||
except (json.JSONDecodeError, KeyError, binascii.Error):
|
||||
pass
|
||||
elif (
|
||||
"images" in message
|
||||
and isinstance(message["images"], list)
|
||||
and message["images"]
|
||||
):
|
||||
image_info = message["images"][0]
|
||||
if image_info.get("type") == "image_url":
|
||||
image_url_obj = image_info.get("image_url", {})
|
||||
url_str = image_url_obj.get("url", "")
|
||||
if url_str.startswith("data:image/png;base64,"):
|
||||
try:
|
||||
b64_data = url_str.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
content = content if content else "[图片已生成]"
|
||||
except (IndexError, binascii.Error) as e:
|
||||
logger.warning(f"解析OpenRouter Base64图片数据失败: {e}")
|
||||
|
||||
parsed_tool_calls: list[LLMToolCall] | None = None
|
||||
if message_tool_calls := message.get("tool_calls"):
|
||||
from ..types.models import LLMToolFunction
|
||||
@ -280,6 +316,7 @@ class BaseAdapter(ABC):
|
||||
text=final_text,
|
||||
tool_calls=parsed_tool_calls,
|
||||
usage_info=usage_info,
|
||||
image_bytes=image_bytes,
|
||||
raw_response=response_json,
|
||||
)
|
||||
|
||||
@ -450,6 +487,13 @@ class OpenAICompatAdapter(BaseAdapter):
|
||||
"""准备高级请求 - OpenAI兼容格式"""
|
||||
url = self.get_api_url(model, self.get_chat_endpoint(model))
|
||||
headers = self.get_base_headers(api_key)
|
||||
if model.api_type == "openrouter":
|
||||
headers.update(
|
||||
{
|
||||
"HTTP-Referer": "https://github.com/zhenxun-org/zhenxun_bot",
|
||||
"X-Title": "Zhenxun Bot",
|
||||
}
|
||||
)
|
||||
openai_messages = self.convert_messages_to_openai_format(messages)
|
||||
|
||||
body = {
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
Gemini API 适配器
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
@ -373,7 +374,16 @@ class GeminiAdapter(BaseAdapter):
|
||||
self.validate_response(response_json)
|
||||
|
||||
try:
|
||||
candidates = response_json.get("candidates", [])
|
||||
if "image_generation" in response_json and isinstance(
|
||||
response_json["image_generation"], dict
|
||||
):
|
||||
candidates_source = response_json["image_generation"]
|
||||
else:
|
||||
candidates_source = response_json
|
||||
|
||||
candidates = candidates_source.get("candidates", [])
|
||||
usage_info = response_json.get("usageMetadata")
|
||||
|
||||
if not candidates:
|
||||
logger.debug("Gemini响应中没有candidates。")
|
||||
return ResponseData(text="", raw_response=response_json)
|
||||
@ -398,6 +408,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
parts = content_data.get("parts", [])
|
||||
|
||||
text_content = ""
|
||||
image_bytes: bytes | None = None
|
||||
parsed_tool_calls: list["LLMToolCall"] | None = None
|
||||
thought_summary_parts = []
|
||||
answer_parts = []
|
||||
@ -409,6 +420,14 @@ class GeminiAdapter(BaseAdapter):
|
||||
thought_summary_parts.append(part["thought"])
|
||||
elif "thoughtSummary" in part:
|
||||
thought_summary_parts.append(part["thoughtSummary"])
|
||||
elif "inlineData" in part:
|
||||
inline_data = part["inlineData"]
|
||||
if "data" in inline_data:
|
||||
image_bytes = base64.b64decode(inline_data["data"])
|
||||
answer_parts.append(
|
||||
f"[图片已生成: {inline_data.get('mimeType', 'image')}]"
|
||||
)
|
||||
|
||||
elif "functionCall" in part:
|
||||
if parsed_tool_calls is None:
|
||||
parsed_tool_calls = []
|
||||
@ -475,6 +494,7 @@ class GeminiAdapter(BaseAdapter):
|
||||
return ResponseData(
|
||||
text=text_content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
image_bytes=image_bytes,
|
||||
usage_info=usage_info,
|
||||
raw_response=response_json,
|
||||
grounding_metadata=grounding_metadata_obj,
|
||||
|
||||
@ -21,7 +21,14 @@ class OpenAIAdapter(OpenAICompatAdapter):
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> list[str]:
|
||||
return ["openai", "deepseek", "zhipu", "general_openai_compat", "ark"]
|
||||
return [
|
||||
"openai",
|
||||
"deepseek",
|
||||
"zhipu",
|
||||
"general_openai_compat",
|
||||
"ark",
|
||||
"openrouter",
|
||||
]
|
||||
|
||||
def get_chat_endpoint(self, model: "LLMModel") -> str:
|
||||
"""返回聊天完成端点"""
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
LLM 服务的高级 API 接口 - 便捷函数入口 (无状态)
|
||||
"""
|
||||
|
||||
from typing import Any, TypeVar
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from nonebot_plugin_alconna.uniseg import UniMessage
|
||||
from pydantic import BaseModel
|
||||
@ -10,7 +11,7 @@ from pydantic import BaseModel
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
from .config import CommonOverrides
|
||||
from .config.generation import create_generation_config_from_kwargs
|
||||
from .config.generation import LLMGenerationConfig, create_generation_config_from_kwargs
|
||||
from .manager import get_model_instance
|
||||
from .session import AI
|
||||
from .tools.manager import tool_provider_manager
|
||||
@ -23,6 +24,7 @@ from .types import (
|
||||
LLMResponse,
|
||||
ModelName,
|
||||
)
|
||||
from .utils import create_multimodal_message
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
@ -303,3 +305,99 @@ async def run_with_tools(
|
||||
raise LLMException(
|
||||
"带工具的执行循环未能产生有效的助手回复。", code=LLMErrorCode.GENERATION_FAILED
|
||||
)
|
||||
|
||||
|
||||
async def _generate_image_from_message(
|
||||
message: UniMessage,
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
[内部] 从 UniMessage 生成图片的核心辅助函数。
|
||||
"""
|
||||
from .utils import normalize_to_llm_messages
|
||||
|
||||
config = (
|
||||
create_generation_config_from_kwargs(**kwargs)
|
||||
if kwargs
|
||||
else LLMGenerationConfig()
|
||||
)
|
||||
|
||||
config.validation_policy = {"require_image": True}
|
||||
config.response_modalities = ["IMAGE", "TEXT"]
|
||||
|
||||
try:
|
||||
messages = await normalize_to_llm_messages(message)
|
||||
|
||||
async with await get_model_instance(model) as model_instance:
|
||||
if not model_instance.can_generate_images():
|
||||
raise LLMException(
|
||||
f"模型 '{model_instance.provider_name}/{model_instance.model_name}'"
|
||||
f"不支持图片生成",
|
||||
code=LLMErrorCode.CONFIGURATION_ERROR,
|
||||
)
|
||||
|
||||
response = await model_instance.generate_response(messages, config=config)
|
||||
|
||||
if not response.image_bytes:
|
||||
error_text = response.text or "模型未返回图片数据。"
|
||||
logger.warning(f"图片生成调用未返回图片,返回文本内容: {error_text}")
|
||||
|
||||
return response
|
||||
except LLMException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"执行图片生成时发生未知错误: {e}", e=e)
|
||||
raise LLMException(f"图片生成失败: {e}", cause=e)
|
||||
|
||||
|
||||
@overload
|
||||
async def create_image(
|
||||
prompt: str | UniMessage,
|
||||
*,
|
||||
images: None = None,
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""根据文本提示生成一张新图片。"""
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def create_image(
|
||||
prompt: str | UniMessage,
|
||||
*,
|
||||
images: list[Path | bytes | str] | Path | bytes | str,
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""在给定图片的基础上,根据文本提示进行编辑或重新生成。"""
|
||||
...
|
||||
|
||||
|
||||
async def create_image(
|
||||
prompt: str | UniMessage,
|
||||
*,
|
||||
images: list[Path | bytes | str] | Path | bytes | str | None = None,
|
||||
model: ModelName = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
智能图片生成/编辑函数。
|
||||
- 如果 `images` 为 None,执行文生图。
|
||||
- 如果提供了 `images`,执行图+文生图,支持多张图片输入。
|
||||
"""
|
||||
text_prompt = (
|
||||
prompt.extract_plain_text() if isinstance(prompt, UniMessage) else str(prompt)
|
||||
)
|
||||
|
||||
image_list = []
|
||||
if images:
|
||||
if isinstance(images, list):
|
||||
image_list.extend(images)
|
||||
else:
|
||||
image_list.append(images)
|
||||
|
||||
message = create_multimodal_message(text=text_prompt, images=image_list)
|
||||
|
||||
return await _generate_image_from_message(message, model=model, **kwargs)
|
||||
|
||||
@ -2,13 +2,15 @@
|
||||
LLM 生成配置相关类和函数
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.pydantic_compat import model_dump
|
||||
|
||||
from ..types import LLMResponse
|
||||
from ..types.enums import ResponseFormat
|
||||
from ..types.exceptions import LLMErrorCode, LLMException
|
||||
|
||||
@ -64,6 +66,15 @@ class ModelConfigOverride(BaseModel):
|
||||
|
||||
custom_params: dict[str, Any] | None = Field(default=None, description="自定义参数")
|
||||
|
||||
validation_policy: dict[str, Any] | None = Field(
|
||||
default=None, description="声明式的响应验证策略 (例如: {'require_image': True})"
|
||||
)
|
||||
response_validator: Callable[[LLMResponse], None] | None = Field(
|
||||
default=None, description="一个高级回调函数,用于验证响应,验证失败时应抛出异常"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典,排除None值"""
|
||||
|
||||
|
||||
@ -50,8 +50,8 @@ class LLMHttpClient:
|
||||
async with self._lock:
|
||||
if self._client is None or self._client.is_closed:
|
||||
logger.debug(
|
||||
f"LLMHttpClient: Initializing new httpx.AsyncClient "
|
||||
f"with config: {self.config}"
|
||||
f"LLMHttpClient: 正在初始化新的 httpx.AsyncClient "
|
||||
f"配置: {self.config}"
|
||||
)
|
||||
headers = get_user_agent()
|
||||
limits = httpx.Limits(
|
||||
@ -92,7 +92,7 @@ class LLMHttpClient:
|
||||
)
|
||||
if self._client is None:
|
||||
raise LLMException(
|
||||
"HTTP client failed to initialize.", LLMErrorCode.CONFIGURATION_ERROR
|
||||
"HTTP 客户端初始化失败。", LLMErrorCode.CONFIGURATION_ERROR
|
||||
)
|
||||
return self._client
|
||||
|
||||
@ -110,17 +110,17 @@ class LLMHttpClient:
|
||||
async with self._lock:
|
||||
if self._client and not self._client.is_closed:
|
||||
logger.debug(
|
||||
f"LLMHttpClient: Closing with config: {self.config}. "
|
||||
f"Active requests: {self._active_requests}"
|
||||
f"LLMHttpClient: 正在关闭,配置: {self.config}. "
|
||||
f"活跃请求数: {self._active_requests}"
|
||||
)
|
||||
if self._active_requests > 0:
|
||||
logger.warning(
|
||||
f"LLMHttpClient: Closing while {self._active_requests} "
|
||||
f"requests are still active."
|
||||
f"LLMHttpClient: 关闭时仍有 {self._active_requests} "
|
||||
f"个请求处于活跃状态。"
|
||||
)
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
logger.debug(f"LLMHttpClient for config {self.config} definitively closed.")
|
||||
logger.debug(f"配置为 {self.config} 的 LLMHttpClient 已完全关闭。")
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
@ -145,20 +145,17 @@ class LLMHttpClientManager:
|
||||
client = self._clients.get(key)
|
||||
if client and not client.is_closed:
|
||||
logger.debug(
|
||||
f"LLMHttpClientManager: Reusing existing LLMHttpClient "
|
||||
f"for key: {key}"
|
||||
f"LLMHttpClientManager: 复用现有的 LLMHttpClient 密钥: {key}"
|
||||
)
|
||||
return client
|
||||
|
||||
if client and client.is_closed:
|
||||
logger.debug(
|
||||
f"LLMHttpClientManager: Found a closed client for key {key}. "
|
||||
f"Creating a new one."
|
||||
f"LLMHttpClientManager: 发现密钥 {key} 对应的客户端已关闭。"
|
||||
f"正在创建新的客户端。"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"LLMHttpClientManager: Creating new LLMHttpClient for key: {key}"
|
||||
)
|
||||
logger.debug(f"LLMHttpClientManager: 为密钥 {key} 创建新的 LLMHttpClient")
|
||||
http_client_config = HttpClientConfig(
|
||||
timeout=provider_config.timeout, proxy=provider_config.proxy
|
||||
)
|
||||
@ -169,8 +166,7 @@ class LLMHttpClientManager:
|
||||
async def shutdown(self):
|
||||
async with self._lock:
|
||||
logger.info(
|
||||
f"LLMHttpClientManager: Shutting down. "
|
||||
f"Closing {len(self._clients)} client(s)."
|
||||
f"LLMHttpClientManager: 正在关闭。关闭 {len(self._clients)} 个客户端。"
|
||||
)
|
||||
close_tasks = [
|
||||
client.close()
|
||||
@ -180,7 +176,7 @@ class LLMHttpClientManager:
|
||||
if close_tasks:
|
||||
await asyncio.gather(*close_tasks, return_exceptions=True)
|
||||
self._clients.clear()
|
||||
logger.info("LLMHttpClientManager: Shutdown complete.")
|
||||
logger.info("LLMHttpClientManager: 关闭完成。")
|
||||
|
||||
|
||||
http_client_manager = LLMHttpClientManager()
|
||||
|
||||
@ -118,6 +118,7 @@ def get_default_api_base_for_type(api_type: str) -> str | None:
|
||||
"deepseek": "https://api.deepseek.com",
|
||||
"zhipu": "https://open.bigmodel.cn",
|
||||
"gemini": "https://generativelanguage.googleapis.com",
|
||||
"openrouter": "https://openrouter.ai/api",
|
||||
"general_openai_compat": None,
|
||||
}
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from typing import Any, TypeVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.log_sanitizer import sanitize_for_logging
|
||||
|
||||
from .adapters.base import RequestData
|
||||
from .config import LLMGenerationConfig
|
||||
@ -34,7 +35,6 @@ from .types import (
|
||||
ToolExecutable,
|
||||
)
|
||||
from .types.capabilities import ModelCapabilities, ModelModality
|
||||
from .utils import _sanitize_request_body_for_logging
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
@ -187,7 +187,13 @@ class LLMModel(LLMModelBase):
|
||||
logger.debug(f"🔑 API密钥: {masked_key}")
|
||||
logger.debug(f"📋 请求头: {dict(request_data.headers)}")
|
||||
|
||||
sanitized_body = _sanitize_request_body_for_logging(request_data.body)
|
||||
sanitizer_req_context_map = {"gemini": "gemini_request"}
|
||||
sanitizer_req_context = sanitizer_req_context_map.get(
|
||||
self.api_type, "openai_request"
|
||||
)
|
||||
sanitized_body = sanitize_for_logging(
|
||||
request_data.body, context=sanitizer_req_context
|
||||
)
|
||||
request_body_str = json.dumps(sanitized_body, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"📦 请求体: {request_body_str}")
|
||||
|
||||
@ -200,8 +206,11 @@ class LLMModel(LLMModelBase):
|
||||
logger.debug(f"📥 响应状态码: {http_response.status_code}")
|
||||
logger.debug(f"📄 响应头: {dict(http_response.headers)}")
|
||||
|
||||
response_bytes = await http_response.aread()
|
||||
logger.debug(f"📦 响应体已完整读取 ({len(response_bytes)} bytes)")
|
||||
|
||||
if http_response.status_code != 200:
|
||||
error_text = http_response.text
|
||||
error_text = response_bytes.decode("utf-8", errors="ignore")
|
||||
logger.error(
|
||||
f"❌ HTTP请求失败: {http_response.status_code} - {error_text} "
|
||||
f"[{log_context}]"
|
||||
@ -232,13 +241,22 @@ class LLMModel(LLMModelBase):
|
||||
)
|
||||
|
||||
try:
|
||||
response_json = http_response.json()
|
||||
response_json = json.loads(response_bytes)
|
||||
|
||||
sanitizer_context_map = {"gemini": "gemini_response"}
|
||||
sanitizer_context = sanitizer_context_map.get(
|
||||
self.api_type, "openai_response"
|
||||
)
|
||||
|
||||
sanitized_for_log = sanitize_for_logging(
|
||||
response_json, context=sanitizer_context
|
||||
)
|
||||
|
||||
response_json_str = json.dumps(
|
||||
response_json, ensure_ascii=False, indent=2
|
||||
sanitized_for_log, ensure_ascii=False, indent=2
|
||||
)
|
||||
logger.debug(f"📋 响应JSON: {response_json_str}")
|
||||
parsed_data = parse_response_func(response_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 {log_context} 响应失败: {e}", e=e)
|
||||
await self.key_store.record_failure(api_key, None, str(e))
|
||||
@ -290,7 +308,7 @@ class LLMModel(LLMModelBase):
|
||||
adapter.validate_embedding_response(response_json)
|
||||
return adapter.parse_embedding_response(response_json)
|
||||
|
||||
parsed_data, api_key_used = await self._perform_api_call(
|
||||
parsed_data, _api_key_used = await self._perform_api_call(
|
||||
prepare_request_func=prepare_request,
|
||||
parse_response_func=parse_response,
|
||||
http_client=http_client,
|
||||
@ -376,6 +394,7 @@ class LLMModel(LLMModelBase):
|
||||
return LLMResponse(
|
||||
text=response_data.text,
|
||||
usage_info=response_data.usage_info,
|
||||
image_bytes=response_data.image_bytes,
|
||||
raw_response=response_data.raw_response,
|
||||
tool_calls=response_tool_calls if response_tool_calls else None,
|
||||
code_executions=response_data.code_executions,
|
||||
@ -390,6 +409,56 @@ class LLMModel(LLMModelBase):
|
||||
failed_keys=failed_keys,
|
||||
log_context="Generation",
|
||||
)
|
||||
|
||||
if config:
|
||||
if config.response_validator:
|
||||
try:
|
||||
config.response_validator(parsed_data)
|
||||
except Exception as e:
|
||||
raise LLMException(
|
||||
f"响应内容未通过自定义验证器: {e}",
|
||||
code=LLMErrorCode.API_RESPONSE_INVALID,
|
||||
details={"validator_error": str(e)},
|
||||
cause=e,
|
||||
) from e
|
||||
|
||||
policy = config.validation_policy
|
||||
if policy:
|
||||
if policy.get("require_image") and not parsed_data.image_bytes:
|
||||
if self.api_type == "gemini" and parsed_data.raw_response:
|
||||
usage_metadata = parsed_data.raw_response.get(
|
||||
"usageMetadata", {}
|
||||
)
|
||||
prompt_token_details = usage_metadata.get(
|
||||
"promptTokensDetails", []
|
||||
)
|
||||
prompt_had_image = any(
|
||||
detail.get("modality") == "IMAGE"
|
||||
for detail in prompt_token_details
|
||||
)
|
||||
|
||||
if prompt_had_image:
|
||||
raise LLMException(
|
||||
"响应验证失败:模型接收了图片输入但未生成图片。",
|
||||
code=LLMErrorCode.API_RESPONSE_INVALID,
|
||||
details={
|
||||
"policy": policy,
|
||||
"text_response": parsed_data.text,
|
||||
"raw_response": parsed_data.raw_response,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.debug("Gemini提示词中未包含图片,跳过图片要求重试。")
|
||||
else:
|
||||
raise LLMException(
|
||||
"响应验证失败:要求返回图片但未找到图片数据。",
|
||||
code=LLMErrorCode.API_RESPONSE_INVALID,
|
||||
details={
|
||||
"policy": policy,
|
||||
"text_response": parsed_data.text,
|
||||
},
|
||||
)
|
||||
|
||||
return parsed_data, api_key_used
|
||||
|
||||
async def close(self):
|
||||
|
||||
@ -44,6 +44,13 @@ GEMINI_CAPABILITIES = ModelCapabilities(
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
GEMINI_IMAGE_GEN_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
output_modalities={ModelModality.TEXT, ModelModality.IMAGE},
|
||||
supports_tool_calling=True,
|
||||
)
|
||||
|
||||
|
||||
DOUBAO_ADVANCED_MULTIMODAL_CAPABILITIES = ModelCapabilities(
|
||||
input_modalities={ModelModality.TEXT, ModelModality.IMAGE, ModelModality.VIDEO},
|
||||
output_modalities={ModelModality.TEXT},
|
||||
@ -83,6 +90,7 @@ MODEL_CAPABILITIES_REGISTRY: dict[str, ModelCapabilities] = {
|
||||
output_modalities={ModelModality.EMBEDDING},
|
||||
is_embedding_model=True,
|
||||
),
|
||||
"*gemini-*-image-preview*": GEMINI_IMAGE_GEN_CAPABILITIES,
|
||||
"gemini-2.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-1.5-pro*": GEMINI_CAPABILITIES,
|
||||
"gemini-2.5-flash*": GEMINI_CAPABILITIES,
|
||||
|
||||
@ -425,6 +425,7 @@ class LLMResponse(BaseModel):
|
||||
"""LLM 响应"""
|
||||
|
||||
text: str
|
||||
image_bytes: bytes | None = None
|
||||
usage_info: dict[str, Any] | None = None
|
||||
raw_response: dict[str, Any] | None = None
|
||||
tool_calls: list[Any] | None = None
|
||||
|
||||
@ -273,54 +273,6 @@ def message_to_unimessage(message: PlatformMessage) -> UniMessage:
|
||||
return UniMessage(uni_segments)
|
||||
|
||||
|
||||
def _sanitize_request_body_for_logging(body: dict) -> dict:
|
||||
"""
|
||||
净化请求体用于日志记录,移除大数据字段并添加摘要信息
|
||||
|
||||
参数:
|
||||
body: 原始请求体字典。
|
||||
|
||||
返回:
|
||||
dict: 净化后的请求体字典。
|
||||
"""
|
||||
try:
|
||||
sanitized_body = copy.deepcopy(body)
|
||||
|
||||
if "contents" in sanitized_body and isinstance(
|
||||
sanitized_body["contents"], list
|
||||
):
|
||||
for content_item in sanitized_body["contents"]:
|
||||
if "parts" in content_item and isinstance(content_item["parts"], list):
|
||||
media_summary = []
|
||||
new_parts = []
|
||||
for part in content_item["parts"]:
|
||||
if "inlineData" in part and isinstance(
|
||||
part["inlineData"], dict
|
||||
):
|
||||
data = part["inlineData"].get("data")
|
||||
if isinstance(data, str):
|
||||
mime_type = part["inlineData"].get(
|
||||
"mimeType", "unknown"
|
||||
)
|
||||
media_summary.append(f"{mime_type} ({len(data)} chars)")
|
||||
continue
|
||||
new_parts.append(part)
|
||||
|
||||
if media_summary:
|
||||
summary_text = (
|
||||
f"[多模态内容: {len(media_summary)}个文件 - "
|
||||
f"{', '.join(media_summary)}]"
|
||||
)
|
||||
new_parts.insert(0, {"text": summary_text})
|
||||
|
||||
content_item["parts"] = new_parts
|
||||
|
||||
return sanitized_body
|
||||
except Exception as e:
|
||||
logger.warning(f"日志净化失败: {e},将记录原始请求体。")
|
||||
return body
|
||||
|
||||
|
||||
def sanitize_schema_for_llm(schema: Any, api_type: str) -> Any:
|
||||
"""
|
||||
递归地净化 JSON Schema,移除特定 LLM API 不支持的关键字。
|
||||
|
||||
@ -87,7 +87,7 @@ class PluginInitManager:
|
||||
|
||||
@classmethod
|
||||
async def remove(cls, module_path: str):
|
||||
"""运行指定插件安装方法"""
|
||||
"""运行指定插件移除方法"""
|
||||
if model := cls.plugins.get(module_path):
|
||||
if model.remove:
|
||||
class_ = model.class_()
|
||||
|
||||
@ -22,6 +22,7 @@ from zhenxun.configs.config import Config
|
||||
from zhenxun.configs.path_config import THEMES_PATH, UI_CACHE_PATH
|
||||
from zhenxun.services.log import logger
|
||||
from zhenxun.utils.exception import RenderingError
|
||||
from zhenxun.utils.log_sanitizer import sanitize_for_logging
|
||||
from zhenxun.utils.pydantic_compat import _dump_pydantic_obj
|
||||
|
||||
from .config import RESERVED_TEMPLATE_KEYS
|
||||
@ -470,10 +471,7 @@ class RendererService:
|
||||
) from e
|
||||
|
||||
async def render(
|
||||
self,
|
||||
component: Renderable,
|
||||
use_cache: bool = False,
|
||||
**render_options,
|
||||
self, component: Renderable, use_cache: bool = False, **render_options
|
||||
) -> bytes:
|
||||
"""
|
||||
统一的、多态的渲染入口,直接返回图片字节。
|
||||
@ -504,9 +502,12 @@ class RendererService:
|
||||
)
|
||||
result = await self._render_component(context)
|
||||
if Config.get_config("UI", "DEBUG_MODE") and result.html_content:
|
||||
sanitized_html = sanitize_for_logging(
|
||||
result.html_content, context="ui_html"
|
||||
)
|
||||
logger.info(
|
||||
f"--- [UI DEBUG] HTML for {component.__class__.__name__} ---\n"
|
||||
f"{result.html_content}\n"
|
||||
f"{sanitized_html}\n"
|
||||
f"--- [UI DEBUG] End of HTML ---"
|
||||
)
|
||||
if result.image_bytes is None:
|
||||
|
||||
@ -1,6 +1,13 @@
|
||||
from typing import Literal
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from ...models.core.table import TableCell, TableData
|
||||
from ...models.core.table import (
|
||||
BaseCell,
|
||||
ImageCell,
|
||||
TableCell,
|
||||
TableData,
|
||||
TextCell,
|
||||
)
|
||||
from ..base import BaseBuilder
|
||||
|
||||
__all__ = ["TableBuilder"]
|
||||
@ -13,6 +20,28 @@ class TableBuilder(BaseBuilder[TableData]):
|
||||
data_model = TableData(title=title, tip=tip, headers=[], rows=[])
|
||||
super().__init__(data_model, template_name="components/core/table")
|
||||
|
||||
def _normalize_cell(self, cell_data: Any) -> TableCell:
|
||||
"""内部辅助方法,将各种原生数据类型转换为TableCell模型。"""
|
||||
if isinstance(cell_data, BaseCell):
|
||||
return cell_data # type: ignore
|
||||
if isinstance(cell_data, str | int | float):
|
||||
return TextCell(content=str(cell_data))
|
||||
if isinstance(cell_data, Path):
|
||||
return ImageCell(src=cell_data.resolve().as_uri())
|
||||
if isinstance(cell_data, tuple) and len(cell_data) == 3:
|
||||
if (
|
||||
isinstance(cell_data[0], Path)
|
||||
and isinstance(cell_data[1], int)
|
||||
and isinstance(cell_data[2], int)
|
||||
):
|
||||
return ImageCell(
|
||||
src=cell_data[0].resolve().as_uri(),
|
||||
width=cell_data[1],
|
||||
height=cell_data[2],
|
||||
)
|
||||
|
||||
return TextCell(content="")
|
||||
|
||||
def set_headers(self, headers: list[str]) -> "TableBuilder":
|
||||
"""
|
||||
设置表格的表头。
|
||||
@ -57,12 +86,13 @@ class TableBuilder(BaseBuilder[TableData]):
|
||||
返回:
|
||||
TableBuilder: 当前构建器实例,以支持链式调用。
|
||||
"""
|
||||
self._data.rows.append(row)
|
||||
normalized_row = [self._normalize_cell(cell) for cell in row]
|
||||
self._data.rows.append(normalized_row)
|
||||
return self
|
||||
|
||||
def add_rows(self, rows: list[list[TableCell]]) -> "TableBuilder":
|
||||
"""
|
||||
向表格中批量添加多行数据。
|
||||
向表格中批量添加多行数据, 并自动转换原生类型。
|
||||
|
||||
参数:
|
||||
rows: 一个包含多行数据的列表。
|
||||
@ -70,5 +100,6 @@ class TableBuilder(BaseBuilder[TableData]):
|
||||
返回:
|
||||
TableBuilder: 当前构建器实例,以支持链式调用。
|
||||
"""
|
||||
self._data.rows.extend(rows)
|
||||
for row in rows:
|
||||
self.add_row(row)
|
||||
return self
|
||||
|
||||
@ -12,6 +12,7 @@ from .components import (
|
||||
from .core import (
|
||||
BaseCell,
|
||||
CodeElement,
|
||||
ComponentCell,
|
||||
HeadingElement,
|
||||
ImageCell,
|
||||
ImageElement,
|
||||
@ -49,6 +50,7 @@ __all__ = [
|
||||
"BaseCell",
|
||||
"BaseChartData",
|
||||
"CodeElement",
|
||||
"ComponentCell",
|
||||
"Divider",
|
||||
"EChartsData",
|
||||
"HeadingElement",
|
||||
|
||||
@ -11,44 +11,68 @@ from .core.base import RenderableComponent
|
||||
|
||||
class EChartsTitle(BaseModel):
|
||||
text: str
|
||||
"""图表主标题"""
|
||||
left: Literal["left", "center", "right"] = "center"
|
||||
"""标题水平对齐方式"""
|
||||
|
||||
|
||||
class EChartsAxis(BaseModel):
|
||||
type: Literal["category", "value", "time", "log"]
|
||||
"""坐标轴类型"""
|
||||
data: list[Any] | None = None
|
||||
"""类目数据"""
|
||||
show: bool = True
|
||||
"""是否显示坐标轴"""
|
||||
|
||||
|
||||
class EChartsSeries(BaseModel):
|
||||
type: str
|
||||
"""系列类型 (e.g., 'bar', 'line', 'pie')"""
|
||||
data: list[Any]
|
||||
"""系列数据"""
|
||||
name: str | None = None
|
||||
"""系列名称,用于 tooltip 的显示"""
|
||||
label: dict[str, Any] | None = None
|
||||
"""图形上的文本标签"""
|
||||
itemStyle: dict[str, Any] | None = None
|
||||
"""图形样式"""
|
||||
barMaxWidth: int | None = None
|
||||
"""柱条的最大宽度"""
|
||||
smooth: bool | None = None
|
||||
"""是否平滑显示折线"""
|
||||
|
||||
|
||||
class EChartsTooltip(BaseModel):
|
||||
trigger: Literal["item", "axis", "none"] = "item"
|
||||
trigger: Literal["item", "axis", "none"] = Field("item", description="触发类型")
|
||||
"""触发类型"""
|
||||
|
||||
|
||||
class EChartsGrid(BaseModel):
|
||||
left: str | None = None
|
||||
"""grid 组件离容器左侧的距离"""
|
||||
right: str | None = None
|
||||
"""grid 组件离容器右侧的距离"""
|
||||
top: str | None = None
|
||||
"""grid 组件离容器上侧的距离"""
|
||||
bottom: str | None = None
|
||||
"""grid 组件离容器下侧的距离"""
|
||||
containLabel: bool = True
|
||||
"""grid 区域是否包含坐标轴的刻度标签"""
|
||||
|
||||
|
||||
class BaseChartData(RenderableComponent, ABC):
|
||||
"""所有图表数据模型的基类"""
|
||||
|
||||
style_name: str | None = None
|
||||
chart_id: str = Field(default_factory=lambda: f"chart-{uuid.uuid4().hex}")
|
||||
"""组件的样式名称"""
|
||||
chart_id: str = Field(
|
||||
default_factory=lambda: f"chart-{uuid.uuid4().hex}",
|
||||
description="图表的唯一ID,用于前端渲染",
|
||||
)
|
||||
"""图表的唯一ID,用于前端渲染"""
|
||||
|
||||
echarts_options: dict[str, Any] | None = None
|
||||
"""原始ECharts选项,用于高级自定义"""
|
||||
|
||||
@abstractmethod
|
||||
def build_option(self) -> dict[str, Any]:
|
||||
@ -70,21 +94,37 @@ class BaseChartData(RenderableComponent, ABC):
|
||||
class EChartsData(BaseChartData):
|
||||
"""统一的 ECharts 图表数据模型"""
|
||||
|
||||
template_path: str = Field(..., exclude=True)
|
||||
title_model: EChartsTitle | None = Field(None, alias="title")
|
||||
grid_model: EChartsGrid | None = Field(None, alias="grid")
|
||||
tooltip_model: EChartsTooltip | None = Field(None, alias="tooltip")
|
||||
x_axis_model: EChartsAxis | None = Field(None, alias="xAxis")
|
||||
y_axis_model: EChartsAxis | None = Field(None, alias="yAxis")
|
||||
series_models: list[EChartsSeries] = Field(default_factory=list, alias="series")
|
||||
legend_model: dict[str, Any] | None = Field(default_factory=dict, alias="legend")
|
||||
template_path: str = Field(..., exclude=True, description="图表组件的模板路径")
|
||||
"""图表组件的模板路径"""
|
||||
title_model: EChartsTitle | None = Field(
|
||||
None, alias="title", description="标题组件"
|
||||
)
|
||||
"""标题组件"""
|
||||
grid_model: EChartsGrid | None = Field(None, alias="grid", description="网格组件")
|
||||
"""网格组件"""
|
||||
tooltip_model: EChartsTooltip | None = Field(
|
||||
None, alias="tooltip", description="提示框组件"
|
||||
)
|
||||
"""提示框组件"""
|
||||
x_axis_model: EChartsAxis | None = Field(None, alias="xAxis", description="X轴配置")
|
||||
"""X轴配置"""
|
||||
y_axis_model: EChartsAxis | None = Field(None, alias="yAxis", description="Y轴配置")
|
||||
"""Y轴配置"""
|
||||
series_models: list[EChartsSeries] = Field(
|
||||
default_factory=list, alias="series", description="系列列表"
|
||||
)
|
||||
"""系列列表"""
|
||||
legend_model: dict[str, Any] | None = Field(
|
||||
default_factory=dict, alias="legend", description="图例组件"
|
||||
)
|
||||
"""图例组件"""
|
||||
raw_options: dict[str, Any] = Field(
|
||||
default_factory=dict, description="用于 set_option 的原始覆盖选项"
|
||||
)
|
||||
"""用于 set_option 的原始覆盖选项"""
|
||||
|
||||
background_image: str | None = Field(
|
||||
None, description="【兼容】用于横向柱状图的背景图片"
|
||||
)
|
||||
background_image: str | None = Field(None, description="用于横向柱状图的背景图片")
|
||||
"""用于横向柱状图的背景图片"""
|
||||
|
||||
def build_option(self) -> dict[str, Any]:
|
||||
"""将 Pydantic 模型序列化为 ECharts 的 option 字典。"""
|
||||
|
||||
@ -14,9 +14,13 @@ class Alert(RenderableComponent):
|
||||
type: Literal["info", "success", "warning", "error"] = Field(
|
||||
default="info", description="提示框的类型,决定了颜色和图标"
|
||||
)
|
||||
"""提示框的类型,决定了颜色和图标"""
|
||||
title: str = Field(..., description="提示框的标题")
|
||||
"""提示框的标题"""
|
||||
content: str = Field(..., description="提示框的主要内容")
|
||||
"""提示框的主要内容"""
|
||||
show_icon: bool = Field(default=True, description="是否显示与类型匹配的图标")
|
||||
"""是否显示与类型匹配的图标"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,8 +12,11 @@ class Avatar(RenderableComponent):
|
||||
|
||||
component_type: Literal["avatar"] = "avatar"
|
||||
src: str = Field(..., description="头像的URL或Base64数据URI")
|
||||
"""头像的URL或Base64数据URI"""
|
||||
shape: Literal["circle", "square"] = Field("circle", description="头像形状")
|
||||
"""头像形状"""
|
||||
size: int = Field(50, description="头像尺寸(像素)")
|
||||
"""头像尺寸(像素)"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
@ -25,10 +28,13 @@ class AvatarGroup(RenderableComponent):
|
||||
|
||||
component_type: Literal["avatar_group"] = "avatar_group"
|
||||
avatars: list[Avatar] = Field(default_factory=list, description="头像列表")
|
||||
"""头像列表"""
|
||||
spacing: int = Field(-15, description="头像间的间距(负数表示重叠)")
|
||||
"""头像间的间距(负数表示重叠)"""
|
||||
max_count: int | None = Field(
|
||||
None, description="最多显示的头像数量,超出部分会显示为'+N'"
|
||||
)
|
||||
"""最多显示的头像数量,超出部分会显示为'+N'"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,10 +12,12 @@ class Badge(RenderableComponent):
|
||||
|
||||
component_type: Literal["badge"] = "badge"
|
||||
text: str = Field(..., description="徽章上显示的文本")
|
||||
"""徽章上显示的文本"""
|
||||
color_scheme: Literal["primary", "success", "warning", "error", "info"] = Field(
|
||||
default="info",
|
||||
description="预设的颜色方案",
|
||||
)
|
||||
"""预设的颜色方案"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,9 +12,13 @@ class Divider(RenderableComponent):
|
||||
|
||||
component_type: Literal["divider"] = "divider"
|
||||
margin: str = Field("2em 0", description="CSS margin属性,控制分割线上下的间距")
|
||||
"""CSS margin属性,控制分割线上下的间距"""
|
||||
color: str = Field("#f7889c", description="分割线颜色")
|
||||
"""分割线颜色"""
|
||||
style: Literal["solid", "dashed", "dotted"] = Field("solid", description="线条样式")
|
||||
"""线条样式"""
|
||||
thickness: str = Field("1px", description="线条粗细")
|
||||
"""线条粗细"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
@ -26,9 +30,13 @@ class Rectangle(RenderableComponent):
|
||||
|
||||
component_type: Literal["rectangle"] = "rectangle"
|
||||
height: str = Field("50px", description="矩形的高度 (CSS value)")
|
||||
"""矩形的高度 (CSS value)"""
|
||||
background_color: str = Field("#fdf1f5", description="背景颜色")
|
||||
"""背景颜色"""
|
||||
border: str = Field("1px solid #fce4ec", description="CSS border属性")
|
||||
"""CSS border属性"""
|
||||
border_radius: str = Field("8px", description="CSS border-radius属性")
|
||||
"""CSS border-radius属性"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,17 +12,23 @@ class KpiCard(RenderableComponent):
|
||||
|
||||
component_type: Literal["kpi_card"] = "kpi_card"
|
||||
label: str = Field(..., description="指标的标签或名称")
|
||||
"""指标的标签或名称"""
|
||||
value: Any = Field(..., description="指标的主要数值")
|
||||
"""指标的主要数值"""
|
||||
unit: str | None = Field(default=None, description="数值的单位,可选")
|
||||
"""数值的单位,可选"""
|
||||
change: str | None = Field(
|
||||
default=None, description="与上一周期的变化,例如 '+15%' 或 '-100'"
|
||||
)
|
||||
"""与上一周期的变化,例如 '+15%' 或 '-100'"""
|
||||
change_type: Literal["positive", "negative", "neutral"] = Field(
|
||||
default="neutral", description="变化的类型,用于决定颜色"
|
||||
)
|
||||
"""变化的类型,用于决定颜色"""
|
||||
icon_svg: str | None = Field(
|
||||
default=None, description="卡片中显示的可选图标 (SVG path data)"
|
||||
)
|
||||
"""卡片中显示的可选图标 (SVG path data)"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,12 +12,16 @@ class ProgressBar(RenderableComponent):
|
||||
|
||||
component_type: Literal["progress_bar"] = "progress_bar"
|
||||
progress: float = Field(..., ge=0, le=100, description="进度百分比 (0-100)")
|
||||
"""进度百分比 (0-100)"""
|
||||
label: str | None = Field(default=None, description="显示在进度条上的可选文本")
|
||||
"""显示在进度条上的可选文本"""
|
||||
color_scheme: Literal["primary", "success", "warning", "error", "info"] = Field(
|
||||
default="primary",
|
||||
description="预设的颜色方案",
|
||||
)
|
||||
"""预设的颜色方案"""
|
||||
animated: bool = Field(default=False, description="是否显示动画效果")
|
||||
"""是否显示动画效果"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -11,10 +11,15 @@ class TimelineItem(BaseModel):
|
||||
"""时间轴中的单个事件点。"""
|
||||
|
||||
timestamp: str = Field(..., description="显示在时间点旁边的时间或标签")
|
||||
"""显示在时间点旁边的时间或标签"""
|
||||
title: str = Field(..., description="事件的标题")
|
||||
"""事件的标题"""
|
||||
content: str = Field(..., description="事件的详细描述")
|
||||
"""事件的详细描述"""
|
||||
icon: str | None = Field(default=None, description="可选的自定义图标SVG路径")
|
||||
"""可选的自定义图标SVG路径"""
|
||||
color: str | None = Field(default=None, description="可选的自定义颜色,覆盖默认")
|
||||
"""可选的自定义颜色,覆盖默认"""
|
||||
|
||||
|
||||
class Timeline(RenderableComponent):
|
||||
@ -24,6 +29,7 @@ class Timeline(RenderableComponent):
|
||||
items: list[TimelineItem] = Field(
|
||||
default_factory=list, description="时间轴项目列表"
|
||||
)
|
||||
"""时间轴项目列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,11 +12,15 @@ class UserInfoBlock(RenderableComponent):
|
||||
|
||||
component_type: Literal["user_info_block"] = "user_info_block"
|
||||
avatar_url: str = Field(..., description="用户头像的URL")
|
||||
"""用户头像的URL"""
|
||||
name: str = Field(..., description="用户的名称")
|
||||
"""用户的名称"""
|
||||
subtitle: str | None = Field(
|
||||
default=None, description="显示在名称下方的副标题 (如UID或角色)"
|
||||
)
|
||||
"""显示在名称下方的副标题 (如UID或角色)"""
|
||||
tags: list[str] = Field(default_factory=list, description="附加的标签列表")
|
||||
"""附加的标签列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -24,6 +24,7 @@ from .markdown import (
|
||||
from .notebook import NotebookData, NotebookElement
|
||||
from .table import (
|
||||
BaseCell,
|
||||
ComponentCell,
|
||||
ImageCell,
|
||||
RichTextCell,
|
||||
StatusBadgeCell,
|
||||
@ -38,6 +39,7 @@ __all__ = [
|
||||
"BaseCell",
|
||||
"CardData",
|
||||
"CodeElement",
|
||||
"ComponentCell",
|
||||
"DetailsData",
|
||||
"DetailsItem",
|
||||
"HeadingElement",
|
||||
|
||||
@ -20,10 +20,15 @@ class RenderableComponent(BaseModel, Renderable):
|
||||
"""
|
||||
|
||||
_is_standalone_template: bool = False
|
||||
"""标记此组件是否为独立模板"""
|
||||
inline_style: dict[str, str] | None = None
|
||||
"""应用于组件根元素的内联CSS样式"""
|
||||
component_css: str | None = None
|
||||
"""注入到页面的额外CSS字符串"""
|
||||
extra_classes: list[str] | None = None
|
||||
"""应用于组件根元素的额外CSS类名列表"""
|
||||
variant: str | None = None
|
||||
"""组件的变体/皮肤名称"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -7,8 +7,11 @@ class CardData(ContainerComponent):
|
||||
"""通用卡片的数据模型,可以包含头部、内容和尾部"""
|
||||
|
||||
header: RenderableComponent | None = None
|
||||
"""卡片的头部内容组件"""
|
||||
content: RenderableComponent
|
||||
"""卡片的主要内容组件"""
|
||||
footer: RenderableComponent | None = None
|
||||
"""卡片的尾部内容组件"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -9,14 +9,18 @@ class DetailsItem(BaseModel):
|
||||
"""描述列表中的单个项目"""
|
||||
|
||||
label: str = Field(..., description="项目的标签/键")
|
||||
"""项目的标签/键"""
|
||||
value: Any = Field(..., description="项目的值")
|
||||
"""项目的值"""
|
||||
|
||||
|
||||
class DetailsData(RenderableComponent):
|
||||
"""描述列表(键值对)的数据模型"""
|
||||
|
||||
title: str | None = Field(None, description="列表的可选标题")
|
||||
"""列表的可选标题"""
|
||||
items: list[DetailsItem] = Field(default_factory=list, description="键值对项目列表")
|
||||
"""键值对项目列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,20 +12,26 @@ class LayoutItem(BaseModel):
|
||||
"""布局中的单个项目,现在持有可渲染组件的数据模型"""
|
||||
|
||||
component: RenderableComponent = Field(..., description="要渲染的组件的数据模型")
|
||||
"""要渲染的组件的数据模型"""
|
||||
metadata: dict[str, Any] | None = Field(None, description="传递给模板的额外元数据")
|
||||
"""传递给模板的额外元数据"""
|
||||
|
||||
|
||||
class LayoutData(ContainerComponent):
|
||||
"""布局构建器的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
"""应用于布局容器的样式名称"""
|
||||
layout_type: str = "column"
|
||||
"""布局类型 (如 'column', 'row', 'grid')"""
|
||||
children: list[LayoutItem] = Field(
|
||||
default_factory=list, description="要布局的项目列表"
|
||||
)
|
||||
"""要布局的项目列表"""
|
||||
options: dict[str, Any] = Field(
|
||||
default_factory=dict, description="传递给模板的选项"
|
||||
)
|
||||
"""传递给模板的选项"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -12,6 +12,7 @@ class ListItem(BaseModel):
|
||||
"""列表中的单个项目,其内容可以是任何可渲染组件。"""
|
||||
|
||||
component: RenderableComponent = Field(..., description="要渲染的组件的数据模型")
|
||||
"""要渲染的组件的数据模型"""
|
||||
|
||||
|
||||
class ListData(ContainerComponent):
|
||||
@ -19,7 +20,9 @@ class ListData(ContainerComponent):
|
||||
|
||||
component_type: Literal["list"] = "list"
|
||||
items: list[ListItem] = Field(default_factory=list, description="列表项目")
|
||||
"""列表项目"""
|
||||
ordered: bool = Field(default=False, description="是否为有序列表")
|
||||
"""是否为有序列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -44,7 +44,9 @@ class TextElement(MarkdownElement):
|
||||
class HeadingElement(MarkdownElement):
|
||||
type: Literal["heading"] = "heading"
|
||||
text: str
|
||||
level: int = Field(..., ge=1, le=6)
|
||||
"""标题文本"""
|
||||
level: int = Field(..., ge=1, le=6, description="标题级别 (1-6)")
|
||||
"""标题级别 (1-6)"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
return f"{'#' * self.level} {self.text}"
|
||||
@ -53,7 +55,9 @@ class HeadingElement(MarkdownElement):
|
||||
class ImageElement(MarkdownElement):
|
||||
type: Literal["image"] = "image"
|
||||
src: str
|
||||
"""图片来源 (URL或data URI)"""
|
||||
alt: str = "image"
|
||||
"""图片的替代文本"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
return f""
|
||||
@ -62,7 +66,9 @@ class ImageElement(MarkdownElement):
|
||||
class CodeElement(MarkdownElement):
|
||||
type: Literal["code"] = "code"
|
||||
code: str
|
||||
"""代码字符串"""
|
||||
language: str = ""
|
||||
"""代码语言,用于语法高亮"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
return f"```{self.language}\n{self.code}\n```"
|
||||
@ -71,6 +77,7 @@ class CodeElement(MarkdownElement):
|
||||
class RawHtmlElement(MarkdownElement):
|
||||
type: Literal["raw_html"] = "raw_html"
|
||||
html: str
|
||||
"""原始HTML字符串"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
return self.html
|
||||
@ -79,8 +86,11 @@ class RawHtmlElement(MarkdownElement):
|
||||
class TableElement(MarkdownElement):
|
||||
type: Literal["table"] = "table"
|
||||
headers: list[str]
|
||||
"""表格的表头列表"""
|
||||
rows: list[list[str]]
|
||||
"""表格的数据行列表"""
|
||||
alignments: list[Literal["left", "center", "right"]] | None = None
|
||||
"""每列的对齐方式"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
header_row = "| " + " | ".join(self.headers) + " |"
|
||||
@ -102,7 +112,10 @@ class TableElement(MarkdownElement):
|
||||
|
||||
|
||||
class ContainerElement(MarkdownElement):
|
||||
content: list[MarkdownElement] = Field(default_factory=list)
|
||||
content: list[MarkdownElement] = Field(
|
||||
default_factory=list, description="容器内包含的Markdown元素列表"
|
||||
)
|
||||
"""容器内包含的Markdown元素列表"""
|
||||
|
||||
|
||||
class QuoteElement(ContainerElement):
|
||||
@ -121,6 +134,7 @@ class ListItemElement(ContainerElement):
|
||||
class ListElement(ContainerElement):
|
||||
type: Literal["list"] = "list"
|
||||
ordered: bool = False
|
||||
"""是否为有序列表 (例如 1., 2.)"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
lines = []
|
||||
@ -137,6 +151,7 @@ class ComponentElement(MarkdownElement):
|
||||
|
||||
type: Literal["component"] = "component"
|
||||
component: RenderableComponent
|
||||
"""嵌入在Markdown中的可渲染组件"""
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
return ""
|
||||
@ -146,9 +161,15 @@ class MarkdownData(ContainerComponent):
|
||||
"""Markdown转图片的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
elements: list[MarkdownElement] = Field(default_factory=list)
|
||||
"""Markdown内容的样式名称"""
|
||||
elements: list[MarkdownElement] = Field(
|
||||
default_factory=list, description="构成Markdown文档的元素列表"
|
||||
)
|
||||
"""构成Markdown文档的元素列表"""
|
||||
width: int = 800
|
||||
"""最终渲染图片的宽度"""
|
||||
css_path: str | None = None
|
||||
"""自定义CSS文件的绝对路径"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
@ -180,7 +201,6 @@ class MarkdownData(ContainerComponent):
|
||||
logger.warning(f"Markdown自定义CSS文件不存在: {self.css_path}")
|
||||
else:
|
||||
style_name = self.style_name or "light"
|
||||
# 使用上下文对象来解析路径
|
||||
css_path = await context.theme_manager.resolve_markdown_style_path(
|
||||
style_name, context
|
||||
)
|
||||
|
||||
@ -22,21 +22,32 @@ class NotebookElement(BaseModel):
|
||||
"component",
|
||||
]
|
||||
text: str | None = None
|
||||
"""元素的文本内容 (用于标题、段落、引用)"""
|
||||
level: int | None = None
|
||||
"""标题的级别 (1-4)"""
|
||||
src: str | None = None
|
||||
"""图片的来源 (URL或data URI)"""
|
||||
caption: str | None = None
|
||||
"""图片的说明文字"""
|
||||
code: str | None = None
|
||||
"""代码块的内容"""
|
||||
language: str | None = None
|
||||
"""代码块的语言"""
|
||||
data: list[str] | None = None
|
||||
"""列表项的内容列表"""
|
||||
ordered: bool | None = None
|
||||
"""是否为有序列表"""
|
||||
component: RenderableComponent | None = None
|
||||
"""嵌入的自定义可渲染组件"""
|
||||
|
||||
|
||||
class NotebookData(ContainerComponent):
|
||||
"""Notebook转图片的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
"""Notebook的样式名称"""
|
||||
elements: list[NotebookElement]
|
||||
"""构成Notebook页面的元素列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -8,6 +8,7 @@ from .text import TextSpan
|
||||
|
||||
__all__ = [
|
||||
"BaseCell",
|
||||
"ComponentCell",
|
||||
"ImageCell",
|
||||
"ProgressBarCell",
|
||||
"RichTextCell",
|
||||
@ -63,8 +64,18 @@ class RichTextCell(BaseCell):
|
||||
|
||||
type: Literal["rich_text"] = "rich_text" # type: ignore
|
||||
spans: list[TextSpan] = Field(default_factory=list, description="文本片段列表")
|
||||
"""文本片段列表"""
|
||||
direction: Literal["column", "row"] = Field("column", description="片段排列方向")
|
||||
"""片段排列方向"""
|
||||
gap: str = Field("4px", description="片段之间的间距")
|
||||
"""片段之间的间距"""
|
||||
|
||||
|
||||
class ComponentCell(BaseCell):
|
||||
"""一个通用的单元格,可以容纳任何可渲染的组件。"""
|
||||
|
||||
type: str = "component"
|
||||
component: RenderableComponent
|
||||
|
||||
|
||||
TableCell = (
|
||||
@ -73,6 +84,7 @@ TableCell = (
|
||||
| StatusBadgeCell
|
||||
| ProgressBarCell
|
||||
| RichTextCell
|
||||
| ComponentCell
|
||||
| str
|
||||
| int
|
||||
| float
|
||||
@ -84,16 +96,23 @@ class TableData(RenderableComponent):
|
||||
"""通用表格的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
"""应用于表格容器的样式名称"""
|
||||
title: str = Field(..., description="表格主标题")
|
||||
"""表格主标题"""
|
||||
tip: str | None = Field(None, description="表格下方的提示信息")
|
||||
"""表格下方的提示信息"""
|
||||
headers: list[str] = Field(default_factory=list, description="表头列表")
|
||||
"""表头列表"""
|
||||
rows: list[list[TableCell]] = Field(default_factory=list, description="数据行列表")
|
||||
"""数据行列表"""
|
||||
column_alignments: list[Literal["left", "center", "right"]] | None = Field(
|
||||
default=None, description="每列的对齐方式"
|
||||
)
|
||||
"""每列的对齐方式"""
|
||||
column_widths: list[str | int] | None = Field(
|
||||
default=None, description="每列的宽度 (e.g., ['50px', 'auto', 100])"
|
||||
)
|
||||
"""每列的宽度 (e.g., ['50px', 'auto', 100])"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import RenderableComponent
|
||||
|
||||
__all__ = ["TemplateComponent"]
|
||||
@ -10,8 +12,11 @@ class TemplateComponent(RenderableComponent):
|
||||
"""基于独立模板文件的UI组件"""
|
||||
|
||||
_is_standalone_template: bool = True
|
||||
template_path: str | Path
|
||||
data: dict[str, Any]
|
||||
"""标记此组件为独立模板"""
|
||||
template_path: str | Path = Field(..., description="指向HTML模板文件的路径")
|
||||
"""指向HTML模板文件的路径"""
|
||||
data: dict[str, Any] = Field(..., description="传递给模板的上下文数据字典")
|
||||
"""传递给模板的上下文数据字典"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -23,9 +23,11 @@ class TextData(RenderableComponent):
|
||||
"""轻量级富文本组件的数据模型"""
|
||||
|
||||
spans: list[TextSpan] = Field(default_factory=list, description="文本片段列表")
|
||||
"""文本片段列表"""
|
||||
align: Literal["left", "right", "center"] = Field(
|
||||
"left", description="整体文本对齐方式"
|
||||
)
|
||||
"""整体文本对齐方式"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -13,25 +13,35 @@ class HelpItem(BaseModel):
|
||||
"""帮助菜单中的单个功能项"""
|
||||
|
||||
name: str
|
||||
"""功能名称"""
|
||||
description: str
|
||||
"""功能描述"""
|
||||
usage: str
|
||||
"""功能用法说明"""
|
||||
|
||||
|
||||
class HelpCategory(BaseModel):
|
||||
"""帮助菜单中的一个功能类别"""
|
||||
|
||||
title: str
|
||||
"""分类标题"""
|
||||
icon_svg_path: str
|
||||
"""分类图标的SVG路径数据"""
|
||||
items: list[HelpItem]
|
||||
"""该分类下的功能项列表"""
|
||||
|
||||
|
||||
class PluginHelpPageData(RenderableComponent):
|
||||
"""通用插件帮助页面的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
"""页面样式名称"""
|
||||
bot_nickname: str
|
||||
"""机器人昵称"""
|
||||
page_title: str
|
||||
"""页面主标题"""
|
||||
categories: list[HelpCategory]
|
||||
"""帮助分类列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -13,29 +13,43 @@ class PluginMenuItem(BaseModel):
|
||||
"""插件菜单中的单个插件项"""
|
||||
|
||||
id: str
|
||||
"""插件的唯一ID"""
|
||||
name: str
|
||||
"""插件名称"""
|
||||
status: bool
|
||||
"""插件在当前群组的开关状态"""
|
||||
has_superuser_help: bool
|
||||
commands: list[str] = Field(default_factory=list)
|
||||
"""插件是否有超级用户专属帮助"""
|
||||
commands: list[str] = Field(default_factory=list, description="插件的主要命令列表")
|
||||
"""插件的主要命令列表"""
|
||||
|
||||
|
||||
class PluginMenuCategory(BaseModel):
|
||||
"""插件菜单中的一个分类"""
|
||||
|
||||
name: str
|
||||
items: list[PluginMenuItem]
|
||||
"""插件分类名称"""
|
||||
items: list[PluginMenuItem] = Field(..., description="该分类下的插件项列表")
|
||||
"""该分类下的插件项列表"""
|
||||
|
||||
|
||||
class PluginMenuData(RenderableComponent):
|
||||
"""通用插件帮助菜单的数据模型"""
|
||||
|
||||
style_name: str | None = None
|
||||
"""页面样式名称"""
|
||||
bot_name: str
|
||||
"""机器人名称"""
|
||||
bot_avatar_url: str
|
||||
"""机器人头像URL"""
|
||||
is_detail: bool
|
||||
"""是否为详细菜单模式"""
|
||||
plugin_count: int
|
||||
"""总插件数量"""
|
||||
active_count: int
|
||||
"""已启用插件数量"""
|
||||
categories: list[PluginMenuCategory]
|
||||
"""插件分类列表"""
|
||||
|
||||
@property
|
||||
def template_name(self) -> str:
|
||||
|
||||
@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from nonebot_plugin_htmlrender import get_browser
|
||||
from nonebot_plugin_htmlrender.browser import get_browser
|
||||
from playwright.async_api import Page
|
||||
|
||||
from zhenxun.utils.message import MessageUtils
|
||||
|
||||
202
zhenxun/utils/log_sanitizer.py
Normal file
202
zhenxun/utils/log_sanitizer.py
Normal file
@ -0,0 +1,202 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from nonebot.adapters import Message, MessageSegment
|
||||
|
||||
|
||||
def _truncate_base64_string(value: str, threshold: int = 256) -> str:
|
||||
"""如果字符串是超长的base64或data URI,则截断它。"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
prefixes = ("base64://", "data:image", "data:video", "data:audio")
|
||||
if value.startswith(prefixes) and len(value) > threshold:
|
||||
prefix = next((p for p in prefixes if value.startswith(p)), "base64")
|
||||
return f"[{prefix}_data_omitted_len={len(value)}]"
|
||||
return value
|
||||
|
||||
|
||||
def _sanitize_ui_html(html_string: str) -> str:
|
||||
"""
|
||||
专门用于净化UI渲染调试HTML的函数。
|
||||
它会查找所有内联的base64数据(如字体、图片)并将其截断。
|
||||
"""
|
||||
if not isinstance(html_string, str):
|
||||
return html_string
|
||||
|
||||
pattern = re.compile(r"(data:[^;]+;base64,)[A-Za-z0-9+/=\s]{100,}")
|
||||
|
||||
def replacer(match):
|
||||
prefix = match.group(1)
|
||||
original_len = len(match.group(0)) - len(prefix)
|
||||
return f"{prefix}[...base64_omitted_len={original_len}...]"
|
||||
|
||||
return pattern.sub(replacer, html_string)
|
||||
|
||||
|
||||
def _sanitize_nonebot_message(message: Message) -> Message:
|
||||
"""净化nonebot.adapter.Message对象,用于日志记录。"""
|
||||
sanitized_message = copy.deepcopy(message)
|
||||
for seg in sanitized_message:
|
||||
seg: MessageSegment
|
||||
if seg.type in ("image", "record", "video"):
|
||||
file_info = seg.data.get("file", "")
|
||||
if isinstance(file_info, str):
|
||||
seg.data["file"] = _truncate_base64_string(file_info)
|
||||
return sanitized_message
|
||||
|
||||
|
||||
def _sanitize_openai_response(response_json: dict) -> dict:
|
||||
"""净化OpenAI兼容API的响应体。"""
|
||||
try:
|
||||
sanitized_json = copy.deepcopy(response_json)
|
||||
if "choices" in sanitized_json and isinstance(sanitized_json["choices"], list):
|
||||
for choice in sanitized_json["choices"]:
|
||||
if "message" in choice and isinstance(choice["message"], dict):
|
||||
message = choice["message"]
|
||||
if "images" in message and isinstance(message["images"], list):
|
||||
for i, image_info in enumerate(message["images"]):
|
||||
if "image_url" in image_info and isinstance(
|
||||
image_info["image_url"], dict
|
||||
):
|
||||
url = image_info["image_url"].get("url", "")
|
||||
message["images"][i]["image_url"]["url"] = (
|
||||
_truncate_base64_string(url)
|
||||
)
|
||||
return sanitized_json
|
||||
except Exception:
|
||||
return response_json
|
||||
|
||||
|
||||
def _sanitize_openai_request(body: dict) -> dict:
|
||||
"""净化OpenAI兼容API的请求体,主要截断图片base64。"""
|
||||
try:
|
||||
sanitized_json = copy.deepcopy(body)
|
||||
if "messages" in sanitized_json and isinstance(
|
||||
sanitized_json["messages"], list
|
||||
):
|
||||
for message in sanitized_json["messages"]:
|
||||
if "content" in message and isinstance(message["content"], list):
|
||||
for i, part in enumerate(message["content"]):
|
||||
if part.get("type") == "image_url":
|
||||
if "image_url" in part and isinstance(
|
||||
part["image_url"], dict
|
||||
):
|
||||
url = part["image_url"].get("url", "")
|
||||
message["content"][i]["image_url"]["url"] = (
|
||||
_truncate_base64_string(url)
|
||||
)
|
||||
return sanitized_json
|
||||
except Exception:
|
||||
return body
|
||||
|
||||
|
||||
def _sanitize_gemini_response(response_json: dict) -> dict:
|
||||
"""净化Gemini API的响应体,处理文本和图片生成两种格式。"""
|
||||
try:
|
||||
sanitized_json = copy.deepcopy(response_json)
|
||||
|
||||
def _process_candidates(candidates_list: list):
|
||||
"""辅助函数,用于处理任何 candidates 列表。"""
|
||||
if not isinstance(candidates_list, list):
|
||||
return
|
||||
for candidate in candidates_list:
|
||||
if "content" in candidate and isinstance(candidate["content"], dict):
|
||||
content = candidate["content"]
|
||||
if "parts" in content and isinstance(content["parts"], list):
|
||||
for i, part in enumerate(content["parts"]):
|
||||
if "inlineData" in part and isinstance(
|
||||
part["inlineData"], dict
|
||||
):
|
||||
data = part["inlineData"].get("data", "")
|
||||
if isinstance(data, str) and len(data) > 256:
|
||||
content["parts"][i]["inlineData"]["data"] = (
|
||||
f"[base64_data_omitted_len={len(data)}]"
|
||||
)
|
||||
|
||||
if "candidates" in sanitized_json:
|
||||
_process_candidates(sanitized_json["candidates"])
|
||||
|
||||
if "image_generation" in sanitized_json and isinstance(
|
||||
sanitized_json["image_generation"], dict
|
||||
):
|
||||
if "candidates" in sanitized_json["image_generation"]:
|
||||
_process_candidates(sanitized_json["image_generation"]["candidates"])
|
||||
|
||||
return sanitized_json
|
||||
except Exception:
|
||||
return response_json
|
||||
|
||||
|
||||
def _sanitize_gemini_request(body: dict) -> dict:
|
||||
"""净化Gemini API的请求体,进行结构转换和总结。"""
|
||||
try:
|
||||
sanitized_body = copy.deepcopy(body)
|
||||
if "contents" in sanitized_body and isinstance(
|
||||
sanitized_body["contents"], list
|
||||
):
|
||||
for content_item in sanitized_body["contents"]:
|
||||
if "parts" in content_item and isinstance(content_item["parts"], list):
|
||||
media_summary = []
|
||||
new_parts = []
|
||||
for part in content_item["parts"]:
|
||||
if "inlineData" in part and isinstance(
|
||||
part["inlineData"], dict
|
||||
):
|
||||
data = part["inlineData"].get("data")
|
||||
if isinstance(data, str):
|
||||
mime_type = part["inlineData"].get(
|
||||
"mimeType", "unknown"
|
||||
)
|
||||
media_summary.append(f"{mime_type} ({len(data)} chars)")
|
||||
continue
|
||||
new_parts.append(part)
|
||||
|
||||
if media_summary:
|
||||
summary_text = (
|
||||
f"[多模态内容: {len(media_summary)}个文件 - "
|
||||
f"{', '.join(media_summary)}]"
|
||||
)
|
||||
new_parts.insert(0, {"text": summary_text})
|
||||
|
||||
content_item["parts"] = new_parts
|
||||
return sanitized_body
|
||||
except Exception:
|
||||
return body
|
||||
|
||||
|
||||
def sanitize_for_logging(data: Any, context: str | None = None) -> Any:
|
||||
"""
|
||||
统一的日志净化入口。
|
||||
|
||||
Args:
|
||||
data: 需要净化的数据 (dict, Message, etc.).
|
||||
context: 净化场景的上下文标识,例如 'gemini_request', 'openai_response'.
|
||||
|
||||
Returns:
|
||||
净化后的数据。
|
||||
"""
|
||||
if context == "nonebot_message":
|
||||
if isinstance(data, Message):
|
||||
return _sanitize_nonebot_message(data)
|
||||
elif context == "openai_response":
|
||||
if isinstance(data, dict):
|
||||
return _sanitize_openai_response(data)
|
||||
elif context == "gemini_response":
|
||||
if isinstance(data, dict):
|
||||
return _sanitize_gemini_response(data)
|
||||
elif context == "gemini_request":
|
||||
if isinstance(data, dict):
|
||||
return _sanitize_gemini_request(data)
|
||||
elif context == "openai_request":
|
||||
if isinstance(data, dict):
|
||||
return _sanitize_openai_request(data)
|
||||
elif context == "ui_html":
|
||||
if isinstance(data, str):
|
||||
return _sanitize_ui_html(data)
|
||||
else:
|
||||
if isinstance(data, str):
|
||||
return _truncate_base64_string(data)
|
||||
|
||||
return data
|
||||
@ -247,7 +247,7 @@ class PlatformUtils:
|
||||
if platform != "qq":
|
||||
return None
|
||||
if user_id.isdigit():
|
||||
return f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=160"
|
||||
return f"http://q1.qlogo.cn/g?b=qq&nk={user_id}&s=640"
|
||||
else:
|
||||
return f"https://q.qlogo.cn/qqapp/{appid}/{user_id}/640"
|
||||
|
||||
|
||||
@ -326,7 +326,7 @@ class RepoFileManager:
|
||||
|
||||
# 获取仓库树信息
|
||||
strategy = GitHubStrategy()
|
||||
strategy.body = await GitHubStrategy.parse_repo_info(repo_info)
|
||||
strategy.body = await strategy.parse_repo_info(repo_info)
|
||||
|
||||
# 处理目录路径,确保格式正确
|
||||
if directory_path and not directory_path.endswith("/") and recursive:
|
||||
@ -480,7 +480,7 @@ class RepoFileManager:
|
||||
target_dir: Path | None = None,
|
||||
) -> FileDownloadResult:
|
||||
"""
|
||||
下载单个文件
|
||||
下载多个文件
|
||||
|
||||
参数:
|
||||
repo_url: 仓库URL
|
||||
|
||||
@ -7,6 +7,7 @@ import base64
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from zhenxun.services.log import logger
|
||||
|
||||
@ -145,80 +146,85 @@ async def sparse_checkout_clone(
|
||||
target_dir: Path,
|
||||
) -> None:
|
||||
"""
|
||||
使用 git 稀疏检出克隆指定路径到目标目录(完全独立于主项目 git)。
|
||||
使用 git 稀疏检出克隆指定路径到目标目录(在临时目录中操作)。
|
||||
|
||||
关键保障:
|
||||
- 在 target_dir 下检测/初始化 .git,所有 git 操作均以 cwd=target_dir 执行
|
||||
- 强制拉取与工作区覆盖: fetch --force、checkout -B、reset --hard、clean -xdf
|
||||
- 反复设置 sparse-checkout 路径,确保路径更新生效
|
||||
- 在临时目录中执行所有 git 操作,避免影响 target_dir 中的现有内容
|
||||
- 只操作 target_dir/sparse_path 路径,不影响 target_dir 其他内容
|
||||
"""
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not await check_git():
|
||||
raise GitUnavailableError()
|
||||
|
||||
git_dir = target_dir / ".git"
|
||||
if not git_dir.exists():
|
||||
success, out, err = await run_git_command("init", target_dir)
|
||||
# 在临时目录中进行 git 操作
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_path = Path(temp_dir)
|
||||
|
||||
# 初始化临时目录为 git 仓库
|
||||
success, out, err = await run_git_command("init", temp_path)
|
||||
if not success:
|
||||
raise RuntimeError(f"git init 失败: {err or out}")
|
||||
success, out, err = await run_git_command(
|
||||
f"remote add origin {repo_url}", target_dir
|
||||
f"remote add origin {repo_url}", temp_path
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(f"添加远程失败: {err or out}")
|
||||
else:
|
||||
|
||||
# 启用稀疏检出(使用 --no-cone 模式以获得更精确的控制)
|
||||
await run_git_command("config core.sparseCheckout true", temp_path)
|
||||
await run_git_command("sparse-checkout init --no-cone", temp_path)
|
||||
|
||||
# 设置需要检出的路径(每次都覆盖配置)
|
||||
if not sparse_path:
|
||||
raise RuntimeError("sparse-checkout 路径不能为空")
|
||||
|
||||
# 使用 --no-cone 模式,直接指定要检出的具体路径
|
||||
success, out, err = await run_git_command(
|
||||
f"remote set-url origin {repo_url}", target_dir
|
||||
f"sparse-checkout set {sparse_path}/", temp_path
|
||||
)
|
||||
if not success:
|
||||
# 兜底尝试添加
|
||||
await run_git_command(f"remote add origin {repo_url}", target_dir)
|
||||
raise RuntimeError(f"配置稀疏路径失败: {err or out}")
|
||||
|
||||
# 启用稀疏检出(使用 --no-cone 模式以获得更精确的控制)
|
||||
await run_git_command("config core.sparseCheckout true", target_dir)
|
||||
await run_git_command("sparse-checkout init --no-cone", target_dir)
|
||||
# 强制拉取并同步到远端
|
||||
success, out, err = await run_git_command(
|
||||
f"fetch --force --depth 1 origin {branch}", temp_path
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(f"fetch 失败: {err or out}")
|
||||
|
||||
# 设置需要检出的路径(每次都覆盖配置)
|
||||
if not sparse_path:
|
||||
raise RuntimeError("sparse-checkout 路径不能为空")
|
||||
# 使用远端强制更新本地分支并覆盖工作区
|
||||
success, out, err = await run_git_command(
|
||||
f"checkout -B {branch} origin/{branch}", temp_path
|
||||
)
|
||||
if not success:
|
||||
# 回退方案
|
||||
success2, out2, err2 = await run_git_command(
|
||||
f"checkout {branch}", temp_path
|
||||
)
|
||||
if not success2:
|
||||
raise RuntimeError(f"checkout 失败: {(err or out) or (err2 or out2)}")
|
||||
|
||||
# 使用 --no-cone 模式,直接指定要检出的具体路径
|
||||
# 例如:sparse_path="plugins/mahiro" -> 只检出 plugins/mahiro/ 下的内容
|
||||
success, out, err = await run_git_command(
|
||||
f"sparse-checkout set {sparse_path}/", target_dir
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(f"配置稀疏路径失败: {err or out}")
|
||||
# 强制对齐工作区
|
||||
await run_git_command(f"reset --hard origin/{branch}", temp_path)
|
||||
await run_git_command("clean -xdf", temp_path)
|
||||
|
||||
# 强制拉取并同步到远端
|
||||
success, out, err = await run_git_command(
|
||||
f"fetch --force --depth 1 origin {branch}", target_dir
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(f"fetch 失败: {err or out}")
|
||||
# 将检出的文件移动到目标位置
|
||||
source_path = temp_path / sparse_path
|
||||
if source_path.exists():
|
||||
# 确保目标路径存在
|
||||
target_path = target_dir / sparse_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用远端强制更新本地分支并覆盖工作区
|
||||
success, out, err = await run_git_command(
|
||||
f"checkout -B {branch} origin/{branch}", target_dir
|
||||
)
|
||||
if not success:
|
||||
# 回退方案
|
||||
success2, out2, err2 = await run_git_command(f"checkout {branch}", target_dir)
|
||||
if not success2:
|
||||
raise RuntimeError(f"checkout 失败: {(err or out) or (err2 or out2)}")
|
||||
# 如果目标路径已存在,先清理
|
||||
if target_path.exists():
|
||||
if target_path.is_dir():
|
||||
shutil.rmtree(target_path)
|
||||
else:
|
||||
target_path.unlink()
|
||||
|
||||
# 强制对齐工作区
|
||||
await run_git_command(f"reset --hard origin/{branch}", target_dir)
|
||||
await run_git_command("clean -xdf", target_dir)
|
||||
|
||||
dir_path = target_dir / Path(sparse_path)
|
||||
for f in dir_path.iterdir():
|
||||
shutil.move(f, target_dir / f.name)
|
||||
dir_name = sparse_path.split("/")[0]
|
||||
rm_path = target_dir / dir_name
|
||||
if rm_path.exists():
|
||||
shutil.rmtree(rm_path)
|
||||
# 移动整个目录结构到目标位置
|
||||
shutil.move(str(source_path), str(target_path))
|
||||
|
||||
|
||||
def prepare_aliyun_url(repo_url: str) -> str:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user