♻️ 重构cache

This commit is contained in:
HibiKier 2025-07-10 17:10:07 +08:00
parent 7d962bb4a8
commit 56162e24ea
24 changed files with 2091 additions and 1109 deletions

View File

@ -27,6 +27,8 @@ QBOT_ID_DATA = '{
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
DB_URL = ""
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
CACHE_MODE = NONE
# REDIS配置使用REDIS替换Cache内存缓存
# REDIS地址
# REDIS_HOST = "127.0.0.1"
@ -50,7 +52,7 @@ PLATFORM_SUPERUSERS = '
DRIVER=~fastapi+~httpx+~websockets
# LOG_LEVEL=DEBUG
# LOG_LEVEL = DEBUG
# 服务器和端口
HOST = 127.0.0.1
PORT = 8080

View File

@ -3,8 +3,7 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.utils.enum import CacheType
from zhenxun.services.data_access import DataAccess
from zhenxun.utils.utils import get_entity_ids
from .exception import SkipPluginException
@ -21,15 +20,21 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
if not plugin.admin_level:
return
entity = get_entity_ids(session)
cache = Cache[list[LevelUser]](CacheType.LEVEL)
user_list = await cache.get(session.user.id) or []
level_dao = DataAccess(LevelUser)
global_user = await level_dao.safe_get_or_none(
user_id=session.user.id, group_id__isnull=True
)
user_level = 0
if global_user:
user_level = global_user.user_level
if entity.group_id:
user_list += await cache.get(session.user.id, entity.group_id) or []
if user_list:
user = max(user_list, key=lambda x: x.user_level)
user_level = user.user_level
else:
user_level = 0
# 获取用户在当前群组的权限数据
group_users = await level_dao.safe_get_or_none(
user_id=session.user.id, group_id=entity.group_id
)
if group_users:
user_level = max(user_level, group_users.user_level)
if user_level < plugin.admin_level:
await send_message(
session,
@ -42,11 +47,12 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)
elif user_list:
user = max(user_list, key=lambda x: x.user_level)
if user.user_level < plugin.admin_level:
elif global_user:
if global_user.user_level < plugin.admin_level:
await send_message(
session,
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
)
raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")
raise SkipPluginException(
f"{plugin.name}({plugin.module}) 管理员权限不足..."
)

View File

@ -1,20 +1,15 @@
import asyncio
from nonebot.adapters import Bot
from nonebot.matcher import Matcher
from nonebot_plugin_alconna import At
from nonebot_plugin_uninfo import Uninfo
from tortoise.exceptions import MultipleObjectsReturned
from zhenxun.configs.config import Config
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType, PluginType
from zhenxun.services.data_access import DataAccess
from zhenxun.utils.enum import PluginType
from zhenxun.utils.utils import EntityIDs, get_entity_ids
from .config import LOGGER_COMMAND
from .exception import SkipPluginException
from .utils import freq, send_message
@ -29,10 +24,18 @@ Config.add_plugin_config(
async def is_ban(user_id: str | None, group_id: str | None) -> int:
if not user_id and not group_id:
return 0
cache = Cache[BanConsole](CacheType.BAN)
group_user, user = await asyncio.gather(
cache.get(user_id, group_id), cache.get(user_id)
)
ban_dao = DataAccess(BanConsole)
# 分别获取用户在群组中的ban记录和全局ban记录
group_user = None
user = None
if user_id and group_id:
group_user = await ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id)
if user_id:
user = await ban_dao.safe_get_or_none(user_id=user_id, group_id="")
results = []
if group_user:
results.append(group_user)
@ -88,76 +91,55 @@ def format_time(time: float) -> str:
return time_str
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
async def group_handle(group_id: str):
"""群组ban检查
参数:
cache: cache
ban_dao: BanConsole数据访问对象
group_id: 群组id
异常:
SkipPluginException: 群组处于黑名单
"""
try:
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"群组黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
)
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
if await is_ban(None, group_id):
raise SkipPluginException("群组处于黑名单中...")
async def user_handle(
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
):
async def user_handle(module: str, entity: EntityIDs, session: Uninfo):
"""用户ban检查
参数:
module: 插件模块名
cache: cache
user_id: 用户id
ban_dao: BanConsole数据访问对象
entity: 实体ID信息
session: Uninfo
异常:
SkipPluginException: 用户处于黑名单
"""
ban_result = Config.get_config("hook", "BAN_RESULT")
try:
time = await is_ban(entity.user_id, entity.group_id)
if not time:
return
time_str = format_time(time)
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
if (
db_plugin
# and not db_plugin.ignore_prompt
and time != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
await send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
)
raise SkipPluginException("用户处于黑名单中...")
except MultipleObjectsReturned:
logger.warning(
"用户黑名单数据重复过滤该次hook并移除多余数据...", LOGGER_COMMAND
time = await is_ban(entity.user_id, entity.group_id)
if not time:
return
time_str = format_time(time)
plugin_dao = DataAccess(PluginInfo)
db_plugin = await plugin_dao.safe_get_or_none(module=module)
if (
db_plugin
and not db_plugin.ignore_prompt
and time != -1
and ban_result
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
):
await send_message(
session,
[
At(flag="user", target=entity.user_id),
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
],
entity.user_id,
)
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
"id", flat=True
)
await BanConsole.filter(id__in=ids[:-1]).delete()
await cache.reload()
raise SkipPluginException("用户处于黑名单中...")
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
@ -168,8 +150,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
entity = get_entity_ids(session)
if entity.user_id in bot.config.superusers:
return
cache = Cache[list[BanConsole]](CacheType.BAN)
if entity.group_id:
await group_handle(cache, entity.group_id)
await group_handle(entity.group_id)
if entity.user_id:
await user_handle(matcher.plugin_name, cache, entity, session)
await user_handle(matcher.plugin_name, entity, session)

View File

@ -1,8 +1,7 @@
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.data_access import DataAccess
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import CacheType
from .exception import SkipPluginException
@ -18,11 +17,11 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
SkipPluginException: 忽略插件
SkipPluginException: 忽略插件
"""
if cache := Cache[BotConsole](CacheType.BOT):
bot = await cache.get(bot_id)
if not bot or not bot.status:
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)
bot_dao = DataAccess(BotConsole)
bot = await bot_dao.safe_get_or_none(bot_id=bot_id)
if not bot or not bot.status:
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
if CommonUtils.format(plugin.module) in bot.block_plugins:
raise SkipPluginException(
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
)

View File

@ -11,6 +11,7 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
"""检测是否满足金币条件
参数:
user: UserConsole
plugin: PluginInfo
session: Uninfo

View File

@ -2,8 +2,7 @@ from nonebot_plugin_alconna import UniMsg
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.utils.enum import CacheType
from zhenxun.services.data_access import DataAccess
from zhenxun.utils.utils import EntityIDs
from .config import SwitchEnum
@ -21,7 +20,10 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
if not entity.group_id:
return
text = message.extract_plain_text()
group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
group_dao = DataAccess(GroupConsole)
group = await group_dao.safe_get_or_none(
group_id=entity.group_id, channel_id__isnull=True
)
if not group:
raise SkipPluginException("群组信息不存在...")
if group.level < 0:

View File

