Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metahandler Union added #184

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions geneticengine/grammar/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import NamedTuple

from geneticengine.grammar.decorators import get_gengy
from geneticengine.grammar.utils import all_init_arguments_typed
from geneticengine.grammar.utils import all_init_arguments_typed, is_union
from geneticengine.grammar.utils import get_arguments
from geneticengine.grammar.utils import get_generic_parameter
from geneticengine.grammar.utils import get_generic_parameters
Expand Down Expand Up @@ -240,12 +240,20 @@ def preprocess(self) -> None:

reachability: dict[type, set[type]] = defaultdict(lambda: set())

def explode_generics(tys: list[type]):
for ty in tys:
if is_union(ty):
yield from explode_generics(get_generic_parameters(ty))
elif is_generic_list(ty) or is_annotated(ty):
yield from explode_generics([get_generic_parameter(ty)])
else:
yield ty

def process_reachability(src: type, dsts: list[type]):
src = strip_annotations(src)
src = strip_annotations(src) # TODO remove strip annotations???
ch = False
src_reach = reachability[src]
for prod in dsts:
prod = strip_annotations(prod)
for prod in explode_generics(dsts):
reach = reachability[prod]
oldlen = len(reach)
reach.add(src)
Expand Down Expand Up @@ -289,7 +297,6 @@ def process_reachability(src: type, dsts: list[type]):
args = get_arguments(sym)
assert args
val = max(1 + self.get_distance_to_terminal(argt) for (_, argt) in args)

changed |= process_reachability(
sym,
[argt for (_, argt) in args],
Expand Down
9 changes: 8 additions & 1 deletion geneticengine/grammar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import inspect
from abc import ABC
from typing import Any
from typing import Any, get_origin, Union
from typing import Callable
from typing import get_type_hints
from typing import Protocol
Expand Down Expand Up @@ -45,6 +45,11 @@ def is_generic(ty: type[Any]):
return hasattr(ty, "__origin__")


def is_union(ty: type[Any]):
"""Returns whether a type is List[T] for any T."""
return get_origin(ty) is Union


def get_generic_parameters(ty: type[Any]) -> list[type]:
"""Annotated[T, <annotations>] or List[T], this function returns
Dict[T,]"""
Expand Down Expand Up @@ -93,6 +98,8 @@ def get_arguments(n) -> list[tuple[str, type]]:

def is_abstract(t: type) -> bool:
"""Returns whether a class is a Protocol or AbstractBaseClass."""
if is_union(t):
return False
return t.mro()[1] in [ABC, Protocol] or get_gengy(t).get("abstract", False)


Expand Down
8 changes: 6 additions & 2 deletions geneticengine/representations/tree/initializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from geneticengine.random.sources import RandomSource
from geneticengine.solutions.tree import GengyList
from geneticengine.representations.tree.utils import relabel_nodes_of_trees
from geneticengine.grammar.utils import build_finalizers
from geneticengine.grammar.utils import build_finalizers, is_union, get_generic_parameters
from geneticengine.grammar.utils import get_arguments
from geneticengine.grammar.utils import get_generic_parameter
from geneticengine.grammar.utils import is_abstract
Expand Down Expand Up @@ -276,6 +276,10 @@ def expand_node(
valb = r.random_bool(str(starting_symbol))
receiver(valb)
return
elif is_union(starting_symbol):
option = r.choice(get_generic_parameters(starting_symbol))
new_symbol(option, receiver, depth, id, ctx)
return
elif is_generic_list(starting_symbol):
ctx = ctx.copy()
ctx["_"] = id
Expand Down Expand Up @@ -339,7 +343,7 @@ def expand_node(
rule = r.choice(valid_productions, str(starting_symbol))
new_symbol(rule, receiver, depth - extra_depth, id, ctx)
else: # Normal production
args = get_arguments(starting_symbol)
args: list[tuple[str, type]] = get_arguments(starting_symbol)
ctx = ctx.copy()
li: list[Any] = []
for argn, _ in args:
Expand Down
17 changes: 16 additions & 1 deletion tests/core/metahandlers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from dataclasses import dataclass
from typing import Annotated
from typing import Annotated, Union

import numpy as np

Expand All @@ -26,6 +26,11 @@ class IntRangeM(Root):
x: Annotated[int, IntRange[9, 10]]


@dataclass
class UnionIntRangeM(Root):
x: Union[Annotated[int, IntRange[0, 10]], Annotated[int, IntRange[20, 30]]]


@dataclass
class IntervalRangeM(Root):
x: Annotated[
Expand Down Expand Up @@ -158,3 +163,13 @@ def test_intervalrange(self):
assert isinstance(n, IntervalRangeM)
assert 5 < n.x[1] - n.x[0] < 10 and n.x[1] < 100
assert isinstance(n, Root)

def test_union_int(self):
r = NativeRandomSource(seed=1)
g = extract_grammar([UnionIntRangeM], Root)
for _ in range(100):
n = random_node(r, g, 3, Root)
assert isinstance(n, UnionIntRangeM)
assert (0 <= n.x <= 10) or (20 <= n.x <= 30)
assert isinstance(n, Root)

Loading