mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ :增强获取群组的安全性和准确性。同时,优化了缓存管理中的相关逻辑,确保缓存操作的一致性。
This commit is contained in:
parent
61a9b74558
commit
251299c56c
@ -114,7 +114,7 @@ class BanManage:
|
||||
if not is_superuser and user_id and session.id1:
|
||||
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
||||
if idx:
|
||||
ban_data = await BanConsole.get_or_none(id=idx)
|
||||
ban_data = await BanConsole.get_ban(id=idx)
|
||||
if not ban_data:
|
||||
return False, "该用户/群组不在黑名单中捏..."
|
||||
if ban_data.ban_level > user_level:
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.task_info import TaskInfo
|
||||
from zhenxun.services.cache import Cache
|
||||
from zhenxun.utils.common_utils import CommonUtils
|
||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||
from zhenxun.utils.exception import GroupInfoNotFound
|
||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||
@ -117,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
||||
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
||||
group = None
|
||||
if group_id:
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
raise GroupInfoNotFound()
|
||||
else:
|
||||
@ -201,20 +201,19 @@ class PluginManage:
|
||||
)
|
||||
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||
if group_id:
|
||||
if group := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
module_list = await PluginInfo.filter(
|
||||
plugin_type=PluginType.NORMAL
|
||||
).values_list("module", flat=True)
|
||||
if group := await GroupConsole.get_group(group_id=group_id):
|
||||
module_list = cast(
|
||||
list[str],
|
||||
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
|
||||
"module", flat=True
|
||||
),
|
||||
)
|
||||
if status:
|
||||
for module in module_list:
|
||||
group.block_plugin = group.block_plugin.replace(
|
||||
f"<{module},", ""
|
||||
)
|
||||
# 开启所有功能 - 清空禁用列表
|
||||
group.block_plugin = ""
|
||||
else:
|
||||
module_list = [f"<{module}" for module in module_list]
|
||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
||||
# 关闭所有功能 - 将模块列表转换为禁用格式
|
||||
group.block_plugin = CommonUtils.convert_module_format(module_list)
|
||||
await group.save(update_fields=["block_plugin"])
|
||||
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||
return "获取群组失败..."
|
||||
@ -235,9 +234,7 @@ class PluginManage:
|
||||
返回:
|
||||
bool: 是否醒来
|
||||
"""
|
||||
if c := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if c := await GroupConsole.get_group(group_id=group_id):
|
||||
return c.status
|
||||
return False
|
||||
|
||||
@ -428,17 +425,18 @@ class PluginManage:
|
||||
"""
|
||||
status_str = "关闭" if status else "开启"
|
||||
if is_all:
|
||||
modules = await TaskInfo.annotate().values_list("module", flat=True)
|
||||
if modules:
|
||||
module_list = cast(
|
||||
list[str], await TaskInfo.annotate().values_list("module", flat=True)
|
||||
)
|
||||
if module_list:
|
||||
group, _ = await GroupConsole.get_or_create(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
modules = [f"<{module}" for module in modules]
|
||||
if status:
|
||||
group.block_task = ",".join(modules) + "," # type: ignore
|
||||
group.block_task = CommonUtils.convert_module_format(module_list)
|
||||
else:
|
||||
for module in modules:
|
||||
group.block_task = group.block_task.replace(f"{module},", "")
|
||||
# 开启所有模块 - 清空禁用列表
|
||||
group.block_task = ""
|
||||
await group.save(update_fields=["block_task"])
|
||||
return f"已成功{status_str}全部被动技能!"
|
||||
elif task := await TaskInfo.get_or_none(name=task_name):
|
||||
|
||||
@ -45,7 +45,7 @@ async def classify_plugin(
|
||||
"""
|
||||
sort_data = await sort_type()
|
||||
classify: dict[str, list] = {}
|
||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||
for menu, value in sort_data.items():
|
||||
for plugin in value:
|
||||
|
||||
@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
color="black" if idx % 2 else "white",
|
||||
)
|
||||
curr_h = 10
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||
for _, plugin in enumerate(plugin_list):
|
||||
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
||||
if group and f"{plugin.module}," in group.block_plugin:
|
||||
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
||||
width, height = 10, 10
|
||||
for s in [
|
||||
"目前支持的功能列表:",
|
||||
"可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法",
|
||||
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
|
||||
]:
|
||||
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
||||
await result.paste(text, (width, height))
|
||||
|
||||
@ -45,9 +45,7 @@ class StatisticsManage:
|
||||
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
||||
elif group_id:
|
||||
"""查群组"""
|
||||
group = await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
||||
else:
|
||||
title = "功能调用统计"
|
||||
|
||||
@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
|
||||
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||
gid = state["group_id"]
|
||||
group = await GroupConsole.get_or_none(group_id=gid)
|
||||
group = await GroupConsole.get_group(group_id=gid)
|
||||
if not group:
|
||||
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
||||
s = "删除" if arparma.find("delete") else "添加"
|
||||
|
||||
@ -250,7 +250,7 @@ class ApiDataSource:
|
||||
返回:
|
||||
GroupDetail | None: 群组详情数据
|
||||
"""
|
||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
||||
group = await GroupConsole.get_group(group_id=group_id)
|
||||
if not group:
|
||||
return None
|
||||
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
||||
|
||||
@ -45,10 +45,10 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
||||
mtime=file_path.stat().st_mtime,
|
||||
)
|
||||
)
|
||||
sorted(data_list, key=lambda f: f.name)
|
||||
return Result.ok(data_list)
|
||||
data_list.sort(key=lambda f: f.name)
|
||||
return Result.ok(data_list)
|
||||
except Exception as e:
|
||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@ -54,12 +54,12 @@ class BanConsole(Model):
|
||||
raise UserAndGroupIsNone()
|
||||
if user_id:
|
||||
return (
|
||||
await cls.get_or_none(user_id=user_id, group_id=group_id)
|
||||
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||
if group_id
|
||||
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||
)
|
||||
else:
|
||||
return await cls.get_or_none(user_id="", group_id=group_id)
|
||||
return await cls.safe_get_or_none(user_id="", group_id=group_id)
|
||||
|
||||
@classmethod
|
||||
async def check_ban_level(
|
||||
@ -175,3 +175,25 @@ class BanConsole(Model):
|
||||
await user.delete()
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_ban(
|
||||
cls,
|
||||
*,
|
||||
id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
group_id: str | None = None,
|
||||
) -> Self | None:
|
||||
"""安全地获取ban记录
|
||||
|
||||
参数:
|
||||
id: 记录id
|
||||
user_id: 用户id
|
||||
group_id: 群组id
|
||||
|
||||
返回:
|
||||
Self | None: ban记录
|
||||
"""
|
||||
if id is not None:
|
||||
return await cls.safe_get_or_none(id=id)
|
||||
return await cls._get_data(user_id, group_id)
|
||||
|
||||
@ -218,20 +218,32 @@ class GroupConsole(Model):
|
||||
|
||||
@classmethod
|
||||
async def get_group(
|
||||
cls, group_id: str, channel_id: str | None = None
|
||||
cls,
|
||||
group_id: str,
|
||||
channel_id: str | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
) -> Self | None:
|
||||
"""获取群组
|
||||
|
||||
参数:
|
||||
group_id: 群组id
|
||||
channel_id: 频道id.
|
||||
channel_id: 频道id
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
|
||||
返回:
|
||||
Self: GroupConsole
|
||||
"""
|
||||
if channel_id:
|
||||
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
|
||||
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id=channel_id,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
return await cls.safe_get_or_none(
|
||||
group_id=group_id,
|
||||
channel_id__isnull=True,
|
||||
clean_duplicates=clean_duplicates,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def is_super_group(cls, group_id: str) -> bool:
|
||||
|
||||
@ -121,7 +121,7 @@ class CacheData(BaseModel):
|
||||
lazy_load: bool = True # 默认延迟加载
|
||||
result_model: type | None = None
|
||||
_keys: set[str] = set() # 存储所有缓存键
|
||||
_cache: BaseCache | AioCache
|
||||
cache: BaseCache | AioCache
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -208,7 +208,7 @@ class CacheData(BaseModel):
|
||||
async def get_data(self) -> Any:
|
||||
"""从缓存获取数据"""
|
||||
try:
|
||||
data = await self._cache.get(self.name)
|
||||
data = await self.cache.get(self.name) # type: ignore
|
||||
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
||||
|
||||
# 如果数据为空,尝试重新加载
|
||||
@ -307,11 +307,11 @@ class CacheData(BaseModel):
|
||||
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
|
||||
|
||||
# 2. 删除旧数据
|
||||
await self._cache.delete(self.name)
|
||||
await self.cache.delete(self.name) # type: ignore
|
||||
logger.debug(f"删除缓存 {self.name} 旧数据")
|
||||
|
||||
# 3. 设置新数据
|
||||
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
||||
await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
|
||||
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
||||
|
||||
except Exception as e:
|
||||
@ -321,7 +321,7 @@ class CacheData(BaseModel):
|
||||
async def delete_data(self):
|
||||
"""删除缓存数据"""
|
||||
try:
|
||||
await self._cache.delete(self.name)
|
||||
await self.cache.delete(self.name) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"删除缓存 {self.name}", e=e)
|
||||
|
||||
@ -428,7 +428,7 @@ class CacheData(BaseModel):
|
||||
"""
|
||||
cache_key = self._get_cache_key(key)
|
||||
try:
|
||||
data = await self._cache.get(cache_key)
|
||||
data = await self.cache.get(cache_key) # type: ignore
|
||||
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
|
||||
|
||||
if self.result_model:
|
||||
@ -467,7 +467,7 @@ class CacheData(BaseModel):
|
||||
for key in self._keys:
|
||||
# 提取原始键名(去掉前缀)
|
||||
original_key = key.split(":", 1)[1]
|
||||
data = await self._cache.get(key)
|
||||
data = await self.cache.get(key) # type: ignore
|
||||
if self.result_model:
|
||||
result[original_key] = self._deserialize_value(
|
||||
data, self.result_model
|
||||
@ -484,7 +484,7 @@ class CacheData(BaseModel):
|
||||
cache_key = self._get_cache_key(key)
|
||||
try:
|
||||
serialized_value = self._serialize_value(value)
|
||||
await self._cache.set(cache_key, serialized_value, ttl=self.expire)
|
||||
await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
|
||||
self._keys.add(cache_key) # 添加到键列表
|
||||
logger.debug(f"设置缓存 {cache_key} 数据完成")
|
||||
except Exception as e:
|
||||
@ -495,7 +495,7 @@ class CacheData(BaseModel):
|
||||
"""删除指定键的缓存数据"""
|
||||
cache_key = self._get_cache_key(key)
|
||||
try:
|
||||
await self._cache.delete(cache_key)
|
||||
await self.cache.delete(cache_key) # type: ignore
|
||||
self._keys.discard(cache_key) # 从键列表中移除
|
||||
logger.debug(f"删除缓存 {cache_key} 完成")
|
||||
except Exception as e:
|
||||
@ -505,7 +505,7 @@ class CacheData(BaseModel):
|
||||
"""清除所有缓存数据"""
|
||||
try:
|
||||
for key in list(self._keys): # 使用列表复制避免在迭代时修改
|
||||
await self._cache.delete(key)
|
||||
await self.cache.delete(key) # type: ignore
|
||||
self._keys.clear()
|
||||
logger.debug(f"清除缓存 {self.name} 完成")
|
||||
except Exception as e:
|
||||
@ -524,7 +524,7 @@ class CacheManager:
|
||||
if self._cache_instance is None:
|
||||
if config.redis_host:
|
||||
self._cache_instance = AioCache(
|
||||
AioCache.REDIS,
|
||||
AioCache.REDIS, # type: ignore
|
||||
serializer=JsonSerializer(),
|
||||
namespace="zhenxun_cache",
|
||||
timeout=30, # 操作超时时间
|
||||
@ -546,12 +546,12 @@ class CacheManager:
|
||||
|
||||
async def close(self):
|
||||
if self._cache_instance:
|
||||
await self._cache_instance.close()
|
||||
await self._cache_instance.close() # type: ignore
|
||||
|
||||
async def verify_connection(self):
|
||||
"""连接测试"""
|
||||
try:
|
||||
await self._cache.get("__test__")
|
||||
await self._cache.get("__test__") # type: ignore
|
||||
except Exception as e:
|
||||
logger.error("连接失败", LOG_COMMAND, e=e)
|
||||
raise
|
||||
@ -560,7 +560,7 @@ class CacheManager:
|
||||
"""初始化所有非延迟加载的缓存"""
|
||||
await self.verify_connection()
|
||||
for name, cache in self._data.items():
|
||||
cache._cache = self._cache
|
||||
cache.cache = self._cache
|
||||
if not cache.lazy_load:
|
||||
try:
|
||||
await cache.reload()
|
||||
@ -587,7 +587,7 @@ class CacheManager:
|
||||
func=func,
|
||||
expire=expire,
|
||||
lazy_load=lazy_load,
|
||||
_cache=self._cache,
|
||||
cache=self._cache,
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
@ -72,12 +72,23 @@ class Model(TortoiseModel):
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
result, is_create = await super().get_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if is_create and (cache_type := cls.get_cache_type()):
|
||||
await CacheRoot.reload(cache_type)
|
||||
return (result, is_create)
|
||||
if sem := cls.get_semaphore(DbLockType.CREATE):
|
||||
async with sem:
|
||||
# 在锁内执行查询和创建操作
|
||||
result, is_create = await super().get_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if is_create and (cache_type := cls.get_cache_type()):
|
||||
await CacheRoot.reload(cache_type)
|
||||
return (result, is_create)
|
||||
else:
|
||||
# 如果没有锁,则执行原来的逻辑
|
||||
result, is_create = await super().get_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if is_create and (cache_type := cls.get_cache_type()):
|
||||
await CacheRoot.reload(cache_type)
|
||||
return (result, is_create)
|
||||
|
||||
@classmethod
|
||||
async def update_or_create(
|
||||
@ -86,12 +97,23 @@ class Model(TortoiseModel):
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[Self, bool]:
|
||||
result = await super().update_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.reload(cache_type)
|
||||
return result
|
||||
if sem := cls.get_semaphore(DbLockType.CREATE):
|
||||
async with sem:
|
||||
# 在锁内执行查询和创建操作
|
||||
result = await super().update_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.reload(cache_type)
|
||||
return result
|
||||
else:
|
||||
# 如果没有锁,则执行原来的逻辑
|
||||
result = await super().update_or_create(
|
||||
defaults=defaults, using_db=using_db, **kwargs
|
||||
)
|
||||
if cache_type := cls.get_cache_type():
|
||||
await CacheRoot.reload(cache_type)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
async def bulk_create( # type: ignore
|
||||
@ -167,6 +189,65 @@ class Model(TortoiseModel):
|
||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
||||
await CacheRoot.reload(cache_type)
|
||||
|
||||
@classmethod
|
||||
async def safe_get_or_none(
|
||||
cls,
|
||||
*args,
|
||||
using_db: BaseDBAsyncClient | None = None,
|
||||
clean_duplicates: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Self | None:
|
||||
"""安全地获取一条记录或None,处理存在多个记录时返回最新的那个
|
||||
注意,默认会删除重复的记录,仅保留最新的
|
||||
|
||||
参数:
|
||||
*args: 查询参数
|
||||
using_db: 数据库连接
|
||||
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||
**kwargs: 查询参数
|
||||
|
||||
返回:
|
||||
Self | None: 查询结果,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 先尝试使用 get_or_none 获取单个记录
|
||||
return await cls.get_or_none(*args, using_db=using_db, **kwargs)
|
||||
except Exception as e:
|
||||
# 如果出现错误(可能是存在多个记录)
|
||||
if "Multiple objects" in str(e):
|
||||
logger.warning(
|
||||
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
|
||||
)
|
||||
# 查询所有匹配记录
|
||||
records = await cls.filter(*args, **kwargs)
|
||||
|
||||
if not records:
|
||||
return None
|
||||
|
||||
# 如果需要清理重复记录
|
||||
if clean_duplicates and hasattr(cls, "id"):
|
||||
# 按 id 排序
|
||||
records = sorted(
|
||||
records, key=lambda x: getattr(x, "id", 0), reverse=True
|
||||
)
|
||||
for record in records[1:]:
|
||||
try:
|
||||
await record.delete()
|
||||
logger.info(
|
||||
f"{cls.__name__} 删除重复记录:"
|
||||
f" id={getattr(record, 'id', None)}"
|
||||
)
|
||||
except Exception as del_e:
|
||||
logger.error(f"删除重复记录失败: {del_e}")
|
||||
return records[0]
|
||||
# 如果不需要清理或没有 id 字段,则返回最新的记录
|
||||
if hasattr(cls, "id"):
|
||||
return await cls.filter(*args, **kwargs).order_by("-id").first()
|
||||
# 如果没有 id 字段,则返回第一个记录
|
||||
return await cls.filter(*args, **kwargs).first()
|
||||
# 其他类型的错误则继续抛出
|
||||
raise
|
||||
|
||||
|
||||
class DbUrlIsNode(HookPriorityException):
|
||||
"""
|
||||
|
||||
@ -53,9 +53,7 @@ class CommonUtils:
|
||||
if await GroupConsole.is_block_task(group_id, module):
|
||||
"""群组是否禁用被动"""
|
||||
return True
|
||||
if g := await GroupConsole.get_or_none(
|
||||
group_id=group_id, channel_id__isnull=True
|
||||
):
|
||||
if g := await GroupConsole.get_group(group_id=group_id):
|
||||
"""群组权限是否小于0"""
|
||||
if g.level < 0:
|
||||
return True
|
||||
|
||||
Loading…
Reference in New Issue
Block a user