Skip to content

Commit

Permalink
refactor(WIP): refactor Graph and Scope to avoid circular reference
Browse files Browse the repository at this point in the history
  • Loading branch information
raceychan committed Feb 1, 2025
1 parent cc957d1 commit e06d9ff
Show file tree
Hide file tree
Showing 12 changed files with 800 additions and 733 deletions.
40 changes: 39 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -901,5 +901,43 @@ dg.node(User)
assert dg.search_node("User").dependent_type is User
```
This is particularly useful for type defined by NewType
- `Graph.override`
```python
UserId = NewType("UserId", str)
assert dg.search_node("UserId")
```
- `Graph.override`
a helper function to override dependent within the graph
```python
def override(self, old_dep: INode[P, T], new_dep: INode[P, T]) -> None:
```
```python
dg = DependencyGraph()
@dg.entry
async def create_user(
user_name: str, user_email: str, service: UserService
) -> UserService:
return service
@dg.node
def user_factory() -> UserService:
return UserService("1", 2)
class FakeUserService(UserService): ...
dg.override(UserService, FakeUserService)
service_res = await create_user("1", "2")
assert isinstance(service_res, FakeUserService)
```
Note that, if you only want to override dependency for `create_user`
you can still just use `create_user.replace(UserService, FakeUserService)`,
and such override won't affect others.
1 change: 1 addition & 0 deletions ididi/_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"""



class BaseRegistry:
"""Base registry class with common functionali"""

Expand Down
27 changes: 13 additions & 14 deletions ididi/_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from functools import lru_cache
from inspect import Parameter, Signature
from inspect import _ParameterKind as ParameterKind # type: ignore
Expand Down Expand Up @@ -41,7 +40,7 @@
resolve_annotation,
resolve_forwardref,
)
from .config import DefaultConfig, NodeConfig
from .config import CacheMax, DefaultConfig, NodeConfig
from .errors import (
ABCNotImplementedError,
MissingAnnotationError,
Expand All @@ -50,6 +49,7 @@
from .interfaces import (
EMPTY_SIGNATURE,
INSPECT_EMPTY,
FrozenData,
INode,
INodeAnyFactory,
INodeConfig,
Expand Down Expand Up @@ -78,7 +78,7 @@ def func(service: UserService = use(factory)): ...
def func(service: Annotated[UserService, use(factory)]): ...
```
"""
node = DependentNode[T].from_node(factory, config=NodeConfig(**iconfig))
node = DependentNode[T].from_node(factory, config=NodeConfig.build(**iconfig))
annt = Annotated[node.dependent_type, node, IDIDI_USE_FACTORY_MARK]
return cast(T, annt)

Expand Down Expand Up @@ -117,8 +117,7 @@ def should_override(
# ======================= Signature =====================================


@dataclass(frozen=True)
class Dependency(Generic[T]):
class Dependency(FrozenData, Generic[T]):
"""'dpram' for short
Represents a parameter and its corresponding dependency node
Expand All @@ -137,8 +136,6 @@ def dependent_factory(self, n: int = 5) -> Any:
default: the default value of the param, 5, in this case.
"""

__slots__ = ("name", "param_kind", "param_type", "default")

name: str
param_type: type[T] # resolved_type
param_kind: ParameterKind
Expand Down Expand Up @@ -340,6 +337,12 @@ def auth_service_factory() -> AuthService:
"_unsolved_params",
)

dependent_type: type[T]
factory: INodeAnyFactory[T]
factory_type: FactoryType
dependencies: Dependencies
config: NodeConfig

def __init__(
self,
*,
Expand All @@ -349,11 +352,7 @@ def __init__(
dependencies: Dependencies,
config: NodeConfig,
):
"""
TODO:
- cancel Dependent, use dependent_type directly
- refactor DependentSignature to Dependencies
"""

self.dependent_type = dependent_type
self.factory = factory
self.factory_type: FactoryType = factory_type
Expand Down Expand Up @@ -410,7 +409,7 @@ def analyze_unsolved_params(
if param.param_type != param_type:
self.dependencies.update(param_name, param)

@lru_cache(1024)
@lru_cache(CacheMax)
def unsolved_params(self, ignore: NodeIgnore) -> list[tuple[str, type]]:
"yield dependencies that needs to be resolved"
if not self._unsolved_params:
Expand Down Expand Up @@ -532,7 +531,7 @@ def _from_class(
)

@classmethod
@lru_cache(1024)
@lru_cache(CacheMax)
def from_node(
cls,
factory_or_class: INode[P, T],
Expand Down
5 changes: 3 additions & 2 deletions ididi/_type_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from typing_extensions import TypeGuard, Unpack

from .config import CacheMax
from .errors import (
ForwardReferenceNotFoundError,
GenericDependencyNotSupportedError,
Expand Down Expand Up @@ -243,7 +244,7 @@ def first_solvable_type(types: tuple[Any, ...]) -> type:
"""


