Skip to content

Commit 2975bf7

Browse files
authored Aug 18, 2023
Revert "Bundling Private Weights from GCP (#552)" (#591)
This reverts commit 0a52e52.
1 parent a6963d1 commit 2975bf7

File tree

9 files changed

+52
-384
lines changed

9 files changed

+52
-384
lines changed
 

‎examples/vllm-gcs/config.yaml

-19
This file was deleted.

‎examples/vllm-gcs/data/service_account.json

-1
This file was deleted.

‎poetry.lock

+26-204
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.6.2"
3+
version = "0.6.1"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"
@@ -41,7 +41,6 @@ watchfiles = "^0.19.0"
4141
huggingface_hub = "^0.16.4"
4242
rich-click = "^1.6.1"
4343
inquirerpy = "^0.3.4"
44-
google-cloud-storage = "2.10.0"
4544

4645
[tool.poetry.group.builder.dependencies]
4746
python = ">=3.8,<3.12"
@@ -60,7 +59,6 @@ uvicorn = "^0.21.1"
6059
httpx = "^0.24.1"
6160
psutil = "^5.9.4"
6261
huggingface_hub = "^0.16.4"
63-
google-cloud-storage = "2.10.0"
6462

6563
[tool.poetry.dev-dependencies]
6664
torch = "^1.9.0"

‎truss/contexts/image_builder/cache_warmer.py

+13-29
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,25 @@
11
import sys
22
from pathlib import Path
33

4-
from google.cloud import storage
54
from huggingface_hub import hf_hub_download
65

76

8-
def download_file(
9-
repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json"
10-
):
11-
# Check if repo_name starts with "gs://"
12-
if "gs://" in repo_name:
13-
# Create directory if not exist
14-
repo_name = repo_name.replace("gs://", "")
15-
cache_dir = Path(f"/app/hf_cache/{repo_name}")
16-
cache_dir.mkdir(parents=True, exist_ok=True)
17-
18-
# Connect to GCS storage
19-
try:
20-
storage_client = storage.Client.from_service_account_json(key_file)
21-
bucket = storage_client.bucket(repo_name)
22-
blob = bucket.blob(file_name)
23-
# Download the blob to a file
24-
blob.download_to_filename(f"{cache_dir}/{file_name}")
25-
except Exception as e:
26-
raise RuntimeError(f"Failure downloading file from GCS: {e}")
27-
else:
28-
secret_path = Path("/etc/secrets/hf_access_token")
29-
secret = secret_path.read_text().strip() if secret_path.exists() else None
30-
try:
31-
hf_hub_download(repo_name, file_name, revision=revision_name, token=secret)
32-
except FileNotFoundError:
33-
raise RuntimeError(
34-
"Hugging Face repository not found (and no valid secret found for possibly private repository)."
35-
)
7+
def download_file(repo_name, file_name, revision_name=None):
8+
secret = None
9+
secret_path = Path("/etc/secrets/hf_access_token")
10+
11+
if secret_path.exists():
12+
secret = secret_path.read_text().strip()
13+
try:
14+
hf_hub_download(repo_name, file_name, revision=revision_name, token=secret)
15+
except FileNotFoundError:
16+
raise RuntimeError(
17+
"Hugging Face repository not found (and no valid secret found for possibly private repository)."
18+
)
3619

3720

3821
if __name__ == "__main__":
22+
# TODO(varun): might make sense to move this check + write to a separate `prepare_cache.py` script
3923
file_path = Path.home() / ".cache/huggingface/hub/version.txt"
4024
file_path.parent.mkdir(parents=True, exist_ok=True)
4125

‎truss/contexts/image_builder/serving_image_builder.py

+5-103
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Dict, Optional
33

44
import yaml
5-
from google.cloud import storage
65
from huggingface_hub import list_repo_files
76
from huggingface_hub.utils import filter_repo_objects
87
from truss.constants import (
@@ -28,13 +27,7 @@
2827
)
2928
from truss.contexts.truss_context import TrussContext
3029
from truss.patch.hash import directory_content_hash
31-
from truss.truss_config import (
32-
Build,
33-
HuggingFaceCache,
34-
HuggingFaceModel,
35-
ModelServer,
36-
TrussConfig,
37-
)
30+
from truss.truss_config import Build, ModelServer, TrussConfig
3831
from truss.truss_spec import TrussSpec
3932
from truss.util.download import download_external_data
4033
from truss.util.jinja import read_template_from_fs
@@ -82,12 +75,7 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path):
8275
supervisord_filepath.write_text(supervisord_contents)
8376

