🩹 优化插件加载与模块格式转换逻辑

This commit is contained in:
HibiKier 2024-12-20 17:16:34 +08:00
parent 5a92a4fe3d
commit 3cc497d6a6
4 changed files with 51 additions and 11 deletions

View File

@ -120,7 +120,7 @@ async def _():
if module_list := await PluginInfo.all().values("id", "module_path"):
module2id = {m["module_path"]: m["id"] for m in module_list}
for plugin in get_loaded_plugins():
load_plugin.append(plugin.name)
load_plugin.append(plugin.module_name)
await _handle_setting(plugin, plugin_list, limit_list, task_list)
create_list = []
update_list = []
@ -198,8 +198,8 @@ async def _():
10,
)
await data_migration()
await PluginInfo.filter(module__in=load_plugin).update(load_status=True)
await PluginInfo.filter(module__not_in=load_plugin).update(load_status=False)
await PluginInfo.filter(module_path__in=load_plugin).update(load_status=True)
await PluginInfo.filter(module_path__not_in=load_plugin).update(load_status=False)
manager.init()
if limit_list:
for limit in limit_list:

View File

@ -13,6 +13,7 @@ from zhenxun.models.plugin_info import PluginInfo
from zhenxun.models.statistics import Statistics
from zhenxun.models.task_info import TaskInfo
from zhenxun.services.log import logger
from zhenxun.utils.common_utils import CommonUtils
from zhenxun.utils.enum import RequestHandleType, RequestType
from zhenxun.utils.exception import NotFoundError
from zhenxun.utils.platform import PlatformUtils
@ -84,12 +85,16 @@ async def _(group: UpdateGroup) -> Result[str]:
db_group.level = group.level
db_group.status = group.status
if group.close_plugins:
group.close_plugins = [f"<{module}" for module in group.close_plugins]
db_group.block_plugin = ",".join(group.close_plugins) + ","
db_group.block_plugin = CommonUtils.convert_module_format(
group.close_plugins
)
else:
db_group.block_plugin = ""
if group.task:
if block_task := [t for t in task_list if t not in group.task]:
block_task = [f"<{module}" for module in block_task]
db_group.block_task = ",".join(block_task) + "," # type: ignore
db_group.block_task = CommonUtils.convert_module_format(block_task) # type: ignore
else:
db_group.block_task = CommonUtils.convert_module_format(task_list) # type: ignore
await db_group.save(
update_fields=["level", "status", "block_plugin", "block_task"]
)
@ -302,8 +307,8 @@ async def _(param: LeaveGroup) -> Result:
platform = PlatformUtils.get_platform(bots[bot_id])
if platform != "qq":
return Result.warning_("该平台不支持退群操作...")
group_list = await bots[bot_id].get_group_list()
if param.group_id not in [str(g["group_id"]) for g in group_list]:
group_list, _ = await PlatformUtils.get_group_list(bots[bot_id])
if param.group_id not in [g.group_id for g in group_list]:
return Result.warning_("Bot未在该群聊中...")
await bots[bot_id].set_group_leave(group_id=param.group_id)
return Result.ok(info="成功处理了请求!")

View File

@ -39,7 +39,7 @@ async def _(
if plugin_type:
query = query.filter(plugin_type__in=plugin_type, load_status=True)
if menu_type:
query = query.filter(menu_type=menu_type)
query = query.filter(menu_type=menu_type, load_status=True)
plugins = await query.all()
for plugin in plugins:
plugin_info = PluginInfo(
@ -116,6 +116,7 @@ async def _(plugin: UpdatePlugin) -> Result:
if c.type and value is not None:
value = cattrs.structure(value, c.type)
Config.set_config(plugin.module, key, value)
Config.save(save_simple_data=True)
except Exception as e:
logger.error("调用API错误", "/update_plugins", e=e)
return Result.fail(f"{type(e)}: {e}")
@ -152,7 +153,11 @@ async def _(param: PluginSwitch) -> Result:
)
async def _() -> Result[list[str]]:
menu_type_list = []
result = await DbPluginInfo.annotate().values_list("menu_type", flat=True)
result = (
await DbPluginInfo.filter(load_status=True)
.annotate()
.values_list("menu_type", flat=True)
)
for r in result:
if r not in menu_type_list and r:
menu_type_list.append(r)

View File

@ -1,3 +1,5 @@
from typing import overload
from nonebot.adapters import Bot
from nonebot_plugin_uninfo import Session, SupportScope, Uninfo, get_interface
@ -62,6 +64,34 @@ class CommonUtils:
return True
return False
@staticmethod
def format(name: str) -> str:
return f"<{name},"
@overload
@classmethod
def convert_module_format(cls, data: str) -> list[str]: ...
@overload
@classmethod
def convert_module_format(cls, data: list[str]) -> str: ...
@classmethod
def convert_module_format(cls, data: str | list[str]) -> str | list[str]:
"""
`<aaa,<bbb,<ccc,` `["aaa", "bbb", "ccc"]` 之间进行相互转换
参数:
data (str | list[str]): 输入数据可能是格式化字符串或字符串列表
返回:
str | list[str]: 根据输入类型返回转换后的数据
"""
if isinstance(data, str):
return [item.strip(",") for item in data.split("<") if item]
elif isinstance(data, list):
return "".join(cls.format(item) for item in data)
class SqlUtils:
@classmethod