:增强获取群组的安全性和准确性。同时,优化了缓存管理中的相关逻辑,确保缓存操作的一致性。

This commit is contained in:
HibiKier 2025-07-07 11:05:21 +08:00
parent 61a9b74558
commit 251299c56c
13 changed files with 182 additions and 73 deletions

View File

@ -114,7 +114,7 @@ class BanManage:
if not is_superuser and user_id and session.id1: if not is_superuser and user_id and session.id1:
user_level = await LevelUser.get_user_level(session.id1, group_id) user_level = await LevelUser.get_user_level(session.id1, group_id)
if idx: if idx:
ban_data = await BanConsole.get_or_none(id=idx) ban_data = await BanConsole.get_ban(id=idx)
if not ban_data: if not ban_data:
return False, "该用户/群组不在黑名单中捏..." return False, "该用户/群组不在黑名单中捏..."
if ban_data.ban_level > user_level: if ban_data.ban_level > user_level:

View File

@ -1,10 +1,12 @@
import os import os
from typing import cast
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import Cache from zhenxun.services.cache import Cache
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType, PluginType from zhenxun.utils.enum import BlockType, CacheType, PluginType
from zhenxun.utils.exception import GroupInfoNotFound from zhenxun.utils.exception import GroupInfoNotFound
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
@ -117,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"] column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
group = None group = None
if group_id: if group_id:
group = await GroupConsole.get_or_none( group = await GroupConsole.get_group(group_id=group_id)
group_id=group_id, channel_id__isnull=True
)
if not group: if not group:
raise GroupInfoNotFound() raise GroupInfoNotFound()
else: else:
@ -201,20 +201,19 @@ class PluginManage:
) )
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}" return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
if group_id: if group_id:
if group := await GroupConsole.get_or_none( if group := await GroupConsole.get_group(group_id=group_id):
group_id=group_id, channel_id__isnull=True module_list = cast(
): list[str],
module_list = await PluginInfo.filter( await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
plugin_type=PluginType.NORMAL "module", flat=True
).values_list("module", flat=True) ),
)
if status: if status:
for module in module_list: # 开启所有功能 - 清空禁用列表
group.block_plugin = group.block_plugin.replace( group.block_plugin = ""
f"<{module},", ""
)
else: else:
module_list = [f"<{module}" for module in module_list] # 关闭所有功能 - 将模块列表转换为禁用格式
group.block_plugin = ",".join(module_list) + "," # type: ignore group.block_plugin = CommonUtils.convert_module_format(module_list)
await group.save(update_fields=["block_plugin"]) await group.save(update_fields=["block_plugin"])
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}" return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
return "获取群组失败..." return "获取群组失败..."
@ -235,9 +234,7 @@ class PluginManage:
返回: 返回:
bool: 是否醒来 bool: 是否醒来
""" """
if c := await GroupConsole.get_or_none( if c := await GroupConsole.get_group(group_id=group_id):
group_id=group_id, channel_id__isnull=True
):
return c.status return c.status
return False return False
@ -428,17 +425,18 @@ class PluginManage:
""" """
status_str = "关闭" if status else "开启" status_str = "关闭" if status else "开启"
if is_all: if is_all:
modules = await TaskInfo.annotate().values_list("module", flat=True) module_list = cast(
if modules: list[str], await TaskInfo.annotate().values_list("module", flat=True)
)
if module_list:
group, _ = await GroupConsole.get_or_create( group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True group_id=group_id, channel_id__isnull=True
) )
modules = [f"<{module}" for module in modules]
if status: if status:
group.block_task = ",".join(modules) + "," # type: ignore group.block_task = CommonUtils.convert_module_format(module_list)
else: else:
for module in modules: # 开启所有模块 - 清空禁用列表
group.block_task = group.block_task.replace(f"{module},", "") group.block_task = ""
await group.save(update_fields=["block_task"]) await group.save(update_fields=["block_task"])
return f"已成功{status_str}全部被动技能!" return f"已成功{status_str}全部被动技能!"
elif task := await TaskInfo.get_or_none(name=task_name): elif task := await TaskInfo.get_or_none(name=task_name):

