diff --git a/.env.dev b/.env.dev index 62f9d1bd..015a950c 100644 --- a/.env.dev +++ b/.env.dev @@ -27,6 +27,8 @@ QBOT_ID_DATA = '{ # 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹 DB_URL = "" +# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存 +CACHE_MODE = NONE # REDIS配置,使用REDIS替换Cache内存缓存 # REDIS地址 # REDIS_HOST = "127.0.0.1" @@ -50,7 +52,7 @@ PLATFORM_SUPERUSERS = ' DRIVER=~fastapi+~httpx+~websockets -# LOG_LEVEL=DEBUG +# LOG_LEVEL = DEBUG # 服务器和端口 HOST = 127.0.0.1 PORT = 8080 diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py index 1e93f089..634a0690 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_admin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_admin.py @@ -3,8 +3,7 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache -from zhenxun.utils.enum import CacheType +from zhenxun.services.data_access import DataAccess from zhenxun.utils.utils import get_entity_ids from .exception import SkipPluginException @@ -21,15 +20,21 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo): if not plugin.admin_level: return entity = get_entity_ids(session) - cache = Cache[list[LevelUser]](CacheType.LEVEL) - user_list = await cache.get(session.user.id) or [] + level_dao = DataAccess(LevelUser) + global_user = await level_dao.safe_get_or_none( + user_id=session.user.id, group_id__isnull=True + ) + user_level = 0 + if global_user: + user_level = global_user.user_level if entity.group_id: - user_list += await cache.get(session.user.id, entity.group_id) or [] - if user_list: - user = max(user_list, key=lambda x: x.user_level) - user_level = user.user_level - else: - user_level = 0 + # 获取用户在当前群组的权限数据 + group_users = await level_dao.safe_get_or_none( + user_id=session.user.id, group_id=entity.group_id + ) + if group_users: + user_level = max(user_level, group_users.user_level) + if user_level < plugin.admin_level: await send_message( session, @@ -42,11 +47,12 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo): raise SkipPluginException( f"{plugin.name}({plugin.module}) 管理员权限不足..." ) - elif user_list: - user = max(user_list, key=lambda x: x.user_level) - if user.user_level < plugin.admin_level: + elif global_user: + if global_user.user_level < plugin.admin_level: await send_message( session, f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}", ) - raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...") + raise SkipPluginException( + f"{plugin.name}({plugin.module}) 管理员权限不足..." + ) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py index dcca0731..8bc7e0a0 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_ban.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_ban.py @@ -1,20 +1,15 @@ -import asyncio - from nonebot.adapters import Bot from nonebot.matcher import Matcher from nonebot_plugin_alconna import At from nonebot_plugin_uninfo import Uninfo -from tortoise.exceptions import MultipleObjectsReturned from zhenxun.configs.config import Config from zhenxun.models.ban_console import BanConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache -from zhenxun.services.log import logger -from zhenxun.utils.enum import CacheType, PluginType +from zhenxun.services.data_access import DataAccess +from zhenxun.utils.enum import PluginType from zhenxun.utils.utils import EntityIDs, get_entity_ids -from .config import LOGGER_COMMAND from .exception import SkipPluginException from .utils import freq, send_message @@ -29,10 +24,18 @@ Config.add_plugin_config( async def is_ban(user_id: str | None, group_id: str | None) -> int: if not user_id and not group_id: return 0 - cache = Cache[BanConsole](CacheType.BAN) - group_user, user = await asyncio.gather( - cache.get(user_id, group_id), cache.get(user_id) - ) + ban_dao = DataAccess(BanConsole) + + # 分别获取用户在群组中的ban记录和全局ban记录 + group_user = None + user = None + + if user_id and group_id: + group_user = await ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id) + + if user_id: + user = await ban_dao.safe_get_or_none(user_id=user_id, group_id="") + results = [] if group_user: results.append(group_user) @@ -88,76 +91,55 @@ def format_time(time: float) -> str: return time_str -async def group_handle(cache: Cache[list[BanConsole]], group_id: str): +async def group_handle(group_id: str): """群组ban检查 参数: - cache: cache + ban_dao: BanConsole数据访问对象 group_id: 群组id 异常: SkipPluginException: 群组处于黑名单 """ - try: - if await is_ban(None, group_id): - raise SkipPluginException("群组处于黑名单中...") - except MultipleObjectsReturned: - logger.warning( - "群组黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND - ) - ids = await BanConsole.filter(user_id="", group_id=group_id).values_list( - "id", flat=True - ) - await BanConsole.filter(id__in=ids[:-1]).delete() - await cache.reload() + if await is_ban(None, group_id): + raise SkipPluginException("群组处于黑名单中...") -async def user_handle( - module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo -): +async def user_handle(module: str, entity: EntityIDs, session: Uninfo): """用户ban检查 参数: module: 插件模块名 - cache: cache - user_id: 用户id + ban_dao: BanConsole数据访问对象 + entity: 实体ID信息 session: Uninfo 异常: SkipPluginException: 用户处于黑名单 """ ban_result = Config.get_config("hook", "BAN_RESULT") - try: - time = await is_ban(entity.user_id, entity.group_id) - if not time: - return - time_str = format_time(time) - db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) - if ( - db_plugin - # and not db_plugin.ignore_prompt - and time != -1 - and ban_result - and freq.is_send_limit_message(db_plugin, entity.user_id, False) - ): - await send_message( - session, - [ - At(flag="user", target=entity.user_id), - f"{ban_result}\n在..在 {time_str} 后才会理你喔", - ], - entity.user_id, - ) - raise SkipPluginException("用户处于黑名单中...") - except MultipleObjectsReturned: - logger.warning( - "用户黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND + time = await is_ban(entity.user_id, entity.group_id) + if not time: + return + time_str = format_time(time) + plugin_dao = DataAccess(PluginInfo) + db_plugin = await plugin_dao.safe_get_or_none(module=module) + if ( + db_plugin + and not db_plugin.ignore_prompt + and time != -1 + and ban_result + and freq.is_send_limit_message(db_plugin, entity.user_id, False) + ): + await send_message( + session, + [ + At(flag="user", target=entity.user_id), + f"{ban_result}\n在..在 {time_str} 后才会理你喔", + ], + entity.user_id, ) - ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list( - "id", flat=True - ) - await BanConsole.filter(id__in=ids[:-1]).delete() - await cache.reload() + raise SkipPluginException("用户处于黑名单中...") async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo): @@ -168,8 +150,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo): entity = get_entity_ids(session) if entity.user_id in bot.config.superusers: return - cache = Cache[list[BanConsole]](CacheType.BAN) if entity.group_id: - await group_handle(cache, entity.group_id) + await group_handle(entity.group_id) if entity.user_id: - await user_handle(matcher.plugin_name, cache, entity, session) + await user_handle(matcher.plugin_name, entity, session) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py index 2427223f..b627156a 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_bot.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_bot.py @@ -1,8 +1,7 @@ from zhenxun.models.bot_console import BotConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache +from zhenxun.services.data_access import DataAccess from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.enum import CacheType from .exception import SkipPluginException @@ -18,11 +17,11 @@ async def auth_bot(plugin: PluginInfo, bot_id: str): SkipPluginException: 忽略插件 SkipPluginException: 忽略插件 """ - if cache := Cache[BotConsole](CacheType.BOT): - bot = await cache.get(bot_id) - if not bot or not bot.status: - raise SkipPluginException("Bot不存在或休眠中阻断权限检测...") - if CommonUtils.format(plugin.module) in bot.block_plugins: - raise SkipPluginException( - f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..." - ) + bot_dao = DataAccess(BotConsole) + bot = await bot_dao.safe_get_or_none(bot_id=bot_id) + if not bot or not bot.status: + raise SkipPluginException("Bot不存在或休眠中阻断权限检测...") + if CommonUtils.format(plugin.module) in bot.block_plugins: + raise SkipPluginException( + f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..." + ) diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_cost.py b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py index 7a971085..4f632c42 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_cost.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_cost.py @@ -11,6 +11,7 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i """检测是否满足金币条件 参数: + user: UserConsole plugin: PluginInfo session: Uninfo diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_group.py b/zhenxun/builtin_plugins/hooks/auth/auth_group.py index 290a3ad9..8bb3ac3c 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_group.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_group.py @@ -2,8 +2,7 @@ from nonebot_plugin_alconna import UniMsg from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache -from zhenxun.utils.enum import CacheType +from zhenxun.services.data_access import DataAccess from zhenxun.utils.utils import EntityIDs from .config import SwitchEnum @@ -21,7 +20,10 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg): if not entity.group_id: return text = message.extract_plain_text() - group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id) + group_dao = DataAccess(GroupConsole) + group = await group_dao.safe_get_or_none( + group_id=entity.group_id, channel_id__isnull=True + ) if not group: raise SkipPluginException("群组信息不存在...") if group.level < 0: diff --git a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py index ebfe7be1..15351f8c 100644 --- a/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py +++ b/zhenxun/builtin_plugins/hooks/auth/auth_plugin.py @@ -3,9 +3,9 @@ from nonebot_plugin_uninfo import Uninfo from zhenxun.models.group_console import GroupConsole from zhenxun.models.plugin_info import PluginInfo -from zhenxun.services.cache import Cache +from zhenxun.services.data_access import DataAccess from zhenxun.utils.common_utils import CommonUtils -from zhenxun.utils.enum import BlockType, CacheType +from zhenxun.utils.enum import BlockType from zhenxun.utils.utils import get_entity_ids from .exception import IsSuperuserException, SkipPluginException @@ -20,10 +20,12 @@ class GroupCheck: self.session = session self.is_poke = is_poke self.plugin = plugin + self.group_dao = DataAccess(GroupConsole) async def __get_data(self): - cache = Cache[GroupConsole](CacheType.GROUPS) - return await cache.get(self.group_id) + return await self.group_dao.safe_get_or_none( + group_id=self.group_id, channel_id__isnull=True + ) async def check(self): await self.check_superuser_block(self.plugin) @@ -89,6 +91,7 @@ class PluginCheck: self.session = session self.is_poke = is_poke self.group_id = group_id + self.group_dao = DataAccess(GroupConsole) async def check_user(self, plugin: PluginInfo): """全局私聊禁用检测 @@ -118,9 +121,11 @@ class PluginCheck: if plugin.status or plugin.block_type != BlockType.ALL: return """全局状态""" - cache = Cache[GroupConsole](CacheType.GROUPS) - if self.group_id and (group := await cache.get(self.group_id)): - if group.is_super: + if self.group_id: + group = await self.group_dao.safe_get_or_none( + group_id=self.group_id, channel_id__isnull=True + ) + if group and group.is_super: raise IsSuperuserException() sid = self.group_id or self.session.user.id if freq.is_send_limit_message(plugin, sid, self.is_poke): diff --git a/zhenxun/builtin_plugins/hooks/auth_checker.py b/zhenxun/builtin_plugins/hooks/auth_checker.py index 0e0d5c64..e69bae6a 100644 --- a/zhenxun/builtin_plugins/hooks/auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/auth_checker.py @@ -9,13 +9,9 @@ from tortoise.exceptions import IntegrityError from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.user_console import UserConsole -from zhenxun.services.cache import Cache +from zhenxun.services.data_access import DataAccess from zhenxun.services.log import logger -from zhenxun.utils.enum import ( - CacheType, - GoldHandle, - PluginType, -) +from zhenxun.utils.enum import GoldHandle, PluginType from zhenxun.utils.exception import InsufficientGold from zhenxun.utils.platform import PlatformUtils from zhenxun.utils.utils import get_entity_ids @@ -54,8 +50,9 @@ async def get_plugin_and_user( 返回: tuple[PluginInfo, UserConsole]: 插件信息,用户信息 """ - user_cache = Cache[UserConsole](CacheType.USERS) - plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module) + user_dao = DataAccess(UserConsole) + plugin_dao = DataAccess(PluginInfo) + plugin = await plugin_dao.safe_get_or_none(module=module) if not plugin: raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...") if plugin.plugin_type == PluginType.HIDDEN: @@ -64,7 +61,7 @@ async def get_plugin_and_user( ) user = None try: - user = await user_cache.get(user_id) + user = await user_dao.safe_get_or_none(user_id=user_id) except IntegrityError as e: raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e if not user: @@ -108,7 +105,7 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo cost_gold: 消耗金币 session: Uninfo """ - user_cache = Cache[UserConsole](CacheType.USERS) + user_dao = DataAccess(UserConsole) try: await UserConsole.reduce_gold( user_id, @@ -121,8 +118,8 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo if u := await UserConsole.get_user(user_id): u.gold = 0 await u.save(update_fields=["gold"]) - # 更新缓存 - await user_cache.update(user_id) + # 清除缓存,使下次查询时从数据库获取最新数据 + await user_dao.clear_cache(user_id=user_id) logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session) diff --git a/zhenxun/builtin_plugins/hooks/call_hook.py b/zhenxun/builtin_plugins/hooks/call_hook.py index 1893754d..1695a48e 100644 --- a/zhenxun/builtin_plugins/hooks/call_hook.py +++ b/zhenxun/builtin_plugins/hooks/call_hook.py @@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType from zhenxun.utils.manager.message_manager import MessageManager from zhenxun.utils.platform import PlatformUtils +LOG_COMMAND = "MessageHook" + def replace_message(message: Message) -> str: """将消息中的at、image、record、face替换为字符串 @@ -54,11 +56,11 @@ async def handle_api_result( if user_id and message_id: MessageManager.add(str(user_id), str(message_id)) logger.debug( - f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook" + f"收集消息id,user_id: {user_id}, msg_id: {message_id}", LOG_COMMAND ) except Exception as e: logger.warning( - f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e + f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e ) if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"): return @@ -80,6 +82,6 @@ async def handle_api_result( except Exception as e: logger.warning( f"消息发送记录发生错误...data: {data}, result: {result}", - "msg_hook", + LOG_COMMAND, e=e, ) diff --git a/zhenxun/builtin_plugins/init/__init__.py b/zhenxun/builtin_plugins/init/__init__.py index 7c78b019..1bc259fc 100644 --- a/zhenxun/builtin_plugins/init/__init__.py +++ b/zhenxun/builtin_plugins/init/__init__.py @@ -4,23 +4,25 @@ import nonebot from nonebot.adapters import Bot from zhenxun.models.group_console import GroupConsole -from zhenxun.services.cache import DbCacheException +from zhenxun.services.cache import CacheException from zhenxun.services.log import logger +from zhenxun.utils.manager.priority_manager import PriorityLifecycle from zhenxun.utils.platform import PlatformUtils nonebot.load_plugins(str(Path(__file__).parent.resolve())) try: - from .__init_cache import CacheRoot -except DbCacheException as e: + from .__init_cache import register_cache_types +except CacheException as e: raise SystemError(f"ERROR:{e}") driver = nonebot.get_driver() -@driver.on_startup +@PriorityLifecycle.on_startup(priority=5) async def _(): - await CacheRoot.init_non_lazy_caches() + register_cache_types() + logger.info("缓存类型注册完成") @driver.on_bot_connect diff --git a/zhenxun/builtin_plugins/init/__init_cache.py b/zhenxun/builtin_plugins/init/__init_cache.py index 53bfe7e7..29653e12 100644 --- a/zhenxun/builtin_plugins/init/__init_cache.py +++ b/zhenxun/builtin_plugins/init/__init_cache.py @@ -1,208 +1,35 @@ +""" +缓存初始化模块 + +负责注册各种缓存类型,实现按需缓存机制 +""" + from zhenxun.models.ban_console import BanConsole from zhenxun.models.bot_console import BotConsole from zhenxun.models.group_console import GroupConsole from zhenxun.models.level_user import LevelUser from zhenxun.models.plugin_info import PluginInfo from zhenxun.models.user_console import UserConsole -from zhenxun.services.cache import CacheData, CacheRoot +from zhenxun.services.cache import CacheRegistry, cache_config +from zhenxun.services.cache.config import CacheMode from zhenxun.services.log import logger from zhenxun.utils.enum import CacheType -@CacheRoot.new(CacheType.PLUGINS) -async def _(): - """初始化插件缓存""" - data_list = await PluginInfo.get_plugins() - return {p.module: p for p in data_list} +# 注册缓存类型 +def register_cache_types(): + """注册所有缓存类型""" + CacheRegistry.register(CacheType.PLUGINS, PluginInfo) + CacheRegistry.register(CacheType.GROUPS, GroupConsole) + CacheRegistry.register(CacheType.BOT, BotConsole) + CacheRegistry.register(CacheType.USERS, UserConsole) + CacheRegistry.register( + CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}" + ) + CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}") - -@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo) -async def _(cache_data: CacheData, module: str): - """获取插件缓存""" - data = await cache_data.get_key(module) - if not data: - if plugin := await PluginInfo.get_plugin(module=module): - await cache_data.set_key(module, plugin) - logger.debug(f"插件 {module} 数据已设置到缓存") - return plugin - return data - - -@CacheRoot.with_refresh(CacheType.PLUGINS) -async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None): - """刷新插件缓存""" - if not data: - return - plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all() - for plugin in plugins: - await cache_data.set_key(plugin.module, plugin) - - -@CacheRoot.new(CacheType.GROUPS) -async def _(): - """初始化群组缓存""" - data_list = await GroupConsole.all() - return {p.group_id: p for p in data_list if not p.channel_id} - - -@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole) -async def _(cache_data: CacheData, group_id: str): - """获取群组缓存""" - data = await cache_data.get_key(group_id) - if not data: - if group := await GroupConsole.get_group(group_id=group_id): - await cache_data.set_key(group_id, group) - return group - return data - - -@CacheRoot.with_refresh(CacheType.GROUPS) -async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None): - """刷新群组缓存""" - if not data: - return - groups = await GroupConsole.filter( - group_id__in=data.keys(), channel_id__isnull=True - ).all() - for group in groups: - await cache_data.set_key(group.group_id, group) - - -@CacheRoot.new(CacheType.BOT) -async def _(): - """初始化机器人缓存""" - data_list = await BotConsole.all() - return {p.bot_id: p for p in data_list} - - -@CacheRoot.getter(CacheType.BOT, result_model=BotConsole) -async def _(cache_data: CacheData, bot_id: str): - """获取机器人缓存""" - data = await cache_data.get_key(bot_id) - if not data: - if bot := await BotConsole.get_or_none(bot_id=bot_id): - await cache_data.set_key(bot_id, bot) - return bot - return data - - -@CacheRoot.with_refresh(CacheType.BOT) -async def _(cache_data: CacheData, data: dict[str, BotConsole] | None): - """刷新机器人缓存""" - if not data: - return - bots = await BotConsole.filter(bot_id__in=data.keys()).all() - for bot in bots: - await cache_data.set_key(bot.bot_id, bot) - - -@CacheRoot.new(CacheType.USERS) -async def _(): - """初始化用户缓存""" - data_list = await UserConsole.all() - return {p.user_id: p for p in data_list} - - -@CacheRoot.getter(CacheType.USERS, result_model=UserConsole) -async def _(cache_data: CacheData, user_id: str): - """获取用户缓存""" - data = await cache_data.get_key(user_id) - if not data: - if user := await UserConsole.get_user(user_id=user_id): - await cache_data.set_key(user_id, user) - return user - return data - - -@CacheRoot.with_refresh(CacheType.USERS) -async def _(cache_data: CacheData, data: dict[str, UserConsole] | None): - """刷新用户缓存""" - if not data: - return - users = await UserConsole.filter(user_id__in=data.keys()).all() - for user in users: - await cache_data.set_key(user.user_id, user) - - -@CacheRoot.new(CacheType.LEVEL) -async def _(): - """初始化等级缓存""" - data_list = await LevelUser().all() - return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list} - - -@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser]) -async def _(cache_data: CacheData, user_id: str, group_id: str | None = None): - """获取等级缓存""" - key = f"{user_id}:{group_id or ''}" - data = await cache_data.get_key(key) - if not data: - if group_id: - data = await LevelUser.filter(user_id=user_id, group_id=group_id).all() - else: - data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all() - if data: - await cache_data.set_key(key, data) - return data - return data or [] - - -@CacheRoot.new(CacheType.BAN, False) -async def _(): - """初始化封禁缓存""" - data_list = await BanConsole.all() - return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list} - - -@CacheRoot.getter(CacheType.BAN, result_model=BanConsole) -async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None): - """获取封禁缓存""" - if not user_id and not group_id: - return [] - key = f"{group_id or ''}:{user_id or ''}" - data = await cache_data.get_key(key) - # if not data: - # start = time.time() - # if user_id and group_id: - # data = await BanConsole.filter(user_id=user_id, group_id=group_id).all() - # elif user_id: - # data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all() - # elif group_id: - # data = await BanConsole.filter( - # user_id__isnull=True, group_id=group_id - # ).all() - # logger.info( - # f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}" - # ) - # if data: - # await cache_data.set_key(key, data) - # return data - return data or [] - - -# @CacheRoot.new(CacheType.LIMIT) -# async def _(): -# """初始化限制缓存""" -# data_list = await PluginLimit.filter(status=True).all() -# return {data.module: data for data in data_list} - - -# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit]) -# async def _(cache_data: CacheData, module: str): -# """获取限制缓存""" -# data = await cache_data.get_key(module) -# if not data: -# if limits := await PluginLimit.filter(module=module, status=True): -# await cache_data.set_key(module, limits) -# return limits -# return data or [] - - -# @CacheRoot.with_refresh(CacheType.LIMIT) -# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None): -# """刷新限制缓存""" -# if not data: -# return -# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all() -# for limit in limits: -# await cache_data.set_key(limit.module, limit) + if cache_config.cache_mode == CacheMode.NONE: + logger.info("缓存功能已禁用,将直接从数据库获取数据") + else: + logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}") + logger.info("使用增量缓存模式,数据将按需加载到缓存中") diff --git a/zhenxun/builtin_plugins/init/manager.py b/zhenxun/builtin_plugins/init/manager.py index d6ffa223..9fab6a1d 100644 --- a/zhenxun/builtin_plugins/init/manager.py +++ b/zhenxun/builtin_plugins/init/manager.py @@ -205,7 +205,7 @@ class Manager: self.cd_data: dict[str, PluginCdBlock] = {} if self.cd_file.exists(): with open(self.cd_file, encoding="utf8") as f: - temp = _yaml.load(f) + temp = _yaml.load(f) or {} if "PluginCdLimit" in temp.keys(): for k, v in temp["PluginCdLimit"].items(): if "." in k: @@ -216,7 +216,7 @@ class Manager: self.block_data: dict[str, BaseBlock] = {} if self.block_file.exists(): with open(self.block_file, encoding="utf8") as f: - temp = _yaml.load(f) + temp = _yaml.load(f) or {} if "PluginBlockLimit" in temp.keys(): for k, v in temp["PluginBlockLimit"].items(): if "." in k: @@ -227,7 +227,7 @@ class Manager: self.count_data: dict[str, PluginCountBlock] = {} if self.count_file.exists(): with open(self.count_file, encoding="utf8") as f: - temp = _yaml.load(f) + temp = _yaml.load(f) or {} if "PluginCountLimit" in temp.keys(): for k, v in temp["PluginCountLimit"].items(): if "." in k: diff --git a/zhenxun/models/ban_console.py b/zhenxun/models/ban_console.py index 6c6f895b..5375fcf9 100644 --- a/zhenxun/models/ban_console.py +++ b/zhenxun/models/ban_console.py @@ -33,6 +33,8 @@ class BanConsole(Model): cache_type = CacheType.BAN """缓存类型""" + cache_key_field = ("user_id", "group_id") + """缓存键字段""" enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE] """开启锁""" diff --git a/zhenxun/models/bot_console.py b/zhenxun/models/bot_console.py index d329c551..01a93535 100644 --- a/zhenxun/models/bot_console.py +++ b/zhenxun/models/bot_console.py @@ -31,6 +31,9 @@ class BotConsole(Model): table_description = "Bot数据表" cache_type = CacheType.BOT + """缓存类型""" + cache_key_field = "bot_id" + """缓存键字段""" @staticmethod def format(name: str) -> str: diff --git a/zhenxun/models/group_console.py b/zhenxun/models/group_console.py index 8e81000b..a09ca476 100644 --- a/zhenxun/models/group_console.py +++ b/zhenxun/models/group_console.py @@ -90,6 +90,8 @@ class GroupConsole(Model): cache_type = CacheType.GROUPS """缓存类型""" + cache_key_field = ("group_id", "channel_id") + """缓存键字段""" enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE] """开启锁""" @@ -123,7 +125,18 @@ class GroupConsole(Model): ) @classmethod - @CacheRoot.listener(CacheType.GROUPS) + async def _update_cache(cls, instance): + """更新缓存 + + 参数: + instance: 需要更新缓存的实例 + """ + if cache_type := cls.get_cache_type(): + key = cls.get_cache_key(instance) + if key is not None: + await CacheRoot.invalidate_cache(cache_type, key) + + @classmethod async def create( cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any ) -> Self: @@ -136,6 +149,9 @@ class GroupConsole(Model): if task_modules or plugin_modules: await cls._update_modules(group, task_modules, plugin_modules, using_db) + # 更新缓存 + await cls._update_cache(group) + return group @classmethod @@ -187,14 +203,13 @@ class GroupConsole(Model): if task_modules or plugin_modules: await cls._update_modules(group, task_modules, plugin_modules, using_db) + # 更新缓存 if is_create: - if cache := await CacheRoot.get_cache(CacheType.GROUPS): - await cache.update(group.group_id, group) + await cls._update_cache(group) return group, is_create @classmethod - @CacheRoot.listener(CacheType.GROUPS) async def update_or_create( cls, defaults: dict | None = None, @@ -214,6 +229,9 @@ class GroupConsole(Model): if task_modules or plugin_modules: await cls._update_modules(group, task_modules, plugin_modules, using_db) + # 更新缓存 + await cls._update_cache(group) + return group, is_create @classmethod @@ -327,6 +345,9 @@ class GroupConsole(Model): if update_fields: await group.save(update_fields=update_fields) + # 更新缓存 + await cls._update_cache(group) + @classmethod async def set_unblock_plugin( cls, @@ -363,6 +384,9 @@ class GroupConsole(Model): if update_fields: await group.save(update_fields=update_fields) + # 更新缓存 + await cls._update_cache(group) + @classmethod async def is_normal_block_plugin( cls, group_id: str, module: str, channel_id: str | None = None @@ -466,6 +490,9 @@ class GroupConsole(Model): if update_fields: await group.save(update_fields=update_fields) + # 更新缓存 + await cls._update_cache(group) + @classmethod async def set_unblock_task( cls, @@ -500,6 +527,9 @@ class GroupConsole(Model): if update_fields: await group.save(update_fields=update_fields) + # 更新缓存 + await cls._update_cache(group) + @classmethod def _run_script(cls): return [ diff --git a/zhenxun/models/level_user.py b/zhenxun/models/level_user.py index dac9e3cb..3b468ad8 100644 --- a/zhenxun/models/level_user.py +++ b/zhenxun/models/level_user.py @@ -22,6 +22,9 @@ class LevelUser(Model): unique_together = ("user_id", "group_id") cache_type = CacheType.LEVEL + """缓存类型""" + cache_key_field = ("user_id", "group_id") + """缓存键字段""" @classmethod async def get_user_level(cls, user_id: str, group_id: str | None) -> int: diff --git a/zhenxun/models/plugin_info.py b/zhenxun/models/plugin_info.py index aeecc71b..177ab70e 100644 --- a/zhenxun/models/plugin_info.py +++ b/zhenxun/models/plugin_info.py @@ -60,6 +60,9 @@ class PluginInfo(Model): table_description = "插件基本信息" cache_type = CacheType.PLUGINS + """缓存类型""" + cache_key_field = "module" + """缓存键字段""" @classmethod async def get_plugin( diff --git a/zhenxun/models/user_console.py b/zhenxun/models/user_console.py index 9529993c..5bca4238 100644 --- a/zhenxun/models/user_console.py +++ b/zhenxun/models/user_console.py @@ -31,6 +31,9 @@ class UserConsole(Model): table_description = "用户数据表" cache_type = CacheType.USERS + """缓存类型""" + cache_key_field = "user_id" + """缓存键字段""" @classmethod async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole": diff --git a/zhenxun/services/cache.py b/zhenxun/services/cache.py deleted file mode 100644 index 04222e95..00000000 --- a/zhenxun/services/cache.py +++ /dev/null @@ -1,703 +0,0 @@ -from collections.abc import Callable -from datetime import datetime -from functools import wraps -import inspect -from typing import Any, ClassVar, Generic, TypeVar - -from aiocache import Cache as AioCache - -# from aiocache.backends.redis import RedisCache -from aiocache.base import BaseCache -from aiocache.serializers import JsonSerializer -import nonebot -from nonebot.compat import model_dump -from nonebot.utils import is_coroutine_callable -from pydantic import BaseModel - -from zhenxun.services.log import logger - -__all__ = ["Cache", "CacheData", "CacheRoot"] - -T = TypeVar("T") - -LOG_COMMAND = "cache" - -driver = nonebot.get_driver() - - -class Config(BaseModel): - enable_cache: bool = True - """是否开启缓存功能""" - redis_host: str | None = None - """redis地址""" - redis_port: int | None = None - """redis端口""" - redis_password: str | None = None - """redis密码""" - redis_expire: int = 600 - """redis过期时间""" - - -config = nonebot.get_plugin_config(Config) - - -class DbCacheException(Exception): - """缓存相关异常""" - - def __init__(self, info: str): - self.info = info - - def __str__(self) -> str: - return self.info - - -def validate_name(func: Callable): - """验证缓存名称是否存在的装饰器""" - - def wrapper(self, name: str, *args, **kwargs): - _name = name.upper() - if _name not in CacheManager._data: - raise DbCacheException(f"缓存数据 {name} 不存在") - return func(self, _name, *args, **kwargs) - - return wrapper - - -class CacheGetter(BaseModel, Generic[T]): - """缓存数据获取器""" - - get_func: Callable[..., Any] | None = None - get_all_func: Callable[..., Any] | None = None - - async def get(self, cache_data: "CacheData", key: str, *args, **kwargs) -> T: - """获取单个缓存数据""" - if not self.get_func: - data = await cache_data.get_key(key) - if cache_data.result_model: - return cache_data._deserialize_value(data, cache_data.result_model) - return data - - if is_coroutine_callable(self.get_func): - data = await self.get_func(cache_data, key, *args, **kwargs) - else: - data = self.get_func(cache_data, key, *args, **kwargs) - - if cache_data.result_model: - return cache_data._deserialize_value(data, cache_data.result_model) - return data - - async def get_all(self, cache_data: "CacheData", *args, **kwargs) -> dict[str, T]: - """获取所有缓存数据""" - if not self.get_all_func: - data = await cache_data.get_all_data() - if cache_data.result_model: - return { - k: cache_data._deserialize_value(v, cache_data.result_model) - for k, v in data.items() - } - return data - - if is_coroutine_callable(self.get_all_func): - data = await self.get_all_func(cache_data, *args, **kwargs) - else: - data = self.get_all_func(cache_data, *args, **kwargs) - - if cache_data.result_model: - return { - k: cache_data._deserialize_value(v, cache_data.result_model) - for k, v in data.items() - } - return data - - -class CacheData(BaseModel): - """缓存数据模型""" - - name: str - func: Callable[..., Any] - getter: CacheGetter | None = None - updater: Callable[..., Any] | None = None - with_refresh: Callable[..., Any] | None = None - expire: int = 600 # 默认10分钟过期 - reload_count: int = 0 - lazy_load: bool = True # 默认延迟加载 - result_model: type | None = None - _keys: set[str] = set() # 存储所有缓存键 - cache: BaseCache | AioCache - - class Config: - arbitrary_types_allowed = True - - def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any: - """反序列化值,将JSON数据转换回原始类型 - - 参数: - value: 需要反序列化的值 - target_type: 目标类型,用于指导反序列化 - - 返回: - 反序列化后的值 - """ - if value is None: - return None - - # 如果是字典且指定了目标类型 - if isinstance(value, dict) and target_type: - # 处理Tortoise-ORM Model - if hasattr(target_type, "_meta"): - return self._extracted_from__deserialize_value_19(value, target_type) - elif hasattr(target_type, "model_validate"): - return target_type.model_validate(value) - elif hasattr(target_type, "from_dict"): - return target_type.from_dict(value) - elif hasattr(target_type, "parse_obj"): - return target_type.parse_obj(value) - else: - return target_type(**value) - - # 处理列表类型 - if isinstance(value, list): - if not value: - return value - if ( - target_type - and hasattr(target_type, "__origin__") - and target_type.__origin__ is list - ): - item_type = target_type.__args__[0] - return [self._deserialize_value(item, item_type) for item in value] - return [self._deserialize_value(item) for item in value] - - # 处理字典类型 - if isinstance(value, dict): - return {k: self._deserialize_value(v) for k, v in value.items()} - - return value - - def _extracted_from__deserialize_value_19(self, value, target_type): - # 处理字段值 - processed_value = {} - for field_name, field_value in value.items(): - if field := target_type._meta.fields_map.get(field_name): - # 跳过反向关系字段 - if hasattr(field, "_related_name"): - continue - processed_value[field_name] = field_value - - logger.debug(f"处理后的值: {processed_value}") - - # 创建模型实例 - instance = target_type() - # 设置字段值 - for field_name, field_value in processed_value.items(): - if field_name in target_type._meta.fields_map: - field = target_type._meta.fields_map[field_name] - # 设置字段值 - try: - if hasattr(field, "to_python_value"): - if not field.field_type: - logger.debug(f"字段 {field_name} 类型为空") - continue - field_value = field.to_python_value(field_value) - setattr(instance, field_name, field_value) - except Exception as e: - logger.warning(f"设置字段 {field_name} 失败", e=e) - - # 设置 _saved_in_db 标志 - instance._saved_in_db = True - return instance - - async def get_data(self) -> Any: - """从缓存获取数据""" - try: - data = await self.cache.get(self.name) # type: ignore - logger.debug(f"获取缓存 {self.name} 数据: {data}") - - # 如果数据为空,尝试重新加载 - # if data is None: - # logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载") - # try: - # if self.has_args(): - # new_data = ( - # await self.func() - # if is_coroutine_callable(self.func) - # else self.func() - # ) - # else: - # new_data = ( - # await self.func() - # if is_coroutine_callable(self.func) - # else self.func() - # ) - - # await self.set_data(new_data) - # self.reload_count += 1 - # logger.info(f"重新加载缓存 {self.name} 完成") - # return new_data - # except Exception as e: - # logger.error(f"重新加载缓存 {self.name} 失败: {e}") - # return None - - # 使用 result_model 进行反序列化 - if self.result_model: - return self._deserialize_value(data, self.result_model) - - return data - except Exception as e: - logger.error(f"获取缓存 {self.name} 失败: {e}") - return None - - def _serialize_value(self, value: Any) -> Any: - """序列化值,将数据转换为JSON可序列化的格式 - - 参数: - value: 需要序列化的值 - - 返回: - JSON可序列化的值 - """ - if value is None: - return None - - # 处理datetime - if isinstance(value, datetime): - return value.isoformat() - - # 处理Tortoise-ORM Model - if hasattr(value, "_meta") and hasattr(value, "__dict__"): - result = {} - for field in value._meta.fields: - try: - field_value = getattr(value, field) - # 跳过反向关系字段 - if isinstance(field_value, list | set) and hasattr( - field_value, "_related_name" - ): - continue - # 跳过外键关系字段 - if hasattr(field_value, "_meta"): - field_value = getattr( - field_value, value._meta.fields[field].related_name or "id" - ) - result[field] = self._serialize_value(field_value) - except AttributeError: - continue - return result - - # 处理Pydantic模型 - elif isinstance(value, BaseModel): - return model_dump(value) - elif isinstance(value, dict): - # 处理字典 - return {str(k): self._serialize_value(v) for k, v in value.items()} - elif isinstance(value, list | tuple | set): - # 处理列表、元组、集合 - return [self._serialize_value(item) for item in value] - elif isinstance(value, int | float | str | bool): - # 基本类型直接返回 - return value - else: - # 其他类型转换为字符串 - return str(value) - - async def set_data(self, value: Any): - """设置缓存数据""" - try: - # 1. 序列化数据 - serialized_value = self._serialize_value(value) - logger.debug(f"设置缓存 {self.name} 原始数据: {value}") - logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}") - - # 2. 删除旧数据 - await self.cache.delete(self.name) # type: ignore - logger.debug(f"删除缓存 {self.name} 旧数据") - - # 3. 设置新数据 - await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore - logger.debug(f"设置缓存 {self.name} 新数据完成") - - except Exception as e: - logger.error(f"设置缓存 {self.name} 失败: {e}") - raise # 重新抛出异常,让上层处理 - - async def delete_data(self): - """删除缓存数据""" - try: - await self.cache.delete(self.name) # type: ignore - except Exception as e: - logger.error(f"删除缓存 {self.name}", e=e) - - async def get(self, key: str, *args, **kwargs) -> Any: - """获取缓存""" - if not self.reload_count and not self.lazy_load: - await self.reload(*args, **kwargs) - - if not self.getter: - return await self.get_key(key) - - return await self.getter.get(self, key, *args, **kwargs) - - async def get_all(self, *args, **kwargs) -> dict[str, Any]: - """获取所有缓存数据""" - if not self.reload_count and not self.lazy_load: - await self.reload(*args, **kwargs) - - if not self.getter: - return await self.get_all_data() - - return await self.getter.get_all(self, *args, **kwargs) - - async def update(self, key: str, value: Any = None, *args, **kwargs): - """更新单个缓存项""" - if not self.updater: - logger.warning(f"缓存 {self.name} 未配置更新方法") - return - - current_data = await self.get_key(key) or {} - if is_coroutine_callable(self.updater): - await self.updater(current_data, key, value, *args, **kwargs) - else: - self.updater(current_data, key, value, *args, **kwargs) - - await self.set_key(key, current_data) - logger.debug(f"更新缓存 {self.name}.{key}") - - async def refresh(self, *args, **kwargs): - """刷新缓存数据""" - if not self.with_refresh: - return await self.reload(*args, **kwargs) - - current_data = await self.get_data() - if current_data: - if is_coroutine_callable(self.with_refresh): - await self.with_refresh(current_data, *args, **kwargs) - else: - self.with_refresh(current_data, *args, **kwargs) - await self.set_data(current_data) - logger.debug(f"刷新缓存 {self.name}") - - async def reload(self, *args, **kwargs): - """重新加载全部数据""" - try: - if self.has_args(): - new_data = ( - await self.func(*args, **kwargs) - if is_coroutine_callable(self.func) - else self.func(*args, **kwargs) - ) - else: - new_data = ( - await self.func() - if is_coroutine_callable(self.func) - else self.func() - ) - - # 如果是字典,则分别存储每个键值对 - if isinstance(new_data, dict): - for key, value in new_data.items(): - await self.set_key(key, value) - else: - # 如果不是字典,则存储为单个键值对 - await self.set_key("default", new_data) - - self.reload_count += 1 - logger.info(f"重新加载缓存 {self.name} 完成") - except Exception as e: - logger.error(f"重新加载缓存 {self.name} 失败: {e}") - raise - - def has_args(self) -> bool: - """检查函数是否需要参数""" - sig = inspect.signature(self.func) - return any( - param.kind - in ( - param.POSITIONAL_OR_KEYWORD, - param.POSITIONAL_ONLY, - param.VAR_POSITIONAL, - ) - for param in sig.parameters.values() - ) - - async def get_key(self, key: str) -> Any: - """获取缓存中指定键的数据 - - 参数: - key: 要获取的键名 - - 返回: - 键对应的值,如果不存在返回None - """ - cache_key = self._get_cache_key(key) - try: - data = await self.cache.get(cache_key) # type: ignore - logger.debug(f"获取缓存 {cache_key} 数据: {data}") - - if self.result_model: - return self._deserialize_value(data, self.result_model) - return data - except Exception as e: - logger.error(f"获取缓存 {cache_key} 失败: {e}") - return None - - async def get_keys(self, keys: list[str]) -> dict[str, Any]: - """获取缓存中多个键的数据 - - 参数: - keys: 要获取的键名列表 - - 返回: - 包含所有请求键值的字典,不存在的键值为None - """ - try: - data = await self.get_data() - if isinstance(data, dict): - return {key: data.get(key) for key in keys} - return dict.fromkeys(keys) - except Exception as e: - logger.error(f"获取缓存 {self.name} 的多个键失败: {e}") - return dict.fromkeys(keys) - - def _get_cache_key(self, key: str) -> str: - """获取缓存键名""" - return f"{self.name}:{key}" - - async def get_all_data(self) -> dict[str, Any]: - """获取所有缓存数据""" - try: - result = {} - for key in self._keys: - # 提取原始键名(去掉前缀) - original_key = key.split(":", 1)[1] - data = await self.cache.get(key) # type: ignore - if self.result_model: - result[original_key] = self._deserialize_value( - data, self.result_model - ) - else: - result[original_key] = data - return result - except Exception as e: - logger.error(f"获取所有缓存数据失败: {e}") - return {} - - async def set_key(self, key: str, value: Any): - """设置指定键的缓存数据""" - cache_key = self._get_cache_key(key) - try: - serialized_value = self._serialize_value(value) - await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore - self._keys.add(cache_key) # 添加到键列表 - logger.debug(f"设置缓存 {cache_key} 数据完成") - except Exception as e: - logger.error(f"设置缓存 {cache_key} 失败: {e}") - raise - - async def delete_key(self, key: str): - """删除指定键的缓存数据""" - cache_key = self._get_cache_key(key) - try: - await self.cache.delete(cache_key) # type: ignore - self._keys.discard(cache_key) # 从键列表中移除 - logger.debug(f"删除缓存 {cache_key} 完成") - except Exception as e: - logger.error(f"删除缓存 {cache_key} 失败: {e}") - - async def clear(self): - """清除所有缓存数据""" - try: - for key in list(self._keys): # 使用列表复制避免在迭代时修改 - await self.cache.delete(key) # type: ignore - self._keys.clear() - logger.debug(f"清除缓存 {self.name} 完成") - except Exception as e: - logger.error(f"清除缓存 {self.name} 失败: {e}") - - -class CacheManager: - """全局缓存管理器""" - - _cache_instance: BaseCache | AioCache | None = None - _data: ClassVar[dict[str, CacheData]] = {} - - @property - def _cache(self) -> BaseCache | AioCache: - """获取aiocache实例""" - if self._cache_instance is None: - if config.redis_host: - self._cache_instance = AioCache( - AioCache.REDIS, # type: ignore - serializer=JsonSerializer(), - namespace="zhenxun_cache", - timeout=30, # 操作超时时间 - ttl=config.redis_expire, # 设置默认过期时间 - endpoint=config.redis_host, - port=config.redis_port, - password=config.redis_password, - ) - else: - self._cache_instance = AioCache( - AioCache.MEMORY, - serializer=JsonSerializer(), - namespace="zhenxun_cache", - timeout=30, # 操作超时时间 - ttl=config.redis_expire, # 设置默认过期时间 - ) - logger.info("初始化缓存完成...", LOG_COMMAND) - return self._cache_instance - - async def close(self): - if self._cache_instance: - await self._cache_instance.close() # type: ignore - - async def verify_connection(self): - """连接测试""" - try: - await self._cache.get("__test__") # type: ignore - except Exception as e: - logger.error("连接失败", LOG_COMMAND, e=e) - raise - - async def init_non_lazy_caches(self): - """初始化所有非延迟加载的缓存""" - await self.verify_connection() - for name, cache in self._data.items(): - cache.cache = self._cache - if not cache.lazy_load: - try: - await cache.reload() - logger.info(f"初始化缓存 {name} 完成") - except Exception as e: - logger.error(f"初始化缓存 {name} 失败: {e}") - - def new(self, name: str, lazy_load: bool = True, expire: int = 600): - """注册新缓存 - - 参数: - name: 缓存名称 - lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载 - expire: 过期时间(秒) - """ - - def wrapper(func: Callable): - _name = name.upper() - if _name in self._data: - raise DbCacheException(f"缓存 {name} 已存在") - - self._data[_name] = CacheData( - name=_name, - func=func, - expire=expire, - lazy_load=lazy_load, - cache=self._cache, - ) - return func - - return wrapper - - def listener(self, name: str): - """创建缓存监听器""" - - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - try: - return ( - await func(*args, **kwargs) - if is_coroutine_callable(func) - else func(*args, **kwargs) - ) - finally: - cache = self._data.get(name.upper()) - if cache and cache.with_refresh: - await cache.refresh() - logger.debug(f"监听器触发缓存 {name} 刷新") - - return wrapper - - return decorator - - @validate_name - def updater(self, name: str): - """设置缓存更新方法""" - - def wrapper(func: Callable): - self._data[name].updater = func - return func - - return wrapper - - @validate_name - def getter(self, name: str, result_model: type): - """设置缓存获取方法""" - - def wrapper(func: Callable): - self._data[name].getter = CacheGetter[result_model](get_func=func) - self._data[name].result_model = result_model - return func - - return wrapper - - @validate_name - def with_refresh(self, name: str): - """设置缓存刷新方法""" - - def wrapper(func: Callable): - self._data[name].with_refresh = func - return func - - return wrapper - - async def get_cache_data(self, name: str) -> Any | None: - """获取缓存数据""" - cache = await self.get_cache(name.upper()) - return await cache.get_data() if cache else None - - async def get_cache(self, name: str) -> CacheData | None: - """获取缓存对象""" - return self._data.get(name.upper()) - - async def get(self, name: str, *args, **kwargs) -> Any: - """获取缓存内容""" - cache = await self.get_cache(name.upper()) - return await cache.get(*args, **kwargs) if cache else None - - async def update(self, name: str, key: str, value: Any = None, *args, **kwargs): - """更新缓存项""" - cache = await self.get_cache(name.upper()) - if cache: - await cache.update(key, value, *args, **kwargs) - - async def reload(self, name: str, *args, **kwargs): - """重新加载缓存""" - cache = await self.get_cache(name.upper()) - if cache: - await cache.reload(*args, **kwargs) - - -# 全局缓存管理器实例 -CacheRoot = CacheManager() - - -class Cache(Generic[T]): - """类型化缓存访问接口""" - - def __init__(self, module: str): - self.module = module.upper() - - async def get(self, *args, **kwargs) -> T | None: - """获取缓存""" - return await CacheRoot.get(self.module, *args, **kwargs) - - async def update(self, key: str, value: Any = None, *args, **kwargs): - """更新缓存项""" - await CacheRoot.update(self.module, key, value, *args, **kwargs) - - async def reload(self, *args, **kwargs): - """重新加载缓存""" - await CacheRoot.reload(self.module, *args, **kwargs) - - -@driver.on_shutdown -async def _(): - await CacheRoot.close() diff --git a/zhenxun/services/cache/__init__.py b/zhenxun/services/cache/__init__.py new file mode 100644 index 00000000..4b98e871 --- /dev/null +++ b/zhenxun/services/cache/__init__.py @@ -0,0 +1,1032 @@ +""" +缓存系统模块 + +提供统一的缓存访问接口,支持内存缓存和Redis缓存 + +使用示例: +1. 使用Cache类进行缓存操作 +```python +from zhenxun.services.cache import Cache +from zhenxun.utils.enum import CacheType + +# 创建缓存访问对象 +level_cache = Cache[list[LevelUser]](CacheType.LEVEL) + +# 获取缓存数据 +users = await level_cache.get({"user_id": "123", "group_id": "456"}) + +# 设置缓存数据 +await level_cache.set({"user_id": "123", "group_id": "456"}, users) +``` + +2. 使用CacheDict作为全局字典 +```python +from zhenxun.services.cache.cache_containers import CacheDict + +# 创建缓存字典(默认永不过期) +config_dict = CacheDict("global_config") + +# 创建有过期时间的缓存字典(1小时后过期) +temp_dict = CacheDict("temp_config", expire=3600) + +# 使用字典操作 +config_dict["key"] = "value" +value = config_dict["key"] + +# 保存缓存数据(可选) +await config_dict.save() +``` + +3. 使用CacheList作为全局列表 +```python +from zhenxun.services.cache.cache_containers import CacheList + +# 创建缓存列表(默认永不过期) +message_list = CacheList("recent_messages") + +# 创建有过期时间的缓存列表(30分钟后过期) +temp_list = CacheList("temp_messages", expire=1800) + +# 使用列表操作 +message_list.append("新消息") +message = message_list[0] + +# 保存缓存数据(可选) +await message_list.save() +``` +""" + +from collections.abc import Callable +from datetime import datetime +from functools import wraps +from typing import Any, ClassVar, Generic, TypeVar, get_type_hints + +from aiocache import Cache as AioCache +from aiocache.base import BaseCache +from aiocache.serializers import JsonSerializer +import nonebot +from nonebot.compat import model_dump +from nonebot.utils import is_coroutine_callable +from pydantic import BaseModel + +from zhenxun.services.log import logger + +from .cache_containers import CacheDict, CacheList +from .config import ( + CACHE_KEY_PREFIX, + CACHE_KEY_SEPARATOR, + DEFAULT_EXPIRE, + LOG_COMMAND, + SPECIAL_KEY_FORMATS, + CacheMode, +) + +__all__ = [ + "Cache", + "CacheData", + "CacheDict", + "CacheList", + "CacheManager", + "CacheRegistry", + "CacheRoot", +] + +T = TypeVar("T") + + +class Config(BaseModel): + """缓存配置""" + + cache_mode: str = CacheMode.NONE + """缓存模式: MEMORY(内存缓存), REDIS(Redis缓存), NONE(不使用缓存)""" + redis_host: str | None = None + """redis地址""" + redis_port: int | None = None + """redis端口""" + redis_password: str | None = None + """redis密码""" + redis_expire: int = DEFAULT_EXPIRE + """redis过期时间""" + + +# 获取配置 +driver = nonebot.get_driver() +cache_config = nonebot.get_plugin_config(Config) + + +class CacheException(Exception): + """缓存相关异常""" + + def __init__(self, info: str): + self.info = info + + def __str__(self) -> str: + return self.info + + +class CacheModel(BaseModel): + """缓存数据模型""" + + name: str + """缓存名称""" + expire: int = DEFAULT_EXPIRE + """过期时间(秒)""" + result_type: type | None = None + """结果类型""" + key_format: str | None = None + """键格式""" + + class Config: + arbitrary_types_allowed = True + + +""" +CacheData类是缓存系统的核心组件,它负责管理单个缓存项的数据和生命周期。 + +设计思路: +1. 每个CacheData实例代表一个具名的缓存项,如"用户列表"、"配置数据"等 +2. 它提供了数据的懒加载、自动过期和持久化等功能 +3. 可以通过func参数提供一个获取数据的函数,在数据不存在或过期时自动调用 +4. 支持直接设置_data属性,方便外部直接操作数据 + +主要用途: +1. 作为CacheDict和CacheList的后端存储 +2. 被CacheManager管理,实现统一的缓存生命周期控制 +3. 提供数据过期和自动刷新机制 + +通常情况下,用户不需要直接使用CacheData,而是通过Cache、CacheDict或CacheList来操作缓存。 +""" + + +class CacheData: + """缓存数据类""" + + def __init__( + self, + name: str, + func: Callable, + expire: int = DEFAULT_EXPIRE, + lazy_load: bool = True, + cache: BaseCache | AioCache | None = None, + ): + """初始化缓存数据 + + 参数: + name: 缓存名称 + func: 获取数据的函数 + expire: 过期时间(秒) + lazy_load: 是否延迟加载 + cache: 缓存后端 + """ + self.name = name.upper() + self.func = func + self.expire = expire + self.lazy_load = lazy_load + self.cache = cache + self._data = None + self._last_update = 0 + + # 如果不是延迟加载,立即加载数据 + if not lazy_load: + import asyncio + + try: + loop = asyncio.get_event_loop() + if not loop.is_running(): + loop.run_until_complete(self.get_data()) + except Exception: + pass + + async def get_data(self) -> Any: + """获取数据 + + 返回: + Any: 缓存数据 + """ + # 检查是否需要更新 + now = datetime.now().timestamp() + if self._data is None or ( + self.expire > 0 and now - self._last_update > self.expire + ): + # 更新数据 + try: + self._data = await self.func() + self._last_update = now + except Exception as e: + logger.error(f"获取缓存数据 {self.name} 失败", LOG_COMMAND, e=e) + + return self._data + + async def set_data(self, data: Any) -> bool: + """设置数据 + + 参数: + data: 缓存数据 + + 返回: + bool: 是否成功 + """ + try: + self._data = data + self._last_update = datetime.now().timestamp() + # 如果有缓存后端,保存到缓存 + if self.cache and cache_config.cache_mode != CacheMode.NONE: + await self.cache.set(self.name, data, ttl=self.expire) # type: ignore + return True + except Exception as e: + logger.error(f"设置缓存数据 {self.name} 失败", LOG_COMMAND, e=e) + return False + + async def clear(self) -> bool: + """清除数据 + + 返回: + bool: 是否成功 + """ + try: + self._data = None + self._last_update = 0 + # 如果有缓存后端,清除缓存 + if self.cache and cache_config.cache_mode != CacheMode.NONE: + await self.cache.delete(self.name) # type: ignore + return True + except Exception as e: + logger.error(f"清除缓存数据 {self.name} 失败", LOG_COMMAND, e=e) + return False + + +class CacheManager: + """缓存管理器""" + + _instance: ClassVar["CacheManager | None"] = None + _cache_backend: BaseCache | AioCache | None = None + _registry: ClassVar[dict[str, CacheModel]] = {} + _data: ClassVar[dict[str, CacheData]] = {} + _enabled = False # 缓存启用标记 + + def __new__(cls) -> "CacheManager": + """单例模式""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @property + def enabled(self) -> bool: + """获取缓存启用状态""" + return self.__class__._enabled + + @enabled.setter + def enabled(self, value: bool): + """设置缓存启用状态""" + self.__class__._enabled = value + + def enable(self): + """启用缓存""" + self.__class__._enabled = True + logger.info("缓存功能已启用", LOG_COMMAND) + + def disable(self): + """禁用缓存""" + self.__class__._enabled = False + logger.info("缓存功能已禁用", LOG_COMMAND) + + def listener(self, cache_type: str): + """缓存监听器装饰器 + + 在方法调用后自动刷新缓存数据 + + 参数: + cache_type: 缓存类型 + + 返回: + Callable: 装饰器 + """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(cls, *args, **kwargs): + # 执行原函数 + result = await func(cls, *args, **kwargs) + + obj = None + # 如果启用了缓存,自动刷新缓存 + if cache_config.cache_mode != CacheMode.NONE: + # 根据返回值类型处理 + if isinstance(result, tuple) and len(result) > 0: + # 处理返回元组的情况,如 update_or_create 返回 (obj, created) + obj = result[0] + else: + # 处理返回单个对象的情况 + obj = result + + # 获取缓存键并刷新缓存 + if ( + obj + and hasattr(cls, "get_cache_key") + and hasattr(obj, cls.get_cache_key_field()) + ): + key = cls.get_cache_key(obj) + if key is not None: + await self.invalidate_cache(cache_type, key) + + return result + + return wrapper + + return decorator + + async def get_cache(self, cache_type: str) -> Any: + """获取指定类型的缓存对象 + + 此方法返回一个简单的缓存对象,具有 update 方法 + + 参数: + cache_type: 缓存类型 + + 返回: + Any: 缓存对象 + """ + + class CacheAdapter: + """缓存适配器""" + + def __init__(self, cache_manager: CacheManager, cache_type: str): + self.cache_manager = cache_manager + self.cache_type = cache_type + + async def update(self, key: Any, value: Any) -> None: + """更新缓存 + + 参数: + key: 缓存键 + value: 缓存值 + """ + # 先清除旧缓存 + await self.cache_manager.invalidate_cache(self.cache_type, key) + + # 如果需要,可以在这里添加重新设置缓存的逻辑 + # 目前我们只清除缓存,让下次查询时自动重建 + + return ( + CacheAdapter(self, cache_type) + if cache_config.cache_mode != CacheMode.NONE + else None + ) + + @property + def cache_backend(self) -> BaseCache | AioCache: + """获取缓存后端""" + if self._cache_backend is None: + try: + from aiocache import RedisCache, SimpleMemoryCache + + if cache_config.cache_mode == CacheMode.NONE: + # 使用内存缓存但禁用持久化 + self._cache_backend = SimpleMemoryCache( + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=0, # 设置为0,不缓存 + ) + logger.info("缓存功能已禁用,使用非持久化内存缓存", LOG_COMMAND) + elif ( + cache_config.cache_mode == CacheMode.REDIS + and cache_config.redis_host + ): + # 使用Redis缓存 + self._cache_backend = RedisCache( + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=cache_config.redis_expire, + endpoint=cache_config.redis_host, + port=cache_config.redis_port, + password=cache_config.redis_password, + ) + logger.info( + f"使用Redis缓存,地址: {cache_config.redis_host}", LOG_COMMAND + ) + else: + # 默认使用内存缓存 + self._cache_backend = SimpleMemoryCache( + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=cache_config.redis_expire, + ) + logger.info("使用内存缓存", LOG_COMMAND) + except ImportError: + logger.error("导入aiocache模块失败,使用内存缓存", LOG_COMMAND) + # 使用内存缓存 + self._cache_backend = AioCache( + cache_class=AioCache.MEMORY, + serializer=JsonSerializer(), + namespace=CACHE_KEY_PREFIX, + timeout=30, + ttl=cache_config.redis_expire, + ) + return self._cache_backend + + @property + def _cache(self) -> BaseCache | AioCache: + """获取缓存后端(别名)""" + return self.cache_backend + + async def get_cache_data(self, name: str) -> Any: + """获取缓存数据 + + 参数: + name: 缓存名称 + + 返回: + Any: 缓存数据 + """ + name = name.upper() + # 检查是否存在缓存数据 + if name in self._data: + return await self._data[name].get_data() + + # 尝试从缓存后端获取 + if cache_config.cache_mode != CacheMode.NONE: + try: + data = await self.cache_backend.get(name) # type: ignore + if data is not None: + return data + except Exception as e: + logger.error(f"从缓存后端获取数据 {name} 失败", LOG_COMMAND, e=e) + return None + + async def invalidate_cache( + self, cache_type: str, key: str | dict[str, Any] | None = None + ) -> bool: + """使指定类型的缓存失效 + + 当数据库中的数据发生变化时,调用此方法清除对应类型的缓存 + + 参数: + cache_type: 缓存类型 + key: 缓存键或键参数,为None时清除该类型的所有缓存 + + 返回: + bool: 是否成功 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回True + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return True + + try: + if key is not None: + # 只清除特定的缓存项 + cache_key = self._build_key(cache_type, key) + await self.cache_backend.delete(cache_key) # type: ignore + logger.debug(f"清除缓存: {cache_type}, 键: {key}", LOG_COMMAND) + return True + else: + # 清除指定类型的所有缓存 + logger.debug(f"清除所有 {cache_type} 缓存", LOG_COMMAND) + return await self.clear(cache_type) + except Exception as e: + logger.error(f"清除缓存 {cache_type} 失败", LOG_COMMAND, e=e) + return False + + async def get( + self, cache_type: str, key: str | dict[str, Any], default: Any = None + ) -> Any: + """获取缓存数据 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + default: 默认值 + + 返回: + Any: 缓存数据,如果不存在返回默认值 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回默认值 + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return default + + try: + cache_key = self._build_key(cache_type, key) + data = await self.cache_backend.get(cache_key) # type: ignore + + if data is None: + return default + + # 获取缓存模型 + model = self.get_model(cache_type) + + # 反序列化 + if model.result_type: + return self._deserialize_value(data, model.result_type) + return data + except Exception as e: + logger.error(f"获取缓存 {cache_type} 失败", LOG_COMMAND, e=e) + return default + + async def set( + self, + cache_type: str, + key: str | dict[str, Any], + value: Any, + expire: int | None = None, + ) -> bool: + """设置缓存数据 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + value: 值 + expire: 过期时间(秒),为None时使用默认值 + + 返回: + bool: 是否成功 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回False + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return False + + try: + cache_key = self._build_key(cache_type, key) + model = self.get_model(cache_type) + + # 序列化 + serialized_value = self._serialize_value(value) + + # 设置过期时间 + ttl = expire if expire is not None else model.expire + + # 设置缓存 + await self.cache_backend.set(cache_key, serialized_value, ttl=ttl) # type: ignore + return True + except Exception as e: + logger.error(f"设置缓存 {cache_type} 失败", LOG_COMMAND, e=e) + return False + + async def delete(self, cache_type: str, key: str | dict[str, Any]) -> bool: + """删除缓存数据 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + + 返回: + bool: 是否成功 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回False + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return False + + try: + cache_key = self._build_key(cache_type, key) + await self.cache_backend.delete(cache_key) # type: ignore + return True + except Exception as e: + logger.error(f"删除缓存 {cache_type} 失败", LOG_COMMAND, e=e) + return False + + async def exists(self, cache_type: str, key: str | dict[str, Any]) -> bool: + """检查缓存是否存在 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + + 返回: + bool: 是否存在 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回False + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return False + + try: + cache_key = self._build_key(cache_type, key) + # 由于aiocache可能没有exists方法,使用get检查 + data = await self.cache_backend.get(cache_key) # type: ignore + return data is not None + except Exception as e: + logger.error(f"检查缓存 {cache_type} 是否存在失败", LOG_COMMAND, e=e) + return False + + async def clear(self, cache_type: str | None = None) -> bool: + """清除缓存 + + 参数: + cache_type: 缓存类型,为None时清除所有缓存 + + 返回: + bool: 是否成功 + """ + # 如果缓存被禁用或缓存模式为NONE,直接返回False + if not self.enabled or cache_config.cache_mode == CacheMode.NONE: + return False + + try: + if cache_type: + # 清除指定类型的缓存 + # pattern = f"{cache_type.upper()}{CACHE_KEY_SEPARATOR}*" + # 由于aiocache可能没有delete_pattern方法,使用其他方式清除 + # 这里简化处理,直接清除所有缓存 + await self.cache_backend.clear() # type: ignore + else: + # 清除所有缓存 + await self.cache_backend.clear() # type: ignore + return True + except Exception as e: + logger.error("清除缓存失败", LOG_COMMAND, e=e) + return False + + async def close(self): + """关闭缓存连接""" + if self._cache_backend: + try: + await self._cache_backend.close() # type: ignore + except (AttributeError, Exception) as e: + logger.warning(f"关闭缓存连接失败: {e}", LOG_COMMAND) + self._cache_backend = None + + def register( + self, + name: str, + result_type: type | None = None, + expire: int = DEFAULT_EXPIRE, + key_format: str | None = None, + ) -> None: + """注册缓存类型 + + 参数: + name: 缓存名称 + result_type: 结果类型 + expire: 过期时间(秒) + key_format: 键格式 + """ + name = name.upper() + if name in self._registry: + logger.warning(f"缓存类型 {name} 已存在,将被覆盖", LOG_COMMAND) + + # 检查是否有特殊键格式 + if not key_format and name in SPECIAL_KEY_FORMATS: + key_format = SPECIAL_KEY_FORMATS[name] + + self._registry[name] = CacheModel( + name=name, + expire=expire, + result_type=result_type, + key_format=key_format, + ) + logger.debug( + f"注册缓存类型: {name}, 类型: {result_type}, 过期时间: {expire}秒", + LOG_COMMAND, + ) + + def get_model(self, name: str) -> CacheModel: + """获取缓存模型 + + 参数: + name: 缓存名称 + + 返回: + CacheModel: 缓存模型 + + 异常: + CacheException: 缓存类型不存在 + """ + name = name.upper() + if name not in self._registry: + raise CacheException(f"缓存类型 {name} 不存在") + return self._registry[name] + + def _build_key(self, cache_type: str, key: str | dict[str, Any]) -> str: + """构建缓存键 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + + 返回: + str: 完整缓存键 + """ + cache_type = cache_type.upper() + if cache_type not in self._registry: + raise CacheException(f"缓存类型 {cache_type} 不存在") + + model = self._registry[cache_type] + + # 如果key是字典,使用键格式 + if isinstance(key, dict) and model.key_format: + try: + formatted_key = model.key_format.format(**key) + except KeyError as e: + raise CacheException(f"键格式错误: {model.key_format}, 缺少参数: {e}") + return f"{cache_type}{CACHE_KEY_SEPARATOR}{formatted_key}" + + # 否则直接使用key + return f"{cache_type}{CACHE_KEY_SEPARATOR}{key}" + + def _serialize_value(self, value: Any) -> Any: + """序列化值 + + 参数: + value: 需要序列化的值 + + 返回: + Any: 序列化后的值 + """ + if value is None: + return None + + # 处理datetime + if isinstance(value, datetime): + return value.isoformat() + + # 处理Tortoise-ORM Model + if hasattr(value, "_meta") and hasattr(value, "__dict__"): + result = {} + for field in value._meta.fields: + try: + field_value = getattr(value, field) + # 跳过反向关系字段 + if isinstance(field_value, list | set) and hasattr( + field_value, "_related_name" + ): + continue + # 跳过外键关系字段 + if hasattr(field_value, "_meta"): + field_value = getattr( + field_value, value._meta.fields[field].related_name or "id" + ) + result[field] = self._serialize_value(field_value) + except AttributeError: + continue + return result + + # 处理Pydantic模型 + elif isinstance(value, BaseModel): + return model_dump(value) + elif isinstance(value, dict): + # 处理字典 + return {str(k): self._serialize_value(v) for k, v in value.items()} + elif isinstance(value, list | tuple | set): + # 处理列表、元组、集合 + return [self._serialize_value(item) for item in value] + elif isinstance(value, int | float | str | bool): + # 基本类型直接返回 + return value + else: + # 其他类型转换为字符串 + return str(value) + + def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any: + """反序列化值 + + 参数: + value: 需要反序列化的值 + target_type: 目标类型 + + 返回: + Any: 反序列化后的值 + """ + if value is None: + return None + + # 如果是字典且指定了目标类型 + if isinstance(value, dict) and target_type: + # 处理Tortoise-ORM Model + if hasattr(target_type, "_meta"): + return self._deserialize_tortoise_model(value, target_type) + elif hasattr(target_type, "model_validate"): + return target_type.model_validate(value) + elif hasattr(target_type, "from_dict"): + return target_type.from_dict(value) + elif hasattr(target_type, "parse_obj"): + return target_type.parse_obj(value) + else: + return target_type(**value) + + # 处理列表类型 + if isinstance(value, list): + if not value: + return value + if ( + target_type + and hasattr(target_type, "__origin__") + and target_type.__origin__ is list + ): + item_type = target_type.__args__[0] + return [self._deserialize_value(item, item_type) for item in value] + return [self._deserialize_value(item) for item in value] + + # 处理字典类型 + if isinstance(value, dict): + return {k: self._deserialize_value(v) for k, v in value.items()} + + return value + + def _deserialize_tortoise_model(self, value: dict, target_type: type) -> Any: + """反序列化Tortoise-ORM模型 + + 参数: + value: 字典数据 + target_type: 目标类型 + + 返回: + Any: 反序列化后的模型实例 + """ + # 处理字段值 + processed_value = {} + for field_name, field_value in value.items(): + if field := target_type._meta.fields_map.get(field_name): + # 跳过反向关系字段 + if hasattr(field, "_related_name"): + continue + processed_value[field_name] = field_value + + # 创建模型实例 + instance = target_type() + # 设置字段值 + for field_name, field_value in processed_value.items(): + if field_name in target_type._meta.fields_map: + field = target_type._meta.fields_map[field_name] + # 设置字段值 + try: + if hasattr(field, "to_python_value"): + if not field.field_type: + logger.debug(f"字段 {field_name} 类型为空", LOG_COMMAND) + continue + field_value = field.to_python_value(field_value) + setattr(instance, field_name, field_value) + except Exception as e: + logger.warning(f"设置字段 {field_name} 失败", LOG_COMMAND, e=e) + + # 设置 _saved_in_db 标志 + instance._saved_in_db = True + return instance + + +# 全局缓存管理器实例 +CacheRoot = CacheManager() + + +class CacheRegistry: + """缓存注册器""" + + @staticmethod + def register( + name: str, + result_type: type | None = None, + expire: int = DEFAULT_EXPIRE, + key_format: str | None = None, + ): + """注册缓存类型 + + 参数: + name: 缓存名称 + result_type: 结果类型 + expire: 过期时间(秒) + key_format: 键格式 + """ + CacheRoot.register(name, result_type, expire, key_format) + + @staticmethod + def invalidate(cache_type: str, key: str | dict[str, Any]): + """使缓存失效的装饰器 + + 参数: + cache_type: 缓存类型 + key: 键或键参数 + + 返回: + Callable: 装饰器 + """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(*args, **kwargs): + # 执行函数 + result = ( + await func(*args, **kwargs) + if is_coroutine_callable(func) + else func(*args, **kwargs) + ) + + # 删除缓存 + if cache_config.cache_mode != CacheMode.NONE: + await CacheRoot.delete(cache_type, key) + + return result + + return wrapper + + return decorator + + +class Cache(Generic[T]): + """类型化缓存访问接口 + + 示例: + ```python + from zhenxun.services.cache import Cache + from zhenxun.models.level_user import LevelUser + from zhenxun.utils.enum import CacheType + + # 创建缓存访问对象 + level_cache = Cache[list[LevelUser]](CacheType.LEVEL) + + # 获取缓存数据 + users = await level_cache.get({"user_id": "123", "group_id": "456"}) + + # 设置缓存数据 + await level_cache.set({"user_id": "123", "group_id": "456"}, users) + ``` + """ + + def __init__(self, cache_type: str): + """初始化缓存访问对象 + + 参数: + cache_type: 缓存类型 + """ + self.cache_type = cache_type.upper() + + # 尝试从类型注解获取结果类型 + try: + type_hints = get_type_hints(self.__class__) + if "T" in type_hints: + result_type = type_hints["T"] + # 确保缓存类型已注册 + try: + CacheRoot.get_model(self.cache_type) + except CacheException: + CacheRoot.register(self.cache_type, result_type) + except Exception: + pass + + async def get( + self, key: str | dict[str, Any], default: T | None = None + ) -> T | None: + """获取缓存数据 + + 参数: + key: 键或键参数 + default: 默认值 + + 返回: + T | None: 缓存数据,如果不存在返回默认值 + """ + return await CacheRoot.get(self.cache_type, key, default) + + async def set( + self, key: str | dict[str, Any], value: T, expire: int | None = None + ) -> bool: + """设置缓存数据 + + 参数: + key: 键或键参数 + value: 值 + expire: 过期时间(秒),为None时使用默认值 + + 返回: + bool: 是否成功 + """ + return await CacheRoot.set(self.cache_type, key, value, expire) + + async def delete(self, key: str | dict[str, Any]) -> bool: + """删除缓存数据 + + 参数: + key: 键或键参数 + + 返回: + bool: 是否成功 + """ + return await CacheRoot.delete(self.cache_type, key) + + async def exists(self, key: str | dict[str, Any]) -> bool: + """检查缓存是否存在 + + 参数: + key: 键或键参数 + + 返回: + bool: 是否存在 + """ + return await CacheRoot.exists(self.cache_type, key) + + async def clear(self) -> bool: + """清除此类型的所有缓存 + + 返回: + bool: 是否成功 + """ + return await CacheRoot.clear(self.cache_type) + + +@driver.on_startup +async def _(): + CacheRoot.enabled = True + logger.info("缓存系统已启用", LOG_COMMAND) + + +@driver.on_shutdown +async def _(): + await CacheRoot.close() diff --git a/zhenxun/services/cache/cache_containers.py b/zhenxun/services/cache/cache_containers.py new file mode 100644 index 00000000..91690e9a --- /dev/null +++ b/zhenxun/services/cache/cache_containers.py @@ -0,0 +1,424 @@ +from typing import Any + +from zhenxun.services.log import logger + +from .config import LOG_COMMAND + + +class CacheDict: + """全局缓存字典类,提供类似普通字典的接口,但数据可以在内存中共享""" + + def __init__(self, name: str, expire: int = 0): + """初始化缓存字典 + + 参数: + name: 字典名称 + expire: 过期时间(秒),默认为0表示永不过期 + """ + self.name = name.upper() + self.expire = expire + self._data = {} + # 自动尝试加载数据 + self._try_load() + + def _try_load(self): + """尝试加载数据(非异步)""" + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheRoot + + # 检查是否已有缓存数据 + if self.name in CacheRoot._data: + # 如果有,直接获取 + data = CacheRoot._data[self.name]._data + if isinstance(data, dict): + self._data = data + except Exception: + # 忽略错误,使用空字典 + pass + + async def load(self) -> bool: + """从缓存加载数据 + + 返回: + bool: 是否成功加载 + """ + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheRoot + + data = await CacheRoot.get_cache_data(self.name) + if isinstance(data, dict): + self._data = data + return True + return False + except Exception as e: + logger.error(f"加载缓存字典 {self.name} 失败", LOG_COMMAND, e=e) + return False + + async def save(self) -> bool: + """保存数据到缓存 + + 返回: + bool: 是否成功保存 + """ + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheData, CacheRoot + + # 检查缓存是否存在 + if self.name not in CacheRoot._data: + # 创建缓存 + async def get_func(): + return self._data + + CacheRoot._data[self.name] = CacheData( + name=self.name, + func=get_func, + expire=self.expire, + lazy_load=False, + cache=CacheRoot._cache, + ) + # 直接设置数据,避免调用func + CacheRoot._data[self.name]._data = self._data + else: + # 直接更新数据 + CacheRoot._data[self.name]._data = self._data + + # 保存数据 + await CacheRoot._data[self.name].set_data(self._data) + return True + except Exception as e: + logger.error(f"保存缓存字典 {self.name} 失败", LOG_COMMAND, e=e) + return False + + def __getitem__(self, key: str) -> Any: + """获取字典项 + + 参数: + key: 字典键 + + 返回: + Any: 字典值 + """ + return self._data.get(key) + + def __setitem__(self, key: str, value: Any) -> None: + """设置字典项 + + 参数: + key: 字典键 + value: 字典值 + """ + self._data[key] = value + + def __delitem__(self, key: str) -> None: + """删除字典项 + + 参数: + key: 字典键 + """ + if key in self._data: + del self._data[key] + + def __contains__(self, key: str) -> bool: + """检查键是否存在 + + 参数: + key: 字典键 + + 返回: + bool: 是否存在 + """ + return key in self._data + + def get(self, key: str, default: Any = None) -> Any: + """获取字典项,如果不存在返回默认值 + + 参数: + key: 字典键 + default: 默认值 + + 返回: + Any: 字典值或默认值 + """ + return self._data.get(key, default) + + def set(self, key: str, value: Any) -> None: + """设置字典项 + + 参数: + key: 字典键 + value: 字典值 + """ + self._data[key] = value + + def pop(self, key: str, default: Any = None) -> Any: + """删除并返回字典项 + + 参数: + key: 字典键 + default: 默认值 + + 返回: + Any: 字典值或默认值 + """ + return self._data.pop(key, default) + + def clear(self) -> None: + """清空字典""" + self._data.clear() + + def keys(self) -> list[str]: + """获取所有键 + + 返回: + list[str]: 键列表 + """ + return list(self._data.keys()) + + def values(self) -> list[Any]: + """获取所有值 + + 返回: + list[Any]: 值列表 + """ + return list(self._data.values()) + + def items(self) -> list[tuple[str, Any]]: + """获取所有键值对 + + 返回: + list[tuple[str, Any]]: 键值对列表 + """ + return list(self._data.items()) + + def __len__(self) -> int: + """获取字典长度 + + 返回: + int: 字典长度 + """ + return len(self._data) + + def __str__(self) -> str: + """字符串表示 + + 返回: + str: 字符串表示 + """ + return f"CacheDict({self.name}, {len(self._data)} items)" + + +class CacheList: + """全局缓存列表类,提供类似普通列表的接口,但数据可以在内存中共享""" + + def __init__(self, name: str, expire: int = 0): + """初始化缓存列表 + + 参数: + name: 列表名称 + expire: 过期时间(秒),默认为0表示永不过期 + """ + self.name = name.upper() + self.expire = expire + self._data = [] + # 自动尝试加载数据 + self._try_load() + + def _try_load(self): + """尝试加载数据(非异步)""" + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheRoot + + # 检查是否已有缓存数据 + if self.name in CacheRoot._data: + # 如果有,直接获取 + data = CacheRoot._data[self.name]._data + if isinstance(data, list): + self._data = data + except Exception: + # 忽略错误,使用空列表 + pass + + async def load(self) -> bool: + """从缓存加载数据 + + 返回: + bool: 是否成功加载 + """ + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheRoot + + data = await CacheRoot.get_cache_data(self.name) + if isinstance(data, list): + self._data = data + return True + return False + except Exception as e: + logger.error(f"加载缓存列表 {self.name} 失败", LOG_COMMAND, e=e) + return False + + async def save(self) -> bool: + """保存数据到缓存 + + 返回: + bool: 是否成功保存 + """ + try: + # 延迟导入,避免循环引用 + from zhenxun.services.cache import CacheData, CacheRoot + + # 检查缓存是否存在 + if self.name not in CacheRoot._data: + # 创建缓存 + async def get_func(): + return self._data + + CacheRoot._data[self.name] = CacheData( + name=self.name, + func=get_func, + expire=self.expire, + lazy_load=False, + cache=CacheRoot._cache, + ) + # 直接设置数据,避免调用func + CacheRoot._data[self.name]._data = self._data + else: + # 直接更新数据 + CacheRoot._data[self.name]._data = self._data + + # 保存数据 + await CacheRoot._data[self.name].set_data(self._data) + return True + except Exception as e: + logger.error(f"保存缓存列表 {self.name} 失败", LOG_COMMAND, e=e) + return False + + def __getitem__(self, index: int) -> Any: + """获取列表项 + + 参数: + index: 列表索引 + + 返回: + Any: 列表值 + """ + if 0 <= index < len(self._data): + return self._data[index] + raise IndexError(f"列表索引 {index} 超出范围") + + def __setitem__(self, index: int, value: Any) -> None: + """设置列表项 + + 参数: + index: 列表索引 + value: 列表值 + """ + # 确保索引有效 + while len(self._data) <= index: + self._data.append(None) + self._data[index] = value + + def __delitem__(self, index: int) -> None: + """删除列表项 + + 参数: + index: 列表索引 + """ + if 0 <= index < len(self._data): + del self._data[index] + else: + raise IndexError(f"列表索引 {index} 超出范围") + + def __len__(self) -> int: + """获取列表长度 + + 返回: + int: 列表长度 + """ + return len(self._data) + + def append(self, value: Any) -> None: + """添加列表项 + + 参数: + value: 列表值 + """ + self._data.append(value) + + def extend(self, values: list[Any]) -> None: + """扩展列表 + + 参数: + values: 要添加的值列表 + """ + self._data.extend(values) + + def insert(self, index: int, value: Any) -> None: + """插入列表项 + + 参数: + index: 插入位置 + value: 列表值 + """ + self._data.insert(index, value) + + def pop(self, index: int = -1) -> Any: + """删除并返回列表项 + + 参数: + index: 列表索引,默认为最后一项 + + 返回: + Any: 列表值 + """ + return self._data.pop(index) + + def remove(self, value: Any) -> None: + """删除第一个匹配的列表项 + + 参数: + value: 要删除的值 + """ + self._data.remove(value) + + def clear(self) -> None: + """清空列表""" + self._data.clear() + + def index(self, value: Any, start: int = 0, end: int | None = None) -> int: + """查找值的索引 + + 参数: + value: 要查找的值 + start: 起始索引 + end: 结束索引 + + 返回: + int: 索引位置 + """ + return self._data.index( + value, start, end if end is not None else len(self._data) + ) + + def count(self, value: Any) -> int: + """计算值出现的次数 + + 参数: + value: 要计数的值 + + 返回: + int: 出现次数 + """ + return self._data.count(value) + + def __str__(self) -> str: + """字符串表示 + + 返回: + str: 字符串表示 + """ + return f"CacheList({self.name}, {len(self._data)} items)" diff --git a/zhenxun/services/cache/config.py b/zhenxun/services/cache/config.py new file mode 100644 index 00000000..b974787b --- /dev/null +++ b/zhenxun/services/cache/config.py @@ -0,0 +1,35 @@ +""" +缓存系统配置 +""" + +# 日志标识 +LOG_COMMAND = "CacheRoot" + +# 默认缓存过期时间(秒) +DEFAULT_EXPIRE = 600 + +# 缓存键前缀 +CACHE_KEY_PREFIX = "ZHENXUN" + +# 缓存键分隔符 +CACHE_KEY_SEPARATOR = ":" + +# 复合键分隔符(用于分隔tuple类型的cache_key_field) +COMPOSITE_KEY_SEPARATOR = "_" + + +# 缓存模式 +class CacheMode: + # 内存缓存 - 使用内存存储缓存数据 + MEMORY = "MEMORY" + # Redis缓存 - 使用Redis服务器存储缓存数据 + REDIS = "REDIS" + # 不使用缓存 - 将使用ttl=0的内存缓存,相当于直接从数据库获取数据 + NONE = "NONE" + + +SPECIAL_KEY_FORMATS = { + "LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}", + "BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}", + "GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}", +} diff --git a/zhenxun/services/data_access.py b/zhenxun/services/data_access.py index 52992b28..71b41524 100644 --- a/zhenxun/services/data_access.py +++ b/zhenxun/services/data_access.py @@ -1,7 +1,10 @@ from typing import Any, Generic, TypeVar, cast -from zhenxun.services.cache import CacheRoot -from zhenxun.services.cache import config as cache_config +from zhenxun.services.cache import Cache, CacheRoot, cache_config +from zhenxun.services.cache.config import ( + COMPOSITE_KEY_SEPARATOR, + CacheMode, +) from zhenxun.services.db_context import Model from zhenxun.services.log import logger @@ -37,15 +40,89 @@ class DataAccess(Generic[T]): ``` """ - def __init__(self, model_cls: type[T], cache_type: str | None = None): + def __init__( + self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None + ): """初始化数据访问对象 参数: model_cls: 模型类 - cache_type: 缓存类型,如果为None则使用模型类的cache_type属性 + key_field: 主键字段 """ self.model_cls = model_cls - self.cache_type = cache_type or getattr(model_cls, "cache_type", None) + self.key_field = getattr(model_cls, "cache_key_field", key_field) + self.cache_type = getattr(model_cls, "cache_type", cache_type) + + if not self.cache_type: + raise ValueError("缓存类型不能为空") + self.cache = Cache(self.cache_type) + + def _build_cache_key_from_kwargs(self, **kwargs) -> str | None: + """从关键字参数构建缓存键 + + 参数: + **kwargs: 关键字参数 + + 返回: + str | None: 缓存键,如果无法构建则返回None + """ + if isinstance(self.key_field, tuple): + # 多字段主键 + key_parts = [] + for field in self.key_field: + key_parts.append(str(kwargs.get(field, ""))) + + if key_parts: + return COMPOSITE_KEY_SEPARATOR.join(key_parts) + return None + elif self.key_field in kwargs: + # 单字段主键 + return str(kwargs[self.key_field]) + return None + + async def safe_get_or_none(self, *args, **kwargs) -> T | None: + """安全的获取单条数据 + + 参数: + *args: 查询参数 + **kwargs: 查询参数 + + 返回: + Optional[T]: 查询结果,如果不存在返回None + """ + # 如果没有缓存类型,直接从数据库获取 + if not self.cache_type or cache_config.cache_mode == CacheMode.NONE: + return await self.model_cls.safe_get_or_none(*args, **kwargs) + + # 尝试从缓存获取 + try: + # 尝试构建缓存键 + cache_key = self._build_cache_key_from_kwargs(**kwargs) + + # 如果成功构建缓存键,尝试从缓存获取 + if cache_key is not None: + data = await self.cache.get(cache_key) + if data: + return cast(T, data) + except Exception as e: + logger.error("从缓存获取数据失败", e=e) + + # 如果缓存中没有,从数据库获取 + data = await self.model_cls.safe_get_or_none(*args, **kwargs) + + # 如果获取到数据,存入缓存 + if data: + try: + # 生成缓存键 + cache_key = self._build_cache_key_for_item(data) + if cache_key is not None: + # 存入缓存 + await self.cache.set(cache_key, data) + logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}") + except Exception as e: + logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e) + + return data async def get_or_none(self, *args, **kwargs) -> T | None: """获取单条数据 @@ -57,23 +134,169 @@ class DataAccess(Generic[T]): 返回: Optional[T]: 查询结果,如果不存在返回None """ - # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 - if not cache_config.enable_cache or not self.cache_type: - return await self.model_cls.safe_get_or_none(*args, **kwargs) + # 如果没有缓存类型,直接从数据库获取 + if not self.cache_type or cache_config.cache_mode == CacheMode.NONE: + return await self.model_cls.get_or_none(*args, **kwargs) - # 从缓存获取 + # 尝试从缓存获取 try: - # 生成缓存键 - key = self._generate_cache_key(kwargs) - # 尝试从缓存获取 - data = await CacheRoot.get(self.cache_type, key) - if data: - return cast(T, data) + # 尝试构建缓存键 + cache_key = self._build_cache_key_from_kwargs(**kwargs) + + # 如果成功构建缓存键,尝试从缓存获取 + if cache_key is not None: + data = await self.cache.get(cache_key) + if data: + return cast(T, data) except Exception as e: logger.error("从缓存获取数据失败", e=e) # 如果缓存中没有,从数据库获取 - return await self.model_cls.safe_get_or_none(*args, **kwargs) + data = await self.model_cls.get_or_none(*args, **kwargs) + + # 如果获取到数据,存入缓存 + if data: + try: + cache_key = self._build_cache_key_for_item(data) + # 生成缓存键 + if cache_key is not None: + # 存入缓存 + await self.cache.set(cache_key, data) + logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}") + except Exception as e: + logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e) + + return data + + async def clear_cache(self, **kwargs) -> bool: + """只清除缓存,不影响数据库数据 + + 参数: + **kwargs: 查询参数,必须包含主键字段 + + 返回: + bool: 是否成功清除缓存 + """ + # 如果没有缓存类型,直接返回True + if not self.cache_type or cache_config.cache_mode == CacheMode.NONE: + return True + + try: + # 构建缓存键 + cache_key = self._build_cache_key_from_kwargs(**kwargs) + if cache_key is None: + if isinstance(self.key_field, tuple): + # 如果是复合键,检查缺少哪些字段 + missing_fields = [ + field for field in self.key_field if field not in kwargs + ] + logger.error( + f"清除{self.model_cls.__name__}缓存失败: " + f"缺少主键字段 {', '.join(missing_fields)}" + ) + else: + logger.error( + f"清除{self.model_cls.__name__}缓存失败: " + f"缺少主键字段 {self.key_field}" + ) + return False + + # 删除缓存 + await self.cache.delete(cache_key) + logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}") + return True + except Exception as e: + logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e) + return False + + def _build_composite_key(self, data: T) -> str | None: + """构建复合缓存键 + + 参数: + data: 数据对象 + + 返回: + str | None: 构建的缓存键,如果无法构建则返回None + """ + # 如果是元组,表示多个字段组成键 + if isinstance(self.key_field, tuple): + # 构建键参数列表 + key_parts = [] + for field in self.key_field: + value = getattr(data, field, "") + key_parts.append(value if value is not None else "") + + # 如果没有有效参数,返回None + if not key_parts: + return None + + return COMPOSITE_KEY_SEPARATOR.join(key_parts) + + # 单个字段作为键 + elif hasattr(data, self.key_field): + value = getattr(data, self.key_field, None) + return str(value) if value is not None else None + + return None + + def _build_cache_key_for_item(self, item: T) -> str | None: + """为数据项构建缓存键 + + 参数: + item: 数据项 + + 返回: + str | None: 缓存键,如果无法生成则返回None + """ + # 如果没有缓存类型,返回None + if not self.cache_type: + return None + + # 获取缓存类型的配置信息 + cache_model = CacheRoot.get_model(self.cache_type) + + # 如果有键格式定义,则需要构建特殊格式的键 + if cache_model.key_format: + # 构建键参数字典 + key_parts = [] + # 从格式字符串中提取所需的字段名 + import re + + field_names = re.findall(r"{([^}]+)}", cache_model.key_format) + + # 收集所有字段值 + for field in field_names: + value = getattr(item, field, "") + key_parts.append(value if value is not None else "") + + return COMPOSITE_KEY_SEPARATOR.join(key_parts) + else: + # 常规处理,使用主键作为缓存键 + return self._build_composite_key(item) + + async def _cache_items(self, data_list: list[T]) -> None: + """将数据列表存入缓存 + + 参数: + data_list: 数据列表 + """ + if ( + not data_list + or not self.cache_type + or cache_config.cache_mode == CacheMode.NONE + ): + return + + try: + # 遍历数据列表,将每条数据存入缓存 + for item in data_list: + cache_key = self._build_cache_key_for_item(item) + if cache_key is not None: + await self.cache.set(cache_key, item) + + logger.debug(f"{self.cache_type} 数据已存入缓存,数量: {len(data_list)}") + except Exception as e: + logger.error(f"{self.cache_type} 数据存入缓存失败", e=e) async def filter(self, *args, **kwargs) -> list[T]: """筛选数据 @@ -85,30 +308,13 @@ class DataAccess(Generic[T]): 返回: List[T]: 查询结果列表 """ - # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 - if not cache_config.enable_cache or not self.cache_type: - return await self.model_cls.filter(*args, **kwargs) + # 从数据库获取数据 + data_list = await self.model_cls.filter(*args, **kwargs) - # 尝试从缓存获取所有数据后筛选 - try: - # 获取缓存数据 - cache_data = await CacheRoot.get_cache_data(self.cache_type) - if isinstance(cache_data, dict) and cache_data: - # 在内存中筛选 - filtered_data = [] - for item in cache_data.values(): - match = not any( - not hasattr(item, k) or getattr(item, k) != v - for k, v in kwargs.items() - ) - if match: - filtered_data.append(item) - return cast(list[T], filtered_data) - except Exception as e: - logger.error("从缓存筛选数据失败", e=e) + # 将数据存入缓存 + await self._cache_items(data_list) - # 如果缓存中没有或筛选失败,从数据库获取 - return await self.model_cls.filter(*args, **kwargs) + return data_list async def all(self) -> list[T]: """获取所有数据 @@ -116,21 +322,13 @@ class DataAccess(Generic[T]): 返回: List[T]: 所有数据列表 """ - # 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取 - if not cache_config.enable_cache or not self.cache_type: - return await self.model_cls.all() + # 直接从数据库获取 + data_list = await self.model_cls.all() - # 尝试从缓存获取所有数据 - try: - # 获取缓存数据 - cache_data = await CacheRoot.get_cache_data(self.cache_type) - if isinstance(cache_data, dict) and cache_data: - return cast(list[T], list(cache_data.values())) - except Exception as e: - logger.error("从缓存获取所有数据失败", e=e) + # 将数据存入缓存 + await self._cache_items(data_list) - # 如果缓存中没有,从数据库获取 - return await self.model_cls.all() + return data_list async def count(self, *args, **kwargs) -> int: """获取数据数量 @@ -167,7 +365,24 @@ class DataAccess(Generic[T]): 返回: T: 创建的数据 """ - return await self.model_cls.create(**kwargs) + # 创建数据 + data = await self.model_cls.create(**kwargs) + + # 如果有缓存类型,将数据存入缓存 + if self.cache_type and cache_config.cache_mode != CacheMode.NONE: + try: + # 生成缓存键 + cache_key = self._build_cache_key_for_item(data) + if cache_key is not None: + # 存入缓存 + await self.cache.set(cache_key, data) + logger.debug( + f"{self.cache_type} 新创建的数据已存入缓存: {cache_key}" + ) + except Exception as e: + logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e) + + return data async def update_or_create( self, defaults: dict[str, Any] | None = None, **kwargs @@ -181,7 +396,24 @@ class DataAccess(Generic[T]): 返回: tuple[T, bool]: (数据, 是否创建) """ - return await self.model_cls.update_or_create(defaults=defaults, **kwargs) + # 更新或创建数据 + data, created = await self.model_cls.update_or_create( + defaults=defaults, **kwargs + ) + + # 如果有缓存类型,将数据存入缓存 + if self.cache_type and cache_config.cache_mode != CacheMode.NONE: + try: + # 生成缓存键 + cache_key = self._build_cache_key_for_item(data) + if cache_key is not None: + # 存入缓存 + await self.cache.set(cache_key, data) + logger.debug(f"更新或创建的数据已存入缓存: {cache_key}") + except Exception as e: + logger.error(f"存入缓存失败,参数: {kwargs}", e=e) + + return data, created async def delete(self, *args, **kwargs) -> int: """删除数据 @@ -193,18 +425,43 @@ class DataAccess(Generic[T]): 返回: int: 删除的数据数量 """ + # 如果有缓存类型且有key_field参数,先尝试删除缓存 + if self.cache_type and cache_config.cache_mode != CacheMode.NONE: + try: + # 尝试构建缓存键 + cache_key = self._build_cache_key_from_kwargs(**kwargs) + + if cache_key is not None: + # 如果成功构建缓存键,直接删除缓存 + await self.cache.delete(cache_key) + logger.debug(f"已删除缓存: {cache_key}") + else: + # 否则需要先查询出要删除的数据,然后删除对应的缓存 + items = await self.model_cls.filter(*args, **kwargs) + for item in items: + item_cache_key = self._build_cache_key_for_item(item) + if item_cache_key is not None: + await self.cache.delete(item_cache_key) + if items: + logger.debug(f"已删除{len(items)}条数据的缓存") + except Exception as e: + logger.error("删除缓存失败", e=e) + + # 删除数据 return await self.model_cls.filter(*args, **kwargs).delete() - def _generate_cache_key(self, kwargs: dict[str, Any]) -> str: - """根据查询参数生成缓存键 + def _generate_cache_key(self, data: T) -> str: + """根据数据对象生成缓存键 参数: - kwargs: 查询参数 + data: 数据对象 返回: str: 缓存键 """ - # 实现一个简单的键生成算法 - if not kwargs: - return "default" - return "_".join(f"{k}:{v}" for k, v in sorted(kwargs.items())) + # 使用新方法构建复合键 + if composite_key := self._build_composite_key(data): + return composite_key + + # 如果无法生成复合键,生成一个唯一键 + return f"object_{id(data)}" diff --git a/zhenxun/services/db_context.py b/zhenxun/services/db_context.py index d4022b94..a9fc9380 100644 --- a/zhenxun/services/db_context.py +++ b/zhenxun/services/db_context.py @@ -16,6 +16,7 @@ from zhenxun.utils.exception import HookPriorityException from zhenxun.utils.manager.priority_manager import PriorityLifecycle from .cache import CacheRoot +from .cache.config import COMPOSITE_KEY_SEPARATOR from .log import logger SCRIPT_METHOD = [] @@ -23,14 +24,6 @@ MODELS: list[str] = [] driver = nonebot.get_driver() -CACHE_FLAG = False - - -@driver.on_bot_connect -def _(): - global CACHE_FLAG - CACHE_FLAG = True - class Model(TortoiseModel): """ @@ -56,8 +49,55 @@ class Model(TortoiseModel): return cls.sem_data.get(cls.__module__, {}).get(lock_type, None) @classmethod - def get_cache_type(cls): - return getattr(cls, "cache_type", None) if CACHE_FLAG else None + def get_cache_type(cls) -> str | None: + return getattr(cls, "cache_type", None) + + @classmethod + def get_cache_key_field(cls) -> str | tuple[str]: + """获取缓存键字段名 + + 返回: + str | tuple[str]: 缓存键字段名,可能是单个字段名或字段名元组 + """ + if hasattr(cls, "cache_key_field"): + return getattr(cls, "cache_key_field", "id") + return "id" + + @classmethod + def get_cache_key(cls, instance) -> str | None: + """获取缓存键 + + 参数: + instance: 模型实例 + + 返回: + str | None + """ + key_field = cls.get_cache_key_field() + + # 如果是元组,表示多个字段组成键 + if isinstance(key_field, tuple): + # 构建键参数列表 + key_parts = [] + for field in key_field: + if hasattr(instance, field): + value = getattr(instance, field, None) + key_parts.append(value if value is not None else "") + else: + # 如果缺少任何必要的字段,返回None + return None + + # 如果没有有效参数,返回None + if not key_parts: + return None + + return COMPOSITE_KEY_SEPARATOR.join(str(param) for param in key_parts) + + # 单个字段作为键 + elif hasattr(instance, key_field): + return getattr(instance, key_field, None) + + return None @classmethod async def create( @@ -79,7 +119,11 @@ class Model(TortoiseModel): defaults=defaults, using_db=using_db, **kwargs ) if is_create and (cache_type := cls.get_cache_type()): - await CacheRoot.reload(cache_type) + # 获取缓存键 + key = cls.get_cache_key(result) + await CacheRoot.invalidate_cache( + cache_type, key if key is not None else None + ) return (result, is_create) else: # 如果没有锁,则执行原来的逻辑 @@ -87,7 +131,11 @@ class Model(TortoiseModel): defaults=defaults, using_db=using_db, **kwargs ) if is_create and (cache_type := cls.get_cache_type()): - await CacheRoot.reload(cache_type) + # 获取缓存键 + key = cls.get_cache_key(result) + await CacheRoot.invalidate_cache( + cache_type, key if key is not None else None + ) return (result, is_create) @classmethod @@ -104,7 +152,11 @@ class Model(TortoiseModel): defaults=defaults, using_db=using_db, **kwargs ) if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) + # 获取缓存键 + key = cls.get_cache_key(result[0]) + await CacheRoot.invalidate_cache( + cache_type, key if key is not None else None + ) return result else: # 如果没有锁,则执行原来的逻辑 @@ -112,19 +164,23 @@ class Model(TortoiseModel): defaults=defaults, using_db=using_db, **kwargs ) if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) + # 获取缓存键 + key = cls.get_cache_key(result[0]) + await CacheRoot.invalidate_cache( + cache_type, key if key is not None else None + ) return result @classmethod async def bulk_create( # type: ignore cls, - objects: Iterable[Self], + objects: Iterable[Self], # type: ignore batch_size: int | None = None, ignore_conflicts: bool = False, update_fields: Iterable[str] | None = None, on_conflict: Iterable[str] | None = None, using_db: BaseDBAsyncClient | None = None, - ) -> list[Self]: + ) -> list[Self]: # type: ignore result = await super().bulk_create( objects=objects, batch_size=batch_size, @@ -134,17 +190,18 @@ class Model(TortoiseModel): using_db=using_db, ) if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) + # 批量创建时清除整个类型的缓存 + await CacheRoot.invalidate_cache(cache_type) return result @classmethod async def bulk_update( # type: ignore cls, - objects: Iterable[Self], + objects: Iterable[Self], # type: ignore fields: Iterable[str], batch_size: int | None = None, using_db: BaseDBAsyncClient | None = None, - ) -> int: + ) -> int: # type: ignore result = await super().bulk_update( objects=objects, fields=fields, @@ -152,7 +209,8 @@ class Model(TortoiseModel): using_db=using_db, ) if cache_type := cls.get_cache_type(): - await CacheRoot.reload(cache_type) + # 批量更新时清除整个类型的缓存 + await CacheRoot.invalidate_cache(cache_type) return result async def save( @@ -181,13 +239,24 @@ class Model(TortoiseModel): force_create=force_create, force_update=force_update, ) - if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): - await CacheRoot.reload(cache_type) + if cache_type := getattr(self, "cache_type", None): + # 获取缓存键 + key = self.__class__.get_cache_key(self) + await CacheRoot.invalidate_cache(cache_type, key) async def delete(self, using_db: BaseDBAsyncClient | None = None): + # 在删除前获取缓存键 + cache_type = getattr(self, "cache_type", None) + key = None + if cache_type: + key = self.__class__.get_cache_key(self) + + # 执行删除操作 await super().delete(using_db=using_db) - if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)): - await CacheRoot.reload(cache_type) + + # 清除缓存 + if cache_type: + await CacheRoot.invalidate_cache(cache_type, key) @classmethod async def safe_get_or_none(