diff --git a/arclet/entari/plugin/module.py b/arclet/entari/plugin/module.py index eed3b7c..e67ef1d 100644 --- a/arclet/entari/plugin/module.py +++ b/arclet/entari/plugin/module.py @@ -34,15 +34,18 @@ def _check_mod(name, package=None): return module -def _unpack_import_from(__fullname: str, mod: str, aliases: list[str]): - if mod == ".": +def _unpack_import_from_level_x(__fullname: str, mod: str, level: int, aliases: list[str]): + if not mod: if len(aliases) == 1: - return _check_mod(f".{aliases[0]}", __fullname) - return tuple(_check_mod(f".{alias}", __fullname) for alias in aliases) - _mod = _check_mod(f".{mod}", __fullname) if mod else _check_mod(__fullname) + return _check_mod(f"{'.' * level}{aliases[0]}", __fullname) + return tuple(_check_mod(f"{'.' * level}{alias}", __fullname) for alias in aliases) + _mod = _check_mod(f"{'.' * level}{mod}", __fullname) # if mod else _check_mod(__fullname) if len(aliases) == 1: return getattr(_mod, aliases[0]) - return tuple(getattr(_mod, alias) for alias in aliases) + args = [] + for alias in aliases: + args.append(getattr(_mod, alias)) + return tuple(args) def _check_import(name: str, plugin_name: str): @@ -57,7 +60,17 @@ def _check_import(name: str, plugin_name: str): if plugin_name != mod.__plugin__.id: service._referents[mod.__plugin__.id].add(plugin_name) return mod.__plugin__.subproxy(name) - return __import__(name) + return __import__(name, fromlist=["__path__"]) + + +def _unpack_import_from_level_0(name, plugin_name, aliases): + mod = _check_import(name, plugin_name) + if len(aliases) == 1: + return getattr(mod, aliases[0]) + args = [] + for alias in aliases: + args.append(getattr(mod, alias)) + return tuple(args) class PluginLoader(SourceFileLoader): @@ -74,53 +87,48 @@ def source_to_code(self, data, path, *, _optimize=-1): # type: ignore nodes = ast.parse(data, type_comments=True) for i, body in enumerate(nodes.body): if isinstance(body, ast.ImportFrom): - if body.level == 0 and ( - body.module in _SUBMODULE_WAITLIST.get(self.name, ()) or body.module in service.plugins - ): + if body.level == 0: + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + ( + f"=__unpack_import_from_level_0({body.module!r}, {self.name!r}, " + f"{[alias.name for alias in body.names]!r})" + ) + ).body[0] + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore + elif body.module is None: + aliases = [alias.asname or alias.name for alias in body.names] + nodes.body[i] = ast.parse( + ",".join(aliases) + + ( + f"=__unpack_import_from_level_x('{self.name}', '', {body.level}, " + f"{[alias.name for alias in body.names]!r})" + ) + ).body[0] + for node in ast.walk(nodes.body[i]): + node.lineno = body.lineno # type: ignore + node.end_lineno = body.end_lineno # type: ignore + else: aliases = [alias.asname or alias.name for alias in body.names] nodes.body[i] = ast.parse( ",".join(aliases) - + f"=__unpack_import_from('{body.module}', '', {[alias.name for alias in body.names]!r})" + + ( + f"=__unpack_import_from_level_x('{self.name}', {body.module!r}, {body.level}, " + f"{[alias.name for alias in body.names]!r})" + ) ).body[0] for node in ast.walk(nodes.body[i]): node.lineno = body.lineno # type: ignore node.end_lineno = body.end_lineno # type: ignore - if body.level == 1: - if body.module is None: - aliases = [alias.asname or alias.name for alias in body.names] - nodes.body[i] = ast.parse( - ",".join(aliases) - + f"=__unpack_import_from('{self.name}', '.', {[alias.name for alias in body.names]!r})" - ).body[0] - for node in ast.walk(nodes.body[i]): - node.lineno = body.lineno # type: ignore - node.end_lineno = body.end_lineno # type: ignore - else: - aliases = [alias.asname or alias.name for alias in body.names] - nodes.body[i] = ast.parse( - ",".join(aliases) - + ( - f"=__unpack_import_from('{self.name}', {body.module!r}, " - f"{[alias.name for alias in body.names]!r})" - ) - ).body[0] - for node in ast.walk(nodes.body[i]): - node.lineno = body.lineno # type: ignore - node.end_lineno = body.end_lineno # type: ignore - elif ( - isinstance(body, ast.Expr) - and isinstance(body.value, ast.Call) - and isinstance(body.value.func, ast.Name) - and body.value.func.id == "package" - ): - if body.value.args and isinstance(body.value.args[0], ast.Constant): - _SUBMODULE_WAITLIST.setdefault(self.name, set()).update(arg.value for arg in body.value.args) # type: ignore elif isinstance(body, ast.Import): aliases = [alias.asname or alias.name for alias in body.names] nodes.body[i] = ast.parse( ",".join(aliases) + "=" - + ",".join((f"__check_import({alias.name!r}, {self.name!r})") for alias in body.names) + + ",".join(f"__check_import({alias.name!r}, {self.name!r})" for alias in body.names) ).body[0] for node in ast.walk(nodes.body[i]): node.lineno = body.lineno # type: ignore @@ -138,7 +146,8 @@ def exec_module(self, module: ModuleType) -> None: if module.__name__ == plugin.module.__name__: # from . import xxxx return setattr(module, "__plugin__", plugin) - setattr(module, "__unpack_import_from", _unpack_import_from) + setattr(module, "__unpack_import_from_level_x", _unpack_import_from_level_x) + setattr(module, "__unpack_import_from_level_0", _unpack_import_from_level_0) setattr(module, "__check_import", _check_import) try: super().exec_module(module) @@ -156,7 +165,8 @@ def exec_module(self, module: ModuleType) -> None: # create plugin before executing plugin = Plugin(module.__name__, module) setattr(module, "__plugin__", plugin) - setattr(module, "__unpack_import_from", _unpack_import_from) + setattr(module, "__unpack_import_from_level_x", _unpack_import_from_level_x) + setattr(module, "__unpack_import_from_level_0", _unpack_import_from_level_0) setattr(module, "__check_import", _check_import) # enter plugin context