@ -3,9 +3,9 @@ from nonebot_plugin_uninfo import Uninfo
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.services.cache import Cache
from zhenxun.services.data_access import DataAccess
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import BlockType, CacheType
from zhenxun.utils.enum import BlockType
from zhenxun.utils.utils import get_entity_ids
from .exception import IsSuperuserException, SkipPluginException
@ -20,10 +20,12 @@ class GroupCheck:
self.session = session
self.is_poke = is_poke
self.plugin = plugin
self.group_dao = DataAccess(GroupConsole)
async def __get_data(self):
cache = Cache[GroupConsole](CacheType.GROUPS)
return await cache.get(self.group_id)
return await self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
)
async def check(self):
await self.check_superuser_block(self.plugin)
@ -89,6 +91,7 @@ class PluginCheck:
self.session = session
self.is_poke = is_poke
self.group_id = group_id
self.group_dao = DataAccess(GroupConsole)
async def check_user(self, plugin: PluginInfo):
"""全局私聊禁用检测
@ -118,9 +121,11 @@ class PluginCheck:
if plugin.status or plugin.block_type != BlockType.ALL:
return
"""全局状态"""
cache = Cache[GroupConsole](CacheType.GROUPS)
if self.group_id and (group := await cache.get(self.group_id)):
if group.is_super:
if self.group_id:
group = await self.group_dao.safe_get_or_none(
group_id=self.group_id, channel_id__isnull=True
)
if group and group.is_super:
raise IsSuperuserException()
sid = self.group_id or self.session.user.id
if freq.is_send_limit_message(plugin, sid, self.is_poke):

View File

@ -9,13 +9,9 @@ from tortoise.exceptions import IntegrityError
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import Cache
from zhenxun.services.data_access import DataAccess
from zhenxun.services.log import logger
from zhenxun.utils.enum import (
CacheType,
GoldHandle,
PluginType,
)
from zhenxun.utils.enum import GoldHandle, PluginType
from zhenxun.utils.exception import InsufficientGold
from zhenxun.utils.platform import PlatformUtils
from zhenxun.utils.utils import get_entity_ids
@ -54,8 +50,9 @@ async def get_plugin_and_user(
返回:
tuple[PluginInfo, UserConsole]: 插件信息用户信息
"""
user_cache = Cache[UserConsole](CacheType.USERS)
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
user_dao = DataAccess(UserConsole)
plugin_dao = DataAccess(PluginInfo)
plugin = await plugin_dao.safe_get_or_none(module=module)
if not plugin:
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
if plugin.plugin_type == PluginType.HIDDEN:
@ -64,7 +61,7 @@ async def get_plugin_and_user(
)
user = None
try:
user = await user_cache.get(user_id)
user = await user_dao.safe_get_or_none(user_id=user_id)
except IntegrityError as e:
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
if not user:
@ -108,7 +105,7 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
cost_gold: 消耗金币
session: Uninfo
"""
user_cache = Cache[UserConsole](CacheType.USERS)
user_dao = DataAccess(UserConsole)
try:
await UserConsole.reduce_gold(
user_id,
@ -121,8 +118,8 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
if u := await UserConsole.get_user(user_id):
u.gold = 0
await u.save(update_fields=["gold"])
# 更新缓存
await user_cache.update(user_id)
# 清除缓存,使下次查询时从数据库获取最新数据
await user_dao.clear_cache(user_id=user_id)
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)

View File

@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
from zhenxun.utils.manager.message_manager import MessageManager
from zhenxun.utils.platform import PlatformUtils
LOG_COMMAND = "MessageHook"
def replace_message(message: Message) -> str:
"""将消息中的at、image、record、face替换为字符串
@ -54,11 +56,11 @@ async def handle_api_result(
if user_id and message_id:
MessageManager.add(str(user_id), str(message_id))
logger.debug(
f"收集消息iduser_id: {user_id}, msg_id: {message_id}", "msg_hook"
f"收集消息iduser_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
)
except Exception as e:
logger.warning(
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
)
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
return
@ -80,6 +82,6 @@ async def handle_api_result(
except Exception as e:
logger.warning(
f"消息发送记录发生错误...data: {data}, result: {result}",
"msg_hook",
LOG_COMMAND,
e=e,
)

View File

@ -4,23 +4,25 @@ import nonebot
from nonebot.adapters import Bot
from zhenxun.models.group_console import GroupConsole
from zhenxun.services.cache import DbCacheException
from zhenxun.services.cache import CacheException
from zhenxun.services.log import logger
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from zhenxun.utils.platform import PlatformUtils
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
try:
from .__init_cache import CacheRoot
except DbCacheException as e:
from .__init_cache import register_cache_types
except CacheException as e:
raise SystemError(f"ERROR{e}")
driver = nonebot.get_driver()
@driver.on_startup
@PriorityLifecycle.on_startup(priority=5)
async def _():
await CacheRoot.init_non_lazy_caches()
register_cache_types()
logger.info("缓存类型注册完成")
@driver.on_bot_connect

View File

@ -1,208 +1,35 @@
"""
缓存初始化模块
负责注册各种缓存类型实现按需缓存机制
"""
from zhenxun.models.ban_console import BanConsole
from zhenxun.models.bot_console import BotConsole
from zhenxun.models.group_console import GroupConsole
from zhenxun.models.level_user import LevelUser
from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.user_console import UserConsole
from zhenxun.services.cache import CacheData, CacheRoot
from zhenxun.services.cache import CacheRegistry, cache_config
from zhenxun.services.cache.config import CacheMode
from zhenxun.services.log import logger
from zhenxun.utils.enum import CacheType
@CacheRoot.new(CacheType.PLUGINS)
async def _():
"""初始化插件缓存"""
data_list = await PluginInfo.get_plugins()
return {p.module: p for p in data_list}
# 注册缓存类型
def register_cache_types():
    """Register every model-backed cache type with the cache registry.

    Binds each ``CacheType`` enum member to its Tortoise model class so the
    cache layer knows which table backs which cache.  LEVEL and BAN use a
    composite key built from the record's user and group ids (``key_format``).
    """
    CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
    CacheRegistry.register(CacheType.GROUPS, GroupConsole)
    CacheRegistry.register(CacheType.BOT, BotConsole)
    CacheRegistry.register(CacheType.USERS, UserConsole)
    # Composite cache keys: one entry per (user_id, group_id) pair.
    CacheRegistry.register(
        CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
    )
    CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
async def _(cache_data: CacheData, module: str):
"""获取插件缓存"""
data = await cache_data.get_key(module)
if not data:
if plugin := await PluginInfo.get_plugin(module=module):
await cache_data.set_key(module, plugin)
logger.debug(f"插件 {module} 数据已设置到缓存")
return plugin
return data
@CacheRoot.with_refresh(CacheType.PLUGINS)
async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None):
"""刷新插件缓存"""
if not data:
return
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
for plugin in plugins:
await cache_data.set_key(plugin.module, plugin)
@CacheRoot.new(CacheType.GROUPS)
async def _():
"""初始化群组缓存"""
data_list = await GroupConsole.all()
return {p.group_id: p for p in data_list if not p.channel_id}
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
async def _(cache_data: CacheData, group_id: str):
"""获取群组缓存"""
data = await cache_data.get_key(group_id)
if not data:
if group := await GroupConsole.get_group(group_id=group_id):
await cache_data.set_key(group_id, group)
return group
return data
@CacheRoot.with_refresh(CacheType.GROUPS)
async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None):
"""刷新群组缓存"""
if not data:
return
groups = await GroupConsole.filter(
group_id__in=data.keys(), channel_id__isnull=True
).all()
for group in groups:
await cache_data.set_key(group.group_id, group)
@CacheRoot.new(CacheType.BOT)
async def _():
"""初始化机器人缓存"""
data_list = await BotConsole.all()
return {p.bot_id: p for p in data_list}
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
async def _(cache_data: CacheData, bot_id: str):
"""获取机器人缓存"""
data = await cache_data.get_key(bot_id)
if not data:
if bot := await BotConsole.get_or_none(bot_id=bot_id):
await cache_data.set_key(bot_id, bot)
return bot
return data
@CacheRoot.with_refresh(CacheType.BOT)
async def _(cache_data: CacheData, data: dict[str, BotConsole] | None):
"""刷新机器人缓存"""
if not data:
return
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
for bot in bots:
await cache_data.set_key(bot.bot_id, bot)
@CacheRoot.new(CacheType.USERS)
async def _():
"""初始化用户缓存"""
data_list = await UserConsole.all()
return {p.user_id: p for p in data_list}
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
async def _(cache_data: CacheData, user_id: str):
"""获取用户缓存"""
data = await cache_data.get_key(user_id)
if not data:
if user := await UserConsole.get_user(user_id=user_id):
await cache_data.set_key(user_id, user)
return user
return data
@CacheRoot.with_refresh(CacheType.USERS)
async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
"""刷新用户缓存"""
if not data:
return
users = await UserConsole.filter(user_id__in=data.keys()).all()
for user in users:
await cache_data.set_key(user.user_id, user)
@CacheRoot.new(CacheType.LEVEL)
async def _():
"""初始化等级缓存"""
data_list = await LevelUser().all()
return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list}
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
"""获取等级缓存"""
key = f"{user_id}:{group_id or ''}"
data = await cache_data.get_key(key)
if not data:
if group_id:
data = await LevelUser.filter(user_id=user_id, group_id=group_id).all()
else:
data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all()
if data:
await cache_data.set_key(key, data)
return data
return data or []
@CacheRoot.new(CacheType.BAN, False)
async def _():
"""初始化封禁缓存"""
data_list = await BanConsole.all()
return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
@CacheRoot.getter(CacheType.BAN, result_model=BanConsole)
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
"""获取封禁缓存"""
if not user_id and not group_id:
return []
key = f"{group_id or ''}:{user_id or ''}"
data = await cache_data.get_key(key)
# if not data:
# start = time.time()
# if user_id and group_id:
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
# elif user_id:
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
# elif group_id:
# data = await BanConsole.filter(
# user_id__isnull=True, group_id=group_id
# ).all()
# logger.info(
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
# )
# if data:
# await cache_data.set_key(key, data)
# return data
return data or []
# @CacheRoot.new(CacheType.LIMIT)
# async def _():
# """初始化限制缓存"""
# data_list = await PluginLimit.filter(status=True).all()
# return {data.module: data for data in data_list}
# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
# async def _(cache_data: CacheData, module: str):
# """获取限制缓存"""
# data = await cache_data.get_key(module)
# if not data:
# if limits := await PluginLimit.filter(module=module, status=True):
# await cache_data.set_key(module, limits)
# return limits
# return data or []
# @CacheRoot.with_refresh(CacheType.LIMIT)
# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
# """刷新限制缓存"""
# if not data:
# return
# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
# for limit in limits:
# await cache_data.set_key(limit.module, limit)
if cache_config.cache_mode == CacheMode.NONE:
logger.info("缓存功能已禁用,将直接从数据库获取数据")
else:
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
logger.info("使用增量缓存模式,数据将按需加载到缓存中")

