:Enhance the safety and accuracy of group retrieval, and refine the related cache-management logic to keep cache operations consistent.

This commit is contained in:
HibiKier 2025-07-07 11:05:21 +08:00
parent 61a9b74558
commit 251299c56c
13 changed files with 182 additions and 73 deletions

View File

@@ -114,7 +114,7 @@ class BanManage:
if not is_superuser and user_id and session.id1:
user_level = await LevelUser.get_user_level(session.id1, group_id)
if idx:
ban_data = await BanConsole.get_or_none(id=idx)
ban_data = await BanConsole.get_ban(id=idx)
if not ban_data:
return False, "该用户/群组不在黑名单中捏..."
if ban_data.ban_level > user_level:

View File

@@ -1,10 +1,12 @@
import os
from typing import cast
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.cache import Cache
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType, PluginType
from zhenxun.utils.exception import GroupInfoNotFound
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
@@ -117,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
group = None
if group_id:
group = await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
)
group = await GroupConsole.get_group(group_id=group_id)
if not group:
raise GroupInfoNotFound()
else:
@@ -201,20 +201,19 @@ class PluginManage:
)
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
if group_id:
if group := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
module_list = await PluginInfo.filter(
plugin_type=PluginType.NORMAL
).values_list("module", flat=True)
if group := await GroupConsole.get_group(group_id=group_id):
module_list = cast(
list[str],
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
"module", flat=True
),
)
if status:
for module in module_list:
group.block_plugin = group.block_plugin.replace(
f"<{module},", ""
)
# 开启所有功能 - 清空禁用列表
group.block_plugin = ""
else:
module_list = [f"<{module}" for module in module_list]
group.block_plugin = ",".join(module_list) + "," # type: ignore
# 关闭所有功能 - 将模块列表转换为禁用格式
group.block_plugin = CommonUtils.convert_module_format(module_list)
await group.save(update_fields=["block_plugin"])
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
return "获取群组失败..."
@@ -235,9 +234,7 @@ class PluginManage:
返回:
bool: 是否醒来
"""
if c := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
if c := await GroupConsole.get_group(group_id=group_id):
return c.status
return False
@@ -428,17 +425,18 @@ class PluginManage:
"""
status_str = "关闭" if status else "开启"
if is_all:
modules = await TaskInfo.annotate().values_list("module", flat=True)
if modules:
module_list = cast(
list[str], await TaskInfo.annotate().values_list("module", flat=True)
)
if module_list:
group, _ = await GroupConsole.get_or_create(
group_id=group_id, channel_id__isnull=True
)
modules = [f"<{module}" for module in modules]
if status:
group.block_task = ",".join(modules) + "," # type: ignore
group.block_task = CommonUtils.convert_module_format(module_list)
else:
for module in modules:
group.block_task = group.block_task.replace(f"{module},", "")
# 开启所有模块 - 清空禁用列表
group.block_task = ""
await group.save(update_fields=["block_task"])
return f"已成功{status_str}全部被动技能!"
elif task := await TaskInfo.get_or_none(name=task_name):

View File

@@ -45,7 +45,7 @@ async def classify_plugin(
"""
sort_data = await sort_type()
classify: dict[str, list] = {}
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
bot = await BotConsole.get_or_none(bot_id=session.self_id)
for menu, value in sort_data.items():
for plugin in value:

View File

@@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
color="black" if idx % 2 else "white",
)
curr_h = 10
group = await GroupConsole.get_or_none(group_id=group_id)
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
for _, plugin in enumerate(plugin_list):
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
if group and f"{plugin.module}," in group.block_plugin:
@@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
width, height = 10, 10
for s in [
"目前支持的功能列表:",
"可以通过 ‘帮助 [功能名称或功能Id] 来获取对应功能的使用方法",
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
]:
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
await result.paste(text, (width, height))

View File

@@ -45,9 +45,7 @@ class StatisticsManage:
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
elif group_id:
"""查群组"""
group = await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
)
group = await GroupConsole.get_group(group_id=group_id)
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
else:
title = "功能调用统计"

View File

@@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
async def _(session: EventSession, arparma: Arparma, state: T_State):
gid = state["group_id"]
group = await GroupConsole.get_or_none(group_id=gid)
group = await GroupConsole.get_group(group_id=gid)
if not group:
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
s = "删除" if arparma.find("delete") else "添加"

View File

@@ -250,7 +250,7 @@ class ApiDataSource:
返回:
GroupDetail | None: 群组详情数据
"""
group = await GroupConsole.get_or_none(group_id=group_id)
group = await GroupConsole.get_group(group_id=group_id)
if not group:
return None
like_plugin = await cls.__get_group_detail_like_plugin(group_id)

View File

@@ -45,10 +45,10 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
mtime=file_path.stat().st_mtime,
)
)
sorted(data_list, key=lambda f: f.name)
return Result.ok(data_list)
data_list.sort(key=lambda f: f.name)
return Result.ok(data_list)
except Exception as e:
return Result.fail(f"获取文件列表失败: {e!s}")
return Result.fail(f"获取文件列表失败: {e!s}")
@router.get(

View File

@@ -54,12 +54,12 @@ class BanConsole(Model):
raise UserAndGroupIsNone()
if user_id:
return (
await cls.get_or_none(user_id=user_id, group_id=group_id)
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
if group_id
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
)
else:
return await cls.get_or_none(user_id="", group_id=group_id)
return await cls.safe_get_or_none(user_id="", group_id=group_id)
@classmethod
async def check_ban_level(
@@ -175,3 +175,25 @@ class BanConsole(Model):
await user.delete()
return True
return False
@classmethod
async def get_ban(
cls,
*,
id: int | None = None,
user_id: str | None = None,
group_id: str | None = None,
) -> Self | None:
"""安全地获取ban记录
参数:
id: 记录id
user_id: 用户id
group_id: 群组id
返回:
Self | None: ban记录
"""
if id is not None:
return await cls.safe_get_or_none(id=id)
return await cls._get_data(user_id, group_id)
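A usage sketch of the new BanConsole.get_ban helper, mirroring the call sites changed above (the ids are illustrative and an initialized Tortoise ORM connection is assumed):

async def ban_lookup_example() -> None:
    # Look up a ban record by primary key, as BanManage now does.
    ban_data = await BanConsole.get_ban(id=1)
    if ban_data is None:
        ...  # not on the blacklist
    # Or look one up by user and group, which falls back to _get_data().
    ban_data = await BanConsole.get_ban(user_id="123456", group_id="654321")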

View File

@@ -218,20 +218,32 @@ class GroupConsole(Model):
@classmethod
async def get_group(
cls, group_id: str, channel_id: str | None = None
cls,
group_id: str,
channel_id: str | None = None,
clean_duplicates: bool = True,
) -> Self | None:
"""获取群组
参数:
group_id: 群组id
channel_id: 频道id.
channel_id: 频道id
clean_duplicates: 是否删除重复的记录仅保留最新的
返回:
Self: GroupConsole
"""
if channel_id:
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
return await cls.safe_get_or_none(
group_id=group_id,
channel_id=channel_id,
clean_duplicates=clean_duplicates,
)
return await cls.safe_get_or_none(
group_id=group_id,
channel_id__isnull=True,
clean_duplicates=clean_duplicates,
)
@classmethod
async def is_super_group(cls, group_id: str) -> bool:
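A usage sketch of the extended get_group signature (group id illustrative); clean_duplicates=True is the default and lets safe_get_or_none prune duplicate rows while keeping the newest one:

async def group_lookup_example() -> None:
    # Group-level record (channel_id IS NULL); duplicates are cleaned by default.
    group = await GroupConsole.get_group(group_id="123456789")
    # Same lookup, but leave any duplicate rows in place.
    group = await GroupConsole.get_group(group_id="123456789", clean_duplicates=False)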

View File

@@ -121,7 +121,7 @@ class CacheData(BaseModel):
lazy_load: bool = True # 默认延迟加载
result_model: type | None = None
_keys: set[str] = set() # 存储所有缓存键
_cache: BaseCache | AioCache
cache: BaseCache | AioCache
class Config:
arbitrary_types_allowed = True
@@ -208,7 +208,7 @@ class CacheData(BaseModel):
async def get_data(self) -> Any:
"""从缓存获取数据"""
try:
data = await self._cache.get(self.name)
data = await self.cache.get(self.name) # type: ignore
logger.debug(f"获取缓存 {self.name} 数据: {data}")
# 如果数据为空,尝试重新加载
@@ -307,11 +307,11 @@ class CacheData(BaseModel):
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
# 2. 删除旧数据
await self._cache.delete(self.name)
await self.cache.delete(self.name) # type: ignore
logger.debug(f"删除缓存 {self.name} 旧数据")
# 3. 设置新数据
await self._cache.set(self.name, serialized_value, ttl=self.expire)
await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
logger.debug(f"设置缓存 {self.name} 新数据完成")
except Exception as e:
@@ -321,7 +321,7 @@ class CacheData(BaseModel):
async def delete_data(self):
"""删除缓存数据"""
try:
await self._cache.delete(self.name)
await self.cache.delete(self.name) # type: ignore
except Exception as e:
logger.error(f"删除缓存 {self.name}", e=e)
@@ -428,7 +428,7 @@ class CacheData(BaseModel):
"""
cache_key = self._get_cache_key(key)
try:
data = await self._cache.get(cache_key)
data = await self.cache.get(cache_key) # type: ignore
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
if self.result_model:
@@ -467,7 +467,7 @@ class CacheData(BaseModel):
for key in self._keys:
# 提取原始键名(去掉前缀)
original_key = key.split(":", 1)[1]
data = await self._cache.get(key)
data = await self.cache.get(key) # type: ignore
if self.result_model:
result[original_key] = self._deserialize_value(
data, self.result_model
@@ -484,7 +484,7 @@ class CacheData(BaseModel):
cache_key = self._get_cache_key(key)
try:
serialized_value = self._serialize_value(value)
await self._cache.set(cache_key, serialized_value, ttl=self.expire)
await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
self._keys.add(cache_key) # 添加到键列表
logger.debug(f"设置缓存 {cache_key} 数据完成")
except Exception as e:
@@ -495,7 +495,7 @@ class CacheData(BaseModel):
"""删除指定键的缓存数据"""
cache_key = self._get_cache_key(key)
try:
await self._cache.delete(cache_key)
await self.cache.delete(cache_key) # type: ignore
self._keys.discard(cache_key) # 从键列表中移除
logger.debug(f"删除缓存 {cache_key} 完成")
except Exception as e:
@@ -505,7 +505,7 @@ class CacheData(BaseModel):
"""清除所有缓存数据"""
try:
for key in list(self._keys): # 使用列表复制避免在迭代时修改
await self._cache.delete(key)
await self.cache.delete(key) # type: ignore
self._keys.clear()
logger.debug(f"清除缓存 {self.name} 完成")
except Exception as e:
@@ -524,7 +524,7 @@ class CacheManager:
if self._cache_instance is None:
if config.redis_host:
self._cache_instance = AioCache(
AioCache.REDIS,
AioCache.REDIS, # type: ignore
serializer=JsonSerializer(),
namespace="zhenxun_cache",
timeout=30, # 操作超时时间
@@ -546,12 +546,12 @@ class CacheManager:
async def close(self):
if self._cache_instance:
await self._cache_instance.close()
await self._cache_instance.close() # type: ignore
async def verify_connection(self):
"""连接测试"""
try:
await self._cache.get("__test__")
await self._cache.get("__test__") # type: ignore
except Exception as e:
logger.error("连接失败", LOG_COMMAND, e=e)
raise
@@ -560,7 +560,7 @@ class CacheManager:
"""初始化所有非延迟加载的缓存"""
await self.verify_connection()
for name, cache in self._data.items():
cache._cache = self._cache
cache.cache = self._cache
if not cache.lazy_load:
try:
await cache.reload()
@@ -587,7 +587,7 @@ class CacheManager:
func=func,
expire=expire,
lazy_load=lazy_load,
_cache=self._cache,
cache=self._cache,
)
return func
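The _cache -> cache rename above is presumably needed because Pydantic does not treat leading-underscore attributes as model fields, so the backend could not be supplied through the constructor the way cache=self._cache now is; a minimal sketch of that distinction (Demo is hypothetical, not part of the project):

from pydantic import BaseModel

class Demo(BaseModel):
    cache: str  # a regular field, so Demo(cache=...) populates it

print(Demo(cache="redis").cache)  # -> "redis"
# A _cache: str annotation would become a private attribute rather than a
# field, so Demo(_cache="redis") could not supply the backend the same way.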

View File

@@ -72,12 +72,23 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return (result, is_create)
if sem := cls.get_semaphore(DbLockType.CREATE):
async with sem:
# 在锁内执行查询和创建操作
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return (result, is_create)
else:
# 如果没有锁,则执行原来的逻辑
result, is_create = await super().get_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
return (result, is_create)
@classmethod
async def update_or_create(
@@ -86,12 +97,23 @@ class Model(TortoiseModel):
using_db: BaseDBAsyncClient | None = None,
**kwargs: Any,
) -> tuple[Self, bool]:
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
if sem := cls.get_semaphore(DbLockType.CREATE):
async with sem:
# 在锁内执行查询和创建操作
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
else:
# 如果没有锁,则执行原来的逻辑
result = await super().update_or_create(
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
return result
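The lock here comes from cls.get_semaphore(DbLockType.CREATE); the following self-contained sketch only illustrates why the lookup and the insert are wrapped in a single critical section (asyncio.Semaphore and a dict stand in for the real lock and table):

import asyncio

_create_lock = asyncio.Semaphore(1)
_rows: dict[str, dict] = {}  # stand-in for the database table

async def get_or_create_sketch(key: str) -> tuple[dict, bool]:
    # Holding the lock makes "check then insert" atomic with respect to other
    # coroutines, so two concurrent calls cannot both insert the same key and
    # create the duplicate rows that safe_get_or_none later has to clean up.
    async with _create_lock:
        if key in _rows:
            return _rows[key], False
        _rows[key] = {"id": key}
        return _rows[key], True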
@classmethod
async def bulk_create( # type: ignore
@@ -167,6 +189,65 @@ class Model(TortoiseModel):
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type)
@classmethod
async def safe_get_or_none(
cls,
*args,
using_db: BaseDBAsyncClient | None = None,
clean_duplicates: bool = True,
**kwargs: Any,
) -> Self | None:
"""安全地获取一条记录或None处理存在多个记录时返回最新的那个
注意默认会删除重复的记录仅保留最新的
参数:
*args: 查询参数
using_db: 数据库连接
clean_duplicates: 是否删除重复的记录仅保留最新的
**kwargs: 查询参数
返回:
Self | None: 查询结果如果不存在返回None
"""
try:
# 先尝试使用 get_or_none 获取单个记录
return await cls.get_or_none(*args, using_db=using_db, **kwargs)
except Exception as e:
# 如果出现错误(可能是存在多个记录)
if "Multiple objects" in str(e):
logger.warning(
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
)
# 查询所有匹配记录
records = await cls.filter(*args, **kwargs)
if not records:
return None
# 如果需要清理重复记录
if clean_duplicates and hasattr(cls, "id"):
# 按 id 排序
records = sorted(
records, key=lambda x: getattr(x, "id", 0), reverse=True
)
for record in records[1:]:
try:
await record.delete()
logger.info(
f"{cls.__name__} 删除重复记录:"
f" id={getattr(record, 'id', None)}"
)
except Exception as del_e:
logger.error(f"删除重复记录失败: {del_e}")
return records[0]
# 如果不需要清理或没有 id 字段,则返回最新的记录
if hasattr(cls, "id"):
return await cls.filter(*args, **kwargs).order_by("-id").first()
# 如果没有 id 字段,则返回第一个记录
return await cls.filter(*args, **kwargs).first()
# 其他类型的错误则继续抛出
raise
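A usage sketch of safe_get_or_none (ids illustrative, Tortoise connection assumed): when several rows match, the newest one is returned and, unless clean_duplicates=False is passed, the older duplicates are deleted:

async def safe_lookup_example() -> None:
    # Returns the newest match; older duplicate rows are deleted by default.
    group = await GroupConsole.safe_get_or_none(
        group_id="123456", channel_id__isnull=True
    )
    # Return the newest match but leave duplicate rows untouched.
    group = await GroupConsole.safe_get_or_none(
        group_id="123456", channel_id__isnull=True, clean_duplicates=False
    )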
class DbUrlIsNode(HookPriorityException):
"""

View File

@@ -53,9 +53,7 @@ class CommonUtils:
if await GroupConsole.is_block_task(group_id, module):
"""群组是否禁用被动"""
return True
if g := await GroupConsole.get_or_none(
group_id=group_id, channel_id__isnull=True
):
if g := await GroupConsole.get_group(group_id=group_id):
"""群组权限是否小于0"""
if g.level < 0:
return True