Skip to content

Commit

Permalink
Allow users to pass a model name when adding node data through API.
Browse files Browse the repository at this point in the history
Example: `config.add_node_data_from_path(model_data1, model_name='model1.tflite')`

Also set numpy version < 2 to suppress warning message.

Fixes #103
  • Loading branch information
jinjingforever committed Jul 22, 2024
1 parent c1c554b commit e0b60aa
Show file tree
Hide file tree
Showing 13 changed files with 9,244 additions and 32,066 deletions.
4 changes: 2 additions & 2 deletions src/server/package/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ai-edge-model-explorer"
version = "0.1.7"
version = "0.1.8"
authors = [
{ name="Google LLC", email="opensource@google.com" },
]
Expand All @@ -24,7 +24,7 @@ dependencies = [
"termcolor",
"typing-extensions",
"torch >= 2.2",
"numpy",
"numpy < 2",
]

[project.scripts]
Expand Down
41 changes: 38 additions & 3 deletions src/server/package/src/model_explorer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

EncodedUrlData = TypedDict(
'EncodedUrlData',
{'models': list[ModelSource], 'nodeData': NotRequired[list[str]]},
{'models': list[ModelSource],
'nodeData': NotRequired[list[str]],
'nodeDataTargets': NotRequired[list[str]]},
)


Expand All @@ -43,6 +45,9 @@ def __init__(self) -> None:
self.model_sources: list[ModelSource] = []
self.graphs_list: list[ModelExplorerGraphs] = []
self.node_data_sources: list[str] = []
# List of model names to apply node data to. For the meaning of
# "model name", see comments in `add_node_data_from_path` method below.
self.node_data_target_models: list[str] = []
self.node_data_list: list[NodeData] = []

def add_model_from_path(
Expand Down Expand Up @@ -95,27 +100,51 @@ def add_model_from_pytorch(

return self

def add_node_data_from_path(self, path: str) -> 'ModelExplorerConfig':
def add_node_data_from_path(
self,
path: str,
model_name: Union[str, None] = None) -> 'ModelExplorerConfig':
"""Adds node data file to the config.
Args:
path: the path of the node data json file to add.
model_name: the name of the model. If not set, the node data will be
applied to the first model added to the config by default.
To set this parameter:
For non-pytorch model, this should be the name of the model file
(e.g. model.tflite). For pytorch model, it should be the `name`
parameter used to call the `add_model_from_pytorch` api.
"""
# Get the absolute path (after expanding home dir path "~").
abs_model_path = os.path.abspath(os.path.expanduser(path))

self.node_data_sources.append(abs_model_path)
if model_name is None:
self.node_data_target_models.append('')
else:
self.node_data_target_models.append(model_name)

return self

def add_node_data(
self, name: str, node_data: NodeData
self,
name: str,
node_data: NodeData,
model_name: Union[str, None] = None
) -> 'ModelExplorerConfig':
"""Adds the given node data object.
Args:
name: the name of the NodeData for display purpose.
node_data: the NodeData object to add.
model_name: the name of the model. If not set, the node data will be
applied to the first model added to the config by default.
To set this parameter:
For non-pytorch model, this should be the name of the model file
(e.g. model.tflite). For pytorch model, it should be the `name`
parameter used to call the `add_model_from_pytorch` api.
"""
node_data_index = len(self.node_data_list)
self.node_data_list.append(node_data)
Expand All @@ -125,6 +154,10 @@ def add_node_data(
# The node data source has a special format, in the form of:
# node_data://{name}//{index}
self.node_data_sources.append(f'node_data://{name}/{node_data_index}')
if model_name is None:
self.node_data_target_models.append('')
else:
self.node_data_target_models.append(model_name)
return self

def to_url_param_value(self) -> str:
Expand All @@ -134,6 +167,8 @@ def to_url_param_value(self) -> str:

if self.node_data_sources:
encoded_url_data['nodeData'] = self.node_data_sources
if self.node_data_target_models:
encoded_url_data['nodeDataTargets'] = self.node_data_target_models

# Return its json string.
return quote(json.dumps(encoded_url_data))
Expand Down
33 changes: 25 additions & 8 deletions src/server/package/src/model_explorer/web_app/index.html

Large diffs are not rendered by default.

4,336 changes: 4,336 additions & 0 deletions src/server/package/src/model_explorer/web_app/main-5R2QLBOR.js

Large diffs are not rendered by default.

Large diffs are not rendered by default.

1,893 changes: 0 additions & 1,893 deletions src/server/package/src/model_explorer/web_app/static_files/app_bundle.js

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit e0b60aa

Please sign in to comment.