增强缓存功能,优化请求管理逻辑 (#2012)
Some checks failed
检查bot是否运行正常 / bot check (push) Has been cancelled
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, javascript-typescript) (push) Has been cancelled
CodeQL Code Security Analysis / Analyze (${{ matrix.language }}) (none, python) (push) Has been cancelled
Sequential Lint and Type Check / ruff-call (push) Has been cancelled
Release Drafter / Update Release Draft (push) Has been cancelled
Force Sync to Aliyun / sync (push) Has been cancelled
Update Version / update-version (push) Has been cancelled
Sequential Lint and Type Check / pyright-call (push) Has been cancelled

- 在 `record_request.py` 和 `group_handle/__init__.py` 中引入了 `CacheRoot`,实现请求缓存,避免重复处理相同请求。
- 在 `exception.py` 中更新 `ForceAddGroupError` 类,新增 `group_id` 属性以便于错误处理。
- 在 `data_source.py` 中修改 `ForceAddGroupError` 的抛出逻辑,包含 `group_id` 信息。
- 更新 `cache` 类,支持类型化缓存字典和列表,增强缓存的类型安全性。

此更新提升了请求处理的效率和准确性,同时增强了错误信息的可追溯性。
This commit is contained in:
HibiKier 2025-08-06 16:31:09 +08:00 committed by GitHub
parent be86e0bb7f
commit 3deffcb46c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 146 additions and 63 deletions

View File

@ -150,6 +150,10 @@ async def get_plugin_and_user(
) )
except IntegrityError: except IntegrityError:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
plugin_task = plugin_dao.safe_get_or_none(module=module)
user_task = user_dao.get_by_func_or_none(
UserConsole.get_user, False, user_id=user_id
)
plugin, user = await with_timeout( plugin, user = await with_timeout(
asyncio.gather(plugin_task, user_task), name="get_plugin_and_user" asyncio.gather(plugin_task, user_task), name="get_plugin_and_user"
) )

View File

@ -3,9 +3,13 @@ class ForceAddGroupError(Exception):
强制拉群 强制拉群
""" """
def __init__(self, info: str): def __init__(self, info: str, group_id: str):
super().__init__(self) super().__init__(self)
self._info = info self._info = info
self._group_id = group_id
def get_info(self) -> str: def get_info(self) -> str:
return self._info return self._info
def get_group_id(self) -> str:
return self._group_id

View File

@ -16,6 +16,7 @@ from zhenxun.configs.config import BotConfig, Config
from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task from zhenxun.configs.utils import PluginExtraData, RegisterConfig, Task
from zhenxun.models.event_log import EventLog from zhenxun.models.event_log import EventLog
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.services.cache import CacheRoot
from zhenxun.utils.common_utils import CommonUtils from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import EventLogType, PluginType from zhenxun.utils.enum import EventLogType, PluginType
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
@ -92,6 +93,10 @@ group_decrease_handle = on_notice(
) )
"""群员减少处理""" """群员减少处理"""
cache = CacheRoot.cache_dict(
"REQUEST_CACHE", (base_config.get("TIP_MESSAGE_LIMIT") or 360) * 60, str
)
@group_increase_handle.handle() @group_increase_handle.handle()
async def _( async def _(
@ -109,7 +114,9 @@ async def _(
bot, str(event.operator_id), str(event.group_id), group bot, str(event.operator_id), str(event.group_id), group
) )
except ForceAddGroupError as e: except ForceAddGroupError as e:
await PlatformUtils.send_superuser(bot, e.get_info()) if not cache.get(e.get_group_id()):
cache.set(e.get_group_id(), "1")
await PlatformUtils.send_superuser(bot, e.get_info())
else: else:
await GroupManager.add_user(session, bot) await GroupManager.add_user(session, bot)

View File

@ -147,9 +147,13 @@ class GroupManager:
group_id=group_id, group_id=group_id,
e=e, e=e,
) )
raise ForceAddGroupError("强制拉群或未有群信息,退出群聊失败...") from e raise ForceAddGroupError(
"强制拉群或未有群信息,退出群聊失败...", group_id
) from e
# await GroupConsole.filter(group_id=group_id).delete() # await GroupConsole.filter(group_id=group_id).delete()
raise ForceAddGroupError(f"触发强制入群保护,已成功退出群聊 {group_id}...") raise ForceAddGroupError(
f"触发强制入群保护,已成功退出群聊 {group_id}...", group_id
)
else: else:
await cls.__handle_add_group(bot, group_id, group) await cls.__handle_add_group(bot, group_id, group)
"""刷新群管理员权限""" """刷新群管理员权限"""

