代码性能优化

This commit is contained in:
HibiKier 2025-05-11 22:52:13 +08:00
parent 92d118685d
commit 61f559d605
6 changed files with 106 additions and 52 deletions

View File

@ -1,3 +1,5 @@
import asyncio
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At from nonebot_plugin_alconna import At
@ -25,8 +27,17 @@ Config.add_plugin_config(
async def is_ban(user_id: str | None, group_id: str | None) -> int: async def is_ban(user_id: str | None, group_id: str | None) -> int:
if not user_id and not group_id:
return 0
cache = Cache[list[BanConsole]](CacheType.BAN) cache = Cache[list[BanConsole]](CacheType.BAN)
results = await cache.get(user_id, group_id) or await cache.get(user_id) group_user, user = await asyncio.gather(
cache.get(user_id, group_id), cache.get(user_id)
)
results = []
if group_user:
results.extend(group_user)
if user:
results.extend(user)
if not results: if not results:
return 0 return 0
for result in results: for result in results:

View File

@ -1,5 +1,6 @@
from typing import ClassVar from typing import ClassVar
import nonebot
from nonebot_plugin_uninfo import Uninfo from nonebot_plugin_uninfo import Uninfo
from pydantic import BaseModel from pydantic import BaseModel
@ -18,6 +19,14 @@ from zhenxun.utils.utils import (
from .config import LOGGER_COMMAND from .config import LOGGER_COMMAND
from .exception import SkipPluginException from .exception import SkipPluginException
driver = nonebot.get_driver()
@driver.on_startup
async def _():
"""初始化限制"""
await LimitManager.init_limit()
class Limit(BaseModel): class Limit(BaseModel):
limit: PluginLimit limit: PluginLimit
@ -34,6 +43,13 @@ class LimitManager:
block_limit: ClassVar[dict[str, Limit]] = {} block_limit: ClassVar[dict[str, Limit]] = {}
count_limit: ClassVar[dict[str, Limit]] = {} count_limit: ClassVar[dict[str, Limit]] = {}
@classmethod
async def init_limit(cls):
"""初始化限制"""
limit_list = await PluginLimit.filter(status=True).all()
for limit in limit_list:
cls.add_limit(limit)
@classmethod @classmethod
def add_limit(cls, limit: PluginLimit): def add_limit(cls, limit: PluginLimit):
"""添加限制 """添加限制
@ -169,9 +185,7 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
""" """
entity = get_entity_ids(session) entity = get_entity_ids(session)
if plugin.module not in LimitManager.add_module: if plugin.module not in LimitManager.add_module:
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
status=True
).all() # type: ignore
for limit in limit_list: for limit in limit_list:
LimitManager.add_limit(limit) LimitManager.add_limit(limit)
if entity.user_id: if entity.user_id:

View File

@ -159,9 +159,9 @@ async def auth(
auth_group(plugin, entity, message), auth_group(plugin, entity, message),
auth_admin(plugin, session), auth_admin(plugin, session),
auth_plugin(plugin, session, event), auth_plugin(plugin, session, event),
auth_limit(plugin, session),
] ]
) )
await auth_limit(plugin, session)
except SkipPluginException as e: except SkipPluginException as e:
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id) LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
logger.info(str(e), LOGGER_COMMAND, session=session) logger.info(str(e), LOGGER_COMMAND, session=session)

View File

@ -11,13 +11,18 @@ from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve())) nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try: try:
from . import __init_cache from .__init_cache import CacheRoot
except DbCacheException as e: except DbCacheException as e:
raise SystemError(f"ERROR{e}") raise SystemError(f"ERROR{e}")
driver = nonebot.get_driver() driver = nonebot.get_driver()
@driver.on_startup
async def _():
await CacheRoot.init_non_lazy_caches()
@driver.on_bot_connect @driver.on_bot_connect
async def _(bot: Bot): async def _(bot: Bot):
"""将bot已存在的群组添加群认证 """将bot已存在的群组添加群认证

View File