View File

@ -45,7 +45,7 @@ async def classify_plugin(
""" """
sort_data = await sort_type() sort_data = await sort_type()
classify: dict[str, list] = {} classify: dict[str, list] = {}
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None group = await GroupConsole.get_group(group_id=group_id) if group_id else None
bot = await BotConsole.get_or_none(bot_id=session.self_id) bot = await BotConsole.get_or_none(bot_id=session.self_id)
for menu, value in sort_data.items(): for menu, value in sort_data.items():
for plugin in value: for plugin in value:

View File

@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
color="black" if idx % 2 else "white", color="black" if idx % 2 else "white",
) )
curr_h = 10 curr_h = 10
group = await GroupConsole.get_or_none(group_id=group_id) group = await GroupConsole.get_group(group_id=group_id) if group_id else None
for _, plugin in enumerate(plugin_list): for _, plugin in enumerate(plugin_list):
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0) text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
if group and f"{plugin.module}," in group.block_plugin: if group and f"{plugin.module}," in group.block_plugin:
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
width, height = 10, 10 width, height = 10, 10
for s in [ for s in [
"目前支持的功能列表:", "目前支持的功能列表:",
"可以通过 ‘帮助 [功能名称或功能Id] 来获取对应功能的使用方法", "可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
]: ]:
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24) text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
await result.paste(text, (width, height)) await result.paste(text, (width, height))

View File

@ -45,9 +45,7 @@ class StatisticsManage:
title = f"{user.user_name if user else user_id} {day_type}功能调用统计" title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
elif group_id: elif group_id:
"""查群组""" """查群组"""
group = await GroupConsole.get_or_none( group = await GroupConsole.get_group(group_id=group_id)
group_id=group_id, channel_id__isnull=True
)
title = f"{group.group_name if group else group_id} {day_type}功能调用统计" title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
else: else:
title = "功能调用统计" title = "功能调用统计"

View File

@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
@_matcher.assign("super-handle", parameterless=[CheckGroupId()]) @_matcher.assign("super-handle", parameterless=[CheckGroupId()])
async def _(session: EventSession, arparma: Arparma, state: T_State): async def _(session: EventSession, arparma: Arparma, state: T_State):
gid = state["group_id"] gid = state["group_id"]
group = await GroupConsole.get_or_none(group_id=gid) group = await GroupConsole.get_group(group_id=gid)
if not group: if not group:
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish() await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
s = "删除" if arparma.find("delete") else "添加" s = "删除" if arparma.find("delete") else "添加"

View File

@ -250,7 +250,7 @@ class ApiDataSource:
返回: 返回:
GroupDetail | None: 群组详情数据 GroupDetail | None: 群组详情数据
""" """
group = await GroupConsole.get_or_none(group_id=group_id) group = await GroupConsole.get_group(group_id=group_id)
if not group: if not group:
return None return None
like_plugin = await cls.__get_group_detail_like_plugin(group_id) like_plugin = await cls.__get_group_detail_like_plugin(group_id)

View File

