mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ :增强获取群组的安全性和准确性。同时,优化了缓存管理中的相关逻辑,确保缓存操作的一致性。
This commit is contained in:
parent
61a9b74558
commit
251299c56c
@ -114,7 +114,7 @@ class BanManage:
|
|||||||
if not is_superuser and user_id and session.id1:
|
if not is_superuser and user_id and session.id1:
|
||||||
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
user_level = await LevelUser.get_user_level(session.id1, group_id)
|
||||||
if idx:
|
if idx:
|
||||||
ban_data = await BanConsole.get_or_none(id=idx)
|
ban_data = await BanConsole.get_ban(id=idx)
|
||||||
if not ban_data:
|
if not ban_data:
|
||||||
return False, "该用户/群组不在黑名单中捏..."
|
return False, "该用户/群组不在黑名单中捏..."
|
||||||
if ban_data.ban_level > user_level:
|
if ban_data.ban_level > user_level:
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
from zhenxun.configs.path_config import DATA_PATH, IMAGE_PATH
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.task_info import TaskInfo
|
from zhenxun.models.task_info import TaskInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.cache import Cache
|
||||||
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
from zhenxun.utils.enum import BlockType, CacheType, PluginType
|
||||||
from zhenxun.utils.exception import GroupInfoNotFound
|
from zhenxun.utils.exception import GroupInfoNotFound
|
||||||
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
from zhenxun.utils.image_utils import BuildImage, ImageTemplate, RowStyle
|
||||||
@ -117,9 +119,7 @@ async def build_task(group_id: str | None) -> BuildImage:
|
|||||||
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
column_name = ["ID", "模块", "名称", "群组状态", "全局状态", "运行时间"]
|
||||||
group = None
|
group = None
|
||||||
if group_id:
|
if group_id:
|
||||||
group = await GroupConsole.get_or_none(
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
)
|
|
||||||
if not group:
|
if not group:
|
||||||
raise GroupInfoNotFound()
|
raise GroupInfoNotFound()
|
||||||
else:
|
else:
|
||||||
@ -201,20 +201,19 @@ class PluginManage:
|
|||||||
)
|
)
|
||||||
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
return f"成功将所有功能进群默认状态修改为: {'开启' if status else '关闭'}"
|
||||||
if group_id:
|
if group_id:
|
||||||
if group := await GroupConsole.get_or_none(
|
if group := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
module_list = cast(
|
||||||
):
|
list[str],
|
||||||
module_list = await PluginInfo.filter(
|
await PluginInfo.filter(plugin_type=PluginType.NORMAL).values_list(
|
||||||
plugin_type=PluginType.NORMAL
|
"module", flat=True
|
||||||
).values_list("module", flat=True)
|
),
|
||||||
|
)
|
||||||
if status:
|
if status:
|
||||||
for module in module_list:
|
# 开启所有功能 - 清空禁用列表
|
||||||
group.block_plugin = group.block_plugin.replace(
|
group.block_plugin = ""
|
||||||
f"<{module},", ""
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
module_list = [f"<{module}" for module in module_list]
|
# 关闭所有功能 - 将模块列表转换为禁用格式
|
||||||
group.block_plugin = ",".join(module_list) + "," # type: ignore
|
group.block_plugin = CommonUtils.convert_module_format(module_list)
|
||||||
await group.save(update_fields=["block_plugin"])
|
await group.save(update_fields=["block_plugin"])
|
||||||
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
return f"成功将此群组所有功能状态修改为: {'开启' if status else '关闭'}"
|
||||||
return "获取群组失败..."
|
return "获取群组失败..."
|
||||||
@ -235,9 +234,7 @@ class PluginManage:
|
|||||||
返回:
|
返回:
|
||||||
bool: 是否醒来
|
bool: 是否醒来
|
||||||
"""
|
"""
|
||||||
if c := await GroupConsole.get_or_none(
|
if c := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
):
|
|
||||||
return c.status
|
return c.status
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -428,17 +425,18 @@ class PluginManage:
|
|||||||
"""
|
"""
|
||||||
status_str = "关闭" if status else "开启"
|
status_str = "关闭" if status else "开启"
|
||||||
if is_all:
|
if is_all:
|
||||||
modules = await TaskInfo.annotate().values_list("module", flat=True)
|
module_list = cast(
|
||||||
if modules:
|
list[str], await TaskInfo.annotate().values_list("module", flat=True)
|
||||||
|
)
|
||||||
|
if module_list:
|
||||||
group, _ = await GroupConsole.get_or_create(
|
group, _ = await GroupConsole.get_or_create(
|
||||||
group_id=group_id, channel_id__isnull=True
|
group_id=group_id, channel_id__isnull=True
|
||||||
)
|
)
|
||||||
modules = [f"<{module}" for module in modules]
|
|
||||||
if status:
|
if status:
|
||||||
group.block_task = ",".join(modules) + "," # type: ignore
|
group.block_task = CommonUtils.convert_module_format(module_list)
|
||||||
else:
|
else:
|
||||||
for module in modules:
|
# 开启所有模块 - 清空禁用列表
|
||||||
group.block_task = group.block_task.replace(f"{module},", "")
|
group.block_task = ""
|
||||||
await group.save(update_fields=["block_task"])
|
await group.save(update_fields=["block_task"])
|
||||||
return f"已成功{status_str}全部被动技能!"
|
return f"已成功{status_str}全部被动技能!"
|
||||||
elif task := await TaskInfo.get_or_none(name=task_name):
|
elif task := await TaskInfo.get_or_none(name=task_name):
|
||||||
|
|||||||
@ -45,7 +45,7 @@ async def classify_plugin(
|
|||||||
"""
|
"""
|
||||||
sort_data = await sort_type()
|
sort_data = await sort_type()
|
||||||
classify: dict[str, list] = {}
|
classify: dict[str, list] = {}
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id) if group_id else None
|
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||||
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
bot = await BotConsole.get_or_none(bot_id=session.self_id)
|
||||||
for menu, value in sort_data.items():
|
for menu, value in sort_data.items():
|
||||||
for plugin in value:
|
for plugin in value:
|
||||||
|
|||||||
@ -45,7 +45,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
|||||||
color="black" if idx % 2 else "white",
|
color="black" if idx % 2 else "white",
|
||||||
)
|
)
|
||||||
curr_h = 10
|
curr_h = 10
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
group = await GroupConsole.get_group(group_id=group_id) if group_id else None
|
||||||
for _, plugin in enumerate(plugin_list):
|
for _, plugin in enumerate(plugin_list):
|
||||||
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
text_color = (255, 255, 255) if idx % 2 else (0, 0, 0)
|
||||||
if group and f"{plugin.module}," in group.block_plugin:
|
if group and f"{plugin.module}," in group.block_plugin:
|
||||||
@ -80,7 +80,7 @@ async def build_normal_image(group_id: str | None, is_detail: bool) -> BuildImag
|
|||||||
width, height = 10, 10
|
width, height = 10, 10
|
||||||
for s in [
|
for s in [
|
||||||
"目前支持的功能列表:",
|
"目前支持的功能列表:",
|
||||||
"可以通过 ‘帮助 [功能名称或功能Id]’ 来获取对应功能的使用方法",
|
"可以通过 '帮助 [功能名称或功能Id]' 来获取对应功能的使用方法",
|
||||||
]:
|
]:
|
||||||
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
text = await BuildImage.build_text_image(s, "HYWenHei-85W.ttf", 24)
|
||||||
await result.paste(text, (width, height))
|
await result.paste(text, (width, height))
|
||||||
|
|||||||
@ -45,9 +45,7 @@ class StatisticsManage:
|
|||||||
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
title = f"{user.user_name if user else user_id} {day_type}功能调用统计"
|
||||||
elif group_id:
|
elif group_id:
|
||||||
"""查群组"""
|
"""查群组"""
|
||||||
group = await GroupConsole.get_or_none(
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
)
|
|
||||||
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
title = f"{group.group_name if group else group_id} {day_type}功能调用统计"
|
||||||
else:
|
else:
|
||||||
title = "功能调用统计"
|
title = "功能调用统计"
|
||||||
|
|||||||
@ -163,7 +163,7 @@ async def _(session: EventSession, arparma: Arparma, state: T_State, level: int)
|
|||||||
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
@_matcher.assign("super-handle", parameterless=[CheckGroupId()])
|
||||||
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
async def _(session: EventSession, arparma: Arparma, state: T_State):
|
||||||
gid = state["group_id"]
|
gid = state["group_id"]
|
||||||
group = await GroupConsole.get_or_none(group_id=gid)
|
group = await GroupConsole.get_group(group_id=gid)
|
||||||
if not group:
|
if not group:
|
||||||
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
await MessageUtils.build_message("群组信息不存在, 请更新群组信息...").finish()
|
||||||
s = "删除" if arparma.find("delete") else "添加"
|
s = "删除" if arparma.find("delete") else "添加"
|
||||||
|
|||||||
@ -250,7 +250,7 @@ class ApiDataSource:
|
|||||||
返回:
|
返回:
|
||||||
GroupDetail | None: 群组详情数据
|
GroupDetail | None: 群组详情数据
|
||||||
"""
|
"""
|
||||||
group = await GroupConsole.get_or_none(group_id=group_id)
|
group = await GroupConsole.get_group(group_id=group_id)
|
||||||
if not group:
|
if not group:
|
||||||
return None
|
return None
|
||||||
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
like_plugin = await cls.__get_group_detail_like_plugin(group_id)
|
||||||
|
|||||||
@ -45,10 +45,10 @@ async def _(path: str | None = None) -> Result[list[DirFile]]:
|
|||||||
mtime=file_path.stat().st_mtime,
|
mtime=file_path.stat().st_mtime,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sorted(data_list, key=lambda f: f.name)
|
data_list.sort(key=lambda f: f.name)
|
||||||
return Result.ok(data_list)
|
return Result.ok(data_list)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return Result.fail(f"获取文件列表失败: {e!s}")
|
return Result.fail(f"获取文件列表失败: {e!s}")
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|||||||
@ -54,12 +54,12 @@ class BanConsole(Model):
|
|||||||
raise UserAndGroupIsNone()
|
raise UserAndGroupIsNone()
|
||||||
if user_id:
|
if user_id:
|
||||||
return (
|
return (
|
||||||
await cls.get_or_none(user_id=user_id, group_id=group_id)
|
await cls.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||||
if group_id
|
if group_id
|
||||||
else await cls.get_or_none(user_id=user_id, group_id__isnull=True)
|
else await cls.safe_get_or_none(user_id=user_id, group_id__isnull=True)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await cls.get_or_none(user_id="", group_id=group_id)
|
return await cls.safe_get_or_none(user_id="", group_id=group_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def check_ban_level(
|
async def check_ban_level(
|
||||||
@ -175,3 +175,25 @@ class BanConsole(Model):
|
|||||||
await user.delete()
|
await user.delete()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_ban(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
id: int | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
group_id: str | None = None,
|
||||||
|
) -> Self | None:
|
||||||
|
"""安全地获取ban记录
|
||||||
|
|
||||||
|
参数:
|
||||||
|
id: 记录id
|
||||||
|
user_id: 用户id
|
||||||
|
group_id: 群组id
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Self | None: ban记录
|
||||||
|
"""
|
||||||
|
if id is not None:
|
||||||
|
return await cls.safe_get_or_none(id=id)
|
||||||
|
return await cls._get_data(user_id, group_id)
|
||||||
|
|||||||
@ -218,20 +218,32 @@ class GroupConsole(Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_group(
|
async def get_group(
|
||||||
cls, group_id: str, channel_id: str | None = None
|
cls,
|
||||||
|
group_id: str,
|
||||||
|
channel_id: str | None = None,
|
||||||
|
clean_duplicates: bool = True,
|
||||||
) -> Self | None:
|
) -> Self | None:
|
||||||
"""获取群组
|
"""获取群组
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
channel_id: 频道id.
|
channel_id: 频道id
|
||||||
|
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
Self: GroupConsole
|
Self: GroupConsole
|
||||||
"""
|
"""
|
||||||
if channel_id:
|
if channel_id:
|
||||||
return await cls.get_or_none(group_id=group_id, channel_id=channel_id)
|
return await cls.safe_get_or_none(
|
||||||
return await cls.get_or_none(group_id=group_id, channel_id__isnull=True)
|
group_id=group_id,
|
||||||
|
channel_id=channel_id,
|
||||||
|
clean_duplicates=clean_duplicates,
|
||||||
|
)
|
||||||
|
return await cls.safe_get_or_none(
|
||||||
|
group_id=group_id,
|
||||||
|
channel_id__isnull=True,
|
||||||
|
clean_duplicates=clean_duplicates,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_super_group(cls, group_id: str) -> bool:
|
async def is_super_group(cls, group_id: str) -> bool:
|
||||||
|
|||||||
@ -121,7 +121,7 @@ class CacheData(BaseModel):
|
|||||||
lazy_load: bool = True # 默认延迟加载
|
lazy_load: bool = True # 默认延迟加载
|
||||||
result_model: type | None = None
|
result_model: type | None = None
|
||||||
_keys: set[str] = set() # 存储所有缓存键
|
_keys: set[str] = set() # 存储所有缓存键
|
||||||
_cache: BaseCache | AioCache
|
cache: BaseCache | AioCache
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -208,7 +208,7 @@ class CacheData(BaseModel):
|
|||||||
async def get_data(self) -> Any:
|
async def get_data(self) -> Any:
|
||||||
"""从缓存获取数据"""
|
"""从缓存获取数据"""
|
||||||
try:
|
try:
|
||||||
data = await self._cache.get(self.name)
|
data = await self.cache.get(self.name) # type: ignore
|
||||||
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
||||||
|
|
||||||
# 如果数据为空,尝试重新加载
|
# 如果数据为空,尝试重新加载
|
||||||
@ -307,11 +307,11 @@ class CacheData(BaseModel):
|
|||||||
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
|
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
|
||||||
|
|
||||||
# 2. 删除旧数据
|
# 2. 删除旧数据
|
||||||
await self._cache.delete(self.name)
|
await self.cache.delete(self.name) # type: ignore
|
||||||
logger.debug(f"删除缓存 {self.name} 旧数据")
|
logger.debug(f"删除缓存 {self.name} 旧数据")
|
||||||
|
|
||||||
# 3. 设置新数据
|
# 3. 设置新数据
|
||||||
await self._cache.set(self.name, serialized_value, ttl=self.expire)
|
await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
|
||||||
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -321,7 +321,7 @@ class CacheData(BaseModel):
|
|||||||
async def delete_data(self):
|
async def delete_data(self):
|
||||||
"""删除缓存数据"""
|
"""删除缓存数据"""
|
||||||
try:
|
try:
|
||||||
await self._cache.delete(self.name)
|
await self.cache.delete(self.name) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除缓存 {self.name}", e=e)
|
logger.error(f"删除缓存 {self.name}", e=e)
|
||||||
|
|
||||||
@ -428,7 +428,7 @@ class CacheData(BaseModel):
|
|||||||
"""
|
"""
|
||||||
cache_key = self._get_cache_key(key)
|
cache_key = self._get_cache_key(key)
|
||||||
try:
|
try:
|
||||||
data = await self._cache.get(cache_key)
|
data = await self.cache.get(cache_key) # type: ignore
|
||||||
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
|
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
|
||||||
|
|
||||||
if self.result_model:
|
if self.result_model:
|
||||||
@ -467,7 +467,7 @@ class CacheData(BaseModel):
|
|||||||
for key in self._keys:
|
for key in self._keys:
|
||||||
# 提取原始键名(去掉前缀)
|
# 提取原始键名(去掉前缀)
|
||||||
original_key = key.split(":", 1)[1]
|
original_key = key.split(":", 1)[1]
|
||||||
data = await self._cache.get(key)
|
data = await self.cache.get(key) # type: ignore
|
||||||
if self.result_model:
|
if self.result_model:
|
||||||
result[original_key] = self._deserialize_value(
|
result[original_key] = self._deserialize_value(
|
||||||
data, self.result_model
|
data, self.result_model
|
||||||
@ -484,7 +484,7 @@ class CacheData(BaseModel):
|
|||||||
cache_key = self._get_cache_key(key)
|
cache_key = self._get_cache_key(key)
|
||||||
try:
|
try:
|
||||||
serialized_value = self._serialize_value(value)
|
serialized_value = self._serialize_value(value)
|
||||||
await self._cache.set(cache_key, serialized_value, ttl=self.expire)
|
await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
|
||||||
self._keys.add(cache_key) # 添加到键列表
|
self._keys.add(cache_key) # 添加到键列表
|
||||||
logger.debug(f"设置缓存 {cache_key} 数据完成")
|
logger.debug(f"设置缓存 {cache_key} 数据完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -495,7 +495,7 @@ class CacheData(BaseModel):
|
|||||||
"""删除指定键的缓存数据"""
|
"""删除指定键的缓存数据"""
|
||||||
cache_key = self._get_cache_key(key)
|
cache_key = self._get_cache_key(key)
|
||||||
try:
|
try:
|
||||||
await self._cache.delete(cache_key)
|
await self.cache.delete(cache_key) # type: ignore
|
||||||
self._keys.discard(cache_key) # 从键列表中移除
|
self._keys.discard(cache_key) # 从键列表中移除
|
||||||
logger.debug(f"删除缓存 {cache_key} 完成")
|
logger.debug(f"删除缓存 {cache_key} 完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -505,7 +505,7 @@ class CacheData(BaseModel):
|
|||||||
"""清除所有缓存数据"""
|
"""清除所有缓存数据"""
|
||||||
try:
|
try:
|
||||||
for key in list(self._keys): # 使用列表复制避免在迭代时修改
|
for key in list(self._keys): # 使用列表复制避免在迭代时修改
|
||||||
await self._cache.delete(key)
|
await self.cache.delete(key) # type: ignore
|
||||||
self._keys.clear()
|
self._keys.clear()
|
||||||
logger.debug(f"清除缓存 {self.name} 完成")
|
logger.debug(f"清除缓存 {self.name} 完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -524,7 +524,7 @@ class CacheManager:
|
|||||||
if self._cache_instance is None:
|
if self._cache_instance is None:
|
||||||
if config.redis_host:
|
if config.redis_host:
|
||||||
self._cache_instance = AioCache(
|
self._cache_instance = AioCache(
|
||||||
AioCache.REDIS,
|
AioCache.REDIS, # type: ignore
|
||||||
serializer=JsonSerializer(),
|
serializer=JsonSerializer(),
|
||||||
namespace="zhenxun_cache",
|
namespace="zhenxun_cache",
|
||||||
timeout=30, # 操作超时时间
|
timeout=30, # 操作超时时间
|
||||||
@ -546,12 +546,12 @@ class CacheManager:
|
|||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
if self._cache_instance:
|
if self._cache_instance:
|
||||||
await self._cache_instance.close()
|
await self._cache_instance.close() # type: ignore
|
||||||
|
|
||||||
async def verify_connection(self):
|
async def verify_connection(self):
|
||||||
"""连接测试"""
|
"""连接测试"""
|
||||||
try:
|
try:
|
||||||
await self._cache.get("__test__")
|
await self._cache.get("__test__") # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("连接失败", LOG_COMMAND, e=e)
|
logger.error("连接失败", LOG_COMMAND, e=e)
|
||||||
raise
|
raise
|
||||||
@ -560,7 +560,7 @@ class CacheManager:
|
|||||||
"""初始化所有非延迟加载的缓存"""
|
"""初始化所有非延迟加载的缓存"""
|
||||||
await self.verify_connection()
|
await self.verify_connection()
|
||||||
for name, cache in self._data.items():
|
for name, cache in self._data.items():
|
||||||
cache._cache = self._cache
|
cache.cache = self._cache
|
||||||
if not cache.lazy_load:
|
if not cache.lazy_load:
|
||||||
try:
|
try:
|
||||||
await cache.reload()
|
await cache.reload()
|
||||||
@ -587,7 +587,7 @@ class CacheManager:
|
|||||||
func=func,
|
func=func,
|
||||||
expire=expire,
|
expire=expire,
|
||||||
lazy_load=lazy_load,
|
lazy_load=lazy_load,
|
||||||
_cache=self._cache,
|
cache=self._cache,
|
||||||
)
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|||||||
@ -72,12 +72,23 @@ class Model(TortoiseModel):
|
|||||||
using_db: BaseDBAsyncClient | None = None,
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[Self, bool]:
|
) -> tuple[Self, bool]:
|
||||||
result, is_create = await super().get_or_create(
|
if sem := cls.get_semaphore(DbLockType.CREATE):
|
||||||
defaults=defaults, using_db=using_db, **kwargs
|
async with sem:
|
||||||
)
|
# 在锁内执行查询和创建操作
|
||||||
if is_create and (cache_type := cls.get_cache_type()):
|
result, is_create = await super().get_or_create(
|
||||||
await CacheRoot.reload(cache_type)
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
return (result, is_create)
|
)
|
||||||
|
if is_create and (cache_type := cls.get_cache_type()):
|
||||||
|
await CacheRoot.reload(cache_type)
|
||||||
|
return (result, is_create)
|
||||||
|
else:
|
||||||
|
# 如果没有锁,则执行原来的逻辑
|
||||||
|
result, is_create = await super().get_or_create(
|
||||||
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
|
)
|
||||||
|
if is_create and (cache_type := cls.get_cache_type()):
|
||||||
|
await CacheRoot.reload(cache_type)
|
||||||
|
return (result, is_create)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def update_or_create(
|
async def update_or_create(
|
||||||
@ -86,12 +97,23 @@ class Model(TortoiseModel):
|
|||||||
using_db: BaseDBAsyncClient | None = None,
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> tuple[Self, bool]:
|
) -> tuple[Self, bool]:
|
||||||
result = await super().update_or_create(
|
if sem := cls.get_semaphore(DbLockType.CREATE):
|
||||||
defaults=defaults, using_db=using_db, **kwargs
|
async with sem:
|
||||||
)
|
# 在锁内执行查询和创建操作
|
||||||
if cache_type := cls.get_cache_type():
|
result = await super().update_or_create(
|
||||||
await CacheRoot.reload(cache_type)
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
return result
|
)
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
await CacheRoot.reload(cache_type)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
# 如果没有锁,则执行原来的逻辑
|
||||||
|
result = await super().update_or_create(
|
||||||
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
|
)
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
await CacheRoot.reload(cache_type)
|
||||||
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def bulk_create( # type: ignore
|
async def bulk_create( # type: ignore
|
||||||
@ -167,6 +189,65 @@ class Model(TortoiseModel):
|
|||||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
||||||
await CacheRoot.reload(cache_type)
|
await CacheRoot.reload(cache_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def safe_get_or_none(
|
||||||
|
cls,
|
||||||
|
*args,
|
||||||
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
|
clean_duplicates: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Self | None:
|
||||||
|
"""安全地获取一条记录或None,处理存在多个记录时返回最新的那个
|
||||||
|
注意,默认会删除重复的记录,仅保留最新的
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
using_db: 数据库连接
|
||||||
|
clean_duplicates: 是否删除重复的记录,仅保留最新的
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Self | None: 查询结果,如果不存在返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先尝试使用 get_or_none 获取单个记录
|
||||||
|
return await cls.get_or_none(*args, using_db=using_db, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# 如果出现错误(可能是存在多个记录)
|
||||||
|
if "Multiple objects" in str(e):
|
||||||
|
logger.warning(
|
||||||
|
f"{cls.__name__} safe_get_or_none 发现多个记录: {kwargs}"
|
||||||
|
)
|
||||||
|
# 查询所有匹配记录
|
||||||
|
records = await cls.filter(*args, **kwargs)
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果需要清理重复记录
|
||||||
|
if clean_duplicates and hasattr(cls, "id"):
|
||||||
|
# 按 id 排序
|
||||||
|
records = sorted(
|
||||||
|
records, key=lambda x: getattr(x, "id", 0), reverse=True
|
||||||
|
)
|
||||||
|
for record in records[1:]:
|
||||||
|
try:
|
||||||
|
await record.delete()
|
||||||
|
logger.info(
|
||||||
|
f"{cls.__name__} 删除重复记录:"
|
||||||
|
f" id={getattr(record, 'id', None)}"
|
||||||
|
)
|
||||||
|
except Exception as del_e:
|
||||||
|
logger.error(f"删除重复记录失败: {del_e}")
|
||||||
|
return records[0]
|
||||||
|
# 如果不需要清理或没有 id 字段,则返回最新的记录
|
||||||
|
if hasattr(cls, "id"):
|
||||||
|
return await cls.filter(*args, **kwargs).order_by("-id").first()
|
||||||
|
# 如果没有 id 字段,则返回第一个记录
|
||||||
|
return await cls.filter(*args, **kwargs).first()
|
||||||
|
# 其他类型的错误则继续抛出
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class DbUrlIsNode(HookPriorityException):
|
class DbUrlIsNode(HookPriorityException):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -53,9 +53,7 @@ class CommonUtils:
|
|||||||
if await GroupConsole.is_block_task(group_id, module):
|
if await GroupConsole.is_block_task(group_id, module):
|
||||||
"""群组是否禁用被动"""
|
"""群组是否禁用被动"""
|
||||||
return True
|
return True
|
||||||
if g := await GroupConsole.get_or_none(
|
if g := await GroupConsole.get_group(group_id=group_id):
|
||||||
group_id=group_id, channel_id__isnull=True
|
|
||||||
):
|
|
||||||
"""群组权限是否小于0"""
|
"""群组权限是否小于0"""
|
||||||
if g.level < 0:
|
if g.level < 0:
|
||||||
return True
|
return True
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user