View File

@ -21,6 +21,7 @@ from zhenxun.models.event_log import EventLog
from zhenxun.models.fg_request import FgRequest from zhenxun.models.fg_request import FgRequest
from zhenxun.models.friend_user import FriendUser from zhenxun.models.friend_user import FriendUser
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.services.cache import CacheRoot
from zhenxun.services.log import logger from zhenxun.services.log import logger
from zhenxun.utils.enum import EventLogType, PluginType, RequestHandleType, RequestType from zhenxun.utils.enum import EventLogType, PluginType, RequestHandleType, RequestType
from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.platform import PlatformUtils
@ -52,6 +53,14 @@ __plugin_meta__ = PluginMetadata(
type=bool, type=bool,
default_value=False, default_value=False,
), ),
RegisterConfig(
module="invite_manager",
key="TIP_MESSAGE_LIMIT",
value=360,
help="重复申请与退群提醒过滤时间(分钟)",
type=int,
default_value=360,
),
], ],
).to_dict(), ).to_dict(),
) )
@ -77,6 +86,11 @@ group_req = on_request(priority=5, block=True)
_t = on_message(priority=999, block=False, rule=lambda: False) _t = on_message(priority=999, block=False, rule=lambda: False)
cache = CacheRoot.cache_dict(
"REQUEST_CACHE", (base_config.get("TIP_MESSAGE_LIMIT") or 360) * 60, str
)
@friend_req.handle() @friend_req.handle()
async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSession): async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSession):
if event.user_id and Timer.check(event.user_id): if event.user_id and Timer.check(event.user_id):
@ -113,22 +127,25 @@ async def _(bot: v12Bot | v11Bot, event: FriendRequestEvent, session: EventSessi
nickname=nickname, nickname=nickname,
comment=comment, comment=comment,
) )
results = await PlatformUtils.send_superuser( cache_key = str(event.user_id)
bot, if not cache.get(cache_key):
f"*****一份好友申请*****\n" cache.set(cache_key, "1")
f"ID: {f.id}\n" results = await PlatformUtils.send_superuser(
f"昵称:{nickname}({event.user_id})\n" bot,
f"自动同意:{'' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n" f"*****一份好友申请*****\n"
f"日期:{datetime.now().replace(microsecond=0)}\n" f"ID: {f.id}\n"
f"备注:{event.comment}", f"昵称:{nickname}({event.user_id})\n"
) f"自动同意:{'' if base_config.get('AUTO_ADD_FRIEND') else '×'}\n"
if message_ids := [ f"日期:{datetime.now().replace(microsecond=0)}\n"
str(r[1].msg_ids[0]["message_id"]) f"备注:{event.comment}",
for r in results )
if r[1] and r[1].msg_ids if message_ids := [
]: str(r[1].msg_ids[0]["message_id"])
f.message_ids = ",".join(message_ids) for r in results
await f.save(update_fields=["message_ids"]) if r[1] and r[1].msg_ids
]:
f.message_ids = ",".join(message_ids)
await f.save(update_fields=["message_ids"])
else: else:
logger.debug("好友请求五分钟内重复, 已忽略", "好友请求", target=event.user_id) logger.debug("好友请求五分钟内重复, 已忽略", "好友请求", target=event.user_id)
@ -210,7 +227,8 @@ async def _(bot: v12Bot | v11Bot, event: GroupRequestEvent, session: EventSessio
"\n在群组中 群组管理员与群主 允许使用管理员帮助" "\n在群组中 群组管理员与群主 允许使用管理员帮助"
"包括ban与功能开关等\n请在群组中发送 '管理员帮助'", "包括ban与功能开关等\n请在群组中发送 '管理员帮助'",
) )
elif Timer.check(f"{event.user_id}:{event.group_id}"): elif cache.get(f"{event.group_id}"):
cache.set(f"{event.group_id}", "1")
logger.debug( logger.debug(
f"收录 用户[{event.user_id}] 群聊[{event.group_id}] 群聊请求", f"收录 用户[{event.user_id}] 群聊[{event.group_id}] 群聊请求",
"群聊请求", "群聊请求",

View File

@ -13,7 +13,7 @@ class BotSetting(BaseModel):
"""回复时NICKNAME""" """回复时NICKNAME"""
system_proxy: str | None = None system_proxy: str | None = None
"""系统代理""" """系统代理"""
db_url: str = "sqlite:data/zhenxun.db" db_url: str = ""
"""数据库链接, 默认值为sqlite:data/zhenxun.db""" """数据库链接, 默认值为sqlite:data/zhenxun.db"""
platform_superusers: dict[str, list[str]] = Field(default_factory=dict) platform_superusers: dict[str, list[str]] = Field(default_factory=dict)
"""平台超级用户""" """平台超级用户"""

View File

@ -54,6 +54,27 @@ message = message_list[0]
# 保存缓存数据(可选) # 保存缓存数据(可选)
await message_list.save() await message_list.save()
``` ```
4. 使用CacheManager的类型化缓存方法
```python
from zhenxun.services.cache import CacheRoot
# 获取字符串类型的缓存字典(向后兼容)
str_cache = CacheRoot.cache_dict("string_cache")
# 获取类型化的缓存字典(推荐)
int_cache = CacheRoot.cache_dict_typed("int_cache", value_type=int)
user_cache = CacheRoot.cache_dict_typed("user_cache", value_type=User)
# 获取类型化的缓存列表
message_list = CacheRoot.cache_list_typed("messages", value_type=str)
user_list = CacheRoot.cache_list_typed("users", value_type=User)
# 使用类型化的缓存
int_cache["count"] = 42 # 类型安全
user_cache["user1"] = User(name="Alice") # 类型安全
message_list.append("Hello") # 类型安全
```
""" """
import asyncio import asyncio
@ -94,6 +115,7 @@ __all__ = [
] ]
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U")
class Config(BaseModel): class Config(BaseModel):
@ -294,16 +316,36 @@ class CacheManager:
self.__class__._enabled = False self.__class__._enabled = False
logger.info("缓存功能已禁用", LOG_COMMAND) logger.info("缓存功能已禁用", LOG_COMMAND)
def cache_dict(self, cache_type: str, expire: int = 0) -> CacheDict: def cache_dict(
"""获取缓存字典""" self, cache_type: str, expire: int = 0, value_type: type[U] = str
) -> CacheDict[U]:
"""获取缓存字典
参数:
cache_type: 缓存类型
expire: 过期时间
value_type: 值类型
返回:
CacheDict: 缓存字典
"""
if cache_type not in self._dict_caches: if cache_type not in self._dict_caches:
self._dict_caches[cache_type] = CacheDict(cache_type, expire) self._dict_caches[cache_type] = CacheDict[value_type](cache_type, expire)
return self._dict_caches[cache_type] return self._dict_caches[cache_type]
def cache_list(self, cache_type: str, expire: int = 0) -> CacheList: def cache_list(
"""获取缓存列表""" self, cache_type: str, expire: int = 0, value_type: type[U] = str
) -> CacheList[U]:
"""获取缓存列表
参数:
cache_type: 缓存类型
expire: 过期时间
value_type: 值类型
返回:
CacheList: 缓存列表
"""
if cache_type not in self._list_caches: if cache_type not in self._list_caches:
self._list_caches[cache_type] = CacheList(cache_type, expire) self._list_caches[cache_type] = CacheList[value_type](cache_type, expire)
return self._list_caches[cache_type] return self._list_caches[cache_type]
def listener(self, cache_type: str): def listener(self, cache_type: str):

View File

@ -13,7 +13,7 @@ class CacheData(Generic[T]):
expire_time: float = 0 # 0表示永不过期 expire_time: float = 0 # 0表示永不过期
class CacheDict: class CacheDict(Generic[T]):
"""缓存字典类,提供类似普通字典的接口,数据只存储在内存中""" """缓存字典类,提供类似普通字典的接口,数据只存储在内存中"""
def __init__(self, name: str, expire: int = 0): def __init__(self, name: str, expire: int = 0):
@ -25,29 +25,32 @@ class CacheDict:
""" """
self.name = name.upper() self.name = name.upper()
self.expire = expire self.expire = expire
self._data: dict[str, CacheData[Any]] = {} self._data: dict[str, CacheData[T]] = {}
def __getitem__(self, key: str) -> Any: def expire_time(self, key: str) -> float:
"""获取字典项的过期时间"""
data = self._data.get(key)
if data is None:
return 0
if data.expire_time > 0 and data.expire_time < time.time():
del self._data[key]
return 0
return data.expire_time
def __getitem__(self, key: str) -> T | None:
"""获取字典项 """获取字典项
参数: 参数:
key: 字典键 key: 字典键
返回: 返回:
Any: 字典值 T: 字典值
""" """
data = self._data.get(key) if value := self._data.get(key):
if data is None: return value.value if self.expire_time(key) else None
return None return None
# 检查是否过期 def __setitem__(self, key: str, value: T) -> None:
if data.expire_time > 0 and data.expire_time < time.time():
del self._data[key]
return None
return data.value
def __setitem__(self, key: str, value: Any) -> None:
"""设置字典项 """设置字典项
参数: 参数:
@ -86,7 +89,7 @@ class CacheDict:
return True return True
def get(self, key: str, default: Any = None) -> Any: def get(self, key: str, default: Any = None) -> T | None:
"""获取字典项,如果不存在返回默认值 """获取字典项,如果不存在返回默认值
参数: 参数:
@ -99,7 +102,7 @@ class CacheDict:
value = self[key] value = self[key]
return default if value is None else value return default if value is None else value
def set(self, key: str, value: Any, expire: int | None = None) -> None: def set(self, key: str, value: Any, expire: int | None = None):
"""设置字典项 """设置字典项
参数: 参数:
@ -116,7 +119,7 @@ class CacheDict:
self._data[key] = CacheData(value=value, expire_time=expire_time) self._data[key] = CacheData(value=value, expire_time=expire_time)
def pop(self, key: str, default: Any = None) -> Any: def pop(self, key: str, default: Any = None) -> T:
"""删除并返回字典项 """删除并返回字典项
参数: 参数:
@ -133,6 +136,7 @@ class CacheDict:
# 检查是否过期 # 检查是否过期
if data.expire_time > 0 and data.expire_time < time.time(): if data.expire_time > 0 and data.expire_time < time.time():
del self._data[key]
return default return default
return data.value return data.value
@ -161,7 +165,7 @@ class CacheDict:
self._clean_expired() self._clean_expired()
return [data.value for data in self._data.values()] return [data.value for data in self._data.values()]
def items(self) -> list[tuple[str, Any]]: def items(self) -> list[tuple[str, T]]:
"""获取所有键值对 """获取所有键值对
返回: 返回:
@ -171,7 +175,7 @@ class CacheDict:
self._clean_expired() self._clean_expired()
return [(key, data.value) for key, data in self._data.items()] return [(key, data.value) for key, data in self._data.items()]
def _clean_expired(self) -> None: def _clean_expired(self):
"""清理过期的键""" """清理过期的键"""
now = time.time() now = time.time()
expired_keys = [ expired_keys = [
@ -203,7 +207,7 @@ class CacheDict:
return f"CacheDict({self.name}, {len(self._data)} items)" return f"CacheDict({self.name}, {len(self._data)} items)"
class CacheList: class CacheList(Generic[T]):
"""缓存列表类,提供类似普通列表的接口,数据只存储在内存中""" """缓存列表类,提供类似普通列表的接口,数据只存储在内存中"""
def __init__(self, name: str, expire: int = 0): def __init__(self, name: str, expire: int = 0):
@ -215,21 +219,21 @@ class CacheList:
""" """
self.name = name.upper() self.name = name.upper()
self.expire = expire self.expire = expire
self._data: list[CacheData[Any]] = [] self._data: list[CacheData[T]] = []
self._expire_time = 0 self._expire_time = 0
# 如果设置了过期时间,计算整个列表的过期时间 # 如果设置了过期时间,计算整个列表的过期时间
if self.expire > 0: if self.expire > 0:
self._expire_time = time.time() + self.expire self._expire_time = time.time() + self.expire
def __getitem__(self, index: int) -> Any: def __getitem__(self, index: int) -> T:
"""获取列表项 """获取列表项
参数: 参数:
index: 列表索引 index: 列表索引
返回: 返回:
Any: 列表值 T: 列表值
""" """
# 检查整个列表是否过期 # 检查整个列表是否过期
if self._is_expired(): if self._is_expired():
@ -240,7 +244,7 @@ class CacheList:
return self._data[index].value return self._data[index].value
raise IndexError(f"列表索引 {index} 超出范围") raise IndexError(f"列表索引 {index} 超出范围")
def __setitem__(self, index: int, value: Any) -> None: def __setitem__(self, index: int, value: T):
"""设置列表项 """设置列表项
参数: 参数:
@ -253,13 +257,13 @@ class CacheList:
# 确保索引有效 # 确保索引有效
while len(self._data) <= index: while len(self._data) <= index:
self._data.append(CacheData(value=None)) raise IndexError(f"列表索引 {index} 超出范围")
self._data[index] = CacheData(value=value) self._data[index] = CacheData(value=value)
# 更新过期时间 # 更新过期时间
self._update_expire_time() self._update_expire_time()
def __delitem__(self, index: int) -> None: def __delitem__(self, index: int):
"""删除列表项 """删除列表项
参数: 参数:
@ -287,7 +291,7 @@ class CacheList:
self.clear() self.clear()
return len(self._data) return len(self._data)
def append(self, value: Any) -> None: def append(self, value: T):
"""添加列表项 """添加列表项
参数: 参数:
@ -302,7 +306,7 @@ class CacheList:
# 更新过期时间 # 更新过期时间
self._update_expire_time() self._update_expire_time()
def extend(self, values: list[Any]) -> None: def extend(self, values: list[T]):
"""扩展列表 """扩展列表
参数: 参数:
@ -317,7 +321,7 @@ class CacheList:
# 更新过期时间 # 更新过期时间
self._update_expire_time() self._update_expire_time()
def insert(self, index: int, value: Any) -> None: def insert(self, index: int, value: T):
"""插入列表项 """插入列表项
参数: 参数:
@ -333,7 +337,7 @@ class CacheList:
# 更新过期时间 # 更新过期时间
self._update_expire_time() self._update_expire_time()
def pop(self, index: int = -1) -> Any: def pop(self, index: int = -1) -> T:
"""删除并返回列表项 """删除并返回列表项
参数: 参数:
@ -357,7 +361,7 @@ class CacheList:
return item.value return item.value
def remove(self, value: Any) -> None: def remove(self, value: T):
"""删除第一个匹配的列表项 """删除第一个匹配的列表项
参数: 参数:
@ -384,7 +388,7 @@ class CacheList:
# 重置过期时间 # 重置过期时间
self._update_expire_time() self._update_expire_time()
def index(self, value: Any, start: int = 0, end: int | None = None) -> int: def index(self, value: T, start: int = 0, end: int | None = None) -> int:
"""查找值的索引 """查找值的索引
参数: 参数:
@ -408,7 +412,7 @@ class CacheList:
raise ValueError(f"{value} 不在列表中") raise ValueError(f"{value} 不在列表中")
def count(self, value: Any) -> int: def count(self, value: T) -> int:
"""计算值出现的次数 """计算值出现的次数
参数: 参数:
@ -429,7 +433,7 @@ class CacheList:
"""检查整个列表是否过期""" """检查整个列表是否过期"""
return self._expire_time > 0 and self._expire_time < time.time() return self._expire_time > 0 and self._expire_time < time.time()
def _update_expire_time(self) -> None: def _update_expire_time(self):
"""更新过期时间""" """更新过期时间"""
self._expire_time = time.time() + self.expire if self.expire > 0 else 0 self._expire_time = time.time() + self.expire if self.expire > 0 else 0