mirror of
https://github.com/zhenxun-org/zhenxun_bot.git
synced 2025-12-15 14:22:55 +08:00
♻️ 重构cache
This commit is contained in:
parent
7d962bb4a8
commit
56162e24ea
4
.env.dev
4
.env.dev
@ -27,6 +27,8 @@ QBOT_ID_DATA = '{
|
|||||||
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
# 示例: "sqlite:data/db/zhenxun.db" 在data目录下建立db文件夹
|
||||||
DB_URL = ""
|
DB_URL = ""
|
||||||
|
|
||||||
|
# NONE: 不使用缓存, MEMORY: 使用内存缓存, REDIS: 使用Redis缓存
|
||||||
|
CACHE_MODE = NONE
|
||||||
# REDIS配置,使用REDIS替换Cache内存缓存
|
# REDIS配置,使用REDIS替换Cache内存缓存
|
||||||
# REDIS地址
|
# REDIS地址
|
||||||
# REDIS_HOST = "127.0.0.1"
|
# REDIS_HOST = "127.0.0.1"
|
||||||
@ -50,7 +52,7 @@ PLATFORM_SUPERUSERS = '
|
|||||||
DRIVER=~fastapi+~httpx+~websockets
|
DRIVER=~fastapi+~httpx+~websockets
|
||||||
|
|
||||||
|
|
||||||
# LOG_LEVEL=DEBUG
|
# LOG_LEVEL = DEBUG
|
||||||
# 服务器和端口
|
# 服务器和端口
|
||||||
HOST = 127.0.0.1
|
HOST = 127.0.0.1
|
||||||
PORT = 8080
|
PORT = 8080
|
||||||
|
|||||||
@ -3,8 +3,7 @@ from nonebot_plugin_uninfo import Uninfo
|
|||||||
|
|
||||||
from zhenxun.models.level_user import LevelUser
|
from zhenxun.models.level_user import LevelUser
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.utils.enum import CacheType
|
|
||||||
from zhenxun.utils.utils import get_entity_ids
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
@ -21,15 +20,21 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
|||||||
if not plugin.admin_level:
|
if not plugin.admin_level:
|
||||||
return
|
return
|
||||||
entity = get_entity_ids(session)
|
entity = get_entity_ids(session)
|
||||||
cache = Cache[list[LevelUser]](CacheType.LEVEL)
|
level_dao = DataAccess(LevelUser)
|
||||||
user_list = await cache.get(session.user.id) or []
|
global_user = await level_dao.safe_get_or_none(
|
||||||
|
user_id=session.user.id, group_id__isnull=True
|
||||||
|
)
|
||||||
|
user_level = 0
|
||||||
|
if global_user:
|
||||||
|
user_level = global_user.user_level
|
||||||
if entity.group_id:
|
if entity.group_id:
|
||||||
user_list += await cache.get(session.user.id, entity.group_id) or []
|
# 获取用户在当前群组的权限数据
|
||||||
if user_list:
|
group_users = await level_dao.safe_get_or_none(
|
||||||
user = max(user_list, key=lambda x: x.user_level)
|
user_id=session.user.id, group_id=entity.group_id
|
||||||
user_level = user.user_level
|
)
|
||||||
else:
|
if group_users:
|
||||||
user_level = 0
|
user_level = max(user_level, group_users.user_level)
|
||||||
|
|
||||||
if user_level < plugin.admin_level:
|
if user_level < plugin.admin_level:
|
||||||
await send_message(
|
await send_message(
|
||||||
session,
|
session,
|
||||||
@ -42,11 +47,12 @@ async def auth_admin(plugin: PluginInfo, session: Uninfo):
|
|||||||
raise SkipPluginException(
|
raise SkipPluginException(
|
||||||
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||||
)
|
)
|
||||||
elif user_list:
|
elif global_user:
|
||||||
user = max(user_list, key=lambda x: x.user_level)
|
if global_user.user_level < plugin.admin_level:
|
||||||
if user.user_level < plugin.admin_level:
|
|
||||||
await send_message(
|
await send_message(
|
||||||
session,
|
session,
|
||||||
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
f"你的权限不足喔,该功能需要的权限等级: {plugin.admin_level}",
|
||||||
)
|
)
|
||||||
raise SkipPluginException(f"{plugin.name}({plugin.module}) 管理员权限不足...")
|
raise SkipPluginException(
|
||||||
|
f"{plugin.name}({plugin.module}) 管理员权限不足..."
|
||||||
|
)
|
||||||
|
|||||||
@ -1,20 +1,15 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot_plugin_alconna import At
|
from nonebot_plugin_alconna import At
|
||||||
from nonebot_plugin_uninfo import Uninfo
|
from nonebot_plugin_uninfo import Uninfo
|
||||||
from tortoise.exceptions import MultipleObjectsReturned
|
|
||||||
|
|
||||||
from zhenxun.configs.config import Config
|
from zhenxun.configs.config import Config
|
||||||
from zhenxun.models.ban_console import BanConsole
|
from zhenxun.models.ban_console import BanConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.utils.enum import PluginType
|
||||||
from zhenxun.utils.enum import CacheType, PluginType
|
|
||||||
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
from zhenxun.utils.utils import EntityIDs, get_entity_ids
|
||||||
|
|
||||||
from .config import LOGGER_COMMAND
|
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
from .utils import freq, send_message
|
from .utils import freq, send_message
|
||||||
|
|
||||||
@ -29,10 +24,18 @@ Config.add_plugin_config(
|
|||||||
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
async def is_ban(user_id: str | None, group_id: str | None) -> int:
|
||||||
if not user_id and not group_id:
|
if not user_id and not group_id:
|
||||||
return 0
|
return 0
|
||||||
cache = Cache[BanConsole](CacheType.BAN)
|
ban_dao = DataAccess(BanConsole)
|
||||||
group_user, user = await asyncio.gather(
|
|
||||||
cache.get(user_id, group_id), cache.get(user_id)
|
# 分别获取用户在群组中的ban记录和全局ban记录
|
||||||
)
|
group_user = None
|
||||||
|
user = None
|
||||||
|
|
||||||
|
if user_id and group_id:
|
||||||
|
group_user = await ban_dao.safe_get_or_none(user_id=user_id, group_id=group_id)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
user = await ban_dao.safe_get_or_none(user_id=user_id, group_id="")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
if group_user:
|
if group_user:
|
||||||
results.append(group_user)
|
results.append(group_user)
|
||||||
@ -88,76 +91,55 @@ def format_time(time: float) -> str:
|
|||||||
return time_str
|
return time_str
|
||||||
|
|
||||||
|
|
||||||
async def group_handle(cache: Cache[list[BanConsole]], group_id: str):
|
async def group_handle(group_id: str):
|
||||||
"""群组ban检查
|
"""群组ban检查
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
cache: cache
|
ban_dao: BanConsole数据访问对象
|
||||||
group_id: 群组id
|
group_id: 群组id
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
SkipPluginException: 群组处于黑名单
|
SkipPluginException: 群组处于黑名单
|
||||||
"""
|
"""
|
||||||
try:
|
if await is_ban(None, group_id):
|
||||||
if await is_ban(None, group_id):
|
raise SkipPluginException("群组处于黑名单中...")
|
||||||
raise SkipPluginException("群组处于黑名单中...")
|
|
||||||
except MultipleObjectsReturned:
|
|
||||||
logger.warning(
|
|
||||||
"群组黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
|
||||||
)
|
|
||||||
ids = await BanConsole.filter(user_id="", group_id=group_id).values_list(
|
|
||||||
"id", flat=True
|
|
||||||
)
|
|
||||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
|
||||||
await cache.reload()
|
|
||||||
|
|
||||||
|
|
||||||
async def user_handle(
|
async def user_handle(module: str, entity: EntityIDs, session: Uninfo):
|
||||||
module: str, cache: Cache[list[BanConsole]], entity: EntityIDs, session: Uninfo
|
|
||||||
):
|
|
||||||
"""用户ban检查
|
"""用户ban检查
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
module: 插件模块名
|
module: 插件模块名
|
||||||
cache: cache
|
ban_dao: BanConsole数据访问对象
|
||||||
user_id: 用户id
|
entity: 实体ID信息
|
||||||
session: Uninfo
|
session: Uninfo
|
||||||
|
|
||||||
异常:
|
异常:
|
||||||
SkipPluginException: 用户处于黑名单
|
SkipPluginException: 用户处于黑名单
|
||||||
"""
|
"""
|
||||||
ban_result = Config.get_config("hook", "BAN_RESULT")
|
ban_result = Config.get_config("hook", "BAN_RESULT")
|
||||||
try:
|
time = await is_ban(entity.user_id, entity.group_id)
|
||||||
time = await is_ban(entity.user_id, entity.group_id)
|
if not time:
|
||||||
if not time:
|
return
|
||||||
return
|
time_str = format_time(time)
|
||||||
time_str = format_time(time)
|
plugin_dao = DataAccess(PluginInfo)
|
||||||
db_plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
db_plugin = await plugin_dao.safe_get_or_none(module=module)
|
||||||
if (
|
if (
|
||||||
db_plugin
|
db_plugin
|
||||||
# and not db_plugin.ignore_prompt
|
and not db_plugin.ignore_prompt
|
||||||
and time != -1
|
and time != -1
|
||||||
and ban_result
|
and ban_result
|
||||||
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
and freq.is_send_limit_message(db_plugin, entity.user_id, False)
|
||||||
):
|
):
|
||||||
await send_message(
|
await send_message(
|
||||||
session,
|
session,
|
||||||
[
|
[
|
||||||
At(flag="user", target=entity.user_id),
|
At(flag="user", target=entity.user_id),
|
||||||
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
f"{ban_result}\n在..在 {time_str} 后才会理你喔",
|
||||||
],
|
],
|
||||||
entity.user_id,
|
entity.user_id,
|
||||||
)
|
|
||||||
raise SkipPluginException("用户处于黑名单中...")
|
|
||||||
except MultipleObjectsReturned:
|
|
||||||
logger.warning(
|
|
||||||
"用户黑名单数据重复,过滤该次hook并移除多余数据...", LOGGER_COMMAND
|
|
||||||
)
|
)
|
||||||
ids = await BanConsole.filter(user_id=entity.user_id, group_id="").values_list(
|
raise SkipPluginException("用户处于黑名单中...")
|
||||||
"id", flat=True
|
|
||||||
)
|
|
||||||
await BanConsole.filter(id__in=ids[:-1]).delete()
|
|
||||||
await cache.reload()
|
|
||||||
|
|
||||||
|
|
||||||
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
|
async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
|
||||||
@ -168,8 +150,7 @@ async def auth_ban(matcher: Matcher, bot: Bot, session: Uninfo):
|
|||||||
entity = get_entity_ids(session)
|
entity = get_entity_ids(session)
|
||||||
if entity.user_id in bot.config.superusers:
|
if entity.user_id in bot.config.superusers:
|
||||||
return
|
return
|
||||||
cache = Cache[list[BanConsole]](CacheType.BAN)
|
|
||||||
if entity.group_id:
|
if entity.group_id:
|
||||||
await group_handle(cache, entity.group_id)
|
await group_handle(entity.group_id)
|
||||||
if entity.user_id:
|
if entity.user_id:
|
||||||
await user_handle(matcher.plugin_name, cache, entity, session)
|
await user_handle(matcher.plugin_name, entity, session)
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import CacheType
|
|
||||||
|
|
||||||
from .exception import SkipPluginException
|
from .exception import SkipPluginException
|
||||||
|
|
||||||
@ -18,11 +17,11 @@ async def auth_bot(plugin: PluginInfo, bot_id: str):
|
|||||||
SkipPluginException: 忽略插件
|
SkipPluginException: 忽略插件
|
||||||
SkipPluginException: 忽略插件
|
SkipPluginException: 忽略插件
|
||||||
"""
|
"""
|
||||||
if cache := Cache[BotConsole](CacheType.BOT):
|
bot_dao = DataAccess(BotConsole)
|
||||||
bot = await cache.get(bot_id)
|
bot = await bot_dao.safe_get_or_none(bot_id=bot_id)
|
||||||
if not bot or not bot.status:
|
if not bot or not bot.status:
|
||||||
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
raise SkipPluginException("Bot不存在或休眠中阻断权限检测...")
|
||||||
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
if CommonUtils.format(plugin.module) in bot.block_plugins:
|
||||||
raise SkipPluginException(
|
raise SkipPluginException(
|
||||||
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
f"Bot插件 {plugin.name}({plugin.module}) 权限检查结果为关闭..."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -11,6 +11,7 @@ async def auth_cost(user: UserConsole, plugin: PluginInfo, session: Uninfo) -> i
|
|||||||
"""检测是否满足金币条件
|
"""检测是否满足金币条件
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
|
user: UserConsole
|
||||||
plugin: PluginInfo
|
plugin: PluginInfo
|
||||||
session: Uninfo
|
session: Uninfo
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,7 @@ from nonebot_plugin_alconna import UniMsg
|
|||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.utils.enum import CacheType
|
|
||||||
from zhenxun.utils.utils import EntityIDs
|
from zhenxun.utils.utils import EntityIDs
|
||||||
|
|
||||||
from .config import SwitchEnum
|
from .config import SwitchEnum
|
||||||
@ -21,7 +20,10 @@ async def auth_group(plugin: PluginInfo, entity: EntityIDs, message: UniMsg):
|
|||||||
if not entity.group_id:
|
if not entity.group_id:
|
||||||
return
|
return
|
||||||
text = message.extract_plain_text()
|
text = message.extract_plain_text()
|
||||||
group = await Cache[GroupConsole](CacheType.GROUPS).get(entity.group_id)
|
group_dao = DataAccess(GroupConsole)
|
||||||
|
group = await group_dao.safe_get_or_none(
|
||||||
|
group_id=entity.group_id, channel_id__isnull=True
|
||||||
|
)
|
||||||
if not group:
|
if not group:
|
||||||
raise SkipPluginException("群组信息不存在...")
|
raise SkipPluginException("群组信息不存在...")
|
||||||
if group.level < 0:
|
if group.level < 0:
|
||||||
|
|||||||
@ -3,9 +3,9 @@ from nonebot_plugin_uninfo import Uninfo
|
|||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.utils.common_utils import CommonUtils
|
from zhenxun.utils.common_utils import CommonUtils
|
||||||
from zhenxun.utils.enum import BlockType, CacheType
|
from zhenxun.utils.enum import BlockType
|
||||||
from zhenxun.utils.utils import get_entity_ids
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
|
|
||||||
from .exception import IsSuperuserException, SkipPluginException
|
from .exception import IsSuperuserException, SkipPluginException
|
||||||
@ -20,10 +20,12 @@ class GroupCheck:
|
|||||||
self.session = session
|
self.session = session
|
||||||
self.is_poke = is_poke
|
self.is_poke = is_poke
|
||||||
self.plugin = plugin
|
self.plugin = plugin
|
||||||
|
self.group_dao = DataAccess(GroupConsole)
|
||||||
|
|
||||||
async def __get_data(self):
|
async def __get_data(self):
|
||||||
cache = Cache[GroupConsole](CacheType.GROUPS)
|
return await self.group_dao.safe_get_or_none(
|
||||||
return await cache.get(self.group_id)
|
group_id=self.group_id, channel_id__isnull=True
|
||||||
|
)
|
||||||
|
|
||||||
async def check(self):
|
async def check(self):
|
||||||
await self.check_superuser_block(self.plugin)
|
await self.check_superuser_block(self.plugin)
|
||||||
@ -89,6 +91,7 @@ class PluginCheck:
|
|||||||
self.session = session
|
self.session = session
|
||||||
self.is_poke = is_poke
|
self.is_poke = is_poke
|
||||||
self.group_id = group_id
|
self.group_id = group_id
|
||||||
|
self.group_dao = DataAccess(GroupConsole)
|
||||||
|
|
||||||
async def check_user(self, plugin: PluginInfo):
|
async def check_user(self, plugin: PluginInfo):
|
||||||
"""全局私聊禁用检测
|
"""全局私聊禁用检测
|
||||||
@ -118,9 +121,11 @@ class PluginCheck:
|
|||||||
if plugin.status or plugin.block_type != BlockType.ALL:
|
if plugin.status or plugin.block_type != BlockType.ALL:
|
||||||
return
|
return
|
||||||
"""全局状态"""
|
"""全局状态"""
|
||||||
cache = Cache[GroupConsole](CacheType.GROUPS)
|
if self.group_id:
|
||||||
if self.group_id and (group := await cache.get(self.group_id)):
|
group = await self.group_dao.safe_get_or_none(
|
||||||
if group.is_super:
|
group_id=self.group_id, channel_id__isnull=True
|
||||||
|
)
|
||||||
|
if group and group.is_super:
|
||||||
raise IsSuperuserException()
|
raise IsSuperuserException()
|
||||||
sid = self.group_id or self.session.user.id
|
sid = self.group_id or self.session.user.id
|
||||||
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
if freq.is_send_limit_message(plugin, sid, self.is_poke):
|
||||||
|
|||||||
@ -9,13 +9,9 @@ from tortoise.exceptions import IntegrityError
|
|||||||
|
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.cache import Cache
|
from zhenxun.services.data_access import DataAccess
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import (
|
from zhenxun.utils.enum import GoldHandle, PluginType
|
||||||
CacheType,
|
|
||||||
GoldHandle,
|
|
||||||
PluginType,
|
|
||||||
)
|
|
||||||
from zhenxun.utils.exception import InsufficientGold
|
from zhenxun.utils.exception import InsufficientGold
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
from zhenxun.utils.utils import get_entity_ids
|
from zhenxun.utils.utils import get_entity_ids
|
||||||
@ -54,8 +50,9 @@ async def get_plugin_and_user(
|
|||||||
返回:
|
返回:
|
||||||
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
tuple[PluginInfo, UserConsole]: 插件信息,用户信息
|
||||||
"""
|
"""
|
||||||
user_cache = Cache[UserConsole](CacheType.USERS)
|
user_dao = DataAccess(UserConsole)
|
||||||
plugin = await Cache[PluginInfo](CacheType.PLUGINS).get(module)
|
plugin_dao = DataAccess(PluginInfo)
|
||||||
|
plugin = await plugin_dao.safe_get_or_none(module=module)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
raise PermissionExemption(f"插件:{module} 数据不存在,已跳过权限检查...")
|
||||||
if plugin.plugin_type == PluginType.HIDDEN:
|
if plugin.plugin_type == PluginType.HIDDEN:
|
||||||
@ -64,7 +61,7 @@ async def get_plugin_and_user(
|
|||||||
)
|
)
|
||||||
user = None
|
user = None
|
||||||
try:
|
try:
|
||||||
user = await user_cache.get(user_id)
|
user = await user_dao.safe_get_or_none(user_id=user_id)
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
raise PermissionExemption("重复创建用户,已跳过该次权限检查...") from e
|
||||||
if not user:
|
if not user:
|
||||||
@ -108,7 +105,7 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
|
|||||||
cost_gold: 消耗金币
|
cost_gold: 消耗金币
|
||||||
session: Uninfo
|
session: Uninfo
|
||||||
"""
|
"""
|
||||||
user_cache = Cache[UserConsole](CacheType.USERS)
|
user_dao = DataAccess(UserConsole)
|
||||||
try:
|
try:
|
||||||
await UserConsole.reduce_gold(
|
await UserConsole.reduce_gold(
|
||||||
user_id,
|
user_id,
|
||||||
@ -121,8 +118,8 @@ async def reduce_gold(user_id: str, module: str, cost_gold: int, session: Uninfo
|
|||||||
if u := await UserConsole.get_user(user_id):
|
if u := await UserConsole.get_user(user_id):
|
||||||
u.gold = 0
|
u.gold = 0
|
||||||
await u.save(update_fields=["gold"])
|
await u.save(update_fields=["gold"])
|
||||||
# 更新缓存
|
# 清除缓存,使下次查询时从数据库获取最新数据
|
||||||
await user_cache.update(user_id)
|
await user_dao.clear_cache(user_id=user_id)
|
||||||
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
logger.debug(f"调用功能花费金币: {cost_gold}", LOGGER_COMMAND, session=session)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,8 @@ from zhenxun.utils.enum import BotSentType
|
|||||||
from zhenxun.utils.manager.message_manager import MessageManager
|
from zhenxun.utils.manager.message_manager import MessageManager
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
|
LOG_COMMAND = "MessageHook"
|
||||||
|
|
||||||
|
|
||||||
def replace_message(message: Message) -> str:
|
def replace_message(message: Message) -> str:
|
||||||
"""将消息中的at、image、record、face替换为字符串
|
"""将消息中的at、image、record、face替换为字符串
|
||||||
@ -54,11 +56,11 @@ async def handle_api_result(
|
|||||||
if user_id and message_id:
|
if user_id and message_id:
|
||||||
MessageManager.add(str(user_id), str(message_id))
|
MessageManager.add(str(user_id), str(message_id))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", "msg_hook"
|
f"收集消息id,user_id: {user_id}, msg_id: {message_id}", LOG_COMMAND
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"收集消息id发生错误...data: {data}, result: {result}", "msg_hook", e=e
|
f"收集消息id发生错误...data: {data}, result: {result}", LOG_COMMAND, e=e
|
||||||
)
|
)
|
||||||
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
if not Config.get_config("hook", "RECORD_BOT_SENT_MESSAGES"):
|
||||||
return
|
return
|
||||||
@ -80,6 +82,6 @@ async def handle_api_result(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"消息发送记录发生错误...data: {data}, result: {result}",
|
f"消息发送记录发生错误...data: {data}, result: {result}",
|
||||||
"msg_hook",
|
LOG_COMMAND,
|
||||||
e=e,
|
e=e,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,23 +4,25 @@ import nonebot
|
|||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
|
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.services.cache import DbCacheException
|
from zhenxun.services.cache import CacheException
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
from zhenxun.utils.platform import PlatformUtils
|
from zhenxun.utils.platform import PlatformUtils
|
||||||
|
|
||||||
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
nonebot.load_plugins(str(Path(__file__).parent.resolve()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .__init_cache import CacheRoot
|
from .__init_cache import register_cache_types
|
||||||
except DbCacheException as e:
|
except CacheException as e:
|
||||||
raise SystemError(f"ERROR:{e}")
|
raise SystemError(f"ERROR:{e}")
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
@PriorityLifecycle.on_startup(priority=5)
|
||||||
async def _():
|
async def _():
|
||||||
await CacheRoot.init_non_lazy_caches()
|
register_cache_types()
|
||||||
|
logger.info("缓存类型注册完成")
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
@driver.on_bot_connect
|
||||||
|
|||||||
@ -1,208 +1,35 @@
|
|||||||
|
"""
|
||||||
|
缓存初始化模块
|
||||||
|
|
||||||
|
负责注册各种缓存类型,实现按需缓存机制
|
||||||
|
"""
|
||||||
|
|
||||||
from zhenxun.models.ban_console import BanConsole
|
from zhenxun.models.ban_console import BanConsole
|
||||||
from zhenxun.models.bot_console import BotConsole
|
from zhenxun.models.bot_console import BotConsole
|
||||||
from zhenxun.models.group_console import GroupConsole
|
from zhenxun.models.group_console import GroupConsole
|
||||||
from zhenxun.models.level_user import LevelUser
|
from zhenxun.models.level_user import LevelUser
|
||||||
from zhenxun.models.plugin_info import PluginInfo
|
from zhenxun.models.plugin_info import PluginInfo
|
||||||
from zhenxun.models.user_console import UserConsole
|
from zhenxun.models.user_console import UserConsole
|
||||||
from zhenxun.services.cache import CacheData, CacheRoot
|
from zhenxun.services.cache import CacheRegistry, cache_config
|
||||||
|
from zhenxun.services.cache.config import CacheMode
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
from zhenxun.utils.enum import CacheType
|
from zhenxun.utils.enum import CacheType
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.PLUGINS)
|
# 注册缓存类型
|
||||||
async def _():
|
def register_cache_types():
|
||||||
"""初始化插件缓存"""
|
"""注册所有缓存类型"""
|
||||||
data_list = await PluginInfo.get_plugins()
|
CacheRegistry.register(CacheType.PLUGINS, PluginInfo)
|
||||||
return {p.module: p for p in data_list}
|
CacheRegistry.register(CacheType.GROUPS, GroupConsole)
|
||||||
|
CacheRegistry.register(CacheType.BOT, BotConsole)
|
||||||
|
CacheRegistry.register(CacheType.USERS, UserConsole)
|
||||||
|
CacheRegistry.register(
|
||||||
|
CacheType.LEVEL, LevelUser, key_format="{user_id}_{group_id}"
|
||||||
|
)
|
||||||
|
CacheRegistry.register(CacheType.BAN, BanConsole, key_format="{user_id}_{group_id}")
|
||||||
|
|
||||||
|
if cache_config.cache_mode == CacheMode.NONE:
|
||||||
@CacheRoot.getter(CacheType.PLUGINS, result_model=PluginInfo)
|
logger.info("缓存功能已禁用,将直接从数据库获取数据")
|
||||||
async def _(cache_data: CacheData, module: str):
|
else:
|
||||||
"""获取插件缓存"""
|
logger.info(f"已注册所有缓存类型,缓存模式: {cache_config.cache_mode}")
|
||||||
data = await cache_data.get_key(module)
|
logger.info("使用增量缓存模式,数据将按需加载到缓存中")
|
||||||
if not data:
|
|
||||||
if plugin := await PluginInfo.get_plugin(module=module):
|
|
||||||
await cache_data.set_key(module, plugin)
|
|
||||||
logger.debug(f"插件 {module} 数据已设置到缓存")
|
|
||||||
return plugin
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.PLUGINS)
|
|
||||||
async def _(cache_data: CacheData, data: dict[str, PluginInfo] | None):
|
|
||||||
"""刷新插件缓存"""
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
plugins = await PluginInfo.filter(module__in=data.keys(), load_status=True).all()
|
|
||||||
for plugin in plugins:
|
|
||||||
await cache_data.set_key(plugin.module, plugin)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.GROUPS)
|
|
||||||
async def _():
|
|
||||||
"""初始化群组缓存"""
|
|
||||||
data_list = await GroupConsole.all()
|
|
||||||
return {p.group_id: p for p in data_list if not p.channel_id}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.GROUPS, result_model=GroupConsole)
|
|
||||||
async def _(cache_data: CacheData, group_id: str):
|
|
||||||
"""获取群组缓存"""
|
|
||||||
data = await cache_data.get_key(group_id)
|
|
||||||
if not data:
|
|
||||||
if group := await GroupConsole.get_group(group_id=group_id):
|
|
||||||
await cache_data.set_key(group_id, group)
|
|
||||||
return group
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.GROUPS)
|
|
||||||
async def _(cache_data: CacheData, data: dict[str, GroupConsole] | None):
|
|
||||||
"""刷新群组缓存"""
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
groups = await GroupConsole.filter(
|
|
||||||
group_id__in=data.keys(), channel_id__isnull=True
|
|
||||||
).all()
|
|
||||||
for group in groups:
|
|
||||||
await cache_data.set_key(group.group_id, group)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BOT)
|
|
||||||
async def _():
|
|
||||||
"""初始化机器人缓存"""
|
|
||||||
data_list = await BotConsole.all()
|
|
||||||
return {p.bot_id: p for p in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BOT, result_model=BotConsole)
|
|
||||||
async def _(cache_data: CacheData, bot_id: str):
|
|
||||||
"""获取机器人缓存"""
|
|
||||||
data = await cache_data.get_key(bot_id)
|
|
||||||
if not data:
|
|
||||||
if bot := await BotConsole.get_or_none(bot_id=bot_id):
|
|
||||||
await cache_data.set_key(bot_id, bot)
|
|
||||||
return bot
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.BOT)
|
|
||||||
async def _(cache_data: CacheData, data: dict[str, BotConsole] | None):
|
|
||||||
"""刷新机器人缓存"""
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
bots = await BotConsole.filter(bot_id__in=data.keys()).all()
|
|
||||||
for bot in bots:
|
|
||||||
await cache_data.set_key(bot.bot_id, bot)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.USERS)
|
|
||||||
async def _():
|
|
||||||
"""初始化用户缓存"""
|
|
||||||
data_list = await UserConsole.all()
|
|
||||||
return {p.user_id: p for p in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.USERS, result_model=UserConsole)
|
|
||||||
async def _(cache_data: CacheData, user_id: str):
|
|
||||||
"""获取用户缓存"""
|
|
||||||
data = await cache_data.get_key(user_id)
|
|
||||||
if not data:
|
|
||||||
if user := await UserConsole.get_user(user_id=user_id):
|
|
||||||
await cache_data.set_key(user_id, user)
|
|
||||||
return user
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.with_refresh(CacheType.USERS)
|
|
||||||
async def _(cache_data: CacheData, data: dict[str, UserConsole] | None):
|
|
||||||
"""刷新用户缓存"""
|
|
||||||
if not data:
|
|
||||||
return
|
|
||||||
users = await UserConsole.filter(user_id__in=data.keys()).all()
|
|
||||||
for user in users:
|
|
||||||
await cache_data.set_key(user.user_id, user)
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.LEVEL)
|
|
||||||
async def _():
|
|
||||||
"""初始化等级缓存"""
|
|
||||||
data_list = await LevelUser().all()
|
|
||||||
return {f"{d.user_id}:{d.group_id or ''}": d for d in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.LEVEL, result_model=list[LevelUser])
|
|
||||||
async def _(cache_data: CacheData, user_id: str, group_id: str | None = None):
|
|
||||||
"""获取等级缓存"""
|
|
||||||
key = f"{user_id}:{group_id or ''}"
|
|
||||||
data = await cache_data.get_key(key)
|
|
||||||
if not data:
|
|
||||||
if group_id:
|
|
||||||
data = await LevelUser.filter(user_id=user_id, group_id=group_id).all()
|
|
||||||
else:
|
|
||||||
data = await LevelUser.filter(user_id=user_id, group_id__isnull=True).all()
|
|
||||||
if data:
|
|
||||||
await cache_data.set_key(key, data)
|
|
||||||
return data
|
|
||||||
return data or []
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.new(CacheType.BAN, False)
|
|
||||||
async def _():
|
|
||||||
"""初始化封禁缓存"""
|
|
||||||
data_list = await BanConsole.all()
|
|
||||||
return {f"{d.group_id or ''}:{d.user_id or ''}": d for d in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
@CacheRoot.getter(CacheType.BAN, result_model=BanConsole)
|
|
||||||
async def _(cache_data: CacheData, user_id: str | None, group_id: str | None = None):
|
|
||||||
"""获取封禁缓存"""
|
|
||||||
if not user_id and not group_id:
|
|
||||||
return []
|
|
||||||
key = f"{group_id or ''}:{user_id or ''}"
|
|
||||||
data = await cache_data.get_key(key)
|
|
||||||
# if not data:
|
|
||||||
# start = time.time()
|
|
||||||
# if user_id and group_id:
|
|
||||||
# data = await BanConsole.filter(user_id=user_id, group_id=group_id).all()
|
|
||||||
# elif user_id:
|
|
||||||
# data = await BanConsole.filter(user_id=user_id, group_id__isnull=True).all()
|
|
||||||
# elif group_id:
|
|
||||||
# data = await BanConsole.filter(
|
|
||||||
# user_id__isnull=True, group_id=group_id
|
|
||||||
# ).all()
|
|
||||||
# logger.info(
|
|
||||||
# f"获取封禁缓存耗时: {time.time() - start:.2f}秒, key: {key}, data: {data}"
|
|
||||||
# )
|
|
||||||
# if data:
|
|
||||||
# await cache_data.set_key(key, data)
|
|
||||||
# return data
|
|
||||||
return data or []
|
|
||||||
|
|
||||||
|
|
||||||
# @CacheRoot.new(CacheType.LIMIT)
|
|
||||||
# async def _():
|
|
||||||
# """初始化限制缓存"""
|
|
||||||
# data_list = await PluginLimit.filter(status=True).all()
|
|
||||||
# return {data.module: data for data in data_list}
|
|
||||||
|
|
||||||
|
|
||||||
# @CacheRoot.getter(CacheType.LIMIT, result_model=list[PluginLimit])
|
|
||||||
# async def _(cache_data: CacheData, module: str):
|
|
||||||
# """获取限制缓存"""
|
|
||||||
# data = await cache_data.get_key(module)
|
|
||||||
# if not data:
|
|
||||||
# if limits := await PluginLimit.filter(module=module, status=True):
|
|
||||||
# await cache_data.set_key(module, limits)
|
|
||||||
# return limits
|
|
||||||
# return data or []
|
|
||||||
|
|
||||||
|
|
||||||
# @CacheRoot.with_refresh(CacheType.LIMIT)
|
|
||||||
# async def _(cache_data: CacheData, data: dict[str, list[PluginLimit]] | None):
|
|
||||||
# """刷新限制缓存"""
|
|
||||||
# if not data:
|
|
||||||
# return
|
|
||||||
# limits = await PluginLimit.filter(module__in=data.keys(), load_status=True).all()
|
|
||||||
# for limit in limits:
|
|
||||||
# await cache_data.set_key(limit.module, limit)
|
|
||||||
|
|||||||
@ -205,7 +205,7 @@ class Manager:
|
|||||||
self.cd_data: dict[str, PluginCdBlock] = {}
|
self.cd_data: dict[str, PluginCdBlock] = {}
|
||||||
if self.cd_file.exists():
|
if self.cd_file.exists():
|
||||||
with open(self.cd_file, encoding="utf8") as f:
|
with open(self.cd_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginCdLimit" in temp.keys():
|
if "PluginCdLimit" in temp.keys():
|
||||||
for k, v in temp["PluginCdLimit"].items():
|
for k, v in temp["PluginCdLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
@ -216,7 +216,7 @@ class Manager:
|
|||||||
self.block_data: dict[str, BaseBlock] = {}
|
self.block_data: dict[str, BaseBlock] = {}
|
||||||
if self.block_file.exists():
|
if self.block_file.exists():
|
||||||
with open(self.block_file, encoding="utf8") as f:
|
with open(self.block_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginBlockLimit" in temp.keys():
|
if "PluginBlockLimit" in temp.keys():
|
||||||
for k, v in temp["PluginBlockLimit"].items():
|
for k, v in temp["PluginBlockLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
@ -227,7 +227,7 @@ class Manager:
|
|||||||
self.count_data: dict[str, PluginCountBlock] = {}
|
self.count_data: dict[str, PluginCountBlock] = {}
|
||||||
if self.count_file.exists():
|
if self.count_file.exists():
|
||||||
with open(self.count_file, encoding="utf8") as f:
|
with open(self.count_file, encoding="utf8") as f:
|
||||||
temp = _yaml.load(f)
|
temp = _yaml.load(f) or {}
|
||||||
if "PluginCountLimit" in temp.keys():
|
if "PluginCountLimit" in temp.keys():
|
||||||
for k, v in temp["PluginCountLimit"].items():
|
for k, v in temp["PluginCountLimit"].items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
|
|||||||
@ -33,6 +33,8 @@ class BanConsole(Model):
|
|||||||
|
|
||||||
cache_type = CacheType.BAN
|
cache_type = CacheType.BAN
|
||||||
"""缓存类型"""
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("user_id", "group_id")
|
||||||
|
"""缓存键字段"""
|
||||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||||
"""开启锁"""
|
"""开启锁"""
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,9 @@ class BotConsole(Model):
|
|||||||
table_description = "Bot数据表"
|
table_description = "Bot数据表"
|
||||||
|
|
||||||
cache_type = CacheType.BOT
|
cache_type = CacheType.BOT
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "bot_id"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format(name: str) -> str:
|
def format(name: str) -> str:
|
||||||
|
|||||||
@ -90,6 +90,8 @@ class GroupConsole(Model):
|
|||||||
|
|
||||||
cache_type = CacheType.GROUPS
|
cache_type = CacheType.GROUPS
|
||||||
"""缓存类型"""
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("group_id", "channel_id")
|
||||||
|
"""缓存键字段"""
|
||||||
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
enable_lock: ClassVar[list[DbLockType]] = [DbLockType.CREATE]
|
||||||
"""开启锁"""
|
"""开启锁"""
|
||||||
|
|
||||||
@ -123,7 +125,18 @@ class GroupConsole(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@CacheRoot.listener(CacheType.GROUPS)
|
async def _update_cache(cls, instance):
|
||||||
|
"""更新缓存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
instance: 需要更新缓存的实例
|
||||||
|
"""
|
||||||
|
if cache_type := cls.get_cache_type():
|
||||||
|
key = cls.get_cache_key(instance)
|
||||||
|
if key is not None:
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
cls, using_db: BaseDBAsyncClient | None = None, **kwargs: Any
|
||||||
) -> Self:
|
) -> Self:
|
||||||
@ -136,6 +149,9 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -187,14 +203,13 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
if is_create:
|
if is_create:
|
||||||
if cache := await CacheRoot.get_cache(CacheType.GROUPS):
|
await cls._update_cache(group)
|
||||||
await cache.update(group.group_id, group)
|
|
||||||
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@CacheRoot.listener(CacheType.GROUPS)
|
|
||||||
async def update_or_create(
|
async def update_or_create(
|
||||||
cls,
|
cls,
|
||||||
defaults: dict | None = None,
|
defaults: dict | None = None,
|
||||||
@ -214,6 +229,9 @@ class GroupConsole(Model):
|
|||||||
if task_modules or plugin_modules:
|
if task_modules or plugin_modules:
|
||||||
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
await cls._update_modules(group, task_modules, plugin_modules, using_db)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
return group, is_create
|
return group, is_create
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -327,6 +345,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_plugin(
|
async def set_unblock_plugin(
|
||||||
cls,
|
cls,
|
||||||
@ -363,6 +384,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def is_normal_block_plugin(
|
async def is_normal_block_plugin(
|
||||||
cls, group_id: str, module: str, channel_id: str | None = None
|
cls, group_id: str, module: str, channel_id: str | None = None
|
||||||
@ -466,6 +490,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def set_unblock_task(
|
async def set_unblock_task(
|
||||||
cls,
|
cls,
|
||||||
@ -500,6 +527,9 @@ class GroupConsole(Model):
|
|||||||
if update_fields:
|
if update_fields:
|
||||||
await group.save(update_fields=update_fields)
|
await group.save(update_fields=update_fields)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
await cls._update_cache(group)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _run_script(cls):
|
def _run_script(cls):
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -22,6 +22,9 @@ class LevelUser(Model):
|
|||||||
unique_together = ("user_id", "group_id")
|
unique_together = ("user_id", "group_id")
|
||||||
|
|
||||||
cache_type = CacheType.LEVEL
|
cache_type = CacheType.LEVEL
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = ("user_id", "group_id")
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
async def get_user_level(cls, user_id: str, group_id: str | None) -> int:
|
||||||
|
|||||||
@ -60,6 +60,9 @@ class PluginInfo(Model):
|
|||||||
table_description = "插件基本信息"
|
table_description = "插件基本信息"
|
||||||
|
|
||||||
cache_type = CacheType.PLUGINS
|
cache_type = CacheType.PLUGINS
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "module"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_plugin(
|
async def get_plugin(
|
||||||
|
|||||||
@ -31,6 +31,9 @@ class UserConsole(Model):
|
|||||||
table_description = "用户数据表"
|
table_description = "用户数据表"
|
||||||
|
|
||||||
cache_type = CacheType.USERS
|
cache_type = CacheType.USERS
|
||||||
|
"""缓存类型"""
|
||||||
|
cache_key_field = "user_id"
|
||||||
|
"""缓存键字段"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
async def get_user(cls, user_id: str, platform: str | None = None) -> "UserConsole":
|
||||||
|
|||||||
@ -1,703 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from datetime import datetime
|
|
||||||
from functools import wraps
|
|
||||||
import inspect
|
|
||||||
from typing import Any, ClassVar, Generic, TypeVar
|
|
||||||
|
|
||||||
from aiocache import Cache as AioCache
|
|
||||||
|
|
||||||
# from aiocache.backends.redis import RedisCache
|
|
||||||
from aiocache.base import BaseCache
|
|
||||||
from aiocache.serializers import JsonSerializer
|
|
||||||
import nonebot
|
|
||||||
from nonebot.compat import model_dump
|
|
||||||
from nonebot.utils import is_coroutine_callable
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from zhenxun.services.log import logger
|
|
||||||
|
|
||||||
__all__ = ["Cache", "CacheData", "CacheRoot"]
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
LOG_COMMAND = "cache"
|
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
|
||||||
enable_cache: bool = True
|
|
||||||
"""是否开启缓存功能"""
|
|
||||||
redis_host: str | None = None
|
|
||||||
"""redis地址"""
|
|
||||||
redis_port: int | None = None
|
|
||||||
"""redis端口"""
|
|
||||||
redis_password: str | None = None
|
|
||||||
"""redis密码"""
|
|
||||||
redis_expire: int = 600
|
|
||||||
"""redis过期时间"""
|
|
||||||
|
|
||||||
|
|
||||||
config = nonebot.get_plugin_config(Config)
|
|
||||||
|
|
||||||
|
|
||||||
class DbCacheException(Exception):
|
|
||||||
"""缓存相关异常"""
|
|
||||||
|
|
||||||
def __init__(self, info: str):
|
|
||||||
self.info = info
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.info
|
|
||||||
|
|
||||||
|
|
||||||
def validate_name(func: Callable):
|
|
||||||
"""验证缓存名称是否存在的装饰器"""
|
|
||||||
|
|
||||||
def wrapper(self, name: str, *args, **kwargs):
|
|
||||||
_name = name.upper()
|
|
||||||
if _name not in CacheManager._data:
|
|
||||||
raise DbCacheException(f"缓存数据 {name} 不存在")
|
|
||||||
return func(self, _name, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class CacheGetter(BaseModel, Generic[T]):
|
|
||||||
"""缓存数据获取器"""
|
|
||||||
|
|
||||||
get_func: Callable[..., Any] | None = None
|
|
||||||
get_all_func: Callable[..., Any] | None = None
|
|
||||||
|
|
||||||
async def get(self, cache_data: "CacheData", key: str, *args, **kwargs) -> T:
|
|
||||||
"""获取单个缓存数据"""
|
|
||||||
if not self.get_func:
|
|
||||||
data = await cache_data.get_key(key)
|
|
||||||
if cache_data.result_model:
|
|
||||||
return cache_data._deserialize_value(data, cache_data.result_model)
|
|
||||||
return data
|
|
||||||
|
|
||||||
if is_coroutine_callable(self.get_func):
|
|
||||||
data = await self.get_func(cache_data, key, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
data = self.get_func(cache_data, key, *args, **kwargs)
|
|
||||||
|
|
||||||
if cache_data.result_model:
|
|
||||||
return cache_data._deserialize_value(data, cache_data.result_model)
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def get_all(self, cache_data: "CacheData", *args, **kwargs) -> dict[str, T]:
|
|
||||||
"""获取所有缓存数据"""
|
|
||||||
if not self.get_all_func:
|
|
||||||
data = await cache_data.get_all_data()
|
|
||||||
if cache_data.result_model:
|
|
||||||
return {
|
|
||||||
k: cache_data._deserialize_value(v, cache_data.result_model)
|
|
||||||
for k, v in data.items()
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
|
|
||||||
if is_coroutine_callable(self.get_all_func):
|
|
||||||
data = await self.get_all_func(cache_data, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
data = self.get_all_func(cache_data, *args, **kwargs)
|
|
||||||
|
|
||||||
if cache_data.result_model:
|
|
||||||
return {
|
|
||||||
k: cache_data._deserialize_value(v, cache_data.result_model)
|
|
||||||
for k, v in data.items()
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class CacheData(BaseModel):
|
|
||||||
"""缓存数据模型"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
func: Callable[..., Any]
|
|
||||||
getter: CacheGetter | None = None
|
|
||||||
updater: Callable[..., Any] | None = None
|
|
||||||
with_refresh: Callable[..., Any] | None = None
|
|
||||||
expire: int = 600 # 默认10分钟过期
|
|
||||||
reload_count: int = 0
|
|
||||||
lazy_load: bool = True # 默认延迟加载
|
|
||||||
result_model: type | None = None
|
|
||||||
_keys: set[str] = set() # 存储所有缓存键
|
|
||||||
cache: BaseCache | AioCache
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def _deserialize_value(self, value: Any, target_type: type | None = None) -> Any:
|
|
||||||
"""反序列化值,将JSON数据转换回原始类型
|
|
||||||
|
|
||||||
参数:
|
|
||||||
value: 需要反序列化的值
|
|
||||||
target_type: 目标类型,用于指导反序列化
|
|
||||||
|
|
||||||
返回:
|
|
||||||
反序列化后的值
|
|
||||||
"""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 如果是字典且指定了目标类型
|
|
||||||
if isinstance(value, dict) and target_type:
|
|
||||||
# 处理Tortoise-ORM Model
|
|
||||||
if hasattr(target_type, "_meta"):
|
|
||||||
return self._extracted_from__deserialize_value_19(value, target_type)
|
|
||||||
elif hasattr(target_type, "model_validate"):
|
|
||||||
return target_type.model_validate(value)
|
|
||||||
elif hasattr(target_type, "from_dict"):
|
|
||||||
return target_type.from_dict(value)
|
|
||||||
elif hasattr(target_type, "parse_obj"):
|
|
||||||
return target_type.parse_obj(value)
|
|
||||||
else:
|
|
||||||
return target_type(**value)
|
|
||||||
|
|
||||||
# 处理列表类型
|
|
||||||
if isinstance(value, list):
|
|
||||||
if not value:
|
|
||||||
return value
|
|
||||||
if (
|
|
||||||
target_type
|
|
||||||
and hasattr(target_type, "__origin__")
|
|
||||||
and target_type.__origin__ is list
|
|
||||||
):
|
|
||||||
item_type = target_type.__args__[0]
|
|
||||||
return [self._deserialize_value(item, item_type) for item in value]
|
|
||||||
return [self._deserialize_value(item) for item in value]
|
|
||||||
|
|
||||||
# 处理字典类型
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return {k: self._deserialize_value(v) for k, v in value.items()}
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _extracted_from__deserialize_value_19(self, value, target_type):
|
|
||||||
# 处理字段值
|
|
||||||
processed_value = {}
|
|
||||||
for field_name, field_value in value.items():
|
|
||||||
if field := target_type._meta.fields_map.get(field_name):
|
|
||||||
# 跳过反向关系字段
|
|
||||||
if hasattr(field, "_related_name"):
|
|
||||||
continue
|
|
||||||
processed_value[field_name] = field_value
|
|
||||||
|
|
||||||
logger.debug(f"处理后的值: {processed_value}")
|
|
||||||
|
|
||||||
# 创建模型实例
|
|
||||||
instance = target_type()
|
|
||||||
# 设置字段值
|
|
||||||
for field_name, field_value in processed_value.items():
|
|
||||||
if field_name in target_type._meta.fields_map:
|
|
||||||
field = target_type._meta.fields_map[field_name]
|
|
||||||
# 设置字段值
|
|
||||||
try:
|
|
||||||
if hasattr(field, "to_python_value"):
|
|
||||||
if not field.field_type:
|
|
||||||
logger.debug(f"字段 {field_name} 类型为空")
|
|
||||||
continue
|
|
||||||
field_value = field.to_python_value(field_value)
|
|
||||||
setattr(instance, field_name, field_value)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"设置字段 {field_name} 失败", e=e)
|
|
||||||
|
|
||||||
# 设置 _saved_in_db 标志
|
|
||||||
instance._saved_in_db = True
|
|
||||||
return instance
|
|
||||||
|
|
||||||
async def get_data(self) -> Any:
|
|
||||||
"""从缓存获取数据"""
|
|
||||||
try:
|
|
||||||
data = await self.cache.get(self.name) # type: ignore
|
|
||||||
logger.debug(f"获取缓存 {self.name} 数据: {data}")
|
|
||||||
|
|
||||||
# 如果数据为空,尝试重新加载
|
|
||||||
# if data is None:
|
|
||||||
# logger.debug(f"缓存 {self.name} 数据为空,尝试重新加载")
|
|
||||||
# try:
|
|
||||||
# if self.has_args():
|
|
||||||
# new_data = (
|
|
||||||
# await self.func()
|
|
||||||
# if is_coroutine_callable(self.func)
|
|
||||||
# else self.func()
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# new_data = (
|
|
||||||
# await self.func()
|
|
||||||
# if is_coroutine_callable(self.func)
|
|
||||||
# else self.func()
|
|
||||||
# )
|
|
||||||
|
|
||||||
# await self.set_data(new_data)
|
|
||||||
# self.reload_count += 1
|
|
||||||
# logger.info(f"重新加载缓存 {self.name} 完成")
|
|
||||||
# return new_data
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
|
||||||
# return None
|
|
||||||
|
|
||||||
# 使用 result_model 进行反序列化
|
|
||||||
if self.result_model:
|
|
||||||
return self._deserialize_value(data, self.result_model)
|
|
||||||
|
|
||||||
return data
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取缓存 {self.name} 失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _serialize_value(self, value: Any) -> Any:
|
|
||||||
"""序列化值,将数据转换为JSON可序列化的格式
|
|
||||||
|
|
||||||
参数:
|
|
||||||
value: 需要序列化的值
|
|
||||||
|
|
||||||
返回:
|
|
||||||
JSON可序列化的值
|
|
||||||
"""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 处理datetime
|
|
||||||
if isinstance(value, datetime):
|
|
||||||
return value.isoformat()
|
|
||||||
|
|
||||||
# 处理Tortoise-ORM Model
|
|
||||||
if hasattr(value, "_meta") and hasattr(value, "__dict__"):
|
|
||||||
result = {}
|
|
||||||
for field in value._meta.fields:
|
|
||||||
try:
|
|
||||||
field_value = getattr(value, field)
|
|
||||||
# 跳过反向关系字段
|
|
||||||
if isinstance(field_value, list | set) and hasattr(
|
|
||||||
field_value, "_related_name"
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
# 跳过外键关系字段
|
|
||||||
if hasattr(field_value, "_meta"):
|
|
||||||
field_value = getattr(
|
|
||||||
field_value, value._meta.fields[field].related_name or "id"
|
|
||||||
)
|
|
||||||
result[field] = self._serialize_value(field_value)
|
|
||||||
except AttributeError:
|
|
||||||
continue
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 处理Pydantic模型
|
|
||||||
elif isinstance(value, BaseModel):
|
|
||||||
return model_dump(value)
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
# 处理字典
|
|
||||||
return {str(k): self._serialize_value(v) for k, v in value.items()}
|
|
||||||
elif isinstance(value, list | tuple | set):
|
|
||||||
# 处理列表、元组、集合
|
|
||||||
return [self._serialize_value(item) for item in value]
|
|
||||||
elif isinstance(value, int | float | str | bool):
|
|
||||||
# 基本类型直接返回
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
# 其他类型转换为字符串
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
async def set_data(self, value: Any):
|
|
||||||
"""设置缓存数据"""
|
|
||||||
try:
|
|
||||||
# 1. 序列化数据
|
|
||||||
serialized_value = self._serialize_value(value)
|
|
||||||
logger.debug(f"设置缓存 {self.name} 原始数据: {value}")
|
|
||||||
logger.debug(f"设置缓存 {self.name} 序列化后数据: {serialized_value}")
|
|
||||||
|
|
||||||
# 2. 删除旧数据
|
|
||||||
await self.cache.delete(self.name) # type: ignore
|
|
||||||
logger.debug(f"删除缓存 {self.name} 旧数据")
|
|
||||||
|
|
||||||
# 3. 设置新数据
|
|
||||||
await self.cache.set(self.name, serialized_value, ttl=self.expire) # type: ignore
|
|
||||||
logger.debug(f"设置缓存 {self.name} 新数据完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"设置缓存 {self.name} 失败: {e}")
|
|
||||||
raise # 重新抛出异常,让上层处理
|
|
||||||
|
|
||||||
async def delete_data(self):
|
|
||||||
"""删除缓存数据"""
|
|
||||||
try:
|
|
||||||
await self.cache.delete(self.name) # type: ignore
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"删除缓存 {self.name}", e=e)
|
|
||||||
|
|
||||||
async def get(self, key: str, *args, **kwargs) -> Any:
|
|
||||||
"""获取缓存"""
|
|
||||||
if not self.reload_count and not self.lazy_load:
|
|
||||||
await self.reload(*args, **kwargs)
|
|
||||||
|
|
||||||
if not self.getter:
|
|
||||||
return await self.get_key(key)
|
|
||||||
|
|
||||||
return await self.getter.get(self, key, *args, **kwargs)
|
|
||||||
|
|
||||||
async def get_all(self, *args, **kwargs) -> dict[str, Any]:
|
|
||||||
"""获取所有缓存数据"""
|
|
||||||
if not self.reload_count and not self.lazy_load:
|
|
||||||
await self.reload(*args, **kwargs)
|
|
||||||
|
|
||||||
if not self.getter:
|
|
||||||
return await self.get_all_data()
|
|
||||||
|
|
||||||
return await self.getter.get_all(self, *args, **kwargs)
|
|
||||||
|
|
||||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
|
||||||
"""更新单个缓存项"""
|
|
||||||
if not self.updater:
|
|
||||||
logger.warning(f"缓存 {self.name} 未配置更新方法")
|
|
||||||
return
|
|
||||||
|
|
||||||
current_data = await self.get_key(key) or {}
|
|
||||||
if is_coroutine_callable(self.updater):
|
|
||||||
await self.updater(current_data, key, value, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
self.updater(current_data, key, value, *args, **kwargs)
|
|
||||||
|
|
||||||
await self.set_key(key, current_data)
|
|
||||||
logger.debug(f"更新缓存 {self.name}.{key}")
|
|
||||||
|
|
||||||
async def refresh(self, *args, **kwargs):
|
|
||||||
"""刷新缓存数据"""
|
|
||||||
if not self.with_refresh:
|
|
||||||
return await self.reload(*args, **kwargs)
|
|
||||||
|
|
||||||
current_data = await self.get_data()
|
|
||||||
if current_data:
|
|
||||||
if is_coroutine_callable(self.with_refresh):
|
|
||||||
await self.with_refresh(current_data, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
self.with_refresh(current_data, *args, **kwargs)
|
|
||||||
await self.set_data(current_data)
|
|
||||||
logger.debug(f"刷新缓存 {self.name}")
|
|
||||||
|
|
||||||
async def reload(self, *args, **kwargs):
|
|
||||||
"""重新加载全部数据"""
|
|
||||||
try:
|
|
||||||
if self.has_args():
|
|
||||||
new_data = (
|
|
||||||
await self.func(*args, **kwargs)
|
|
||||||
if is_coroutine_callable(self.func)
|
|
||||||
else self.func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_data = (
|
|
||||||
await self.func()
|
|
||||||
if is_coroutine_callable(self.func)
|
|
||||||
else self.func()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果是字典,则分别存储每个键值对
|
|
||||||
if isinstance(new_data, dict):
|
|
||||||
for key, value in new_data.items():
|
|
||||||
await self.set_key(key, value)
|
|
||||||
else:
|
|
||||||
# 如果不是字典,则存储为单个键值对
|
|
||||||
await self.set_key("default", new_data)
|
|
||||||
|
|
||||||
self.reload_count += 1
|
|
||||||
logger.info(f"重新加载缓存 {self.name} 完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"重新加载缓存 {self.name} 失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def has_args(self) -> bool:
|
|
||||||
"""检查函数是否需要参数"""
|
|
||||||
sig = inspect.signature(self.func)
|
|
||||||
return any(
|
|
||||||
param.kind
|
|
||||||
in (
|
|
||||||
param.POSITIONAL_OR_KEYWORD,
|
|
||||||
param.POSITIONAL_ONLY,
|
|
||||||
param.VAR_POSITIONAL,
|
|
||||||
)
|
|
||||||
for param in sig.parameters.values()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_key(self, key: str) -> Any:
|
|
||||||
"""获取缓存中指定键的数据
|
|
||||||
|
|
||||||
参数:
|
|
||||||
key: 要获取的键名
|
|
||||||
|
|
||||||
返回:
|
|
||||||
键对应的值,如果不存在返回None
|
|
||||||
"""
|
|
||||||
cache_key = self._get_cache_key(key)
|
|
||||||
try:
|
|
||||||
data = await self.cache.get(cache_key) # type: ignore
|
|
||||||
logger.debug(f"获取缓存 {cache_key} 数据: {data}")
|
|
||||||
|
|
||||||
if self.result_model:
|
|
||||||
return self._deserialize_value(data, self.result_model)
|
|
||||||
return data
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取缓存 {cache_key} 失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_keys(self, keys: list[str]) -> dict[str, Any]:
|
|
||||||
"""获取缓存中多个键的数据
|
|
||||||
|
|
||||||
参数:
|
|
||||||
keys: 要获取的键名列表
|
|
||||||
|
|
||||||
返回:
|
|
||||||
包含所有请求键值的字典,不存在的键值为None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
data = await self.get_data()
|
|
||||||
if isinstance(data, dict):
|
|
||||||
return {key: data.get(key) for key in keys}
|
|
||||||
return dict.fromkeys(keys)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取缓存 {self.name} 的多个键失败: {e}")
|
|
||||||
return dict.fromkeys(keys)
|
|
||||||
|
|
||||||
def _get_cache_key(self, key: str) -> str:
|
|
||||||
"""获取缓存键名"""
|
|
||||||
return f"{self.name}:{key}"
|
|
||||||
|
|
||||||
async def get_all_data(self) -> dict[str, Any]:
|
|
||||||
"""获取所有缓存数据"""
|
|
||||||
try:
|
|
||||||
result = {}
|
|
||||||
for key in self._keys:
|
|
||||||
# 提取原始键名(去掉前缀)
|
|
||||||
original_key = key.split(":", 1)[1]
|
|
||||||
data = await self.cache.get(key) # type: ignore
|
|
||||||
if self.result_model:
|
|
||||||
result[original_key] = self._deserialize_value(
|
|
||||||
data, self.result_model
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result[original_key] = data
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取所有缓存数据失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def set_key(self, key: str, value: Any):
|
|
||||||
"""设置指定键的缓存数据"""
|
|
||||||
cache_key = self._get_cache_key(key)
|
|
||||||
try:
|
|
||||||
serialized_value = self._serialize_value(value)
|
|
||||||
await self.cache.set(cache_key, serialized_value, ttl=self.expire) # type: ignore
|
|
||||||
self._keys.add(cache_key) # 添加到键列表
|
|
||||||
logger.debug(f"设置缓存 {cache_key} 数据完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"设置缓存 {cache_key} 失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete_key(self, key: str):
|
|
||||||
"""删除指定键的缓存数据"""
|
|
||||||
cache_key = self._get_cache_key(key)
|
|
||||||
try:
|
|
||||||
await self.cache.delete(cache_key) # type: ignore
|
|
||||||
self._keys.discard(cache_key) # 从键列表中移除
|
|
||||||
logger.debug(f"删除缓存 {cache_key} 完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"删除缓存 {cache_key} 失败: {e}")
|
|
||||||
|
|
||||||
async def clear(self):
|
|
||||||
"""清除所有缓存数据"""
|
|
||||||
try:
|
|
||||||
for key in list(self._keys): # 使用列表复制避免在迭代时修改
|
|
||||||
await self.cache.delete(key) # type: ignore
|
|
||||||
self._keys.clear()
|
|
||||||
logger.debug(f"清除缓存 {self.name} 完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清除缓存 {self.name} 失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
class CacheManager:
|
|
||||||
"""全局缓存管理器"""
|
|
||||||
|
|
||||||
_cache_instance: BaseCache | AioCache | None = None
|
|
||||||
_data: ClassVar[dict[str, CacheData]] = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _cache(self) -> BaseCache | AioCache:
|
|
||||||
"""获取aiocache实例"""
|
|
||||||
if self._cache_instance is None:
|
|
||||||
if config.redis_host:
|
|
||||||
self._cache_instance = AioCache(
|
|
||||||
AioCache.REDIS, # type: ignore
|
|
||||||
serializer=JsonSerializer(),
|
|
||||||
namespace="zhenxun_cache",
|
|
||||||
timeout=30, # 操作超时时间
|
|
||||||
ttl=config.redis_expire, # 设置默认过期时间
|
|
||||||
endpoint=config.redis_host,
|
|
||||||
port=config.redis_port,
|
|
||||||
password=config.redis_password,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._cache_instance = AioCache(
|
|
||||||
AioCache.MEMORY,
|
|
||||||
serializer=JsonSerializer(),
|
|
||||||
namespace="zhenxun_cache",
|
|
||||||
timeout=30, # 操作超时时间
|
|
||||||
ttl=config.redis_expire, # 设置默认过期时间
|
|
||||||
)
|
|
||||||
logger.info("初始化缓存完成...", LOG_COMMAND)
|
|
||||||
return self._cache_instance
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
if self._cache_instance:
|
|
||||||
await self._cache_instance.close() # type: ignore
|
|
||||||
|
|
||||||
async def verify_connection(self):
|
|
||||||
"""连接测试"""
|
|
||||||
try:
|
|
||||||
await self._cache.get("__test__") # type: ignore
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("连接失败", LOG_COMMAND, e=e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def init_non_lazy_caches(self):
|
|
||||||
"""初始化所有非延迟加载的缓存"""
|
|
||||||
await self.verify_connection()
|
|
||||||
for name, cache in self._data.items():
|
|
||||||
cache.cache = self._cache
|
|
||||||
if not cache.lazy_load:
|
|
||||||
try:
|
|
||||||
await cache.reload()
|
|
||||||
logger.info(f"初始化缓存 {name} 完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"初始化缓存 {name} 失败: {e}")
|
|
||||||
|
|
||||||
def new(self, name: str, lazy_load: bool = True, expire: int = 600):
|
|
||||||
"""注册新缓存
|
|
||||||
|
|
||||||
参数:
|
|
||||||
name: 缓存名称
|
|
||||||
lazy_load: 是否延迟加载,默认为True。为False时会在程序启动时自动加载
|
|
||||||
expire: 过期时间(秒)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
|
||||||
_name = name.upper()
|
|
||||||
if _name in self._data:
|
|
||||||
raise DbCacheException(f"缓存 {name} 已存在")
|
|
||||||
|
|
||||||
self._data[_name] = CacheData(
|
|
||||||
name=_name,
|
|
||||||
func=func,
|
|
||||||
expire=expire,
|
|
||||||
lazy_load=lazy_load,
|
|
||||||
cache=self._cache,
|
|
||||||
)
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def listener(self, name: str):
|
|
||||||
"""创建缓存监听器"""
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
@wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
return (
|
|
||||||
await func(*args, **kwargs)
|
|
||||||
if is_coroutine_callable(func)
|
|
||||||
else func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
cache = self._data.get(name.upper())
|
|
||||||
if cache and cache.with_refresh:
|
|
||||||
await cache.refresh()
|
|
||||||
logger.debug(f"监听器触发缓存 {name} 刷新")
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@validate_name
|
|
||||||
def updater(self, name: str):
|
|
||||||
"""设置缓存更新方法"""
|
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
|
||||||
self._data[name].updater = func
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
@validate_name
|
|
||||||
def getter(self, name: str, result_model: type):
|
|
||||||
"""设置缓存获取方法"""
|
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
|
||||||
self._data[name].getter = CacheGetter[result_model](get_func=func)
|
|
||||||
self._data[name].result_model = result_model
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
@validate_name
|
|
||||||
def with_refresh(self, name: str):
|
|
||||||
"""设置缓存刷新方法"""
|
|
||||||
|
|
||||||
def wrapper(func: Callable):
|
|
||||||
self._data[name].with_refresh = func
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
async def get_cache_data(self, name: str) -> Any | None:
|
|
||||||
"""获取缓存数据"""
|
|
||||||
cache = await self.get_cache(name.upper())
|
|
||||||
return await cache.get_data() if cache else None
|
|
||||||
|
|
||||||
async def get_cache(self, name: str) -> CacheData | None:
|
|
||||||
"""获取缓存对象"""
|
|
||||||
return self._data.get(name.upper())
|
|
||||||
|
|
||||||
async def get(self, name: str, *args, **kwargs) -> Any:
|
|
||||||
"""获取缓存内容"""
|
|
||||||
cache = await self.get_cache(name.upper())
|
|
||||||
return await cache.get(*args, **kwargs) if cache else None
|
|
||||||
|
|
||||||
async def update(self, name: str, key: str, value: Any = None, *args, **kwargs):
|
|
||||||
"""更新缓存项"""
|
|
||||||
cache = await self.get_cache(name.upper())
|
|
||||||
if cache:
|
|
||||||
await cache.update(key, value, *args, **kwargs)
|
|
||||||
|
|
||||||
async def reload(self, name: str, *args, **kwargs):
|
|
||||||
"""重新加载缓存"""
|
|
||||||
cache = await self.get_cache(name.upper())
|
|
||||||
if cache:
|
|
||||||
await cache.reload(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# 全局缓存管理器实例
|
|
||||||
CacheRoot = CacheManager()
|
|
||||||
|
|
||||||
|
|
||||||
class Cache(Generic[T]):
|
|
||||||
"""类型化缓存访问接口"""
|
|
||||||
|
|
||||||
def __init__(self, module: str):
|
|
||||||
self.module = module.upper()
|
|
||||||
|
|
||||||
async def get(self, *args, **kwargs) -> T | None:
|
|
||||||
"""获取缓存"""
|
|
||||||
return await CacheRoot.get(self.module, *args, **kwargs)
|
|
||||||
|
|
||||||
async def update(self, key: str, value: Any = None, *args, **kwargs):
|
|
||||||
"""更新缓存项"""
|
|
||||||
await CacheRoot.update(self.module, key, value, *args, **kwargs)
|
|
||||||
|
|
||||||
async def reload(self, *args, **kwargs):
|
|
||||||
"""重新加载缓存"""
|
|
||||||
await CacheRoot.reload(self.module, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_shutdown
|
|
||||||
async def _():
|
|
||||||
await CacheRoot.close()
|
|
||||||
1032
zhenxun/services/cache/__init__.py
vendored
Normal file
1032
zhenxun/services/cache/__init__.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
424
zhenxun/services/cache/cache_containers.py
vendored
Normal file
424
zhenxun/services/cache/cache_containers.py
vendored
Normal file
@ -0,0 +1,424 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
|
from .config import LOG_COMMAND
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDict:
|
||||||
|
"""全局缓存字典类,提供类似普通字典的接口,但数据可以在内存中共享"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, expire: int = 0):
|
||||||
|
"""初始化缓存字典
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 字典名称
|
||||||
|
expire: 过期时间(秒),默认为0表示永不过期
|
||||||
|
"""
|
||||||
|
self.name = name.upper()
|
||||||
|
self.expire = expire
|
||||||
|
self._data = {}
|
||||||
|
# 自动尝试加载数据
|
||||||
|
self._try_load()
|
||||||
|
|
||||||
|
def _try_load(self):
|
||||||
|
"""尝试加载数据(非异步)"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
|
||||||
|
# 检查是否已有缓存数据
|
||||||
|
if self.name in CacheRoot._data:
|
||||||
|
# 如果有,直接获取
|
||||||
|
data = CacheRoot._data[self.name]._data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
self._data = data
|
||||||
|
except Exception:
|
||||||
|
# 忽略错误,使用空字典
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def load(self) -> bool:
|
||||||
|
"""从缓存加载数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功加载
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
|
||||||
|
data = await CacheRoot.get_cache_data(self.name)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
self._data = data
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载缓存字典 {self.name} 失败", LOG_COMMAND, e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def save(self) -> bool:
|
||||||
|
"""保存数据到缓存
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功保存
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheData, CacheRoot
|
||||||
|
|
||||||
|
# 检查缓存是否存在
|
||||||
|
if self.name not in CacheRoot._data:
|
||||||
|
# 创建缓存
|
||||||
|
async def get_func():
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
CacheRoot._data[self.name] = CacheData(
|
||||||
|
name=self.name,
|
||||||
|
func=get_func,
|
||||||
|
expire=self.expire,
|
||||||
|
lazy_load=False,
|
||||||
|
cache=CacheRoot._cache,
|
||||||
|
)
|
||||||
|
# 直接设置数据,避免调用func
|
||||||
|
CacheRoot._data[self.name]._data = self._data
|
||||||
|
else:
|
||||||
|
# 直接更新数据
|
||||||
|
CacheRoot._data[self.name]._data = self._data
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
await CacheRoot._data[self.name].set_data(self._data)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存缓存字典 {self.name} 失败", LOG_COMMAND, e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
"""获取字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值
|
||||||
|
"""
|
||||||
|
return self._data.get(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
"""设置字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
value: 字典值
|
||||||
|
"""
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
def __delitem__(self, key: str) -> None:
|
||||||
|
"""删除字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
"""
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""检查键是否存在
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否存在
|
||||||
|
"""
|
||||||
|
return key in self._data
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""获取字典项,如果不存在返回默认值
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值或默认值
|
||||||
|
"""
|
||||||
|
return self._data.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any) -> None:
|
||||||
|
"""设置字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
value: 字典值
|
||||||
|
"""
|
||||||
|
self._data[key] = value
|
||||||
|
|
||||||
|
def pop(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""删除并返回字典项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
key: 字典键
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 字典值或默认值
|
||||||
|
"""
|
||||||
|
return self._data.pop(key, default)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空字典"""
|
||||||
|
self._data.clear()
|
||||||
|
|
||||||
|
def keys(self) -> list[str]:
|
||||||
|
"""获取所有键
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[str]: 键列表
|
||||||
|
"""
|
||||||
|
return list(self._data.keys())
|
||||||
|
|
||||||
|
def values(self) -> list[Any]:
|
||||||
|
"""获取所有值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[Any]: 值列表
|
||||||
|
"""
|
||||||
|
return list(self._data.values())
|
||||||
|
|
||||||
|
def items(self) -> list[tuple[str, Any]]:
|
||||||
|
"""获取所有键值对
|
||||||
|
|
||||||
|
返回:
|
||||||
|
list[tuple[str, Any]]: 键值对列表
|
||||||
|
"""
|
||||||
|
return list(self._data.items())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""获取字典长度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 字典长度
|
||||||
|
"""
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""字符串表示
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 字符串表示
|
||||||
|
"""
|
||||||
|
return f"CacheDict({self.name}, {len(self._data)} items)"
|
||||||
|
|
||||||
|
|
||||||
|
class CacheList:
|
||||||
|
"""全局缓存列表类,提供类似普通列表的接口,但数据可以在内存中共享"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, expire: int = 0):
|
||||||
|
"""初始化缓存列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
name: 列表名称
|
||||||
|
expire: 过期时间(秒),默认为0表示永不过期
|
||||||
|
"""
|
||||||
|
self.name = name.upper()
|
||||||
|
self.expire = expire
|
||||||
|
self._data = []
|
||||||
|
# 自动尝试加载数据
|
||||||
|
self._try_load()
|
||||||
|
|
||||||
|
def _try_load(self):
|
||||||
|
"""尝试加载数据(非异步)"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
|
||||||
|
# 检查是否已有缓存数据
|
||||||
|
if self.name in CacheRoot._data:
|
||||||
|
# 如果有,直接获取
|
||||||
|
data = CacheRoot._data[self.name]._data
|
||||||
|
if isinstance(data, list):
|
||||||
|
self._data = data
|
||||||
|
except Exception:
|
||||||
|
# 忽略错误,使用空列表
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def load(self) -> bool:
|
||||||
|
"""从缓存加载数据
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功加载
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheRoot
|
||||||
|
|
||||||
|
data = await CacheRoot.get_cache_data(self.name)
|
||||||
|
if isinstance(data, list):
|
||||||
|
self._data = data
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载缓存列表 {self.name} 失败", LOG_COMMAND, e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def save(self) -> bool:
|
||||||
|
"""保存数据到缓存
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功保存
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 延迟导入,避免循环引用
|
||||||
|
from zhenxun.services.cache import CacheData, CacheRoot
|
||||||
|
|
||||||
|
# 检查缓存是否存在
|
||||||
|
if self.name not in CacheRoot._data:
|
||||||
|
# 创建缓存
|
||||||
|
async def get_func():
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
CacheRoot._data[self.name] = CacheData(
|
||||||
|
name=self.name,
|
||||||
|
func=get_func,
|
||||||
|
expire=self.expire,
|
||||||
|
lazy_load=False,
|
||||||
|
cache=CacheRoot._cache,
|
||||||
|
)
|
||||||
|
# 直接设置数据,避免调用func
|
||||||
|
CacheRoot._data[self.name]._data = self._data
|
||||||
|
else:
|
||||||
|
# 直接更新数据
|
||||||
|
CacheRoot._data[self.name]._data = self._data
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
await CacheRoot._data[self.name].set_data(self._data)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存缓存列表 {self.name} 失败", LOG_COMMAND, e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Any:
|
||||||
|
"""获取列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 列表值
|
||||||
|
"""
|
||||||
|
if 0 <= index < len(self._data):
|
||||||
|
return self._data[index]
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
def __setitem__(self, index: int, value: Any) -> None:
|
||||||
|
"""设置列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
# 确保索引有效
|
||||||
|
while len(self._data) <= index:
|
||||||
|
self._data.append(None)
|
||||||
|
self._data[index] = value
|
||||||
|
|
||||||
|
def __delitem__(self, index: int) -> None:
|
||||||
|
"""删除列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引
|
||||||
|
"""
|
||||||
|
if 0 <= index < len(self._data):
|
||||||
|
del self._data[index]
|
||||||
|
else:
|
||||||
|
raise IndexError(f"列表索引 {index} 超出范围")
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""获取列表长度
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 列表长度
|
||||||
|
"""
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
def append(self, value: Any) -> None:
|
||||||
|
"""添加列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
self._data.append(value)
|
||||||
|
|
||||||
|
def extend(self, values: list[Any]) -> None:
|
||||||
|
"""扩展列表
|
||||||
|
|
||||||
|
参数:
|
||||||
|
values: 要添加的值列表
|
||||||
|
"""
|
||||||
|
self._data.extend(values)
|
||||||
|
|
||||||
|
def insert(self, index: int, value: Any) -> None:
|
||||||
|
"""插入列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 插入位置
|
||||||
|
value: 列表值
|
||||||
|
"""
|
||||||
|
self._data.insert(index, value)
|
||||||
|
|
||||||
|
def pop(self, index: int = -1) -> Any:
|
||||||
|
"""删除并返回列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index: 列表索引,默认为最后一项
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Any: 列表值
|
||||||
|
"""
|
||||||
|
return self._data.pop(index)
|
||||||
|
|
||||||
|
def remove(self, value: Any) -> None:
|
||||||
|
"""删除第一个匹配的列表项
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要删除的值
|
||||||
|
"""
|
||||||
|
self._data.remove(value)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空列表"""
|
||||||
|
self._data.clear()
|
||||||
|
|
||||||
|
def index(self, value: Any, start: int = 0, end: int | None = None) -> int:
|
||||||
|
"""查找值的索引
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要查找的值
|
||||||
|
start: 起始索引
|
||||||
|
end: 结束索引
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 索引位置
|
||||||
|
"""
|
||||||
|
return self._data.index(
|
||||||
|
value, start, end if end is not None else len(self._data)
|
||||||
|
)
|
||||||
|
|
||||||
|
def count(self, value: Any) -> int:
|
||||||
|
"""计算值出现的次数
|
||||||
|
|
||||||
|
参数:
|
||||||
|
value: 要计数的值
|
||||||
|
|
||||||
|
返回:
|
||||||
|
int: 出现次数
|
||||||
|
"""
|
||||||
|
return self._data.count(value)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""字符串表示
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 字符串表示
|
||||||
|
"""
|
||||||
|
return f"CacheList({self.name}, {len(self._data)} items)"
|
||||||
35
zhenxun/services/cache/config.py
vendored
Normal file
35
zhenxun/services/cache/config.py
vendored
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
缓存系统配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 日志标识
|
||||||
|
LOG_COMMAND = "CacheRoot"
|
||||||
|
|
||||||
|
# 默认缓存过期时间(秒)
|
||||||
|
DEFAULT_EXPIRE = 600
|
||||||
|
|
||||||
|
# 缓存键前缀
|
||||||
|
CACHE_KEY_PREFIX = "ZHENXUN"
|
||||||
|
|
||||||
|
# 缓存键分隔符
|
||||||
|
CACHE_KEY_SEPARATOR = ":"
|
||||||
|
|
||||||
|
# 复合键分隔符(用于分隔tuple类型的cache_key_field)
|
||||||
|
COMPOSITE_KEY_SEPARATOR = "_"
|
||||||
|
|
||||||
|
|
||||||
|
# 缓存模式
|
||||||
|
class CacheMode:
|
||||||
|
# 内存缓存 - 使用内存存储缓存数据
|
||||||
|
MEMORY = "MEMORY"
|
||||||
|
# Redis缓存 - 使用Redis服务器存储缓存数据
|
||||||
|
REDIS = "REDIS"
|
||||||
|
# 不使用缓存 - 将使用ttl=0的内存缓存,相当于直接从数据库获取数据
|
||||||
|
NONE = "NONE"
|
||||||
|
|
||||||
|
|
||||||
|
SPECIAL_KEY_FORMATS = {
|
||||||
|
"LEVEL": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||||
|
"BAN": "{user_id}" + COMPOSITE_KEY_SEPARATOR + "{group_id}",
|
||||||
|
"GROUPS": "{group_id}" + COMPOSITE_KEY_SEPARATOR + "{channel_id}",
|
||||||
|
}
|
||||||
@ -1,7 +1,10 @@
|
|||||||
from typing import Any, Generic, TypeVar, cast
|
from typing import Any, Generic, TypeVar, cast
|
||||||
|
|
||||||
from zhenxun.services.cache import CacheRoot
|
from zhenxun.services.cache import Cache, CacheRoot, cache_config
|
||||||
from zhenxun.services.cache import config as cache_config
|
from zhenxun.services.cache.config import (
|
||||||
|
COMPOSITE_KEY_SEPARATOR,
|
||||||
|
CacheMode,
|
||||||
|
)
|
||||||
from zhenxun.services.db_context import Model
|
from zhenxun.services.db_context import Model
|
||||||
from zhenxun.services.log import logger
|
from zhenxun.services.log import logger
|
||||||
|
|
||||||
@ -37,15 +40,89 @@ class DataAccess(Generic[T]):
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_cls: type[T], cache_type: str | None = None):
|
def __init__(
|
||||||
|
self, model_cls: type[T], key_field: str = "id", cache_type: str | None = None
|
||||||
|
):
|
||||||
"""初始化数据访问对象
|
"""初始化数据访问对象
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
model_cls: 模型类
|
model_cls: 模型类
|
||||||
cache_type: 缓存类型,如果为None则使用模型类的cache_type属性
|
key_field: 主键字段
|
||||||
"""
|
"""
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.cache_type = cache_type or getattr(model_cls, "cache_type", None)
|
self.key_field = getattr(model_cls, "cache_key_field", key_field)
|
||||||
|
self.cache_type = getattr(model_cls, "cache_type", cache_type)
|
||||||
|
|
||||||
|
if not self.cache_type:
|
||||||
|
raise ValueError("缓存类型不能为空")
|
||||||
|
self.cache = Cache(self.cache_type)
|
||||||
|
|
||||||
|
def _build_cache_key_from_kwargs(self, **kwargs) -> str | None:
|
||||||
|
"""从关键字参数构建缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**kwargs: 关键字参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 缓存键,如果无法构建则返回None
|
||||||
|
"""
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 多字段主键
|
||||||
|
key_parts = []
|
||||||
|
for field in self.key_field:
|
||||||
|
key_parts.append(str(kwargs.get(field, "")))
|
||||||
|
|
||||||
|
if key_parts:
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||||
|
return None
|
||||||
|
elif self.key_field in kwargs:
|
||||||
|
# 单字段主键
|
||||||
|
return str(kwargs[self.key_field])
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def safe_get_or_none(self, *args, **kwargs) -> T | None:
|
||||||
|
"""安全的获取单条数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
*args: 查询参数
|
||||||
|
**kwargs: 查询参数
|
||||||
|
|
||||||
|
返回:
|
||||||
|
Optional[T]: 查询结果,如果不存在返回None
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,直接从数据库获取
|
||||||
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
return await self.model_cls.safe_get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
|
# 尝试从缓存获取
|
||||||
|
try:
|
||||||
|
# 尝试构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
|
||||||
|
# 如果成功构建缓存键,尝试从缓存获取
|
||||||
|
if cache_key is not None:
|
||||||
|
data = await self.cache.get(cache_key)
|
||||||
|
if data:
|
||||||
|
return cast(T, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("从缓存获取数据失败", e=e)
|
||||||
|
|
||||||
|
# 如果缓存中没有,从数据库获取
|
||||||
|
data = await self.model_cls.safe_get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
|
# 如果获取到数据,存入缓存
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
async def get_or_none(self, *args, **kwargs) -> T | None:
|
async def get_or_none(self, *args, **kwargs) -> T | None:
|
||||||
"""获取单条数据
|
"""获取单条数据
|
||||||
@ -57,23 +134,169 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
Optional[T]: 查询结果,如果不存在返回None
|
Optional[T]: 查询结果,如果不存在返回None
|
||||||
"""
|
"""
|
||||||
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
|
# 如果没有缓存类型,直接从数据库获取
|
||||||
if not cache_config.enable_cache or not self.cache_type:
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
return await self.model_cls.safe_get_or_none(*args, **kwargs)
|
return await self.model_cls.get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
# 从缓存获取
|
# 尝试从缓存获取
|
||||||
try:
|
try:
|
||||||
# 生成缓存键
|
# 尝试构建缓存键
|
||||||
key = self._generate_cache_key(kwargs)
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
# 尝试从缓存获取
|
|
||||||
data = await CacheRoot.get(self.cache_type, key)
|
# 如果成功构建缓存键,尝试从缓存获取
|
||||||
if data:
|
if cache_key is not None:
|
||||||
return cast(T, data)
|
data = await self.cache.get(cache_key)
|
||||||
|
if data:
|
||||||
|
return cast(T, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从缓存获取数据失败", e=e)
|
logger.error("从缓存获取数据失败", e=e)
|
||||||
|
|
||||||
# 如果缓存中没有,从数据库获取
|
# 如果缓存中没有,从数据库获取
|
||||||
return await self.model_cls.safe_get_or_none(*args, **kwargs)
|
data = await self.model_cls.get_or_none(*args, **kwargs)
|
||||||
|
|
||||||
|
# 如果获取到数据,存入缓存
|
||||||
|
if data:
|
||||||
|
try:
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
# 生成缓存键
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
logger.debug(f"{self.cache_type} 数据已存入缓存: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def clear_cache(self, **kwargs) -> bool:
|
||||||
|
"""只清除缓存,不影响数据库数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
**kwargs: 查询参数,必须包含主键字段
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 是否成功清除缓存
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,直接返回True
|
||||||
|
if not self.cache_type or cache_config.cache_mode == CacheMode.NONE:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
if cache_key is None:
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 如果是复合键,检查缺少哪些字段
|
||||||
|
missing_fields = [
|
||||||
|
field for field in self.key_field if field not in kwargs
|
||||||
|
]
|
||||||
|
logger.error(
|
||||||
|
f"清除{self.model_cls.__name__}缓存失败: "
|
||||||
|
f"缺少主键字段 {', '.join(missing_fields)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"清除{self.model_cls.__name__}缓存失败: "
|
||||||
|
f"缺少主键字段 {self.key_field}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 删除缓存
|
||||||
|
await self.cache.delete(cache_key)
|
||||||
|
logger.debug(f"已清除{self.model_cls.__name__}缓存: {cache_key}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清除{self.model_cls.__name__}缓存失败", e=e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _build_composite_key(self, data: T) -> str | None:
|
||||||
|
"""构建复合缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data: 数据对象
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 构建的缓存键,如果无法构建则返回None
|
||||||
|
"""
|
||||||
|
# 如果是元组,表示多个字段组成键
|
||||||
|
if isinstance(self.key_field, tuple):
|
||||||
|
# 构建键参数列表
|
||||||
|
key_parts = []
|
||||||
|
for field in self.key_field:
|
||||||
|
value = getattr(data, field, "")
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
|
||||||
|
# 如果没有有效参数,返回None
|
||||||
|
if not key_parts:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||||
|
|
||||||
|
# 单个字段作为键
|
||||||
|
elif hasattr(data, self.key_field):
|
||||||
|
value = getattr(data, self.key_field, None)
|
||||||
|
return str(value) if value is not None else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_cache_key_for_item(self, item: T) -> str | None:
|
||||||
|
"""为数据项构建缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
item: 数据项
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None: 缓存键,如果无法生成则返回None
|
||||||
|
"""
|
||||||
|
# 如果没有缓存类型,返回None
|
||||||
|
if not self.cache_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取缓存类型的配置信息
|
||||||
|
cache_model = CacheRoot.get_model(self.cache_type)
|
||||||
|
|
||||||
|
# 如果有键格式定义,则需要构建特殊格式的键
|
||||||
|
if cache_model.key_format:
|
||||||
|
# 构建键参数字典
|
||||||
|
key_parts = []
|
||||||
|
# 从格式字符串中提取所需的字段名
|
||||||
|
import re
|
||||||
|
|
||||||
|
field_names = re.findall(r"{([^}]+)}", cache_model.key_format)
|
||||||
|
|
||||||
|
# 收集所有字段值
|
||||||
|
for field in field_names:
|
||||||
|
value = getattr(item, field, "")
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(key_parts)
|
||||||
|
else:
|
||||||
|
# 常规处理,使用主键作为缓存键
|
||||||
|
return self._build_composite_key(item)
|
||||||
|
|
||||||
|
async def _cache_items(self, data_list: list[T]) -> None:
|
||||||
|
"""将数据列表存入缓存
|
||||||
|
|
||||||
|
参数:
|
||||||
|
data_list: 数据列表
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
not data_list
|
||||||
|
or not self.cache_type
|
||||||
|
or cache_config.cache_mode == CacheMode.NONE
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 遍历数据列表,将每条数据存入缓存
|
||||||
|
for item in data_list:
|
||||||
|
cache_key = self._build_cache_key_for_item(item)
|
||||||
|
if cache_key is not None:
|
||||||
|
await self.cache.set(cache_key, item)
|
||||||
|
|
||||||
|
logger.debug(f"{self.cache_type} 数据已存入缓存,数量: {len(data_list)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.cache_type} 数据存入缓存失败", e=e)
|
||||||
|
|
||||||
async def filter(self, *args, **kwargs) -> list[T]:
|
async def filter(self, *args, **kwargs) -> list[T]:
|
||||||
"""筛选数据
|
"""筛选数据
|
||||||
@ -85,30 +308,13 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
List[T]: 查询结果列表
|
List[T]: 查询结果列表
|
||||||
"""
|
"""
|
||||||
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
|
# 从数据库获取数据
|
||||||
if not cache_config.enable_cache or not self.cache_type:
|
data_list = await self.model_cls.filter(*args, **kwargs)
|
||||||
return await self.model_cls.filter(*args, **kwargs)
|
|
||||||
|
|
||||||
# 尝试从缓存获取所有数据后筛选
|
# 将数据存入缓存
|
||||||
try:
|
await self._cache_items(data_list)
|
||||||
# 获取缓存数据
|
|
||||||
cache_data = await CacheRoot.get_cache_data(self.cache_type)
|
|
||||||
if isinstance(cache_data, dict) and cache_data:
|
|
||||||
# 在内存中筛选
|
|
||||||
filtered_data = []
|
|
||||||
for item in cache_data.values():
|
|
||||||
match = not any(
|
|
||||||
not hasattr(item, k) or getattr(item, k) != v
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
)
|
|
||||||
if match:
|
|
||||||
filtered_data.append(item)
|
|
||||||
return cast(list[T], filtered_data)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("从缓存筛选数据失败", e=e)
|
|
||||||
|
|
||||||
# 如果缓存中没有或筛选失败,从数据库获取
|
return data_list
|
||||||
return await self.model_cls.filter(*args, **kwargs)
|
|
||||||
|
|
||||||
async def all(self) -> list[T]:
|
async def all(self) -> list[T]:
|
||||||
"""获取所有数据
|
"""获取所有数据
|
||||||
@ -116,21 +322,13 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
List[T]: 所有数据列表
|
List[T]: 所有数据列表
|
||||||
"""
|
"""
|
||||||
# 如果缓存功能被禁用或模型没有缓存类型,直接从数据库获取
|
# 直接从数据库获取
|
||||||
if not cache_config.enable_cache or not self.cache_type:
|
data_list = await self.model_cls.all()
|
||||||
return await self.model_cls.all()
|
|
||||||
|
|
||||||
# 尝试从缓存获取所有数据
|
# 将数据存入缓存
|
||||||
try:
|
await self._cache_items(data_list)
|
||||||
# 获取缓存数据
|
|
||||||
cache_data = await CacheRoot.get_cache_data(self.cache_type)
|
|
||||||
if isinstance(cache_data, dict) and cache_data:
|
|
||||||
return cast(list[T], list(cache_data.values()))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("从缓存获取所有数据失败", e=e)
|
|
||||||
|
|
||||||
# 如果缓存中没有,从数据库获取
|
return data_list
|
||||||
return await self.model_cls.all()
|
|
||||||
|
|
||||||
async def count(self, *args, **kwargs) -> int:
|
async def count(self, *args, **kwargs) -> int:
|
||||||
"""获取数据数量
|
"""获取数据数量
|
||||||
@ -167,7 +365,24 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
T: 创建的数据
|
T: 创建的数据
|
||||||
"""
|
"""
|
||||||
return await self.model_cls.create(**kwargs)
|
# 创建数据
|
||||||
|
data = await self.model_cls.create(**kwargs)
|
||||||
|
|
||||||
|
# 如果有缓存类型,将数据存入缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.cache_type} 新创建的数据已存入缓存: {cache_key}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.cache_type} 存入缓存失败,参数: {kwargs}", e=e)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
async def update_or_create(
|
async def update_or_create(
|
||||||
self, defaults: dict[str, Any] | None = None, **kwargs
|
self, defaults: dict[str, Any] | None = None, **kwargs
|
||||||
@ -181,7 +396,24 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
tuple[T, bool]: (数据, 是否创建)
|
tuple[T, bool]: (数据, 是否创建)
|
||||||
"""
|
"""
|
||||||
return await self.model_cls.update_or_create(defaults=defaults, **kwargs)
|
# 更新或创建数据
|
||||||
|
data, created = await self.model_cls.update_or_create(
|
||||||
|
defaults=defaults, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有缓存类型,将数据存入缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 生成缓存键
|
||||||
|
cache_key = self._build_cache_key_for_item(data)
|
||||||
|
if cache_key is not None:
|
||||||
|
# 存入缓存
|
||||||
|
await self.cache.set(cache_key, data)
|
||||||
|
logger.debug(f"更新或创建的数据已存入缓存: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存入缓存失败,参数: {kwargs}", e=e)
|
||||||
|
|
||||||
|
return data, created
|
||||||
|
|
||||||
async def delete(self, *args, **kwargs) -> int:
|
async def delete(self, *args, **kwargs) -> int:
|
||||||
"""删除数据
|
"""删除数据
|
||||||
@ -193,18 +425,43 @@ class DataAccess(Generic[T]):
|
|||||||
返回:
|
返回:
|
||||||
int: 删除的数据数量
|
int: 删除的数据数量
|
||||||
"""
|
"""
|
||||||
|
# 如果有缓存类型且有key_field参数,先尝试删除缓存
|
||||||
|
if self.cache_type and cache_config.cache_mode != CacheMode.NONE:
|
||||||
|
try:
|
||||||
|
# 尝试构建缓存键
|
||||||
|
cache_key = self._build_cache_key_from_kwargs(**kwargs)
|
||||||
|
|
||||||
|
if cache_key is not None:
|
||||||
|
# 如果成功构建缓存键,直接删除缓存
|
||||||
|
await self.cache.delete(cache_key)
|
||||||
|
logger.debug(f"已删除缓存: {cache_key}")
|
||||||
|
else:
|
||||||
|
# 否则需要先查询出要删除的数据,然后删除对应的缓存
|
||||||
|
items = await self.model_cls.filter(*args, **kwargs)
|
||||||
|
for item in items:
|
||||||
|
item_cache_key = self._build_cache_key_for_item(item)
|
||||||
|
if item_cache_key is not None:
|
||||||
|
await self.cache.delete(item_cache_key)
|
||||||
|
if items:
|
||||||
|
logger.debug(f"已删除{len(items)}条数据的缓存")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("删除缓存失败", e=e)
|
||||||
|
|
||||||
|
# 删除数据
|
||||||
return await self.model_cls.filter(*args, **kwargs).delete()
|
return await self.model_cls.filter(*args, **kwargs).delete()
|
||||||
|
|
||||||
def _generate_cache_key(self, kwargs: dict[str, Any]) -> str:
|
def _generate_cache_key(self, data: T) -> str:
|
||||||
"""根据查询参数生成缓存键
|
"""根据数据对象生成缓存键
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
kwargs: 查询参数
|
data: 数据对象
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
str: 缓存键
|
str: 缓存键
|
||||||
"""
|
"""
|
||||||
# 实现一个简单的键生成算法
|
# 使用新方法构建复合键
|
||||||
if not kwargs:
|
if composite_key := self._build_composite_key(data):
|
||||||
return "default"
|
return composite_key
|
||||||
return "_".join(f"{k}:{v}" for k, v in sorted(kwargs.items()))
|
|
||||||
|
# 如果无法生成复合键,生成一个唯一键
|
||||||
|
return f"object_{id(data)}"
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from zhenxun.utils.exception import HookPriorityException
|
|||||||
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
from zhenxun.utils.manager.priority_manager import PriorityLifecycle
|
||||||
|
|
||||||
from .cache import CacheRoot
|
from .cache import CacheRoot
|
||||||
|
from .cache.config import COMPOSITE_KEY_SEPARATOR
|
||||||
from .log import logger
|
from .log import logger
|
||||||
|
|
||||||
SCRIPT_METHOD = []
|
SCRIPT_METHOD = []
|
||||||
@ -23,14 +24,6 @@ MODELS: list[str] = []
|
|||||||
|
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
CACHE_FLAG = False
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
|
||||||
def _():
|
|
||||||
global CACHE_FLAG
|
|
||||||
CACHE_FLAG = True
|
|
||||||
|
|
||||||
|
|
||||||
class Model(TortoiseModel):
|
class Model(TortoiseModel):
|
||||||
"""
|
"""
|
||||||
@ -56,8 +49,55 @@ class Model(TortoiseModel):
|
|||||||
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
|
return cls.sem_data.get(cls.__module__, {}).get(lock_type, None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cache_type(cls):
|
def get_cache_type(cls) -> str | None:
|
||||||
return getattr(cls, "cache_type", None) if CACHE_FLAG else None
|
return getattr(cls, "cache_type", None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_key_field(cls) -> str | tuple[str]:
|
||||||
|
"""获取缓存键字段名
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | tuple[str]: 缓存键字段名,可能是单个字段名或字段名元组
|
||||||
|
"""
|
||||||
|
if hasattr(cls, "cache_key_field"):
|
||||||
|
return getattr(cls, "cache_key_field", "id")
|
||||||
|
return "id"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cache_key(cls, instance) -> str | None:
|
||||||
|
"""获取缓存键
|
||||||
|
|
||||||
|
参数:
|
||||||
|
instance: 模型实例
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str | None
|
||||||
|
"""
|
||||||
|
key_field = cls.get_cache_key_field()
|
||||||
|
|
||||||
|
# 如果是元组,表示多个字段组成键
|
||||||
|
if isinstance(key_field, tuple):
|
||||||
|
# 构建键参数列表
|
||||||
|
key_parts = []
|
||||||
|
for field in key_field:
|
||||||
|
if hasattr(instance, field):
|
||||||
|
value = getattr(instance, field, None)
|
||||||
|
key_parts.append(value if value is not None else "")
|
||||||
|
else:
|
||||||
|
# 如果缺少任何必要的字段,返回None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果没有有效参数,返回None
|
||||||
|
if not key_parts:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return COMPOSITE_KEY_SEPARATOR.join(str(param) for param in key_parts)
|
||||||
|
|
||||||
|
# 单个字段作为键
|
||||||
|
elif hasattr(instance, key_field):
|
||||||
|
return getattr(instance, key_field, None)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(
|
async def create(
|
||||||
@ -79,7 +119,11 @@ class Model(TortoiseModel):
|
|||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if is_create and (cache_type := cls.get_cache_type()):
|
if is_create and (cache_type := cls.get_cache_type()):
|
||||||
await CacheRoot.reload(cache_type)
|
# 获取缓存键
|
||||||
|
key = cls.get_cache_key(result)
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, key if key is not None else None
|
||||||
|
)
|
||||||
return (result, is_create)
|
return (result, is_create)
|
||||||
else:
|
else:
|
||||||
# 如果没有锁,则执行原来的逻辑
|
# 如果没有锁,则执行原来的逻辑
|
||||||
@ -87,7 +131,11 @@ class Model(TortoiseModel):
|
|||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if is_create and (cache_type := cls.get_cache_type()):
|
if is_create and (cache_type := cls.get_cache_type()):
|
||||||
await CacheRoot.reload(cache_type)
|
# 获取缓存键
|
||||||
|
key = cls.get_cache_key(result)
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, key if key is not None else None
|
||||||
|
)
|
||||||
return (result, is_create)
|
return (result, is_create)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -104,7 +152,11 @@ class Model(TortoiseModel):
|
|||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if cache_type := cls.get_cache_type():
|
if cache_type := cls.get_cache_type():
|
||||||
await CacheRoot.reload(cache_type)
|
# 获取缓存键
|
||||||
|
key = cls.get_cache_key(result[0])
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, key if key is not None else None
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
# 如果没有锁,则执行原来的逻辑
|
# 如果没有锁,则执行原来的逻辑
|
||||||
@ -112,19 +164,23 @@ class Model(TortoiseModel):
|
|||||||
defaults=defaults, using_db=using_db, **kwargs
|
defaults=defaults, using_db=using_db, **kwargs
|
||||||
)
|
)
|
||||||
if cache_type := cls.get_cache_type():
|
if cache_type := cls.get_cache_type():
|
||||||
await CacheRoot.reload(cache_type)
|
# 获取缓存键
|
||||||
|
key = cls.get_cache_key(result[0])
|
||||||
|
await CacheRoot.invalidate_cache(
|
||||||
|
cache_type, key if key is not None else None
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def bulk_create( # type: ignore
|
async def bulk_create( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
objects: Iterable[Self],
|
objects: Iterable[Self], # type: ignore
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
ignore_conflicts: bool = False,
|
ignore_conflicts: bool = False,
|
||||||
update_fields: Iterable[str] | None = None,
|
update_fields: Iterable[str] | None = None,
|
||||||
on_conflict: Iterable[str] | None = None,
|
on_conflict: Iterable[str] | None = None,
|
||||||
using_db: BaseDBAsyncClient | None = None,
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
) -> list[Self]:
|
) -> list[Self]: # type: ignore
|
||||||
result = await super().bulk_create(
|
result = await super().bulk_create(
|
||||||
objects=objects,
|
objects=objects,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@ -134,17 +190,18 @@ class Model(TortoiseModel):
|
|||||||
using_db=using_db,
|
using_db=using_db,
|
||||||
)
|
)
|
||||||
if cache_type := cls.get_cache_type():
|
if cache_type := cls.get_cache_type():
|
||||||
await CacheRoot.reload(cache_type)
|
# 批量创建时清除整个类型的缓存
|
||||||
|
await CacheRoot.invalidate_cache(cache_type)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def bulk_update( # type: ignore
|
async def bulk_update( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
objects: Iterable[Self],
|
objects: Iterable[Self], # type: ignore
|
||||||
fields: Iterable[str],
|
fields: Iterable[str],
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
using_db: BaseDBAsyncClient | None = None,
|
using_db: BaseDBAsyncClient | None = None,
|
||||||
) -> int:
|
) -> int: # type: ignore
|
||||||
result = await super().bulk_update(
|
result = await super().bulk_update(
|
||||||
objects=objects,
|
objects=objects,
|
||||||
fields=fields,
|
fields=fields,
|
||||||
@ -152,7 +209,8 @@ class Model(TortoiseModel):
|
|||||||
using_db=using_db,
|
using_db=using_db,
|
||||||
)
|
)
|
||||||
if cache_type := cls.get_cache_type():
|
if cache_type := cls.get_cache_type():
|
||||||
await CacheRoot.reload(cache_type)
|
# 批量更新时清除整个类型的缓存
|
||||||
|
await CacheRoot.invalidate_cache(cache_type)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def save(
|
async def save(
|
||||||
@ -181,13 +239,24 @@ class Model(TortoiseModel):
|
|||||||
force_create=force_create,
|
force_create=force_create,
|
||||||
force_update=force_update,
|
force_update=force_update,
|
||||||
)
|
)
|
||||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
if cache_type := getattr(self, "cache_type", None):
|
||||||
await CacheRoot.reload(cache_type)
|
# 获取缓存键
|
||||||
|
key = self.__class__.get_cache_key(self)
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, key)
|
||||||
|
|
||||||
async def delete(self, using_db: BaseDBAsyncClient | None = None):
|
async def delete(self, using_db: BaseDBAsyncClient | None = None):
|
||||||
|
# 在删除前获取缓存键
|
||||||
|
cache_type = getattr(self, "cache_type", None)
|
||||||
|
key = None
|
||||||
|
if cache_type:
|
||||||
|
key = self.__class__.get_cache_key(self)
|
||||||
|
|
||||||
|
# 执行删除操作
|
||||||
await super().delete(using_db=using_db)
|
await super().delete(using_db=using_db)
|
||||||
if CACHE_FLAG and (cache_type := getattr(self, "cache_type", None)):
|
|
||||||
await CacheRoot.reload(cache_type)
|
# 清除缓存
|
||||||
|
if cache_type:
|
||||||
|
await CacheRoot.invalidate_cache(cache_type, key)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def safe_get_or_none(
|
async def safe_get_or_none(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user