Skip to content

Commit

Permalink
Fixed target
Browse files Browse the repository at this point in the history
  • Loading branch information
alcides committed Oct 22, 2024
1 parent f3da02b commit 36f82c7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
8 changes: 4 additions & 4 deletions geml/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,16 @@ def fit(self, X, y):
random = NativeRandomSource(self.seed)

feature_names, data = self.prepare_inputs(X)
target = self.prepare_outputs(y)
assert data.shape[0] == target.shape[0]
y = self.prepare_outputs(y)
assert data.shape[0] == y.shape[0]

grammar = self.get_grammar(feature_names, data, target)
grammar = self.get_grammar(feature_names, data, y)

def fitness_function(x: Expression) -> float:
try:
y_pred = forward_dataset(x.to_numpy(), data)
with np.errstate(all="ignore"):
return r2_score(target, y_pred)
return r2_score(y, y_pred)
except ValueError:
return -10000000

Expand Down
2 changes: 0 additions & 2 deletions geneticengine/algorithms/hill_climbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
def search(self) -> Individual:
assert isinstance(self.representation, RepresentationWithMutation)
current_ind = None
print(self.is_done())
while not self.is_done():
if current_ind is None:
n = self.representation.create_genotype(self.random)
Expand All @@ -41,5 +40,4 @@ def search(self) -> Individual:
neighbourhood = [Individual(genotype=n2, representation=self.representation) for n2 in genotypes]
self.tracker.evaluate(neighbourhood)
current_ind = self.tracker.get_best_individual()
print("debug", current_ind, self.tracker.get_best_individual())
return self.tracker.get_best_individual()

0 comments on commit 36f82c7

Please sign in to comment.