Skip to content

Commit

Permalink
Finish implementing bound handling and unit tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
mlojek committed Feb 8, 2025
1 parent 57414d5 commit f98db62
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 3 deletions.
18 changes: 17 additions & 1 deletion src/optilab/data_classes/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,23 @@ def reflect(self, point: Point) -> Point:
Returns:
Point: Reflected point.
"""
raise NotImplementedError
new_x = []

for val in point.x:
if val < self.lower or val > self.upper:
val -= self.lower
remainder = val % (self.upper - self.lower)
relative_distance = val // (self.upper - self.lower)

if relative_distance % 2 == 0:
new_x.append(self.lower + remainder)
else:
new_x.append(self.upper - remainder)
else:
new_x.append(val)

point.x = new_x
return point

def wrap(self, point: Point) -> Point:
"""
Expand Down
9 changes: 9 additions & 0 deletions tests/data_classes/bounds_handle/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,12 @@ def point_multidimensional() -> Point:
- project: [14, 10, 20, 10, 20, 10, 20]
"""
return Point([14, 10, 20, 2, 23, -4, 32])


@pytest.fixture
def evaluated_point() -> Point:
"""
An evaluated point, with y value and is_evaluated set to True.
Used to check if the handled point has the same y and is_evaluated values.
"""
return Point(x=[14, 10, 20, -4], y=10.1, is_evaluated=True)
8 changes: 8 additions & 0 deletions tests/data_classes/bounds_handle/test_bounds_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ def test_multidimensional(self, example_bounds, point_multidimensional):
"""
handled_point = example_bounds.project(point_multidimensional)
assert handled_point.x == [14, 10, 20, 10, 20, 10, 20]

def test_evaluated_point(self, example_bounds, evaluated_point):
"""
Test if projecting a point leaves y and is_evaluated members unchanged.
"""
handled_point = example_bounds.project(evaluated_point)
assert handled_point.y == 10.1
assert handled_point.is_evaluated
9 changes: 8 additions & 1 deletion tests/data_classes/bounds_handle/test_bounds_reflect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest


@pytest.mark.xfail
class TestBoundsReflect:
"""
Unit tests for Bounds.reflect method.
Expand Down Expand Up @@ -66,3 +65,11 @@ def test_multidimensional(self, example_bounds, point_multidimensional):
"""
handled_point = example_bounds.reflect(point_multidimensional)
assert handled_point.x == [14, 10, 20, 18, 17, 16, 12]

def test_evaluated_point(self, example_bounds, evaluated_point):
"""
Test if reflecting a point leaves y and is_evaluated members unchanged.
"""
handled_point = example_bounds.reflect(evaluated_point)
assert handled_point.y == 10.1
assert handled_point.is_evaluated
8 changes: 8 additions & 0 deletions tests/data_classes/bounds_handle/test_bounds_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ def test_multidimensional(self, example_bounds, point_multidimensional):
"""
handled_point = example_bounds.wrap(point_multidimensional)
assert handled_point.x == [14, 10, 20, 12, 13, 16, 12]

def test_evaluated_point(self, example_bounds, evaluated_point):
"""
Test if wrapping a point leaves y and is_evaluated members unchanged.
"""
handled_point = example_bounds.wrap(evaluated_point)
assert handled_point.y == 10.1
assert handled_point.is_evaluated
1 change: 0 additions & 1 deletion tests/data_classes/bounds_handle/test_handle_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def test_project(self, example_bounds, point_multidimensional):
handled_point = example_bounds.handle_bounds(point_multidimensional, "project")
assert handled_point.x == [14, 10, 20, 10, 20, 10, 20]

@pytest.mark.xfail()
def test_reflect(self, example_bounds, point_multidimensional):
handled_point = example_bounds.handle_bounds(point_multidimensional, "reflect")
assert handled_point.x == [14, 10, 20, 18, 17, 16, 12]
Expand Down

0 comments on commit f98db62

Please sign in to comment.