Skip to content

Commit

Permalink
✨ EntariConfig.save[json/yaml]
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Feb 28, 2025
1 parent eb6cbcc commit e5f384f
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 14 deletions.
7 changes: 7 additions & 0 deletions arclet/entari/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dataclasses import field as field # noqa

from .file import EntariConfig as EntariConfig
from .file import load_config as load_config
from .model import BasicConfModel as BasicConfModel
from .model import config_model_validate as config_model_validate
from .model import config_validator_register as config_validator_register
184 changes: 184 additions & 0 deletions arclet/entari/config/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import annotations

from dataclasses import dataclass, field
import json
import os
from pathlib import Path
import re
from typing import Any, Callable, ClassVar, TypedDict
import warnings

ENV_CONTEXT_PAT = re.compile(r"\$\{\{\s?env\.(?P<name>[^}]+)\s?\}\}")


class BasicConfig(TypedDict, total=False):
network: list[dict[str, Any]]
ignore_self_message: bool
skip_req_missing: bool
log_level: int | str
prefix: list[str]
cmd_count: int
external_dirs: list[str]


@dataclass
class EntariConfig:
path: Path
basic: BasicConfig = field(default_factory=dict, init=False) # type: ignore
plugin: dict[str, dict] = field(default_factory=dict, init=False)
prelude_plugin: list[str] = field(default_factory=list, init=False)
plugin_extra_files: list[str] = field(default_factory=list, init=False)
updater: Callable[[EntariConfig], None]

instance: ClassVar[EntariConfig]

def __post_init__(self):
self.__class__.instance = self
self.reload()

@staticmethod
def _load_plugin(path: Path):
if path.suffix.startswith(".json"):
with path.open("r", encoding="utf-8") as f:
return json.load(f)
if path.suffix in (".yaml", ".yml"):
try:
import yaml
except ImportError:
raise RuntimeError("yaml is not installed")

with path.open("r", encoding="utf-8") as f:
return yaml.safe_load(f)
raise NotImplementedError(f"unsupported plugin config file format: {path!s}")

def reload(self):
self.updater(self)
self.plugin_extra_files: list[str] = self.plugin.pop("$files", []) # type: ignore
for file in self.plugin_extra_files:
path = Path(file)
if not path.exists():
raise FileNotFoundError(file)
if path.is_dir():
for _path in path.iterdir():
if not _path.is_file():
continue
self.plugin[_path.stem] = self._load_plugin(_path)
else:
self.plugin[path.stem] = self._load_plugin(path)

self.plugin.setdefault(".commands", {})
self.prelude_plugin = self.plugin.pop("$prelude", []) # type: ignore
disabled = []
for k, v in self.plugin.items():
if v is True:
self.plugin[k] = {}
warnings.warn(
f"`True` usage in plugin '{k}' config is deprecated, use empty dict instead", DeprecationWarning
)
elif v is False:
disabled.append(k)
for k in disabled:
self.plugin[f"~{k}"] = self.plugin.pop(k)
warnings.warn(
f"`False` usage in plugin '{k}' config is deprecated, use `~` prefix instead", DeprecationWarning
)

def dump(self):
plugins = self.plugin.copy()
if plugins[".commands"] == {}:
plugins.pop(".commands")
if self.prelude_plugin:
plugins = {"$prelude": self.prelude_plugin, **plugins}
if self.plugin_extra_files:
for file in self.plugin_extra_files:
path = Path(file)
if path.is_file():
plugins.pop(path.stem)
else:
for _path in path.iterdir():
if _path.is_file():
plugins.pop(_path.stem)
plugins = {"$files": self.plugin_extra_files, **plugins}
return {"basic": self.basic, "plugins": plugins}

def save_json(self):
with self.path.open("r", encoding="utf-8") as f:
origin = json.load(f)
if "entari" in origin:
origin["entari"] = self.dump()
else:
origin = self.dump()
with self.path.open("w", encoding="utf-8") as f1:
json.dump(origin, f1, indent=2, ensure_ascii=False)

def save_yaml(self):
try:
import yaml
except ImportError:
raise RuntimeError("yaml is not installed. Please install with `arclet-entari[yaml]`")
with self.path.open("r", encoding="utf-8") as f:
origin = yaml.safe_load(f)
if "entari" in origin:
origin["entari"] = self.dump()
else:
origin = self.dump()
with self.path.open("w", encoding="utf-8") as f1:
yaml.dump(origin, f1)

@classmethod
def load(cls, path: str | os.PathLike[str] | None = None) -> EntariConfig:
try:
import dotenv

dotenv.load_dotenv()
except ImportError:
dotenv = None # noqa
pass
if path is None:
if "ENTARI_CONFIG_FILE" in os.environ:
_path = Path(os.environ["ENTARI_CONFIG_FILE"])
elif (Path.cwd() / ".entari.json").exists():
_path = Path.cwd() / ".entari.json"
else:
_path = Path.cwd() / "entari.yml"
else:
_path = Path(path)
if not _path.exists():
return cls(_path, lambda _: None)
if not _path.is_file():
raise ValueError(f"{_path} is not a file")

if _path.suffix.startswith(".json"):

def _updater(self: EntariConfig):
with self.path.open("r", encoding="utf-8") as f:
data = json.load(f)
if "entari" in data:
data = data["entari"]
self.basic = data.get("basic", {})
self.plugin = data.get("plugins", {})

