From 7c07a01516dc004baf5c85245d8b5de860477f74 Mon Sep 17 00:00:00 2001 From: Alcides Fonseca Date: Thu, 7 Nov 2024 21:56:57 +0000 Subject: [PATCH 1/2] Enumerative is now BFS instead of DFS --- geneticengine/algorithms/enumerative.py | 61 +++++++++++++++++++------ tests/enumerative/enumerative_test.py | 41 +++++++++++++++++ 2 files changed, 88 insertions(+), 14 deletions(-) create mode 100644 tests/enumerative/enumerative_test.py diff --git a/geneticengine/algorithms/enumerative.py b/geneticengine/algorithms/enumerative.py index 4e90e44f..12a3811f 100644 --- a/geneticengine/algorithms/enumerative.py +++ b/geneticengine/algorithms/enumerative.py @@ -1,6 +1,6 @@ from __future__ import annotations from itertools import count, takewhile -from typing import Any +from typing import Any, Generator, Optional from geneticengine.algorithms.api import SynthesisAlgorithm @@ -41,9 +41,12 @@ def combine_list_types(ts: list[type], acc: list[Any], gen): yield from combine_list_types(tail, acc + [x], gen) -def iterate_grammar(grammar: Grammar, starting_symbol: type): - def rec_generator(symbol): - return iterate_grammar(grammar, symbol) +def iterate_grammar(grammar: Grammar, starting_symbol: type, generator_for_recursive: Optional[Any] = None): + + if generator_for_recursive is None: + + def generator_for_recursive(symbol: type): + return iterate_grammar(grammar, symbol, generator_for_recursive) if starting_symbol is int: yield from range(-100000000, 100000000) @@ -54,7 +57,7 @@ def rec_generator(symbol): yield False elif is_generic_tuple(starting_symbol): types = get_generic_parameters(starting_symbol) - for li in combine_list_types(types, [], rec_generator): + for li in combine_list_types(types, [], generator_for_recursive): yield tuple(li) elif is_generic_list(starting_symbol): inner_type = get_generic_parameter(starting_symbol) @@ -62,7 +65,7 @@ def rec_generator(symbol): for length in range(0, 1024): generator_list = [inner_type for _ in range(length)] - for concrete_list in combine_list_types(generator_list, [], rec_generator): + for concrete_list in combine_list_types(generator_list, [], generator_for_recursive): yield concrete_list elif is_metahandler(starting_symbol): @@ -70,36 +73,66 @@ def rec_generator(symbol): base_type = get_generic_parameter(starting_symbol) if hasattr(metahandler, "iterate"): - yield from metahandler.iterate(base_type, lambda xs: combine_list_types(xs, [], rec_generator)) + yield from metahandler.iterate(base_type, lambda xs: combine_list_types(xs, [], generator_for_recursive)) else: base_type = get_generic_parameter(starting_symbol) - for ins in iterate_grammar(grammar, base_type): + for ins in generator_for_recursive(base_type): if metahandler.validate(ins): yield ins elif is_union(starting_symbol): for alt in get_generic_parameters(starting_symbol): - yield from iterate_grammar(grammar, alt) + yield from generator_for_recursive(alt) else: if starting_symbol not in grammar.all_nodes: raise GeneticEngineError( f"Symbol {starting_symbol} not in grammar rules.", ) elif starting_symbol in grammar.alternatives: - compatible_productions = grammar.alternatives[starting_symbol] + compatible_productions = sorted( + grammar.alternatives[starting_symbol], + key=lambda x: grammar.distanceToTerminal[x], + ) - for prod in sorted(compatible_productions, key=lambda x: grammar.distanceToTerminal[x]): - yield from iterate_grammar(grammar, prod) + non_recursive = [c for c in compatible_productions if c not in grammar.recursive_prods] + recursive = [c for c in compatible_productions if c in grammar.recursive_prods] + # key to sort from shallow to deepest + + cache = [] + + # Non-recursive + for prod in non_recursive: + for v in generator_for_recursive(prod): + yield v + cache.append(v) # Build level 0 of cache + + # Recursive cases, by level + + # This reader will replace the generator with reading from the cache of previous levels + # If the type is different, it generates it as it was previously done. + def rgenerator(t: type) -> Generator[Any, Any, Any]: + if t is starting_symbol: + yield from cache + else: + yield from generator_for_recursive(t) + + while True: + tmp = [] + for prod in recursive: + for v in iterate_grammar(grammar, prod, rgenerator): + yield v + tmp.append(v) + cache.extend(tmp) else: # Normal production args = [] # TODO: Add dependent types to enumerative # dependent_values = {} args = [argt for _, argt in get_arguments(starting_symbol)] - for li in combine_list_types(args, [], rec_generator): + for li in combine_list_types(args, [], generator_for_recursive): yield apply_constructor(starting_symbol, li) -def iterate_individuals(grammar: Grammar, starting_symbol: type): +def iterate_individuals(grammar: Grammar, starting_symbol: type) -> Generator[ConcreteIndividual, Any, Any]: for p in iterate_grammar(grammar, starting_symbol): yield ConcreteIndividual(instance=p) diff --git a/tests/enumerative/enumerative_test.py b/tests/enumerative/enumerative_test.py new file mode 100644 index 00000000..54a0be26 --- /dev/null +++ b/tests/enumerative/enumerative_test.py @@ -0,0 +1,41 @@ +from abc import ABC +from dataclasses import dataclass + +from geneticengine.algorithms.enumerative import iterate_individuals +from geneticengine.grammar.grammar import extract_grammar + + +class Root(ABC): + pass + + +@dataclass +class Leaf(Root): + pass + + +@dataclass +class Branch1(Root): + v1: Root + v2: Root + + +@dataclass +class Branch2(Root): + v1: Root + v2: Root + + +def test_enumerative(): + g = extract_grammar([Leaf, Branch1, Branch2], Root) + exp = [Leaf(), Branch1(Leaf(), Leaf()), Branch2(Leaf(), Leaf())] + + xs = [] + for x in iterate_individuals(g, Root): + xs.append(x.instance) + if len(xs) > 10: + break + print(xs) + + for expected, real in zip(exp, xs): + assert expected == real From 37129e20df3a2d73856a0b5d019607ec68ab5f60 Mon Sep 17 00:00:00 2001 From: Alcides Fonseca Date: Thu, 7 Nov 2024 23:33:00 +0000 Subject: [PATCH 2/2] Lists are now forever --- geneticengine/algorithms/enumerative.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/geneticengine/algorithms/enumerative.py b/geneticengine/algorithms/enumerative.py index 12a3811f..0c7833cc 100644 --- a/geneticengine/algorithms/enumerative.py +++ b/geneticengine/algorithms/enumerative.py @@ -62,11 +62,12 @@ def generator_for_recursive(symbol: type): elif is_generic_list(starting_symbol): inner_type = get_generic_parameter(starting_symbol) - for length in range(0, 1024): - + length = 0 + while True: generator_list = [inner_type for _ in range(length)] for concrete_list in combine_list_types(generator_list, [], generator_for_recursive): yield concrete_list + length += 1 elif is_metahandler(starting_symbol): metahandler: MetaHandlerGenerator = starting_symbol.__metadata__[0] # type: ignore