Skip to content

Commit

Permalink
Merge pull request #257 from alcides/bfs
Browse files Browse the repository at this point in the history
Enumerative is now BFS instead of DFS
  • Loading branch information
alcides authored Nov 7, 2024
2 parents 205b80c + 37129e2 commit 5112e91
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 16 deletions.
66 changes: 50 additions & 16 deletions geneticengine/algorithms/enumerative.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from itertools import count, takewhile
from typing import Any
from typing import Any, Generator, Optional

from geneticengine.algorithms.api import SynthesisAlgorithm

Expand Down Expand Up @@ -41,9 +41,12 @@ def combine_list_types(ts: list[type], acc: list[Any], gen):
yield from combine_list_types(tail, acc + [x], gen)


def iterate_grammar(grammar: Grammar, starting_symbol: type):
def rec_generator(symbol):
return iterate_grammar(grammar, symbol)
def iterate_grammar(grammar: Grammar, starting_symbol: type, generator_for_recursive: Optional[Any] = None):

if generator_for_recursive is None:

def generator_for_recursive(symbol: type):
return iterate_grammar(grammar, symbol, generator_for_recursive)

if starting_symbol is int:
yield from range(-100000000, 100000000)
Expand All @@ -54,52 +57,83 @@ def rec_generator(symbol):
yield False
elif is_generic_tuple(starting_symbol):
types = get_generic_parameters(starting_symbol)
for li in combine_list_types(types, [], rec_generator):
for li in combine_list_types(types, [], generator_for_recursive):
yield tuple(li)
elif is_generic_list(starting_symbol):
inner_type = get_generic_parameter(starting_symbol)

for length in range(0, 1024):

length = 0
while True:
generator_list = [inner_type for _ in range(length)]
for concrete_list in combine_list_types(generator_list, [], rec_generator):
for concrete_list in combine_list_types(generator_list, [], generator_for_recursive):
yield concrete_list
length += 1

elif is_metahandler(starting_symbol):
metahandler: MetaHandlerGenerator = starting_symbol.__metadata__[0] # type: ignore
base_type = get_generic_parameter(starting_symbol)

if hasattr(metahandler, "iterate"):
yield from metahandler.iterate(base_type, lambda xs: combine_list_types(xs, [], rec_generator))
yield from metahandler.iterate(base_type, lambda xs: combine_list_types(xs, [], generator_for_recursive))
else:
base_type = get_generic_parameter(starting_symbol)
for ins in iterate_grammar(grammar, base_type):
for ins in generator_for_recursive(base_type):
if metahandler.validate(ins):
yield ins
elif is_union(starting_symbol):
for alt in get_generic_parameters(starting_symbol):
yield from iterate_grammar(grammar, alt)
yield from generator_for_recursive(alt)
else:
if starting_symbol not in grammar.all_nodes:
raise GeneticEngineError(
f"Symbol {starting_symbol} not in grammar rules.",
)
elif starting_symbol in grammar.alternatives:
compatible_productions = grammar.alternatives[starting_symbol]
compatible_productions = sorted(
grammar.alternatives[starting_symbol],
key=lambda x: grammar.distanceToTerminal[x],
)

for prod in sorted(compatible_productions, key=lambda x: grammar.distanceToTerminal[x]):
yield from iterate_grammar(grammar, prod)
non_recursive = [c for c in compatible_productions if c not in grammar.recursive_prods]
recursive = [c for c in compatible_productions if c in grammar.recursive_prods]
# key to sort from shallow to deepest

cache = []

# Non-recursive
for prod in non_recursive:
for v in generator_for_recursive(prod):
yield v
cache.append(v) # Build level 0 of cache

# Recursive cases, by level

# This reader will replace the generator with reading from the cache of previous levels
# If the type is different, it generates it as it was previously done.
def rgenerator(t: type) -> Generator[Any, Any, Any]:
if t is starting_symbol:
yield from cache
else:
yield from generator_for_recursive(t)

while True:
tmp = []
for prod in recursive:
for v in iterate_grammar(grammar, prod, rgenerator):
yield v
tmp.append(v)
cache.extend(tmp)
else:
# Normal production
args = []
# TODO: Add dependent types to enumerative
# dependent_values = {}
args = [argt for _, argt in get_arguments(starting_symbol)]
for li in combine_list_types(args, [], rec_generator):
for li in combine_list_types(args, [], generator_for_recursive):
yield apply_constructor(starting_symbol, li)


def iterate_individuals(grammar: Grammar, starting_symbol: type):
def iterate_individuals(grammar: Grammar, starting_symbol: type) -> Generator[ConcreteIndividual, Any, Any]:
for p in iterate_grammar(grammar, starting_symbol):
yield ConcreteIndividual(instance=p)

Expand Down
41 changes: 41 additions & 0 deletions tests/enumerative/enumerative_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC
from dataclasses import dataclass

from geneticengine.algorithms.enumerative import iterate_individuals
from geneticengine.grammar.grammar import extract_grammar


class Root(ABC):
pass


@dataclass
class Leaf(Root):
pass


@dataclass
class Branch1(Root):
v1: Root
v2: Root


@dataclass
class Branch2(Root):
v1: Root
v2: Root


def test_enumerative():
g = extract_grammar([Leaf, Branch1, Branch2], Root)
exp = [Leaf(), Branch1(Leaf(), Leaf()), Branch2(Leaf(), Leaf())]

xs = []
for x in iterate_individuals(g, Root):
xs.append(x.instance)
if len(xs) > 10:
break
print(xs)

for expected, real in zip(exp, xs):
assert expected == real

0 comments on commit 5112e91

Please sign in to comment.