obj = cls(_path, _updater)
cls.instance = obj
return obj
if _path.suffix in (".yaml", ".yml"):
try:
import yaml
except ImportError:
raise RuntimeError("yaml is not installed. Please install with `arclet-entari[yaml]`")

def _updater(self: EntariConfig):
with self.path.open("r", encoding="utf-8") as f:
text = f.read()
text = ENV_CONTEXT_PAT.sub(lambda m: os.environ.get(m["name"], ""), text)
data = yaml.safe_load(text)
if "entari" in data:
data = data["entari"]
self.basic = data.get("basic", {})
self.plugin = data.get("plugins", {})

return cls(_path, _updater)
raise NotImplementedError(f"unsupported config file format: {_path!s}")


load_config = EntariConfig.load
68 changes: 68 additions & 0 deletions arclet/entari/config/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from dataclasses import dataclass, fields, is_dataclass
from inspect import Signature
from typing import Any, Callable, TypeVar, get_args, get_origin
from typing_extensions import dataclass_transform

_available_dc_attrs = set(Signature.from_callable(dataclass).parameters.keys())

_config_model_validators = {}

C = TypeVar("C")


def config_validator_register(base: type):
def wrapper(func: Callable[[dict[str, Any], type[C]], C]):
_config_model_validators[base] = func
return func

return wrapper


def config_model_validate(base: type[C], data: dict[str, Any]) -> C:
for b in base.__mro__[-2::-1]:
if b in _config_model_validators:
return _config_model_validators[b](data, base)
return base(**data)


@dataclass_transform(kw_only_default=True)
class BasicConfModel:
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
dataclass(**{k: v for k, v in kwargs.items() if k in _available_dc_attrs})(cls)


@config_validator_register(BasicConfModel)
def _basic_config_validate(data: dict[str, Any], base: type[C]) -> C:
def _nested_validate(namespace: dict[str, Any], cls):
result = {}
for field_ in fields(cls):
if field_.name not in namespace:
continue
if is_dataclass(field_.type):
result[field_.name] = _nested_validate(namespace[field_.name], field_.type)
elif get_origin(field_.type) is list and is_dataclass(get_args(field_.type)[0]):
result[field_.name] = [_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]]
elif get_origin(field_.type) is set and is_dataclass(get_args(field_.type)[0]):
result[field_.name] = {_nested_validate(d, get_args(field_.type)[0]) for d in namespace[field_.name]}
elif get_origin(field_.type) is dict and is_dataclass(get_args(field_.type)[1]):
result[field_.name] = {
k: _nested_validate(v, get_args(field_.type)[1]) for k, v in namespace[field_.name].items()
}
elif get_origin(field_.type) is tuple:
args = get_args(field_.type)
result[field_.name] = tuple(
_nested_validate(d, args[i]) if is_dataclass(args[i]) else d
for i, d in enumerate(namespace[field_.name])
)
else:
result[field_.name] = namespace[field_.name]
return cls(**result)

return _nested_validate(data, base)
10 changes: 6 additions & 4 deletions arclet/entari/event/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Any, Optional, overload
from typing import Any, Optional, TypeVar, overload

from arclet.letoderea import make_event

from ..config import C, config_model_validate
from ..config import config_model_validate

_C = TypeVar("_C")


@dataclass
Expand All @@ -20,9 +22,9 @@ class ConfigReload:
def plugin_config(self) -> dict[str, Any]: ...

@overload
def plugin_config(self, model_type: type[C]) -> C: ...
def plugin_config(self, model_type: type[_C]) -> _C: ...

def plugin_config(self, model_type: Optional[type[C]] = None):
def plugin_config(self, model_type: Optional[type[_C]] = None):
if self.scope != "plugin":
raise ValueError("not a plugin config")
if model_type:
Expand Down
3 changes: 1 addition & 2 deletions arclet/entari/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ async def before(self, session: Optional[Session] = None):
return STOP

async def after(self):
if self.success:
self.last_time = datetime.now()
self.last_time = datetime.now()

def compose(self):
yield self.before, True, 15
Expand Down
16 changes: 8 additions & 8 deletions arclet/entari/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import inspect
from os import PathLike
from pathlib import Path
from typing import Any, Callable, overload
from typing import Any, Callable, TypeVar, overload

from arclet.letoderea import es
from tarina import init_spec

from ..config import C, EntariConfig, config_model_validate
from ..config import EntariConfig, config_model_validate
from ..event.plugin import PluginLoadedFailed, PluginLoadedSuccess, PluginUnloaded
from ..logger import log
from .model import PluginMetadata as PluginMetadata
Expand Down Expand Up @@ -63,10 +63,7 @@ def load_plugin(
"""
if config is None:
config = EntariConfig.instance.plugin.get(path)
conf = config or {}
if "$static" in conf:
del conf["$static"]
conf = conf.copy()
conf = (config or {}).copy()
if prelude:
conf["$static"] = True
if recursive_guard is None:
Expand Down Expand Up @@ -129,15 +126,18 @@ def metadata(data: PluginMetadata):
get_plugin(1)._metadata = data # type: ignore


_C = TypeVar("_C")


@overload
def plugin_config() -> dict[str, Any]: ...


@overload
def plugin_config(model_type: type[C]) -> C: ...
def plugin_config(model_type: type[_C]) -> _C: ...


def plugin_config(model_type: type[C] | None = None):
def plugin_config(model_type: type[_C] | None = None):
"""获取当前插件的配置"""
plugin = get_plugin(1)
if model_type:
Expand Down

0 comments on commit e5f384f

Please sign in to comment.