代码性能优化

This commit is contained in:
HibiKier 2025-05-11 22:52:13 +08:00
parent 92d118685d
commit 61f559d605
6 changed files with 106 additions and 52 deletions

View File

@ -1,3 +1,5 @@
import asyncio
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
@ -25,8 +27,17 @@ Config.add_plugin_config(
async def is_ban(user_id: str | None, group_id: str | None) -> int:
if not user_id and not group_id:
return 0
cache = Cache[list[BanConsole]](CacheType.BAN)
results = await cache.get(user_id, group_id) or await cache.get(user_id)
group_user, user = await asyncio.gather(
cache.get(user_id, group_id), cache.get(user_id)
)
results = []
if group_user:
results.extend(group_user)
if user:
results.extend(user)
if not results:
return 0
for result in results:

View File

@ -1,5 +1,6 @@
from typing import ClassVar
import nonebot
from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel
@ -18,6 +19,14 @@ from zhenxun.utils.utils import (
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
driver = nonebot.get_driver()
@driver.on_startup
async def _():
"""初始化限制"""
await LimitManager.init_limit()
class Limit(BaseModel):
limit: PluginLimit
@ -34,6 +43,13 @@ class LimitManager:
block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {}
@classmethod
async def init_limit(cls):
"""初始化限制"""
limit_list = await PluginLimit.filter(status=True).all()
for limit in limit_list:
cls.add_limit(limit)
@classmethod
def add_limit(cls, limit: PluginLimit):
"""添加限制
@ -169,9 +185,7 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
"""
entity = get_entity_ids(session)
if plugin.module not in LimitManager.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
status=True
).all() # type: ignore
limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
for limit in limit_list:
LimitManager.add_limit(limit)
if entity.user_id:

View File

@ -159,9 +159,9 @@ async def auth(
auth_group(plugin, entity, message),
auth_admin(plugin, session),
auth_plugin(plugin, session, event),
auth_limit(plugin, session),
]
)
await auth_limit(plugin, session)
except SkipPluginException as e:
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
logger.info(str(e), LOGGER_COMMAND, session=session)

View File

@ -11,13 +11,18 @@ from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try:
from . import __init_cache
from .__init_cache import CacheRoot
except DbCacheException as e:
raise SystemError(f"ERROR{e}")
driver = nonebot.get_driver()
@driver.on_startup
async def _():
await CacheRoot.init_non_lazy_caches()
@driver.on_bot_connect
async def _(bot: Bot):
"""将bot已存在的群组添加群认证

View File

@ -3,7 +3,6 @@ from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.services.log import logger
@ -125,7 +124,7 @@ async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
await cache_data.set_key(user.user_id, user)
@CacheRoot.new(CacheType.LEVEL, False)
@CacheRoot.new(CacheType.LEVEL)
async def _():
"""初始化等级缓存"""
data_list = await LevelUser().all()
@ -152,52 +151,61 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
async def _():
"""初始化封禁缓存"""
data_list = await BanConsole.all()
return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list}
return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
"""获取封禁缓存"""
key = f"{user_id or ''}:{group_id or ''}"
if not user_id and not group_id:
return []
key = f"{group_id or ''}:{user_id or ''}"
logger.info(f"获取封禁缓存: {key}")
data = await cache_data.get_key(key)
if not data:
if user_id and group_id:
data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
elif user_id:
data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
elif group_id:
data = await BanConsole.filter(
user_id__isnull=True, group_id=group_id
).all()
if data:
await cache_data.set_key(key, data)
return data
if data:
logger.info(f"已存在缓存: {key}:{data}")
# if not data:
# start = time.time()
# if user_id and group_id:
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
# elif user_id:
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
# elif group_id:
# data = await BanConsole.filter(
# user_id__isnull=True, group_id=group_id
# ).all()
# logger.info(
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
# )
# if data:
# await cache_data.set_key(key, data)
# return data
return data or []
@CacheRoot.new(CacheType.LIMIT)
async def _():
"""初始化限制缓存"""
data_list = await PluginLimit.filter(status=True).all()
return {data.module: data for data in data_list}
# @CacheRoot.new(CacheType.LIMIT)
# async def _():
# """初始化限制缓存"""
# data_list = await PluginLimit.filter(status=True).all()
# return {data.module: data for data in data_list}
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
async def _(cache_data: CacheData, module: str):
"""获取限制缓存"""
data = await cache_data.get_key(module)
if not data:
if limits := await PluginLimit.filter(module=module, status=True):
await cache_data.set_key(module, limits)
return limits
return data or []
# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
# async def _(cache_data: CacheData, module: str):
# """获取限制缓存"""
# data = await cache_data.get_key(module)
# if not data:
# if limits := await PluginLimit.filter(module=module, status=True):
# await cache_data.set_key(module, limits)
# return limits
# return data or []
@CacheRoot.with_refresh(CacheType.LIMIT)
async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
"""刷新限制缓存"""
if not data:
return
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
for limit in limits:
await cache_data.set_key(limit.module, limit)
# @CacheRoot.with_refresh(CacheType.LIMIT)
# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
# """刷新限制缓存"""
# if not data:
# return
# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
# for limit in limits:
# await cache_data.set_key(limit.module, limit)

View File

@ -98,7 +98,7 @@ class CacheData(BaseModel):
with_refresh: Callable[..., Any] | None = None
expire: int = 600 # 默认10分钟过期
reload_count: int = 0
incremental_update: bool = True
lazy_load: bool = True # 默认延迟加载
_cache_instance: BaseCache | None = None
result_model: type | None = None
_keys: set[str] = set() # 存储所有缓存键
@ -168,12 +168,12 @@ class CacheData(BaseModel):
try:
if hasattr(field, "to_python_value"):
if not field.field_type:
logger.warning(f"字段 {field_name} 类型为空")
logger.debug(f"字段 {field_name} 类型为空")
continue
field_value = field.to_python_value(field_value)
setattr(instance, field_name, field_value)
except Exception as e:
logger.warning(f"设置字段 {field_name} 失败: {e}")
logger.warning(f"设置字段 {field_name} 失败", e=e)
# 设置 _saved_in_db 标志
instance._saved_in_db = True
@ -333,7 +333,7 @@ class CacheData(BaseModel):
async def get(self, key: str, *args, **kwargs) -> Any:
"""获取缓存"""
if not self.reload_count and not self.incremental_update:
if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs)
if not self.getter:
@ -343,7 +343,7 @@ class CacheData(BaseModel):
async def get_all(self, *args, **kwargs) -> dict[str, Any]:
"""获取所有缓存数据"""
if not self.reload_count and not self.incremental_update:
if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs)
if not self.getter:
@ -523,8 +523,24 @@ class CacheManager:
_data: ClassVar[dict[str, CacheData]] = {}
def new(self, name: str, incremental_update: bool = True, expire: int = 600):
"""注册新缓存"""
async def init_non_lazy_caches(self):
"""初始化所有非延迟加载的缓存"""
for name, cache in self._data.items():
if not cache.lazy_load:
try:
await cache.reload()
logger.info(f"初始化缓存 {name} 完成")
except Exception as e:
logger.error(f"初始化缓存 {name} 失败: {e}")
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
"""注册新缓存
Args:
name: 缓存名称
lazy_load: 是否延迟加载默认为True为False时会在程序启动时自动加载
expire: 过期时间
"""
def wrapper(func: Callable):
_name = name.upper()
@ -535,7 +551,7 @@ class CacheManager:
name=_name,
func=func,
expire=expire,
incremental_update=incremental_update,
lazy_load=lazy_load,
)
return func