8477

85-
def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
86-
def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
87-
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]
88-
89-
copy_tree_path(truss_dir, build_dir)
90-
78+
def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
9179
server_endpoint_config = {
9280
"Completions": "/v1/completions",
9381
"ChatCompletions": "/v1/chat/completions",
@@ -97,58 +85,13 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
9785

9886
build_config: Build = config.build
9987
server_endpoint = server_endpoint_config[build_config.arguments.pop("endpoint")]
100-
101-
model_name = build_config.arguments.pop("model")
102-
if "gs://" in model_name:
103-
# if we are pulling from a gs bucket, we want to alias it as a part of the cache
104-
model_to_cache = {"repo_id": model_name}
105-
if config.hf_cache:
106-
config.hf_cache.models.append(HuggingFaceModel.from_dict(model_to_cache))
107-
else:
108-
config.hf_cache = HuggingFaceCache.from_list([model_to_cache])
109-
build_config.arguments[
110-
"model"
111-
] = f"/app/hf_cache/{model_name.replace('gs://', '')}"
112-
11388
hf_access_token = config.secrets.get(HF_ACCESS_TOKEN_SECRET_NAME)
11489
dockerfile_template = read_template_from_fs(
11590
TEMPLATES_DIR, "vllm/vllm.Dockerfile.jinja"
11691
)
11792
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")
118-
copy_into_build_dir(
119-
TEMPLATES_DIR / "cache_requirements.txt", "cache_requirements.txt"
120-
)
12193

122-
model_files = {}
123-
if config.hf_cache:
124-
curr_dir = Path(__file__).parent.resolve()
125-
copy_into_build_dir(curr_dir / "cache_warmer.py", "cache_warmer.py")
126-
for model in config.hf_cache.models:
127-
repo_id = model.repo_id
128-
revision = model.revision
129-
130-
allow_patterns = model.allow_patterns
131-
ignore_patterns = model.ignore_patterns
132-
133-
filtered_repo_files = list(
134-
filter_repo_objects(
135-
items=list_files(
136-
repo_id, truss_dir / config.data_dir, revision=revision
137-
),
138-
allow_patterns=allow_patterns,
139-
ignore_patterns=ignore_patterns,
140-
)
141-
)
142-
model_files[repo_id] = {
143-
"files": filtered_repo_files,
144-
"revision": revision,
145-
}
146-
147-
dockerfile_content = dockerfile_template.render(
148-
hf_access_token=hf_access_token,
149-
models=model_files,
150-
should_install_server_requirements=True,
151-
)
94+
dockerfile_content = dockerfile_template.render(hf_access_token=hf_access_token)
15295
dockerfile_filepath = build_dir / "Dockerfile"
15396
dockerfile_filepath.write_text(dockerfile_content)
15497

@@ -167,47 +110,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
167110
supervisord_filepath.write_text(supervisord_contents)
168111

169112

170-
def split_gs_path(gs_path):
171-
# Remove the 'gs://' prefix
172-
path = gs_path.replace("gs://", "")
173-
174-
# Split on the first slash
175-
parts = path.split("/", 1)
176-
177-
bucket_name = parts[0]
178-
prefix = parts[1] if len(parts) > 1 else ""
179-
180-
return bucket_name, prefix
181-
182-
183-
def list_bucket_files(bucket_name, data_dir, is_trusted=False):
184-
# TODO(varun): provide support for aws s3
185-
186-
if is_trusted:
187-
storage_client = storage.Client.from_service_account_json(
188-
data_dir / "service_account.json"
189-
)
190-
else:
191-
storage_client = storage.Client()
192-
print(bucket_name.replace("gs://", ""))
193-
bucket_name, prefix = split_gs_path(bucket_name)
194-
blobs = storage_client.list_blobs(bucket_name, prefix=prefix)
195-
196-
all_objects = []
197-
for blob in blobs:
198-
all_objects.append(Path(blob.name).name)
199-
print(Path(blob.name).name)
200-
return all_objects
201-
202-
203-
def list_files(repo_id, data_dir, revision=None):
204-
if repo_id.startswith(("s3://", "gs://")):
205-
return list_bucket_files(repo_id, data_dir, is_trusted=True)
206-
else:
207-
# we assume it's a HF bucket
208-
list_repo_files(repo_id, revision=revision)
209-
210-
211113
class ServingImageBuilderContext(TrussContext):
212114
@staticmethod
213115
def run(truss_dir: Path):
@@ -241,7 +143,7 @@ def prepare_image_build_dir(
241143
create_tgi_build_dir(config, build_dir)
242144
return
243145
elif config.build.model_server is ModelServer.VLLM:
244-
create_vllm_build_dir(config, build_dir, truss_dir)
146+
create_vllm_build_dir(config, build_dir)
245147
return
246148

247149
data_dir = build_dir / config.data_dir # type: ignore[operator]
@@ -273,7 +175,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
273175

274176
filtered_repo_files = list(
275177
filter_repo_objects(
276-
items=list_files(repo_id, data_dir, revision=revision),
178+
items=list_repo_files(repo_id, revision=revision),
277179
allow_patterns=allow_patterns,
278180
ignore_patterns=ignore_patterns,
279181
)

‎truss/templates/base.Dockerfile.jinja

+2-4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ RUN pip install -r requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
4444
{% endblock %}
4545

4646

47+
{% block cache_weights %}
48+
{% endblock %}
4749

4850

4951
ENV APP_HOME /app
@@ -53,10 +55,6 @@ WORKDIR $APP_HOME
5355
{% block app_copy %}
5456
{% endblock %}
5557

56-
57-
{% block cache_weights %}
58-
{% endblock %}
59-
6058
{% block bundled_packages_copy %}
6159
{%- if bundled_packages_dir_exists %}
6260
COPY ./{{config.bundled_packages_dir}} /packages

‎truss/templates/cache_requirements.txt

-2
This file was deleted.

‎truss/templates/vllm/vllm.Dockerfile.jinja

+5-19
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,19 @@ FROM baseten/vllm:latest
22

33
EXPOSE 8080-9000
44

5-
{% if hf_access_token %}
6-
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
7-
{% endif %}
8-
9-
COPY ./data /app/data
10-
11-
{%- if hf_cache != None %}
12-
COPY ./cache_warmer.py /cache_warmer.py
13-
14-
COPY ./cache_requirements.txt /app/cache_requirements.txt
15-
16-
RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
17-
{% for repo, hf_dir in models.items() %}
18-
{% for file in hf_dir.files %}
19-
RUN python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}
20-
{% endfor %}
21-
{% endfor %}
22-
{%- endif %}
23-
245
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
256
curl nginx supervisor && \
267
rm -rf /var/lib/apt/lists/*
278

9+
2810
COPY ./proxy.conf /etc/nginx/conf.d/proxy.conf
2911

3012
RUN mkdir -p /var/log/supervisor
3113
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
3214

15+
{% if hf_access_token %}
16+
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
17+
{% endif %}
18+
3319
ENV SERVER_START_CMD /usr/bin/supervisord
3420
ENTRYPOINT ["/usr/bin/supervisord"]

0 commit comments

Comments (0)