@ -3,7 +3,6 @@ from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.plugin_limit import PluginLimit
from zhenxun.models.user_console import UserConsole from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.services.log import logger from zhenxun.services.log import logger
@ -125,7 +124,7 @@ async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
await cache_data.set_key(user.user_id, user) await cache_data.set_key(user.user_id, user)
@CacheRoot.new(CacheType.LEVEL, False) @CacheRoot.new(CacheType.LEVEL)
async def _(): async def _():
"""初始化等级缓存""" """初始化等级缓存"""
data_list = await LevelUser().all() data_list = await LevelUser().all()
@ -152,52 +151,61 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
async def _(): async def _():
"""初始化封禁缓存""" """初始化封禁缓存"""
data_list = await BanConsole.all() data_list = await BanConsole.all()
return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list} return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole]) @CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
"""获取封禁缓存""" """获取封禁缓存"""
key = f"{user_id or ''}:{group_id or ''}" if not user_id and not group_id:
return []
key = f"{group_id or ''}:{user_id or ''}"
logger.info(f"获取封禁缓存: {key}")
data = await cache_data.get_key(key) data = await cache_data.get_key(key)
if not data:
if user_id and group_id:
data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
elif user_id:
data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
elif group_id:
data = await BanConsole.filter(
user_id__isnull=True, group_id=group_id
).all()
if data: if data:
await cache_data.set_key(key, data) logger.info(f"已存在缓存: {key}:{data}")
return data # if not data:
# start = time.time()
# if user_id and group_id:
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
# elif user_id:
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
# elif group_id:
# data = await BanConsole.filter(
# user_id__isnull=True, group_id=group_id
# ).all()
# logger.info(
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
# )
# if data:
# await cache_data.set_key(key, data)
# return data
return data or [] return data or []
@CacheRoot.new(CacheType.LIMIT) # @CacheRoot.new(CacheType.LIMIT)
async def _(): # async def _():
"""初始化限制缓存""" # """初始化限制缓存"""
data_list = await PluginLimit.filter(status=True).all() # data_list = await PluginLimit.filter(status=True).all()
return {data.module: data for data in data_list} # return {data.module: data for data in data_list}
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) # @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
async def _(cache_data: CacheData, module: str): # async def _(cache_data: CacheData, module: str):
"""获取限制缓存""" # """获取限制缓存"""
data = await cache_data.get_key(module) # data = await cache_data.get_key(module)
if not data: # if not data:
if limits := await PluginLimit.filter(module=module, status=True): # if limits := await PluginLimit.filter(module=module, status=True):
await cache_data.set_key(module, limits) # await cache_data.set_key(module, limits)
return limits # return limits
return data or [] # return data or []
@CacheRoot.with_refresh(CacheType.LIMIT) # @CacheRoot.with_refresh(CacheType.LIMIT)
async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None): # async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
"""刷新限制缓存""" # """刷新限制缓存"""
if not data: # if not data:
return # return
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() # limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
for limit in limits: # for limit in limits:
await cache_data.set_key(limit.module, limit) # await cache_data.set_key(limit.module, limit)

View File

@ -98,7 +98,7 @@ class CacheData(BaseModel):
with_refresh: Callable[..., Any] | None = None with_refresh: Callable[..., Any] | None = None
expire: int = 600 # 默认10分钟过期 expire: int = 600 # 默认10分钟过期
reload_count: int = 0 reload_count: int = 0
incremental_update: bool = True lazy_load: bool = True # 默认延迟加载
_cache_instance: BaseCache | None = None _cache_instance: BaseCache | None = None
result_model: type | None = None result_model: type | None = None
_keys: set[str] = set() # 存储所有缓存键 _keys: set[str] = set() # 存储所有缓存键
@ -168,12 +168,12 @@ class CacheData(BaseModel):
try: try:
if hasattr(field, "to_python_value"): if hasattr(field, "to_python_value"):
if not field.field_type: if not field.field_type:
logger.warning(f"字段 {field_name} 类型为空") logger.debug(f"字段 {field_name} 类型为空")
continue continue
field_value = field.to_python_value(field_value) field_value = field.to_python_value(field_value)
setattr(instance, field_name, field_value) setattr(instance, field_name, field_value)
except Exception as e: except Exception as e:
logger.warning(f"设置字段 {field_name} 失败: {e}") logger.warning(f"设置字段 {field_name} 失败", e=e)
# 设置 _saved_in_db 标志 # 设置 _saved_in_db 标志
instance._saved_in_db = True instance._saved_in_db = True
@ -333,7 +333,7 @@ class CacheData(BaseModel):
async def get(self, key: str, *args, **kwargs) -> Any: async def get(self, key: str, *args, **kwargs) -> Any:
"""获取缓存""" """获取缓存"""
if not self.reload_count and not self.incremental_update: if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs) await self.reload(*args, **kwargs)
if not self.getter: if not self.getter:
@ -343,7 +343,7 @@ class CacheData(BaseModel):
async def get_all(self, *args, **kwargs) -> dict[str, Any]: async def get_all(self, *args, **kwargs) -> dict[str, Any]:
"""获取所有缓存数据""" """获取所有缓存数据"""
if not self.reload_count and not self.incremental_update: if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs) await self.reload(*args, **kwargs)
if not self.getter: if not self.getter:
@ -523,8 +523,24 @@ class CacheManager:
_data: ClassVar[dict[str, CacheData]] = {} _data: ClassVar[dict[str, CacheData]] = {}
def new(self, name: str, incremental_update: bool = True, expire: int = 600): async def init_non_lazy_caches(self):
"""注册新缓存""" """初始化所有非延迟加载的缓存"""
for name, cache in self._data.items():
if not cache.lazy_load:
try:
await cache.reload()
logger.info(f"初始化缓存 {name} 完成")
except Exception as e:
logger.error(f"初始化缓存 {name} 失败: {e}")
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
"""注册新缓存
Args:
name: 缓存名称
lazy_load: 是否延迟加载默认为True为False时会在程序启动时自动加载
expire: 过期时间
"""
def wrapper(func: Callable): def wrapper(func: Callable):
_name = name.upper() _name = name.upper()
@ -535,7 +551,7 @@ class CacheManager:
name=_name, name=_name,
func=func, func=func,
expire=expire, expire=expire,
incremental_update=incremental_update, lazy_load=lazy_load,
) )
return func return func