Skip to content

Commit

Permalink
Added add_floor utility to db
Browse files Browse the repository at this point in the history
  • Loading branch information
ppizarror committed Feb 3, 2025
1 parent 980d7cc commit 5519b5d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
2 changes: 1 addition & 1 deletion MLStructFP/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__email__ = 'pablo@ppizarror.com'
__keywords__ = ['ml', 'ai', 'floor plan', 'architectural', 'dataset', 'cnn']
__license__ = 'MIT'
__version__ = '0.7.2'
__version__ = '0.7.3'

# URL
__url__ = 'https://github.com/MLSTRUCT/MLSTRUCT-FP'
Expand Down
34 changes: 30 additions & 4 deletions MLStructFP/db/_db_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DbLoader(object):
__filter: Optional[Callable[['Floor'], bool]]
__filtered_floors: List['Floor']
__floor: Dict[int, 'Floor']
__floor_categories: Dict[int, str]
__path: str

def __init__(self, db: str, floor_only: bool = False) -> None:
Expand All @@ -45,17 +46,17 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
assert os.path.isfile(db), f'Dataset file {db} not found'
self.__filter = None
self.__filtered_floors = []
self.__path = str(Path(os.path.realpath(db)).parent)
self.__floor = {}
self.__floor_categories: Dict[int, str] = {}
self.__path = str(Path(os.path.realpath(db)).parent)

with open(db, 'r', encoding='utf8') as dbfile:
data: dict = json.load(dbfile)
meta: dict = data['meta'] if 'meta' in data else {}

# Load metadata
floor_categories: Dict[int, str] = {}
for cat in (meta['floor_categories'] if 'floor_categories' in meta else {}):
floor_categories[meta['floor_categories'][cat]] = cat
self.__floor_categories[meta['floor_categories'][cat]] = cat
item_types: Dict[int, Tuple[str, str]] = {}
for cat in (meta['item_types'] if 'item_types' in meta else {}):
ic = meta['item_types'][cat]
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
project_id=project_id,
project_label=project_label[project_id] if project_id in project_label else '',
category=f_cat,
category_name=floor_categories.get(f_cat, ''),
category_name=self.__floor_categories.get(f_cat, ''),
elevation=f_data['elevation'] if 'elevation' in f_data else False
)
if floor_only:
Expand Down Expand Up @@ -153,6 +154,31 @@ def __init__(self, db: str, floor_only: bool = False) -> None:
def __getitem__(self, item: int) -> 'Floor':
return self.__floor[item]

def add_floor(self, floor_image: str, scale: float, category: int, elevation: bool) -> 'Floor':
"""
Adds a floor to the dataset. No project.
:param floor_image: Floor image file
:param scale: Image scale
:param category: Floor category
:param elevation: Floor is elevation
:return: Added floor object
"""
assert os.path.isfile(floor_image)
f_id: int = len(self.__floor) + 1
f = Floor(
floor_id=int(f_id),
image_path=floor_image,
image_scale=scale,
project_id=-1,
project_label='',
category=category,
category_name=self.__floor_categories.get(category, ''),
elevation=elevation
)
self.__floor[f_id] = f
return f

@property
def floors(self) -> Tuple['Floor', ...]:
if len(self.__filtered_floors) == 0:
Expand Down
12 changes: 12 additions & 0 deletions test/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ def test_hist(self) -> None:
db = DbLoader(DB_PATH)
self.assertEqual(db.hist(show_plot=False), ('',))

def test_add_floor(self) -> None:
"""
Test add floor to database.
"""
db = DbLoader(DB_PATH)
f0 = db.floors[0]
f = db.add_floor(floor_image=f0.image_path, scale=f0.image_scale, category=f0.category, elevation=f0.elevation)
self.assertEqual(f.image_path, f0.image_path)
self.assertEqual(f.image_scale, f0.image_scale)
self.assertEqual(f.category, f0.category)
self.assertEqual(f.elevation, f0.elevation)

def test_image(self) -> None:
"""
Test image obtain in binary/photo.
Expand Down

0 comments on commit 5519b5d

Please sign in to comment.