mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
⚡ 代码性能优化
This commit is contained in:
parent
92d118685d
commit
61f559d605
@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot_plugin_alconna import At
|
||||
@ -25,8 +27,17 @@ Config.add_plugin_config(
|
||||
|
||||
|
||||
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||
if not user_id and not group_id:
|
||||
return 0
|
||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
||||
results = await cache.get(user_id, group_id) or await cache.get(user_id)
|
||||
group_user, user = await asyncio.gather(
|
||||
cache.get(user_id, group_id), cache.get(user_id)
|
||||
)
|
||||
results = []
|
||||
if group_user:
|
||||
results.extend(group_user)
|
||||
if user:
|
||||
results.extend(user)
|
||||
if not results:
|
||||
return 0
|
||||
for result in results:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import ClassVar
|
||||
|
||||
import nonebot
|
||||
from nonebot_plugin_uninfo import Uninfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -18,6 +19,14 @@ from zhenxun.utils.utils import (
|
||||
from .config import LOGGER_COMMAND
|
||||
from .exception import SkipPluginException
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
"""初始化限制"""
|
||||
await LimitManager.init_limit()
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
limit: PluginLimit
|
||||
@ -34,6 +43,13 @@ class LimitManager:
|
||||
block_limit: ClassVar[dict[str, Limit]] = {}
|
||||
count_limit: ClassVar[dict[str, Limit]] = {}
|
||||
|
||||
@classmethod
|
||||
async def init_limit(cls):
|
||||
"""初始化限制"""
|
||||
limit_list = await PluginLimit.filter(status=True).all()
|
||||
for limit in limit_list:
|
||||
cls.add_limit(limit)
|
||||
|
||||
@classmethod
|
||||
def add_limit(cls, limit: PluginLimit):
|
||||
"""添加限制
|
||||
@ -169,9 +185,7 @@ async def auth_limit(plugin: PluginInfo, session: Uninfo):
|
||||
"""
|
||||
entity = get_entity_ids(session)
|
||||
if plugin.module not in LimitManager.add_module:
|
||||
limit_list: list[PluginLimit] = await plugin.plugin_limit.filter(
|
||||
status=True
|
||||
).all() # type: ignore
|
||||
limit_list = await PluginLimit.filter(module=plugin.module, status=True).all()
|
||||
for limit in limit_list:
|
||||
LimitManager.add_limit(limit)
|
||||
if entity.user_id:
|
||||
|
||||
@ -159,9 +159,9 @@ async def auth(
|
||||
auth_group(plugin, entity, message),
|
||||
auth_admin(plugin, session),
|
||||
auth_plugin(plugin, session, event),
|
||||
auth_limit(plugin, session),
|
||||
]
|
||||
)
|
||||
await auth_limit(plugin, session)
|
||||
except SkipPluginException as e:
|
||||
LimitManager.unblock(module, entity.user_id, entity.group_id, entity.channel_id)
|
||||
logger.info(str(e), LOGGER_COMMAND, session=session)
|
||||
|
||||
@ -11,13 +11,18 @@ from zhenxun.utils.platform import PlatformUtils
|
||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||
|
||||
try:
|
||||
from . import __init_cache
|
||||
from .__init_cache import CacheRoot
|
||||
except DbCacheException as e:
|
||||
raise SystemError(f"ERROR:{e}")
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
await CacheRoot.init_non_lazy_caches()
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
"""将bot已存在的群组添加群认证
|
||||
|
||||
@ -3,7 +3,6 @@ from zhenxun.models.bot_console import BotConsole
|
||||
from zhenxun.models.group_console import GroupConsole
|
||||
from zhenxun.models.level_user import LevelUser
|
||||
from zhenxun.models.plugin_info import PluginInfo
|
||||
from zhenxun.models.plugin_limit import PluginLimit
|
||||
from zhenxun.models.user_console import UserConsole
|
||||
from zhenxun.services.cache import CacheData, CacheRoot
|
||||
from zhenxun.services.log import logger
|
||||
@ -125,7 +124,7 @@ async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
|
||||
await cache_data.set_key(user.user_id, user)
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LEVEL, False)
|
||||
@CacheRoot.new(CacheType.LEVEL)
|
||||
async def _():
|
||||
"""初始化等级缓存"""
|
||||
data_list = await LevelUser().all()
|
||||
@ -152,52 +151,61 @@ async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
||||
async def _():
|
||||
"""初始化封禁缓存"""
|
||||
data_list = await BanConsole.all()
|
||||
return {f"{d.user_id or ''}:{d.group_id or ''}": d for d in data_list}
|
||||
return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.BAN, result_model=list[BanConsole])
|
||||
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
||||
"""获取封禁缓存"""
|
||||
key = f"{user_id or ''}:{group_id or ''}"
|
||||
if not user_id and not group_id:
|
||||
return []
|
||||
key = f"{group_id or ''}:{user_id or ''}"
|
||||
logger.info(f"获取封禁缓存: {key}")
|
||||
data = await cache_data.get_key(key)
|
||||
if not data:
|
||||
if user_id and group_id:
|
||||
data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
|
||||
elif user_id:
|
||||
data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
|
||||
elif group_id:
|
||||
data = await BanConsole.filter(
|
||||
user_id__isnull=True, group_id=group_id
|
||||
).all()
|
||||
if data:
|
||||
await cache_data.set_key(key, data)
|
||||
return data
|
||||
if data:
|
||||
logger.info(f"已存在缓存: {key}:{data}")
|
||||
# if not data:
|
||||
# start = time.time()
|
||||
# if user_id and group_id:
|
||||
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
|
||||
# elif user_id:
|
||||
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
|
||||
# elif group_id:
|
||||
# data = await BanConsole.filter(
|
||||
# user_id__isnull=True, group_id=group_id
|
||||
# ).all()
|
||||
# logger.info(
|
||||
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
|
||||
# )
|
||||
# if data:
|
||||
# await cache_data.set_key(key, data)
|
||||
# return data
|
||||
return data or []
|
||||
|
||||
|
||||
@CacheRoot.new(CacheType.LIMIT)
|
||||
async def _():
|
||||
"""初始化限制缓存"""
|
||||
data_list = await PluginLimit.filter(status=True).all()
|
||||
return {data.module: data for data in data_list}
|
||||
# @CacheRoot.new(CacheType.LIMIT)
|
||||
# async def _():
|
||||
# """初始化限制缓存"""
|
||||
# data_list = await PluginLimit.filter(status=True).all()
|
||||
# return {data.module: data for data in data_list}
|
||||
|
||||
|
||||
@CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||
async def _(cache_data: CacheData, module: str):
|
||||
"""获取限制缓存"""
|
||||
data = await cache_data.get_key(module)
|
||||
if not data:
|
||||
if limits := await PluginLimit.filter(module=module, status=True):
|
||||
await cache_data.set_key(module, limits)
|
||||
return limits
|
||||
return data or []
|
||||
# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
||||
# async def _(cache_data: CacheData, module: str):
|
||||
# """获取限制缓存"""
|
||||
# data = await cache_data.get_key(module)
|
||||
# if not data:
|
||||
# if limits := await PluginLimit.filter(module=module, status=True):
|
||||
# await cache_data.set_key(module, limits)
|
||||
# return limits
|
||||
# return data or []
|
||||
|
||||
|
||||
@CacheRoot.with_refresh(CacheType.LIMIT)
|
||||
async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
|
||||
"""刷新限制缓存"""
|
||||
if not data:
|
||||
return
|
||||
limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||
for limit in limits:
|
||||
await cache_data.set_key(limit.module, limit)
|
||||
# @CacheRoot.with_refresh(CacheType.LIMIT)
|
||||
# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
|
||||
# """刷新限制缓存"""
|
||||
# if not data:
|
||||
# return
|
||||
# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
||||
# for limit in limits:
|
||||
# await cache_data.set_key(limit.module, limit)
|
||||
|
||||
@ -98,7 +98,7 @@ class CacheData(BaseModel):
|
||||
with_refresh: Callable[..., Any] | None = None
|
||||
expire: int = 600 # 默认10分钟过期
|
||||
reload_count: int = 0
|
||||
incremental_update: bool = True
|
||||
lazy_load: bool = True # 默认延迟加载
|
||||
_cache_instance: BaseCache | None = None
|
||||
result_model: type | None = None
|
||||
_keys: set[str] = set() # 存储所有缓存键
|
||||
@ -168,12 +168,12 @@ class CacheData(BaseModel):
|
||||
try:
|
||||
if hasattr(field, "to_python_value"):
|
||||
if not field.field_type:
|
||||
logger.warning(f"字段 {field_name} 类型为空")
|
||||
logger.debug(f"字段 {field_name} 类型为空")
|
||||
continue
|
||||
field_value = field.to_python_value(field_value)
|
||||
setattr(instance, field_name, field_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"设置字段 {field_name} 失败: {e}")
|
||||
logger.warning(f"设置字段 {field_name} 失败", e=e)
|
||||
|
||||
# 设置 _saved_in_db 标志
|
||||
instance._saved_in_db = True
|
||||
@ -333,7 +333,7 @@ class CacheData(BaseModel):
|
||||
|
||||
async def get(self, key: str, *args, **kwargs) -> Any:
|
||||
"""获取缓存"""
|
||||
if not self.reload_count and not self.incremental_update:
|
||||
if not self.reload_count and not self.lazy_load:
|
||||
await self.reload(*args, **kwargs)
|
||||
|
||||
if not self.getter:
|
||||
@ -343,7 +343,7 @@ class CacheData(BaseModel):
|
||||
|
||||
async def get_all(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""获取所有缓存数据"""
|
||||
if not self.reload_count and not self.incremental_update:
|
||||
if not self.reload_count and not self.lazy_load:
|
||||
await self.reload(*args, **kwargs)
|
||||
|
||||
if not self.getter:
|
||||
@ -523,8 +523,24 @@ class CacheManager:
|
||||
|
||||
_data: ClassVar[dict[str, CacheData]] = {}
|
||||
|
||||
def new(self, name: str, incremental_update: bool = True, expire: int = 600):
|
||||
"""注册新缓存"""
|
||||
async def init_non_lazy_caches(self):
|
||||
"""初始化所有非延迟加载的缓存"""
|
||||
for name, cache in self._data.items():
|
||||
if not cache.lazy_load:
|
||||
try:
|
||||
await cache.reload()
|
||||
logger.info(f"初始化缓存 {name} 完成")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化缓存 {name} 失败: {e}")
|
||||
|
||||
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
|
||||
"""注册新缓存
|
||||
|
||||
Args:
|
||||
name: 缓存名称
|
||||
lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载
|
||||
expire: 过期时间(秒)
|
||||
"""
|
||||
|
||||
def wrapper(func: Callable):
|
||||
_name = name.upper()
|
||||
@ -535,7 +551,7 @@ class CacheManager:
|
||||
name=_name,
|
||||
func=func,
|
||||
expire=expire,
|
||||
incremental_update=incremental_update,
|
||||
lazy_load=lazy_load,
|
||||
)
|
||||
return func
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user