@lru_cache(1024)
@lru_cache(CacheMax)
def resolve_new_type(annotation: Any) -> type:
name = getattr(annotation, "__name__")
tyep_repr = getattr(annotation.__supertype__, "__name__")
Expand Down Expand Up @@ -273,7 +274,7 @@ def flatten_annotated(typ: Annotated[Any, Any]) -> list[Any]:
return flattened_metadata


@lru_cache(1024)
@lru_cache(CacheMax)
def get_bases(dependent: type) -> tuple[type, ...]:
if issubclass(dependent, Protocol):
# -3 excludes Protocol, Gener, object
Expand Down
72 changes: 20 additions & 52 deletions ididi/config.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,22 @@
from dataclasses import FrozenInstanceError
from typing import Any, Iterable
# from dataclasses import FrozenInstanceError
from typing import Final, Iterable

from .interfaces import GraphIgnore, GraphIgnoreConfig, NodeIgnore, NodeIgnoreConfig
from .interfaces import (
FrozenData,
GraphIgnore,
GraphIgnoreConfig,
NodeIgnore,
NodeIgnoreConfig,
)


class FrozenSlot:
"""
A Mixin class provides a hashable, frozen class with slots defined.
This is mainly due to the fact that dataclass does not support slots before python 3.10
"""

__slots__: tuple[str, ...]

def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return False

return all(
getattr(self, attr) == getattr(other, attr) for attr in self.__slots__
)

def __repr__(self):
attr_repr = "".join(
f"{attr}={getattr(self, attr)}, " for attr in self.__slots__
).rstrip(", ")
return f"{self.__class__.__name__}({attr_repr})"

def __setattr__(self, name: str, value: Any) -> None:
raise FrozenInstanceError("can't set attribute")

def __hash__(self) -> int:
attrs = tuple(getattr(self, attr) for attr in self.__slots__)
return hash(attrs)


class NodeConfig(FrozenSlot):
__slots__ = ("reuse", "ignore")
class NodeConfig(FrozenData):

ignore: NodeIgnore
reuse: bool

def __init__(
self,
*,
reuse: bool = True,
ignore: NodeIgnoreConfig = frozenset(),
):

@classmethod
def build(cls, ignore: NodeIgnoreConfig = frozenset(), reuse: bool = True):
if not isinstance(ignore, frozenset):
if isinstance(ignore, Iterable):
if isinstance(ignore, str):
Expand All @@ -56,17 +26,15 @@ def __init__(
else:
ignore = frozenset([ignore])

object.__setattr__(self, "ignore", ignore)
object.__setattr__(self, "reuse", reuse)
return cls(reuse=reuse, ignore=ignore)


class GraphConfig(FrozenSlot):
__slots__ = ("self_inject", "ignore")

class GraphConfig(FrozenData):
self_inject: bool
ignore: GraphIgnore

def __init__(self, *, self_inject: bool, ignore: GraphIgnoreConfig):

@classmethod
def build(cls, *, self_inject: bool, ignore: GraphIgnoreConfig):
if not isinstance(ignore, frozenset):
if isinstance(ignore, Iterable):
if isinstance(ignore, str):
Expand All @@ -76,8 +44,8 @@ def __init__(self, *, self_inject: bool, ignore: GraphIgnoreConfig):
else:
ignore = frozenset([ignore])

object.__setattr__(self, "self_inject", self_inject)
object.__setattr__(self, "ignore", ignore)
return cls(self_inject=self_inject, ignore=ignore)


DefaultConfig = NodeConfig()
DefaultConfig: Final[NodeConfig] = NodeConfig.build()
CacheMax: Final[int] = 1024
Loading

0 comments on commit e06d9ff

Please sign in to comment.