@ -45,10 +45,10 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
mtime=file_path.stat().st_mtime, mtime=file_path.stat().st_mtime,
) )
) )
sorted(data_list, key=lambda f: f.name) data_list.sort(key=lambda f: f.name)
return Result.ok(data_list) return Result.ok(data_list)
except Exception as e: except Exception as e:
return Result.fail(f"获取文件列表失败: {e!s}") return Result.fail(f"获取文件列表失败: {e!s}")
@router.get( @router.get(

View File

@ -54,12 +54,12 @@ class BanConsole(Model):
raise UserAndGroupIsNone() raise UserAndGroupIsNone()
if user_id: if user_id:
return ( return (
await cls.get_or_none(user_id=user_id, group_id=group_id) await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
if group_id if group_id
else await cls.get_or_none(user_id=user_id, group_id__isnull=True) else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
) )
else: else:
return await cls.get_or_none(user_id="", group_id=group_id) return await cls.safe_get_or_none(user_id="", group_id=group_id)
@classmethod @classmethod
async def check_ban_level( async def check_ban_level(
@ -175,3 +175,25 @@ class BanConsole(Model):
await user.delete() await user.delete()
return True return True
return False return False
@classmethod
async def get_ban(
cls,
*,
id: int | None = None,
user_id: str | None = None,
group_id: str | None = None,
) -> Self | None:
"""安全地获取ban记录
参数:
id: 记录id
user_id: 用户id
group_id: 群组id
返回:
Self | None: ban记录
"""
if id is not None:
return await cls.safe_get_or_none(id=id)
return await cls._get_data(user_id, group_id)

View File

@ -218,20 +218,32 @@ class GroupConsole(Model):
@classmethod @classmethod
async def get_group( async def get_group(
cls, group_id: str, channel_id: str | None = None cls,
group_id: str,
channel_id: str | None = None,
clean_duplicates: bool = True,
) -> Self | None: ) -> Self | None:
"""获取群组 """获取群组
参数: 参数:
group_id: 群组id group_id: 群组id
channel_id: 频道id. channel_id: 频道id
clean_duplicates: 是否删除重复的记录仅保留最新的
返回: 返回:
Self: GroupConsole Self: GroupConsole
""" """
if channel_id: if channel_id:
return await cls.get_or_none(group_id=group_id, channel_id=channel_id) return await cls.safe_get_or_none(
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True) group_id=group_id,
channel_id=channel_id,
clean_duplicates=clean_duplicates,
)
return await cls.safe_get_or_none(
group_id=group_id,
channel_id__isnull=True,
clean_duplicates=clean_duplicates,
)
@classmethod @classmethod
async def is_super_group(cls, group_id: str) -> bool: async def is_super_group(cls, group_id: str) -> bool:

View File

@ -121,7 +121,7 @@ class CacheData(BaseModel):
lazy_load: bool = True # 默认延迟加载 lazy_load: bool = True # 默认延迟加载
result_model: type | None = None result_model: type | None = None
_keys: set[str] = set() # 存储所有缓存键 _keys: set[str] = set() # 存储所有缓存键
_cache: BaseCache | AioCache cache: BaseCache | AioCache
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -208,7 +208,7 @@ class CacheData(BaseModel):
async def get_data(self) -> Any: async def get_data(self) -> Any:
"""从缓存获取数据""" """从缓存获取数据"""
try: try:
data = await self._cache.get(self.name) data = await self.cache.get(self.name) # type: ignore
logger.debug(f"获取缓存 {self.name} 数据: {data}") logger.debug(f"获取缓存 {self.name} 数据: {data}")
# 如果数据为空,尝试重新加载 # 如果数据为空,尝试重新加载
@ -307,11 +307,11 @@ class CacheData(BaseModel):
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}") logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
# 2. 删除旧数据 # 2. 删除旧数据
await self._cache.delete(self.name) await self.cache.delete(self.name) # type: ignore
logger.debug(f"删除缓存 {self.name} 旧数据") logger.debug(f"删除缓存 {self.name} 旧数据")
# 3. 设置新数据 # 3. 设置新数据
await self._cache.set(self.name, serialized_value, ttl=self.expire) await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
logger.debug(f"设置缓存 {self.name} 新数据完成") logger.debug(f"设置缓存 {self.name} 新数据完成")
except Exception as e: except Exception as e:
@ -321,7 +321,7 @@ class CacheData(BaseModel):
async def delete_data(self): async def delete_data(self):
"""删除缓存数据""" """删除缓存数据"""
try: try:
await self._cache.delete(self.name) await self.cache.delete(self.name) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"删除缓存 {self.name}", e=e) logger.error(f"删除缓存 {self.name}", e=e)
@ -428,7 +428,7 @@ class CacheData(BaseModel):
""" """
cache_key = self._get_cache_key(key) cache_key = self._get_cache_key(key)
try: try:
data = await self._cache.get(cache_key) data = await self.cache.get(cache_key) # type: ignore
logger.debug(f"获取缓存 {cache_key} 数据: {data}") logger.debug(f"获取缓存 {cache_key} 数据: {data}")
if self.result_model: if self.result_model:
@ -467,7 +467,7 @@ class CacheData(BaseModel):
for key in self._keys: for key in self._keys:
# 提取原始键名(去掉前缀) # 提取原始键名(去掉前缀)
original_key = key.split(":", 1)[1] original_key = key.split(":", 1)[1]
data = await self._cache.get(key) data = await self.cache.get(key) # type: ignore
if self.result_model: if self.result_model:
result[original_key] = self._deserialize_value( result[original_key] = self._deserialize_value(
data, self.result_model data, self.result_model
@ -484,7 +484,7 @@ class CacheData(BaseModel):
cache_key = self._get_cache_key(key) cache_key = self._get_cache_key(key)
try: try:
serialized_value = self._serialize_value(value) serialized_value = self._serialize_value(value)
await self._cache.set(cache_key, serialized_value, ttl=self.expire) await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
self._keys.add(cache_key) # 添加到键列表 self._keys.add(cache_key) # 添加到键列表
logger.debug(f"设置缓存 {cache_key} 数据完成") logger.debug(f"设置缓存 {cache_key} 数据完成")
except Exception as e: except Exception as e:
@ -495,7 +495,7 @@ class CacheData(BaseModel):
"""删除指定键的缓存数据""" """删除指定键的缓存数据"""
cache_key = self._get_cache_key(key) cache_key = self._get_cache_key(key)
try: try:
await self._cache.delete(cache_key) await self.cache.delete(cache_key) # type: ignore
self._keys.discard(cache_key) # 从键列表中移除 self._keys.discard(cache_key) # 从键列表中移除
logger.debug(f"删除缓存 {cache_key} 完成") logger.debug(f"删除缓存 {cache_key} 完成")
except Exception as e: except Exception as e:
@ -505,7 +505,7 @@ class CacheData(BaseModel):
"""清除所有缓存数据""" """清除所有缓存数据"""
try: try:
for key in list(self._keys): # 使用列表复制避免在迭代时修改 for key in list(self._keys): # 使用列表复制避免在迭代时修改
await self._cache.delete(key) await self.cache.delete(key) # type: ignore
self._keys.clear() self._keys.clear()
logger.debug(f"清除缓存 {self.name} 完成") logger.debug(f"清除缓存 {self.name} 完成")
except Exception as e: except Exception as e:
@ -524,7 +524,7 @@ class CacheManager:
if self._cache_instance is None: if self._cache_instance is None:
if config.redis_host: if config.redis_host:
self._cache_instance = AioCache( self._cache_instance = AioCache(
AioCache.REDIS, AioCache.REDIS, # type: ignore
serializer=JsonSerializer(), serializer=JsonSerializer(),
namespace="zhenxun_cache", namespace="zhenxun_cache",
timeout=30, # 操作超时时间 timeout=30, # 操作超时时间
@ -546,12 +546,12 @@ class CacheManager:
async def close(self): async def close(self):
if self._cache_instance: if self._cache_instance:
await self._cache_instance.close() await self._cache_instance.close() # type: ignore
async def verify_connection(self): async def verify_connection(self):
"""连接测试""" """连接测试"""
try: try:
await self._cache.get("__test__") await self._cache.get("__test__") # type: ignore
except Exception as e: except Exception as e:
logger.error("连接失败", LOG_COMMAND, e=e) logger.error("连接失败", LOG_COMMAND, e=e)
raise raise
@ -560,7 +560,7 @@ class CacheManager:
"""初始化所有非延迟加载的缓存""" """初始化所有非延迟加载的缓存"""
await self.verify_connection() await self.verify_connection()
for name, cache in self._data.items(): for name, cache in self._data.items():
cache._cache = self._cache cache.cache = self._cache
if not cache.lazy_load: if not cache.lazy_load:
try: try:
await cache.reload() await cache.reload()
@ -587,7 +587,7 @@ class CacheManager:
func=func, func=func,
expire=expire, expire=expire,
lazy_load=lazy_load, lazy_load=lazy_load,
_cache=self._cache, cache=self._cache,
) )
return func return func

View File

@ -72,12 +72,23 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None, using_db: BaseDBAsyncClient | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[Self, bool]: ) -> tuple[Self, bool]:
result, is_create = await super().get_or_create( if sem := cls.get_semaphore(DbLockType.CREATE):
defaults=defaults, using_db=using_db, **kwargs async with sem:
) # 在锁内执行查询和创建操作
if is_create and (cache_type := cls.get_cache_type()): result, is_create = await super().get_or_create(
await CacheRoot.reload(cache_type) defaults=defaults, using_db=using_db, **kwargs
return (result, is_create) )
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return (result, is_create)
else:
# 如果没有锁,则执行原来的逻辑
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return (result, is_create)
@classmethod @classmethod
async def update_or_create( async def update_or_create(
@ -86,12 +97,23 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None, using_db: BaseDBAsyncClient | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[Self, bool]: ) -> tuple[Self, bool]:
result = await super().update_or_create( if sem := cls.get_semaphore(DbLockType.CREATE):
defaults=defaults, using_db=using_db, **kwargs async with sem:
) # 在锁内执行查询和创建操作
if cache_type := cls.get_cache_type(): result = await super().update_or_create(
await CacheRoot.reload(cache_type) defaults=defaults, using_db=using_db, **kwargs
return result )
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
else:
# 如果没有锁,则执行原来的逻辑
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
@classmethod @classmethod
async def bulk_create( # type: ignore async def bulk_create( # type: ignore
@ -167,6 +189,65 @@ class Model(TortoiseModel):
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type) await CacheRoot.reload(cache_type)
@classmethod
async def safe_get_or_none(
cls,
*args,
using_db: BaseDBAsyncClient | None = None,
clean_duplicates: bool = True,
**kwargs: Any,
) -> Self | None:
"""安全地获取一条记录或None处理存在多个记录时返回最新的那个
注意默认会删除重复的记录仅保留最新的
参数:
*args: 查询参数
using_db: 数据库连接
clean_duplicates: 是否删除重复的记录仅保留最新的
**kwargs: 查询参数
返回:
Self | None: 查询结果如果不存在返回None
"""
try:
# 先尝试使用 get_or_none 获取单个记录
return await cls.get_or_none(*args, using_db=using_db, **kwargs)
except Exception as e:
# 如果出现错误(可能是存在多个记录)
if "Multiple objects" in str(e):
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
)
# 查询所有匹配记录
records = await cls.filter(*args, **kwargs)
if not records:
return None
# 如果需要清理重复记录
if clean_duplicates and hasattr(cls, "id"):
# 按 id 排序
records = sorted(
records, key=lambda x: getattr(x, "id", 0), reverse=True
)
for record in records[1:]:
try:
await record.delete()
logger.info(
f"{cls.__name__} 删除重复记录:"
f" id={getattr(record, 'id', None)}"
)
except Exception as del_e:
logger.error(f"删除重复记录失败: {del_e}")
return records[0]
# 如果不需要清理或没有 id 字段,则返回最新的记录
if hasattr(cls, "id"):
return await cls.filter(*args, **kwargs).order_by("-id").first()
# 如果没有 id 字段,则返回第一个记录
return await cls.filter(*args, **kwargs).first()
# 其他类型的错误则继续抛出
raise
class DbUrlIsNode(HookPriorityException): class DbUrlIsNode(HookPriorityException):
""" """

View File

@ -53,9 +53,7 @@ class CommonUtils:
if await GroupConsole.is_block_task(group_id, module): if await GroupConsole.is_block_task(group_id, module):
"""群组是否禁用被动""" """群组是否禁用被动"""
return True return True
if g := await GroupConsole.get_or_none( if g := await GroupConsole.get_group(group_id=group_id):
group_id=group_id, channel_id__isnull=True
):
"""群组权限是否小于0""" """群组权限是否小于0"""
if g.level < 0: if g.level < 0:
return True return True