Skip to content

Commit 0a52e52

Browse files
author
Varun Shenoy
authored
Bundling Private Weights from GCP (#552)
* working gcp bundling * update save dir to cache_dir / {bucket} * change toml * update poetry.lock * update pyproject * bump * vllm working * cleaned up example * cleanup * bump * model can now point to a gs bucket * alias model_name for gcs * bump * fix cache_requirements * bump * bump * bump * remove cache_reqs * remove cache_reqs * bump * add cache_reqs * bump * update serving builder * remove spec * revert toml * bump * add gc to toml * update lock
1 parent 073ef84 commit 0a52e52

File tree

9 files changed

+1089
-735
lines changed

9 files changed

+1089
-735
lines changed

examples/vllm-gcs/config.yaml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
build:
2+
arguments:
3+
endpoint: Completions
4+
model: gs://llama-2-7b
5+
tokenizer: hf-internal-testing/llama-tokenizer
6+
model_server: VLLM
7+
environment_variables: {}
8+
external_package_dirs: []
9+
model_metadata: {}
10+
model_name: vllm llama gcs
11+
python_version: py39
12+
requirements: []
13+
resources:
14+
accelerator: A10G
15+
cpu: 500m
16+
memory: 30Gi
17+
use_gpu: true
18+
secrets: {}
19+
system_packages: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
YOUR SERVICE ACCOUNT KEY

poetry.lock

+909-709
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.6.1"
3+
version = "0.6.2"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"
@@ -41,6 +41,7 @@ watchfiles = "^0.19.0"
4141
huggingface_hub = "^0.16.4"
4242
rich-click = "^1.6.1"
4343
inquirerpy = "^0.3.4"
44+
google-cloud-storage = "2.10.0"
4445

4546
[tool.poetry.group.builder.dependencies]
4647
python = ">=3.8,<3.12"
@@ -59,6 +60,7 @@ uvicorn = "^0.21.1"
5960
httpx = "^0.24.1"
6061
psutil = "^5.9.4"
6162
huggingface_hub = "^0.16.4"
63+
google-cloud-storage = "2.10.0"
6264

6365
[tool.poetry.dev-dependencies]
6466
torch = "^1.9.0"

truss/contexts/image_builder/cache_warmer.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,41 @@
11
import sys
22
from pathlib import Path
33

4+
from google.cloud import storage
45
from huggingface_hub import hf_hub_download
56

67

7-
def download_file(repo_name, file_name, revision_name=None):
8-
secret = None
9-
secret_path = Path("/etc/secrets/hf_access_token")
10-
11-
if secret_path.exists():
12-
secret = secret_path.read_text().strip()
13-
try:
14-
hf_hub_download(repo_name, file_name, revision=revision_name, token=secret)
15-
except FileNotFoundError:
16-
raise RuntimeError(
17-
"Hugging Face repository not found (and no valid secret found for possibly private repository)."
18-
)
8+
def download_file(
    repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json"
):
    """Download a single model file into the image's weight cache.

    Args:
        repo_name: Either a Hugging Face repo id, or a GCS location of the
            form ``gs://<bucket>`` or ``gs://<bucket>/<prefix>``.
        file_name: Name of the file to fetch (relative to the repo, or to the
            bucket prefix for GCS).
        revision_name: Optional Hugging Face revision; unused for GCS.
        key_file: Path to the GCP service-account JSON key used for GCS.

    Raises:
        RuntimeError: If the GCS download fails for any reason, or if the
            Hugging Face download raises ``FileNotFoundError``.
    """
    gcs_scheme = "gs://"
    # Only a leading scheme marks a GCS location; a substring match would
    # misroute any repo id that merely contains "gs://".
    if repo_name.startswith(gcs_scheme):
        bucket_path = repo_name[len(gcs_scheme) :]

        # The bucket is the first path component; anything after it is an
        # object prefix and must not be passed as part of the bucket name.
        bucket_name, _, prefix = bucket_path.partition("/")
        prefix = prefix.strip("/")
        blob_name = f"{prefix}/{file_name}" if prefix else file_name

        # Create directory if it does not exist
        cache_dir = Path(f"/app/hf_cache/{bucket_path}")
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Connect to GCS storage
        try:
            storage_client = storage.Client.from_service_account_json(key_file)
            blob = storage_client.bucket(bucket_name).blob(blob_name)
            # Download the blob to a file
            blob.download_to_filename(f"{cache_dir}/{file_name}")
        except Exception as e:
            # Keep the original exception chained so the root cause is visible.
            raise RuntimeError(f"Failure downloading file from GCS: {e}") from e
    else:
        secret_path = Path("/etc/secrets/hf_access_token")
        secret = secret_path.read_text().strip() if secret_path.exists() else None
        try:
            hf_hub_download(repo_name, file_name, revision=revision_name, token=secret)
        except FileNotFoundError as e:
            raise RuntimeError(
                "Hugging Face repository not found (and no valid secret found for possibly private repository)."
            ) from e
1936

2037

2138
if __name__ == "__main__":
22-
# TODO(varun): might make sense to move this check + write to a separate `prepare_cache.py` script
2339
file_path = Path.home() / ".cache/huggingface/hub/version.txt"
2440
file_path.parent.mkdir(parents=True, exist_ok=True)
2541

truss/contexts/image_builder/serving_image_builder.py

+103-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, Optional
33

44
import yaml
5+
from google.cloud import storage
56
from huggingface_hub import list_repo_files
67
from huggingface_hub.utils import filter_repo_objects
78
from truss.constants import (
@@ -27,7 +28,13 @@
2728
)
2829
from truss.contexts.truss_context import TrussContext
2930
from truss.patch.hash import directory_content_hash
30-
from truss.truss_config import Build, ModelServer, TrussConfig
31+
from truss.truss_config import (
32+
Build,
33+
HuggingFaceCache,
34+
HuggingFaceModel,
35+
ModelServer,
36+
TrussConfig,
37+
)
3138
from truss.truss_spec import TrussSpec
3239
from truss.util.download import download_external_data
3340
from truss.util.jinja import read_template_from_fs
@@ -75,7 +82,12 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path):
7582
supervisord_filepath.write_text(supervisord_contents)
7683

