|
10 | 10 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
11 | 11 | from urllib.error import HTTPError
|
12 | 12 |
|
| 13 | +import pkg_resources |
13 | 14 | import requests
|
14 | 15 | import yaml
|
15 | 16 | from requests import exceptions
|
|
79 | 80 | logger.addHandler(logging.StreamHandler(sys.stdout))
|
80 | 81 |
|
81 | 82 |
|
| 83 | +def _python_req_name(python_requirement: str) -> str: |
| 84 | + return pkg_resources.Requirement.parse(python_requirement).name |
| 85 | + |
| 86 | + |
82 | 87 | class TrussHandle:
|
83 | 88 | def __init__(self, truss_dir: Path, validate: bool = True) -> None:
|
84 | 89 | self._truss_dir = truss_dir
|
@@ -106,7 +111,6 @@ def spec(self) -> TrussSpec:
|
106 | 111 | def _wait_for_predict(
|
107 | 112 | model_base_url: str, request: Dict, binary: bool = False
|
108 | 113 | ) -> Response:
|
109 |
| - |
110 | 114 | url = f"{model_base_url}/v1/models/model:predict"
|
111 | 115 |
|
112 | 116 | if binary:
|
@@ -397,11 +401,15 @@ def training_docker_build_setup(self, build_dir: Optional[Path] = None):
|
397 | 401 |
|
398 | 402 | def add_python_requirement(self, python_requirement: str):
|
399 | 403 | """Add a python requirement to truss model's config."""
|
400 |
| - self._update_config( |
401 |
| - lambda conf: replace( |
402 |
| - conf, requirements=[*conf.requirements, python_requirement] |
403 |
| - ) |
404 |
| - ) |
| 404 | + |
| 405 | + # Parse the added python requirements |
| 406 | + input_python_req_name = _python_req_name(python_requirement) |
| 407 | + new_reqs = [ |
| 408 | + req for req in self._spec.config.requirements |
| 409 | + if _python_req_name(req) != input_python_req_name |
| 410 | + ] |
| 411 | + new_reqs.append(python_requirement) |
| 412 | + self._update_config(lambda conf: replace(conf, requirements=new_reqs)) |
405 | 413 |
|
406 | 414 | def remove_python_requirement(self, python_requirement: str):
|
407 | 415 | """Remove a python requirement to truss model's config.
|
|
0 commit comments