View File

@ -205,7 +205,7 @@ class Manager:
self.cd_data: dict[str, PluginCdBlock] = {}
if self.cd_file.exists():
with open(self.cd_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginCdLimit" in temp.keys():
for k, v in temp["PluginCdLimit"].items():
if "." in k:
@ -216,7 +216,7 @@ class Manager:
self.block_data: dict[str, BaseBlock] = {}
if self.block_file.exists():
with open(self.block_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginBlockLimit" in temp.keys():
for k, v in temp["PluginBlockLimit"].items():
if "." in k:
@ -227,7 +227,7 @@ class Manager:
self.count_data: dict[str, PluginCountBlock] = {}
if self.count_file.exists():
with open(self.count_file, encoding="utf8") as f:
temp = _yaml.load(f)
temp = _yaml.load(f) or {}
if "PluginCountLimit" in temp.keys():
for k, v in temp["PluginCountLimit"].items():
if "." in k:

View File

@ -33,6 +33,8 @@ class BanConsole(Model):
cache_type = CacheType.BAN
"""缓存类型"""
cache_key_field = ("user_id", "group_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""

View File

@ -31,6 +31,9 @@ class BotConsole(Model):
table_description = "Bot数据表"
cache_type = CacheType.BOT
"""缓存类型"""
cache_key_field = "bot_id"
"""缓存键字段"""
@staticmethod
def format(name: str) -> str:

View File

@ -90,6 +90,8 @@ class GroupConsole(Model):
cache_type = CacheType.GROUPS
"""缓存类型"""
cache_key_field = ("group_id", "channel_id")
"""缓存键字段"""
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
"""开启锁"""
@ -123,7 +125,18 @@ class GroupConsole(Model):
)
@classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def _update_cache(cls, instance):
"""更新缓存
参数:
instance: 需要更新缓存的实例
"""
if cache_type := cls.get_cache_type():
key = cls.get_cache_key(instance)
if key is not None:
await CacheRoot.invalidate_cache(cache_type, key)
@classmethod
async def create(
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
) -> Self:
@ -136,6 +149,9 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
await cls._update_cache(group)
return group
@classmethod
@ -187,14 +203,13 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
if is_create:
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
await cache.update(group.group_id, group)
await cls._update_cache(group)
return group, is_create
@classmethod
@CacheRoot.listener(CacheType.GROUPS)
async def update_or_create(
cls,
defaults: dict | None = None,
@ -214,6 +229,9 @@ class GroupConsole(Model):
if task_modules or plugin_modules:
await cls._update_modules(group, task_modules, plugin_modules, using_db)
# 更新缓存
await cls._update_cache(group)
return group, is_create
@classmethod
@ -327,6 +345,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def set_unblock_plugin(
cls,
@ -363,6 +384,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def is_normal_block_plugin(
cls, group_id: str, module: str, channel_id: str | None = None
@ -466,6 +490,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
async def set_unblock_task(
cls,
@ -500,6 +527,9 @@ class GroupConsole(Model):
if update_fields:
await group.save(update_fields=update_fields)
# 更新缓存
await cls._update_cache(group)
@classmethod
def _run_script(cls):
return [

View File

@ -22,6 +22,9 @@ class LevelUser(Model):
unique_together = ("user_id", "group_id")
cache_type = CacheType.LEVEL
"""缓存类型"""
cache_key_field = ("user_id", "group_id")
"""缓存键字段"""
@classmethod
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:

View File

@ -60,6 +60,9 @@ class PluginInfo(Model):
table_description = "插件基本信息"
cache_type = CacheType.PLUGINS
"""缓存类型"""
cache_key_field = "module"
"""缓存键字段"""
@classmethod
async def get_plugin(

View File

@ -31,6 +31,9 @@ class UserConsole(Model):
table_description = "用户数据表"
cache_type = CacheType.USERS
"""缓存类型"""
cache_key_field = "user_id"
"""缓存键字段"""
@classmethod
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":

View File

@ -1,703 +0,0 @@
from collections.abc import Callable
from datetime import datetime
from functools import wraps
import inspect
from typing import Any, ClassVar, Generic, TypeVar
from aiocache import Cache as AioCache
# from aiocache.backends.redis import RedisCache
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer
import nonebot
from nonebot.compat import model_dump
from nonebot.utils import is_coroutine_callable
from pydantic import BaseModel
from zhenxun.services.log import logger
__all__ = ["Cache", "CacheData", "CacheRoot"]
T = TypeVar("T")
LOG_COMMAND = "cache"
driver = nonebot.get_driver()
class Config(BaseModel):
    """Plugin-level cache configuration read from the NoneBot config."""

    enable_cache: bool = True
    """Whether the cache feature is enabled at all"""
    redis_host: str | None = None
    """Redis host address (None = no Redis backend)"""
    redis_port: int | None = None
    """Redis port"""
    redis_password: str | None = None
    """Redis password"""
    redis_expire: int = 600
    """Default Redis key TTL in seconds"""
config = nonebot.get_plugin_config(Config)
class DbCacheException(Exception):
    """Base exception for cache-layer failures.

    Carries a human-readable description in ``info``.  The message is also
    forwarded to ``Exception.__init__`` so that ``args`` is populated and the
    exception reprs, chains and pickles correctly (the original left ``args``
    empty, which breaks ``repr(e)`` and pickling).
    """

    def __init__(self, info: str):
        # Forward to the base class: keeps e.args in sync with the message.
        super().__init__(info)
        self.info = info

    def __str__(self) -> str:
        return self.info
def validate_name(func: Callable):
    """Decorator: validate that a named cache exists before dispatching.

    Upper-cases ``name``, raises ``DbCacheException`` when no cache with that
    (normalized) name is registered, and otherwise forwards the call with the
    normalized name.  Uses ``functools.wraps`` (already imported in this file
    but previously unused) so the wrapped function keeps its name/docstring,
    which matters for logging and debugging.
    """

    @wraps(func)
    def wrapper(self, name: str, *args, **kwargs):
        _name = name.upper()
        if _name not in CacheManager._data:
            raise DbCacheException(f"缓存数据 {name} 不存在")
        return func(self, _name, *args, **kwargs)

    return wrapper
class CacheGetter(BaseModel, Generic[T]):
    """Cache data getter.

    Wraps optional user-supplied retrieval callables (sync or async) and
    deserializes results into ``cache_data.result_model`` when one is set.
    """

    # Optional override for fetching a single key; may be sync or async.
    get_func: Callable[..., Any] | None = None
    # Optional override for fetching the whole mapping; may be sync or async.
    get_all_func: Callable[..., Any] | None = None

    async def get(self, cache_data: "CacheData", key: str, *args, **kwargs) -> T:
        """Fetch a single cached value, via get_func when configured."""
        if not self.get_func:
            # No override: read the raw key and re-hydrate if a model is set.
            data = await cache_data.get_key(key)
            if cache_data.result_model:
                return cache_data._deserialize_value(data, cache_data.result_model)
            return data
        # Dispatch on whether the override is a coroutine function.
        if is_coroutine_callable(self.get_func):
            data = await self.get_func(cache_data, key, *args, **kwargs)
        else:
            data = self.get_func(cache_data, key, *args, **kwargs)
        if cache_data.result_model:
            return cache_data._deserialize_value(data, cache_data.result_model)
        return data

    async def get_all(self, cache_data: "CacheData", *args, **kwargs) -> dict[str, T]:
        """Fetch all cached values, via get_all_func when configured."""
        if not self.get_all_func:
            data = await cache_data.get_all_data()
            if cache_data.result_model:
                # Deserialize each value of the mapping individually.
                return {
                    k: cache_data._deserialize_value(v, cache_data.result_model)
                    for k, v in data.items()
                }
            return data
        if is_coroutine_callable(self.get_all_func):
            data = await self.get_all_func(cache_data, *args, **kwargs)
        else:
            data = self.get_all_func(cache_data, *args, **kwargs)
        if cache_data.result_model:
            return {
                k: cache_data._deserialize_value(v, cache_data.result_model)
                for k, v in data.items()
            }
        return data
class CacheData(BaseModel):
    """Cache data model: one named cache entry plus its access callables."""

    # Cache name; also used as the key in the backing store.
    name: str
    # Loader that produces the initial data set for this cache.
    func: Callable[..., Any]
    # Optional custom getter (see CacheGetter).
    getter: CacheGetter | None = None
    # Optional in-place updater for single items (see update()).
    updater: Callable[..., Any] | None = None
    # Optional incremental-refresh callable (falls back to full reload).
    with_refresh: Callable[..., Any] | None = None
    expire: int = 600  # TTL in seconds, default 10 minutes
    reload_count: int = 0  # how many times the cache has been (re)loaded
    lazy_load: bool = True  # load on first access rather than at startup
    # When set, raw JSON payloads are deserialized back into this type.
    result_model: type | None = None
    _keys: set[str] = set()  # all cache keys known to this entry
    # Underlying aiocache backend instance.
    cache: BaseCache | AioCache

    class Config:
        # Needed because BaseCache/AioCache are not pydantic types.
        arbitrary_types_allowed = True
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
    """Deserialize a value: convert JSON data back to its original type.

    参数:
        value: value to deserialize
        target_type: target type used to guide deserialization

    返回:
        the deserialized value
    """
    if value is None:
        return None
    # Dict with an explicit target type: pick a construction strategy.
    if isinstance(value, dict) and target_type:
        # Tortoise-ORM models are detected by their _meta attribute.
        if hasattr(target_type, "_meta"):
            return self._extracted_from__deserialize_value_19(value, target_type)
        elif hasattr(target_type, "model_validate"):
            # pydantic v2 style
            return target_type.model_validate(value)
        elif hasattr(target_type, "from_dict"):
            return target_type.from_dict(value)
        elif hasattr(target_type, "parse_obj"):
            # pydantic v1 style
            return target_type.parse_obj(value)
        else:
            # Last resort: assume keys match the constructor signature.
            return target_type(**value)
    # Lists: recurse per item, using list[X]'s item type when available.
    if isinstance(value, list):
        if not value:
            return value
        if (
            target_type
            and hasattr(target_type, "__origin__")
            and target_type.__origin__ is list
        ):
            item_type = target_type.__args__[0]
            return [self._deserialize_value(item, item_type) for item in value]
        return [self._deserialize_value(item) for item in value]
    # Untyped dicts: recurse over values only.
    if isinstance(value, dict):
        return {k: self._deserialize_value(v) for k, v in value.items()}
    return value
def _extracted_from__deserialize_value_19(self, value, target_type):
    """Rebuild a Tortoise-ORM model instance from a plain dict.

    NOTE(review): name looks auto-generated by a refactoring tool
    (Sourcery); consider renaming to e.g. _deserialize_orm_model together
    with its caller.
    """
    # Keep only known model fields, dropping reverse-relation fields.
    processed_value = {}
    for field_name, field_value in value.items():
        if field := target_type._meta.fields_map.get(field_name):
            # Skip reverse relation fields
            if hasattr(field, "_related_name"):
                continue
            processed_value[field_name] = field_value
    logger.debug(f"处理后的值: {processed_value}")
    # Create a blank model instance without hitting the database.
    instance = target_type()
    # Assign each field, converting through the field's own converter.
    for field_name, field_value in processed_value.items():
        if field_name in target_type._meta.fields_map:
            field = target_type._meta.fields_map[field_name]
            try:
                if hasattr(field, "to_python_value"):
                    if not field.field_type:
                        logger.debug(f"字段 {field_name} 类型为空")
                        continue
                    field_value = field.to_python_value(field_value)
                setattr(instance, field_name, field_value)
            except Exception as e:
                # Best-effort: a single bad field is logged, not fatal.
                logger.warning(f"设置字段 {field_name} 失败", e=e)
    # Mark as persisted so save() issues an UPDATE instead of an INSERT.
    instance._saved_in_db = True
    return instance
async def get_data(self) -> Any:
    """Read this cache's full payload from the backing store.

    返回:
        the stored (deserialized) value, or None when the key is absent or
        the backend errors — failures are logged, never raised.

    The large commented-out "reload on empty" block from the original has
    been removed as dead code; reloading is handled by get()/get_all().
    """
    try:
        data = await self.cache.get(self.name)  # type: ignore
        logger.debug(f"获取缓存 {self.name} 数据: {data}")
        # Re-hydrate plain JSON payloads back into the declared model type.
        if self.result_model:
            return self._deserialize_value(data, self.result_model)
        return data
    except Exception as e:
        # Best-effort read: any backend failure is treated as a cache miss.
        logger.error(f"获取缓存 {self.name} 失败: {e}")
        return None
def _serialize_value(self, value: Any) -> Any:
    """Serialize a value into a JSON-compatible structure.

    参数:
        value: value to serialize

    返回:
        a JSON-serializable value
    """
    if value is None:
        return None
    # datetime -> ISO-8601 string
    if isinstance(value, datetime):
        return value.isoformat()
    # Tortoise-ORM model (detected via _meta) -> dict of its fields.
    if hasattr(value, "_meta") and hasattr(value, "__dict__"):
        result = {}
        # NOTE(review): assumes value._meta.fields iterates field names and
        # supports item lookup below — confirm against the Tortoise version.
        for field in value._meta.fields:
            try:
                field_value = getattr(value, field)
                # Skip reverse relation fields
                if isinstance(field_value, list | set) and hasattr(
                    field_value, "_related_name"
                ):
                    continue
                # Foreign keys: store the related id instead of the object.
                if hasattr(field_value, "_meta"):
                    field_value = getattr(
                        field_value, value._meta.fields[field].related_name or "id"
                    )
                result[field] = self._serialize_value(field_value)
            except AttributeError:
                # Missing attribute on a lazy/unset field: skip it.
                continue
        return result
    # Pydantic models
    elif isinstance(value, BaseModel):
        return model_dump(value)
    elif isinstance(value, dict):
        # Dicts: stringify keys, recurse on values.
        return {str(k): self._serialize_value(v) for k, v in value.items()}
    elif isinstance(value, list | tuple | set):
        # Sequences/sets become JSON arrays.
        return [self._serialize_value(item) for item in value]
    elif isinstance(value, int | float | str | bool):
        # Primitives pass through unchanged.
        return value
    else:
        # Anything else falls back to its string representation.
        return str(value)
async def set_data(self, value: Any):
"""设置缓存数据"""
try:
# 1. 序列化数据
serialized_value = self._serialize_value(value)
logger.debug(f"设置缓存 {self.name} 原始数据: {value}")
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
# 2. 删除旧数据
await self.cache.delete(self.name) # type: ignore
logger.debug(f"删除缓存 {self.name} 旧数据")
# 3. 设置新数据
await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
logger.debug(f"设置缓存 {self.name} 新数据完成")
except Exception as e:
logger.error(f"设置缓存 {self.name} 失败: {e}")
raise # 重新抛出异常,让上层处理
async def delete_data(self):
"""删除缓存数据"""
try:
await self.cache.delete(self.name) # type: ignore
except Exception as e:
logger.error(f"删除缓存 {self.name}", e=e)
async def get(self, key: str, *args, **kwargs) -> Any:
"""获取缓存"""
if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs)
if not self.getter:
return await self.get_key(key)
return await self.getter.get(self, key, *args, **kwargs)
async def get_all(self, *args, **kwargs) -> dict[str, Any]:
"""获取所有缓存数据"""
if not self.reload_count and not self.lazy_load:
await self.reload(*args, **kwargs)
if not self.getter:
return await self.get_all_data()
return await self.getter.get_all(self, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新单个缓存项"""
if not self.updater:
logger.warning(f"缓存 {self.name} 未配置更新方法")
return
current_data = await self.get_key(key) or {}
if is_coroutine_callable(self.updater):
await self.updater(current_data, key, value, *args, **kwargs)
else:
self.updater(current_data, key, value, *args, **kwargs)
await self.set_key(key, current_data)
logger.debug(f"更新缓存 {self.name}.{key}")
async def refresh(self, *args, **kwargs):
"""刷新缓存数据"""
if not self.with_refresh:
return await self.reload(*args, **kwargs)
current_data = await self.get_data()
if current_data:
if is_coroutine_callable(self.with_refresh):
await self.with_refresh(current_data, *args, **kwargs)
else:
self.with_refresh(current_data, *args, **kwargs)
await self.set_data(current_data)
logger.debug(f"刷新缓存 {self.name}")
async def reload(self, *args, **kwargs):
"""重新加载全部数据"""
try:
if self.has_args():
new_data = (
await self.func(*args, **kwargs)
if is_coroutine_callable(self.func)
else self.func(*args, **kwargs)
)
else:
new_data = (
await self.func()
if is_coroutine_callable(self.func)
else self.func()
)
# 如果是字典,则分别存储每个键值对
if isinstance(new_data, dict):
for key, value in new_data.items():
await self.set_key(key, value)
else:
# 如果不是字典,则存储为单个键值对
await self.set_key("default", new_data)
self.reload_count += 1
logger.info(f"重新加载缓存 {self.name} 完成")
except Exception as e:
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
raise
def has_args(self) -> bool:
"""检查函数是否需要参数"""
sig = inspect.signature(self.func)
return any(
param.kind
in (
param.POSITIONAL_OR_KEYWORD,
param.POSITIONAL_ONLY,
param.VAR_POSITIONAL,
)
for param in sig.parameters.values()
)
async def get_key(self, key: str) -> Any:
"""获取缓存中指定键的数据
参数:
key: 要获取的键名
返回:
键对应的值如果不存在返回None
"""
cache_key = self._get_cache_key(key)
try:
data = await self.cache.get(cache_key) # type: ignore
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
if self.result_model:
return self._deserialize_value(data, self.result_model)
return data
except Exception as e:
logger.error(f"获取缓存 {cache_key} 失败: {e}")
return None
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
"""获取缓存中多个键的数据
参数:
keys: 要获取的键名列表
返回:
包含所有请求键值的字典不存在的键值为None
"""
try:
data = await self.get_data()
if isinstance(data, dict):
return {key: data.get(key) for key in keys}
return dict.fromkeys(keys)
except Exception as e:
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
return dict.fromkeys(keys)
def _get_cache_key(self, key: str) -> str:
"""获取缓存键名"""
return f"{self.name}:{key}"
async def get_all_data(self) -> dict[str, Any]:
"""获取所有缓存数据"""
try:
result = {}
for key in self._keys:
# 提取原始键名(去掉前缀)
original_key = key.split(":", 1)[1]
data = await self.cache.get(key) # type: ignore
if self.result_model:
result[original_key] = self._deserialize_value(
data, self.result_model
)
else:
result[original_key] = data
return result
except Exception as e:
logger.error(f"获取所有缓存数据失败: {e}")
return {}
async def set_key(self, key: str, value: Any):
"""设置指定键的缓存数据"""
cache_key = self._get_cache_key(key)
try:
serialized_value = self._serialize_value(value)
await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
self._keys.add(cache_key) # 添加到键列表
logger.debug(f"设置缓存 {cache_key} 数据完成")
except Exception as e:
logger.error(f"设置缓存 {cache_key} 失败: {e}")
raise
async def delete_key(self, key: str):
"""删除指定键的缓存数据"""
cache_key = self._get_cache_key(key)
try:
await self.cache.delete(cache_key) # type: ignore
self._keys.discard(cache_key) # 从键列表中移除
logger.debug(f"删除缓存 {cache_key} 完成")
except Exception as e:
logger.error(f"删除缓存 {cache_key} 失败: {e}")
async def clear(self):
"""清除所有缓存数据"""
try:
for key in list(self._keys): # 使用列表复制避免在迭代时修改
await self.cache.delete(key) # type: ignore
self._keys.clear()
logger.debug(f"清除缓存 {self.name} 完成")
except Exception as e:
logger.error(f"清除缓存 {self.name} 失败: {e}")
class CacheManager:
"""全局缓存管理器"""
_cache_instance: BaseCache | AioCache | None = None
_data: ClassVar[dict[str, CacheData]] = {}
@property
def _cache(self) -> BaseCache | AioCache:
"""获取aiocache实例"""
if self._cache_instance is None:
if config.redis_host:
self._cache_instance = AioCache(
AioCache.REDIS, # type: ignore
serializer=JsonSerializer(),
namespace="zhenxun_cache",
timeout=30, # 操作超时时间
ttl=config.redis_expire, # 设置默认过期时间
endpoint=config.redis_host,
port=config.redis_port,
password=config.redis_password,
)
else:
self._cache_instance = AioCache(
AioCache.MEMORY,
serializer=JsonSerializer(),
namespace="zhenxun_cache",
timeout=30, # 操作超时时间
ttl=config.redis_expire, # 设置默认过期时间
)
logger.info("初始化缓存完成...", LOG_COMMAND)
return self._cache_instance
async def close(self):
if self._cache_instance:
await self._cache_instance.close() # type: ignore
async def verify_connection(self):
"""连接测试"""
try:
await self._cache.get("__test__") # type: ignore
except Exception as e:
logger.error("连接失败", LOG_COMMAND, e=e)
raise
async def init_non_lazy_caches(self):
"""初始化所有非延迟加载的缓存"""
await self.verify_connection()
for name, cache in self._data.items():
cache.cache = self._cache
if not cache.lazy_load:
try:
await cache.reload()
logger.info(f"初始化缓存 {name} 完成")
except Exception as e:
logger.error(f"初始化缓存 {name} 失败: {e}")
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
"""注册新缓存
参数:
name: 缓存名称
lazy_load: 是否延迟加载默认为True为False时会在程序启动时自动加载
expire: 过期时间
"""
def wrapper(func: Callable):
_name = name.upper()
if _name in self._data:
raise DbCacheException(f"缓存 {name} 已存在")
self._data[_name] = CacheData(
name=_name,
func=func,
expire=expire,
lazy_load=lazy_load,
cache=self._cache,
)
return func
return wrapper
def listener(self, name: str):
"""创建缓存监听器"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return (
await func(*args, **kwargs)
if is_coroutine_callable(func)
else func(*args, **kwargs)
)
finally:
cache = self._data.get(name.upper())
if cache and cache.with_refresh:
await cache.refresh()
logger.debug(f"监听器触发缓存 {name} 刷新")
return wrapper
return decorator
@validate_name
def updater(self, name: str):
"""设置缓存更新方法"""
def wrapper(func: Callable):
self._data[name].updater = func
return func
return wrapper
@validate_name
def getter(self, name: str, result_model: type):
"""设置缓存获取方法"""
def wrapper(func: Callable):
self._data[name].getter = CacheGetter[result_model](get_func=func)
self._data[name].result_model = result_model
return func
return wrapper
@validate_name
def with_refresh(self, name: str):
"""设置缓存刷新方法"""
def wrapper(func: Callable):
self._data[name].with_refresh = func
return func
return wrapper
async def get_cache_data(self, name: str) -> Any | None:
"""获取缓存数据"""
cache = await self.get_cache(name.upper())
return await cache.get_data() if cache else None
async def get_cache(self, name: str) -> CacheData | None:
"""获取缓存对象"""
return self._data.get(name.upper())
async def get(self, name: str, *args, **kwargs) -> Any:
"""获取缓存内容"""
cache = await self.get_cache(name.upper())
return await cache.get(*args, **kwargs) if cache else None
async def update(self, name: str, key: str, value: Any = None, *args, **kwargs):
"""更新缓存项"""
cache = await self.get_cache(name.upper())
if cache:
await cache.update(key, value, *args, **kwargs)
async def reload(self, name: str, *args, **kwargs):
"""重新加载缓存"""
cache = await self.get_cache(name.upper())
if cache:
await cache.reload(*args, **kwargs)
# 全局缓存管理器实例
CacheRoot = CacheManager()
class Cache(Generic[T]):
"""类型化缓存访问接口"""
def __init__(self, module: str):
self.module = module.upper()
async def get(self, *args, **kwargs) -> T | None:
"""获取缓存"""
return await CacheRoot.get(self.module, *args, **kwargs)
async def update(self, key: str, value: Any = None, *args, **kwargs):
"""更新缓存项"""
await CacheRoot.update(self.module, key, value, *args, **kwargs)
async def reload(self, *args, **kwargs):
"""重新加载缓存"""
await CacheRoot.reload(self.module, *args, **kwargs)
@driver.on_shutdown
async def _():
await CacheRoot.close()

1032
zhenxun/services/cache/__init__.py vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,424 @@
from typing import Any
from zhenxun.services.log import logger
from .config import LOG_COMMAND
class CacheDict:
"""全局缓存字典类,提供类似普通字典的接口,但数据可以在内存中共享"""
def __init__(self, name: str, expire: int = 0):
"""初始化缓存字典
参数:
name: 字典名称
expire: 过期时间默认为0表示永不过期
"""
self.name = name.upper()
self.expire = expire
self._data = {}
# 自动尝试加载数据
self._try_load()
def _try_load(self):
"""尝试加载数据(非异步)"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheRoot
# 检查是否已有缓存数据
if self.name in CacheRoot._data:
# 如果有,直接获取
data = CacheRoot._data[self.name]._data
if isinstance(data, dict):
self._data = data
except Exception:
# 忽略错误,使用空字典
pass
async def load(self) -> bool:
"""从缓存加载数据
返回:
bool: 是否成功加载
"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheRoot
data = await CacheRoot.get_cache_data(self.name)
if isinstance(data, dict):
self._data = data
return True
return False
except Exception as e:
logger.error(f"加载缓存字典 {self.name} 失败", LOG_COMMAND, e=e)
return False
async def save(self) -> bool:
"""保存数据到缓存
返回:
bool: 是否成功保存
"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheData, CacheRoot
# 检查缓存是否存在
if self.name not in CacheRoot._data:
# 创建缓存
async def get_func():
return self._data
CacheRoot._data[self.name] = CacheData(
name=self.name,
func=get_func,
expire=self.expire,
lazy_load=False,
cache=CacheRoot._cache,
)
# 直接设置数据避免调用func
CacheRoot._data[self.name]._data = self._data
else:
# 直接更新数据
CacheRoot._data[self.name]._data = self._data
# 保存数据
await CacheRoot._data[self.name].set_data(self._data)
return True
except Exception as e:
logger.error(f"保存缓存字典 {self.name} 失败", LOG_COMMAND, e=e)
return False
def __getitem__(self, key: str) -> Any:
"""获取字典项
参数:
key: 字典键
返回:
Any: 字典值
"""
return self._data.get(key)
def __setitem__(self, key: str, value: Any) -> None:
"""设置字典项
参数:
key: 字典键
value: 字典值
"""
self._data[key] = value
def __delitem__(self, key: str) -> None:
"""删除字典项
参数:
key: 字典键
"""
if key in self._data:
del self._data[key]
def __contains__(self, key: str) -> bool:
"""检查键是否存在
参数:
key: 字典键
返回:
bool: 是否存在
"""
return key in self._data
def get(self, key: str, default: Any = None) -> Any:
"""获取字典项,如果不存在返回默认值
参数:
key: 字典键
default: 默认值
返回:
Any: 字典值或默认值
"""
return self._data.get(key, default)
def set(self, key: str, value: Any) -> None:
"""设置字典项
参数:
key: 字典键
value: 字典值
"""
self._data[key] = value
def pop(self, key: str, default: Any = None) -> Any:
"""删除并返回字典项
参数:
key: 字典键
default: 默认值
返回:
Any: 字典值或默认值
"""
return self._data.pop(key, default)
def clear(self) -> None:
"""清空字典"""
self._data.clear()
def keys(self) -> list[str]:
"""获取所有键
返回:
list[str]: 键列表
"""
return list(self._data.keys())
def values(self) -> list[Any]:
"""获取所有值
返回:
list[Any]: 值列表
"""
return list(self._data.values())
def items(self) -> list[tuple[str, Any]]:
"""获取所有键值对
返回:
list[tuple[str, Any]]: 键值对列表
"""
return list(self._data.items())
def __len__(self) -> int:
"""获取字典长度
返回:
int: 字典长度
"""
return len(self._data)
def __str__(self) -> str:
"""字符串表示
返回:
str: 字符串表示
"""
return f"CacheDict({self.name}, {len(self._data)} items)"
class CacheList:
"""全局缓存列表类,提供类似普通列表的接口,但数据可以在内存中共享"""
def __init__(self, name: str, expire: int = 0):
"""初始化缓存列表
参数:
name: 列表名称
expire: 过期时间默认为0表示永不过期
"""
self.name = name.upper()
self.expire = expire
self._data = []
# 自动尝试加载数据
self._try_load()
def _try_load(self):
"""尝试加载数据(非异步)"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheRoot
# 检查是否已有缓存数据
if self.name in CacheRoot._data:
# 如果有,直接获取
data = CacheRoot._data[self.name]._data
if isinstance(data, list):
self._data = data
except Exception:
# 忽略错误,使用空列表
pass
async def load(self) -> bool:
"""从缓存加载数据
返回:
bool: 是否成功加载
"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheRoot
data = await CacheRoot.get_cache_data(self.name)
if isinstance(data, list):
self._data = data
return True
return False
except Exception as e:
logger.error(f"加载缓存列表 {self.name} 失败", LOG_COMMAND, e=e)
return False
async def save(self) -> bool:
"""保存数据到缓存
返回:
bool: 是否成功保存
"""
try:
# 延迟导入,避免循环引用
from zhenxun.services.cache import CacheData, CacheRoot
# 检查缓存是否存在
if self.name not in CacheRoot._data:
# 创建缓存
async def get_func():
return self._data
CacheRoot._data[self.name] = CacheData(
name=self.name,
func=get_func,
expire=self.expire,
lazy_load=False,
cache=CacheRoot._cache,
)
# 直接设置数据避免调用func
CacheRoot._data[self.name]._data = self._data
else:
# 直接更新数据
CacheRoot._data[self.name]._data = self._data
# 保存数据
await CacheRoot._data[self.name].set_data(self._data)
return True
except Exception as e:
logger.error(f"保存缓存列表 {self.name} 失败", LOG_COMMAND, e=e)
return False
def __getitem__(self, index: int) -> Any:
"""获取列表项
参数:
index: 列表索引
返回:
Any: 列表值
"""
if 0 <= index < len(self._data):
return self._data[index]
raise IndexError(f"列表索引 {index} 超出范围")
def __setitem__(self, index: int, value: Any) -> None:
"""设置列表项
参数:
index: 列表索引
value: 列表值
"""
# 确保索引有效
while len(self._data) <= index:
self._data.append(None)
self._data[index] = value
def __delitem__(self, index: int) -> None:
"""删除列表项
参数:
index: 列表索引
"""
if 0 <= index < len(self._data):
del self._data[index]
else:
raise IndexError(f"列表索引 {index} 超出范围")
def __len__(self) -> int:
"""获取列表长度
返回:
int: 列表长度
"""
return len(self._data)
def append(self, value: Any) -> None:
"""添加列表项
参数:
value: 列表值
"""
self._data.append(value)
def extend(self, values: list[Any]) -> None:
"""扩展列表
参数:
values: 要添加的值列表
"""
self._data.extend(values)
def insert(self, index: int, value: Any) -> None:
"""插入列表项
参数:
index: 插入位置
value: 列表值
"""
self._data.insert(index, value)
def pop(self, index: int = -1) -> Any:
"""删除并返回列表项
参数:
index: 列表索引默认为最后一项
返回:
Any: 列表值
"""
return self._data.pop(index)
def remove(self, value: Any) -> None:
"""删除第一个匹配的列表项
参数:
value: 要删除的值
"""
self._data.remove(value)
def clear(self) -> None:
"""清空列表"""
self._data.clear()
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
"""查找值的索引
参数:
value: 要查找的值
start: 起始索引
end: 结束索引
返回:
int: 索引位置
"""
return self._data.index(
value, start, end if end is not None else len(self._data)
)
def count(self, value: Any) -> int:
"""计算值出现的次数
参数:
value: 要计数的值
返回:
int: 出现次数
"""
return self._data.count(value)
def __str__(self) -> str:
"""字符串表示
返回:
str: 字符串表示
"""
return f"CacheList({self.name}, {len(self._data)} items)"

35
zhenxun/services/cache/config.py vendored Normal file
View File

@ -0,0 +1,35 @@
"""
缓存系统配置
"""
# 日志标识
LOG_COMMAND = "CacheRoot"
# 默认缓存过期时间(秒)
DEFAULT_EXPIRE = 600
# 缓存键前缀
CACHE_KEY_PREFIX = "ZHENXUN"
# 缓存键分隔符
CACHE_KEY_SEPARATOR = ":"
# 复合键分隔符用于分隔tuple类型的cache_key_field
COMPOSITE_KEY_SEPARATOR = "_"
# 缓存模式
class CacheMode:
# 内存缓存 - 使用内存存储缓存数据
MEMORY = "MEMORY"
# Redis缓存 - 使用Redis服务器存储缓存数据
REDIS = "REDIS"
# 不使用缓存 - 将使用ttl=0的内存缓存相当于直接从数据库获取数据
NONE = "NONE"
SPECIAL_KEY_FORMATS = {
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
}

View File

@ -1,7 +1,10 @@
from typing import Any, Generic, TypeVar, cast
from zhenxun.services.cache import CacheRoot
from zhenxun.services.cache import config as cache_config
from zhenxun.services.cache import Cache, CacheRoot, cache_config
from zhenxun.services.cache.config import (
COMPOSITE_KEY_SEPARATOR,
CacheMode,
)
from zhenxun.services.db_context import Model
from zhenxun.services.log import logger
@ -37,15 +40,89 @@ class DataAccess(Generic[T]):
```
"""
def __init__(self, model_cls: type[T], cache_type: str | None = None):
def __init__(
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
):
"""初始化数据访问对象
参数:
model_cls: 模型类
cache_type: 缓存类型如果为None则使用模型类的cache_type属性
key_field: 主键字段
"""
self.model_cls = model_cls
self.cache_type = cache_type or getattr(model_cls, "cache_type", None)
self.key_field = getattr(model_cls, "cache_key_field", key_field)
self.cache_type = getattr(model_cls, "cache_type", cache_type)
if not self.cache_type:
raise ValueError("缓存类型不能为空")
self.cache = Cache(self.cache_type)
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
"""从关键字参数构建缓存键
参数:
**kwargs: 关键字参数
返回:
str | None: 缓存键如果无法构建则返回None
"""
if isinstance(self.key_field, tuple):
# 多字段主键
key_parts = []
for field in self.key_field:
key_parts.append(str(kwargs.get(field, "")))
if key_parts:
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
return None
elif self.key_field in kwargs:
# 单字段主键
return str(kwargs[self.key_field])
return None
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
"""安全的获取单条数据
参数:
*args: 查询参数
**kwargs: 查询参数
返回:
Optional[T]: 查询结果如果不存在返回None
"""
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return await self.model_cls.safe_get_or_none(*args, **kwargs)
# 尝试从缓存获取
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
if data:
return cast(T, data)
except Exception as e:
logger.error("从缓存获取数据失败", e=e)
# 如果缓存中没有,从数据库获取
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
if data:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
return data
async def get_or_none(self, *args, **kwargs) -> T | None:
"""获取单条数据
@ -57,23 +134,169 @@ class DataAccess(Generic[T]):
返回:
Optional[T]: 查询结果如果不存在返回None
"""
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
if not cache_config.enable_cache or not self.cache_type:
return await self.model_cls.safe_get_or_none(*args, **kwargs)
# 如果没有缓存类型,直接从数据库获取
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return await self.model_cls.get_or_none(*args, **kwargs)
# 从缓存获取
# 尝试从缓存获取
try:
# 生成缓存键
key = self._generate_cache_key(kwargs)
# 尝试从缓存获取
data = await CacheRoot.get(self.cache_type, key)
if data:
return cast(T, data)
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
# 如果成功构建缓存键,尝试从缓存获取
if cache_key is not None:
data = await self.cache.get(cache_key)
if data:
return cast(T, data)
except Exception as e:
logger.error("从缓存获取数据失败", e=e)
# 如果缓存中没有,从数据库获取
return await self.model_cls.safe_get_or_none(*args, **kwargs)
data = await self.model_cls.get_or_none(*args, **kwargs)
# 如果获取到数据,存入缓存
if data:
try:
cache_key = self._build_cache_key_for_item(data)
# 生成缓存键
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
return data
async def clear_cache(self, **kwargs) -> bool:
"""只清除缓存,不影响数据库数据
参数:
**kwargs: 查询参数必须包含主键字段
返回:
bool: 是否成功清除缓存
"""
# 如果没有缓存类型直接返回True
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
return True
try:
# 构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
if cache_key is None:
if isinstance(self.key_field, tuple):
# 如果是复合键,检查缺少哪些字段
missing_fields = [
field for field in self.key_field if field not in kwargs
]
logger.error(
f"清除{self.model_cls.__name__}缓存失败: "
f"缺少主键字段 {', '.join(missing_fields)}"
)
else:
logger.error(
f"清除{self.model_cls.__name__}缓存失败: "
f"缺少主键字段 {self.key_field}"
)
return False
# 删除缓存
await self.cache.delete(cache_key)
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
return True
except Exception as e:
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
return False
def _build_composite_key(self, data: T) -> str | None:
"""构建复合缓存键
参数:
data: 数据对象
返回:
str | None: 构建的缓存键如果无法构建则返回None
"""
# 如果是元组,表示多个字段组成键
if isinstance(self.key_field, tuple):
# 构建键参数列表
key_parts = []
for field in self.key_field:
value = getattr(data, field, "")
key_parts.append(value if value is not None else "")
# 如果没有有效参数返回None
if not key_parts:
return None
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
# 单个字段作为键
elif hasattr(data, self.key_field):
value = getattr(data, self.key_field, None)
return str(value) if value is not None else None
return None
def _build_cache_key_for_item(self, item: T) -> str | None:
"""为数据项构建缓存键
参数:
item: 数据项
返回:
str | None: 缓存键如果无法生成则返回None
"""
# 如果没有缓存类型返回None
if not self.cache_type:
return None
# 获取缓存类型的配置信息
cache_model = CacheRoot.get_model(self.cache_type)
# 如果有键格式定义,则需要构建特殊格式的键
if cache_model.key_format:
# 构建键参数字典
key_parts = []
# 从格式字符串中提取所需的字段名
import re
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
# 收集所有字段值
for field in field_names:
value = getattr(item, field, "")
key_parts.append(value if value is not None else "")
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
else:
# 常规处理,使用主键作为缓存键
return self._build_composite_key(item)
async def _cache_items(self, data_list: list[T]) -> None:
"""将数据列表存入缓存
参数:
data_list: 数据列表
"""
if (
not data_list
or not self.cache_type
or cache_config.cache_mode == CacheMode.NONE
):
return
try:
# 遍历数据列表,将每条数据存入缓存
for item in data_list:
cache_key = self._build_cache_key_for_item(item)
if cache_key is not None:
await self.cache.set(cache_key, item)
logger.debug(f"{self.cache_type} 数据已存入缓存,数量: {len(data_list)}")
except Exception as e:
logger.error(f"{self.cache_type} 数据存入缓存失败", e=e)
async def filter(self, *args, **kwargs) -> list[T]:
"""筛选数据
@ -85,30 +308,13 @@ class DataAccess(Generic[T]):
返回:
List[T]: 查询结果列表
"""
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
if not cache_config.enable_cache or not self.cache_type:
return await self.model_cls.filter(*args, **kwargs)
# 从数据库获取数据
data_list = await self.model_cls.filter(*args, **kwargs)
# 尝试从缓存获取所有数据后筛选
try:
# 获取缓存数据
cache_data = await CacheRoot.get_cache_data(self.cache_type)
if isinstance(cache_data, dict) and cache_data:
# 在内存中筛选
filtered_data = []
for item in cache_data.values():
match = not any(
not hasattr(item, k) or getattr(item, k) != v
for k, v in kwargs.items()
)
if match:
filtered_data.append(item)
return cast(list[T], filtered_data)
except Exception as e:
logger.error("从缓存筛选数据失败", e=e)
# 将数据存入缓存
await self._cache_items(data_list)
# 如果缓存中没有或筛选失败,从数据库获取
return await self.model_cls.filter(*args, **kwargs)
return data_list
async def all(self) -> list[T]:
"""获取所有数据
@ -116,21 +322,13 @@ class DataAccess(Generic[T]):
返回:
List[T]: 所有数据列表
"""
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
if not cache_config.enable_cache or not self.cache_type:
return await self.model_cls.all()
# 直接从数据库获取
data_list = await self.model_cls.all()
# 尝试从缓存获取所有数据
try:
# 获取缓存数据
cache_data = await CacheRoot.get_cache_data(self.cache_type)
if isinstance(cache_data, dict) and cache_data:
return cast(list[T], list(cache_data.values()))
except Exception as e:
logger.error("从缓存获取所有数据失败", e=e)
# 将数据存入缓存
await self._cache_items(data_list)
# 如果缓存中没有,从数据库获取
return await self.model_cls.all()
return data_list
async def count(self, *args, **kwargs) -> int:
"""获取数据数量
@ -167,7 +365,24 @@ class DataAccess(Generic[T]):
返回:
T: 创建的数据
"""
return await self.model_cls.create(**kwargs)
# 创建数据
data = await self.model_cls.create(**kwargs)
# 如果有缓存类型,将数据存入缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(
f"{self.cache_type} 新创建的数据已存入缓存: {cache_key}"
)
except Exception as e:
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
return data
async def update_or_create(
self, defaults: dict[str, Any] | None = None, **kwargs
@ -181,7 +396,24 @@ class DataAccess(Generic[T]):
返回:
tuple[T, bool]: (数据, 是否创建)
"""
return await self.model_cls.update_or_create(defaults=defaults, **kwargs)
# 更新或创建数据
data, created = await self.model_cls.update_or_create(
defaults=defaults, **kwargs
)
# 如果有缓存类型,将数据存入缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 生成缓存键
cache_key = self._build_cache_key_for_item(data)
if cache_key is not None:
# 存入缓存
await self.cache.set(cache_key, data)
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
except Exception as e:
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
return data, created
async def delete(self, *args, **kwargs) -> int:
"""删除数据
@ -193,18 +425,43 @@ class DataAccess(Generic[T]):
返回:
int: 删除的数据数量
"""
# 如果有缓存类型且有key_field参数先尝试删除缓存
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
try:
# 尝试构建缓存键
cache_key = self._build_cache_key_from_kwargs(**kwargs)
if cache_key is not None:
# 如果成功构建缓存键,直接删除缓存
await self.cache.delete(cache_key)
logger.debug(f"已删除缓存: {cache_key}")
else:
# 否则需要先查询出要删除的数据,然后删除对应的缓存
items = await self.model_cls.filter(*args, **kwargs)
for item in items:
item_cache_key = self._build_cache_key_for_item(item)
if item_cache_key is not None:
await self.cache.delete(item_cache_key)
if items:
logger.debug(f"已删除{len(items)}条数据的缓存")
except Exception as e:
logger.error("删除缓存失败", e=e)
# 删除数据
return await self.model_cls.filter(*args, **kwargs).delete()
def _generate_cache_key(self, kwargs: dict[str, Any]) -> str:
"""根据查询参数生成缓存键
def _generate_cache_key(self, data: T) -> str:
"""根据据对象生成缓存键
参数:
kwargs: 查询参数
data: 数据对象
返回:
str: 缓存键
"""
# 实现一个简单的键生成算法
if not kwargs:
return "default"
return "_".join(f"{k}:{v}" for k, v in sorted(kwargs.items()))
# 使用新方法构建复合键
if composite_key := self._build_composite_key(data):
return composite_key
# 如果无法生成复合键,生成一个唯一键
return f"object_{id(data)}"

View File

@ -16,6 +16,7 @@ from zhenxun.utils.exception import HookPriorityException
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
from .cache import CacheRoot
from .cache.config import COMPOSITE_KEY_SEPARATOR
from .log import logger
SCRIPT_METHOD = []
@ -23,14 +24,6 @@ MODELS: list[str] = []
driver = nonebot.get_driver()
CACHE_FLAG = False
@driver.on_bot_connect
def _():
global CACHE_FLAG
CACHE_FLAG = True
class Model(TortoiseModel):
"""
@ -56,8 +49,55 @@ class Model(TortoiseModel):
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
@classmethod
def get_cache_type(cls):
return getattr(cls, "cache_type", None) if CACHE_FLAG else None
def get_cache_type(cls) -> str | None:
return getattr(cls, "cache_type", None)
@classmethod
def get_cache_key_field(cls) -> str | tuple[str]:
"""获取缓存键字段名
返回:
str | tuple[str]: 缓存键字段名可能是单个字段名或字段名元组
"""
if hasattr(cls, "cache_key_field"):
return getattr(cls, "cache_key_field", "id")
return "id"
@classmethod
def get_cache_key(cls, instance) -> str | None:
"""获取缓存键
参数:
instance: 模型实例
返回:
str | None
"""
key_field = cls.get_cache_key_field()
# 如果是元组,表示多个字段组成键
if isinstance(key_field, tuple):
# 构建键参数列表
key_parts = []
for field in key_field:
if hasattr(instance, field):
value = getattr(instance, field, None)
key_parts.append(value if value is not None else "")
else:
# 如果缺少任何必要的字段返回None
return None
# 如果没有有效参数返回None
if not key_parts:
return None
return COMPOSITE_KEY_SEPARATOR.join(str(param) for param in key_parts)
# 单个字段作为键
elif hasattr(instance, key_field):
return getattr(instance, key_field, None)
return None
@classmethod
async def create(
@ -79,7 +119,11 @@ class Model(TortoiseModel):
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
# 获取缓存键
key = cls.get_cache_key(result)
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return (result, is_create)
else:
# 如果没有锁,则执行原来的逻辑
@ -87,7 +131,11 @@ class Model(TortoiseModel):
defaults=defaults, using_db=using_db, **kwargs
)
if is_create and (cache_type := cls.get_cache_type()):
await CacheRoot.reload(cache_type)
# 获取缓存键
key = cls.get_cache_key(result)
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return (result, is_create)
@classmethod
@ -104,7 +152,11 @@ class Model(TortoiseModel):
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
# 获取缓存键
key = cls.get_cache_key(result[0])
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return result
else:
# 如果没有锁,则执行原来的逻辑
@ -112,19 +164,23 @@ class Model(TortoiseModel):
defaults=defaults, using_db=using_db, **kwargs
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
# 获取缓存键
key = cls.get_cache_key(result[0])
await CacheRoot.invalidate_cache(
cache_type, key if key is not None else None
)
return result
@classmethod
async def bulk_create( # type: ignore
cls,
objects: Iterable[Self],
objects: Iterable[Self], # type: ignore
batch_size: int | None = None,
ignore_conflicts: bool = False,
update_fields: Iterable[str] | None = None,
on_conflict: Iterable[str] | None = None,
using_db: BaseDBAsyncClient | None = None,
) -> list[Self]:
) -> list[Self]: # type: ignore
result = await super().bulk_create(
objects=objects,
batch_size=batch_size,
@ -134,17 +190,18 @@ class Model(TortoiseModel):
using_db=using_db,
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
# 批量创建时清除整个类型的缓存
await CacheRoot.invalidate_cache(cache_type)
return result
@classmethod
async def bulk_update( # type: ignore
cls,
objects: Iterable[Self],
objects: Iterable[Self], # type: ignore
fields: Iterable[str],
batch_size: int | None = None,
using_db: BaseDBAsyncClient | None = None,
) -> int:
) -> int: # type: ignore
result = await super().bulk_update(
objects=objects,
fields=fields,
@ -152,7 +209,8 @@ class Model(TortoiseModel):
using_db=using_db,
)
if cache_type := cls.get_cache_type():
await CacheRoot.reload(cache_type)
# 批量更新时清除整个类型的缓存
await CacheRoot.invalidate_cache(cache_type)
return result
async def save(
@ -181,13 +239,24 @@ class Model(TortoiseModel):
force_create=force_create,
force_update=force_update,
)
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type)
if cache_type := getattr(self, "cache_type", None):
# 获取缓存键
key = self.__class__.get_cache_key(self)
await CacheRoot.invalidate_cache(cache_type, key)
async def delete(self, using_db: BaseDBAsyncClient | None = None):
# 在删除前获取缓存键
cache_type = getattr(self, "cache_type", None)
key = None
if cache_type:
key = self.__class__.get_cache_key(self)
# 执行删除操作
await super().delete(using_db=using_db)
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
await CacheRoot.reload(cache_type)
# 清除缓存
if cache_type:
await CacheRoot.invalidate_cache(cache_type, key)
@classmethod
async def safe_get_or_none(