7784

78-
def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
85+
def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
86+
def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
87+
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]
88+
89+
copy_tree_path(truss_dir, build_dir)
90+
7991
server_endpoint_config = {
8092
"Completions": "/v1/completions",
8193
"ChatCompletions": "/v1/chat/completions",
@@ -85,13 +97,58 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
8597

8698
build_config: Build = config.build
8799
server_endpoint = server_endpoint_config[build_config.arguments.pop("endpoint")]
100+
101+
model_name = build_config.arguments.pop("model")
102+
if "gs://" in model_name:
103+
# if we are pulling from a gs bucket, we want to alias it as a part of the cache
104+
model_to_cache = {"repo_id": model_name}
105+
if config.hf_cache:
106+
config.hf_cache.models.append(HuggingFaceModel.from_dict(model_to_cache))
107+
else:
108+
config.hf_cache = HuggingFaceCache.from_list([model_to_cache])
109+
build_config.arguments[
110+
"model"
111+
] = f"/app/hf_cache/{model_name.replace('gs://', '')}"
112+
88113
hf_access_token = config.secrets.get(HF_ACCESS_TOKEN_SECRET_NAME)
89114
dockerfile_template = read_template_from_fs(
90115
TEMPLATES_DIR, "vllm/vllm.Dockerfile.jinja"
91116
)
92117
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")
118+
copy_into_build_dir(
119+
TEMPLATES_DIR / "cache_requirements.txt", "cache_requirements.txt"
120+
)
93121

94-
dockerfile_content = dockerfile_template.render(hf_access_token=hf_access_token)
122+
model_files = {}
123+
if config.hf_cache:
124+
curr_dir = Path(__file__).parent.resolve()
125+
copy_into_build_dir(curr_dir / "cache_warmer.py", "cache_warmer.py")
126+
for model in config.hf_cache.models:
127+
repo_id = model.repo_id
128+
revision = model.revision
129+
130+
allow_patterns = model.allow_patterns
131+
ignore_patterns = model.ignore_patterns
132+
133+
filtered_repo_files = list(
134+
filter_repo_objects(
135+
items=list_files(
136+
repo_id, truss_dir / config.data_dir, revision=revision
137+
),
138+
allow_patterns=allow_patterns,
139+
ignore_patterns=ignore_patterns,
140+
)
141+
)
142+
model_files[repo_id] = {
143+
"files": filtered_repo_files,
144+
"revision": revision,
145+
}
146+
147+
dockerfile_content = dockerfile_template.render(
148+
hf_access_token=hf_access_token,
149+
models=model_files,
150+
should_install_server_requirements=True,
151+
)
95152
dockerfile_filepath = build_dir / "Dockerfile"
96153
dockerfile_filepath.write_text(dockerfile_content)
97154

@@ -110,6 +167,47 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
110167
supervisord_filepath.write_text(supervisord_contents)
111168

112169

170+
def split_gs_path(gs_path):
    """Split a GCS location into ``(bucket_name, prefix)``.

    Args:
        gs_path: A string such as ``gs://bucket`` or ``gs://bucket/some/prefix``.
            A value without the ``gs://`` scheme is split the same way.

    Returns:
        A ``(bucket_name, prefix)`` tuple; ``prefix`` is ``""`` when the path
        has nothing after the bucket name.
    """
    scheme = "gs://"
    # Strip only a leading scheme; removing every occurrence would silently
    # rewrite object prefixes that happen to contain "gs://".
    path = gs_path[len(scheme) :] if gs_path.startswith(scheme) else gs_path

    # Split on the first slash: bucket before it, object prefix after it.
    bucket_name, _, prefix = path.partition("/")
    return bucket_name, prefix
181+
182+
183+
def list_bucket_files(bucket_name, data_dir, is_trusted=False):
    """List the file names stored under a GCS bucket location.

    Args:
        bucket_name: GCS location, e.g. ``gs://bucket`` or ``gs://bucket/prefix``.
        data_dir: Directory expected to contain ``service_account.json`` when
            ``is_trusted`` is true.
        is_trusted: When true, authenticate with the service-account key in
            ``data_dir``; otherwise build a default ``storage.Client()``.

    Returns:
        A list with the final path component of every blob under the prefix.
    """
    # TODO(varun): provide support for aws s3

    if is_trusted:
        storage_client = storage.Client.from_service_account_json(
            data_dir / "service_account.json"
        )
    else:
        storage_client = storage.Client()

    bucket_name, prefix = split_gs_path(bucket_name)
    blobs = storage_client.list_blobs(bucket_name, prefix=prefix)

    # Only the base name is kept for each blob.
    return [Path(blob.name).name for blob in blobs]
201+
202+
203+
def list_files(repo_id, data_dir, revision=None):
    """Return the file names available for a model source.

    Args:
        repo_id: A bucket location starting with ``s3://`` or ``gs://``, or a
            Hugging Face repo id.
        data_dir: Directory passed to ``list_bucket_files`` for bucket sources.
        revision: Optional Hugging Face revision; ignored for bucket sources.

    Returns:
        The list of file names reported by the underlying source.
    """
    if repo_id.startswith(("s3://", "gs://")):
        return list_bucket_files(repo_id, data_dir, is_trusted=True)

    # Otherwise we assume it's a Hugging Face repo. The result must be
    # returned: callers feed it straight into filter_repo_objects(items=...).
    return list_repo_files(repo_id, revision=revision)
209+
210+
113211
class ServingImageBuilderContext(TrussContext):
114212
@staticmethod
115213
def run(truss_dir: Path):
@@ -143,7 +241,7 @@ def prepare_image_build_dir(
143241
create_tgi_build_dir(config, build_dir)
144242
return
145243
elif config.build.model_server is ModelServer.VLLM:
146-
create_vllm_build_dir(config, build_dir)
244+
create_vllm_build_dir(config, build_dir, truss_dir)
147245
return
148246

149247
data_dir = build_dir / config.data_dir # type: ignore[operator]
@@ -175,7 +273,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
175273

176274
filtered_repo_files = list(
177275
filter_repo_objects(
178-
items=list_repo_files(repo_id, revision=revision),
276+
items=list_files(repo_id, data_dir, revision=revision),
179277
allow_patterns=allow_patterns,
180278
ignore_patterns=ignore_patterns,
181279
)

truss/templates/base.Dockerfile.jinja

+4-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ RUN pip install -r requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
4444
{% endblock %}
4545

4646

47-
{% block cache_weights %}
48-
{% endblock %}
4947

5048

5149
ENV APP_HOME /app
@@ -55,6 +53,10 @@ WORKDIR $APP_HOME
5553
{% block app_copy %}
5654
{% endblock %}
5755

56+
57+
{% block cache_weights %}
58+
{% endblock %}
59+
5860
{% block bundled_packages_copy %}
5961
{%- if bundled_packages_dir_exists %}
6062
COPY ./{{config.bundled_packages_dir}} /packages
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
huggingface-hub==0.16.4
2+
google-cloud-storage==2.10.0

truss/templates/vllm/vllm.Dockerfile.jinja

+19-5
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,33 @@ FROM baseten/vllm:latest
22

33
EXPOSE 8080-9000
44

5+
{% if hf_access_token %}
6+
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
7+
{% endif %}
8+
9+
COPY ./data /app/data
10+
11+
{# Guard on `models` (the mapping passed by the builder): the cache steps are only needed when there is at least one file to warm. #}
{%- if models %}
COPY ./cache_warmer.py /cache_warmer.py

COPY ./cache_requirements.txt /app/cache_requirements.txt

RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
{% for repo, hf_dir in models.items() %}
{% for file in hf_dir.files %}
RUN python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}
{% endfor %}
{% endfor %}
{%- endif %}
23+
524
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
625
curl nginx supervisor && \
726
rm -rf /var/lib/apt/lists/*
827

9-
1028
COPY ./proxy.conf /etc/nginx/conf.d/proxy.conf
1129

1230
RUN mkdir -p /var/log/supervisor
1331
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
1432

15-
{% if hf_access_token %}
16-
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
17-
{% endif %}
18-
1933
ENV SERVER_START_CMD /usr/bin/supervisord
2034
ENTRYPOINT ["/usr/bin/supervisord"]

0 commit comments

Comments
 (0)