Skip to content

Commit 59de884

Browse files
author
Varun Shenoy
authored
Multistage builds for caching (#623)
* vllm multistage * add multistage builds for vllm + tgi * use slim instead * added multistage hf copies on separate layers * unit tests complete * add truss server integration test * added truss server test * rename function * clean up cache_warmer * template out cache in separate jinja file * templated out hf_cache section and cleaned up code * add integration tests with download checks + template out copies * increase sleep time * fix integration test for
1 parent db2dd05 commit 59de884

17 files changed

+421
-115
lines changed

truss/contexts/image_builder/cache_warmer.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ def download_file(
6363
secret_path = Path("/etc/secrets/hf-access-token")
6464
secret = secret_path.read_text().strip() if secret_path.exists() else None
6565
try:
66-
hf_hub_download(repo_name, file_name, revision=revision_name, token=secret)
66+
hf_hub_download(
67+
repo_name,
68+
file_name,
69+
revision=revision_name,
70+
token=secret,
71+
)
6772
except FileNotFoundError:
6873
raise RuntimeError(
6974
"Hugging Face repository not found (and no valid secret found for possibly private repository)."

truss/contexts/image_builder/serving_image_builder.py

+103-46
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from pathlib import Path
2-
from typing import Any, Dict, Optional
2+
from typing import Any, Dict, List, Optional
33

44
import yaml
55
from google.cloud import storage
6-
from huggingface_hub import list_repo_files
6+
from huggingface_hub import get_hf_file_metadata, hf_hub_url, list_repo_files
77
from huggingface_hub.utils import filter_repo_objects
88
from truss.constants import (
99
BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
@@ -29,13 +29,7 @@
2929
)
3030
from truss.contexts.truss_context import TrussContext
3131
from truss.patch.hash import directory_content_hash
32-
from truss.truss_config import (
33-
Build,
34-
HuggingFaceCache,
35-
HuggingFaceModel,
36-
ModelServer,
37-
TrussConfig,
38-
)
32+
from truss.truss_config import Build, HuggingFaceModel, ModelServer, TrussConfig
3933
from truss.truss_spec import TrussSpec
4034
from truss.util.download import download_external_data
4135
from truss.util.jinja import read_template_from_fs
@@ -133,33 +127,43 @@ def list_files(repo_id, data_dir, revision=None):
133127
return list_repo_files(repo_id, revision=revision)
134128

135129

136-
def update_config_and_gather_files(
137-
config: TrussConfig, truss_dir: Path, build_dir: Path, server_name: str
138-
):
130+
def update_model_key(config: TrussConfig) -> str:
    """Return the `build.arguments` key under which the model is named.

    TGI expects the model under `model_id`, vLLM under `model`.

    Raises:
        ValueError: if `config.build.model_server` is neither TGI nor VLLM.
    """
    server = config.build.model_server

    # Compared with `==` (not a dict lookup) so equality semantics of the
    # server value are exactly those of the configured type.
    key_by_server = (
        (ModelServer.TGI, "model_id"),
        (ModelServer.VLLM, "model"),
    )
    for candidate, argument_key in key_by_server:
        if server == candidate:
            return argument_key

    raise ValueError(
        f"Invalid server name (must be `TGI` or `VLLM`, not `{server}`)."
    )
141+
142+
143+
def update_model_name(config: TrussConfig, model_key: str) -> str:
    """Look up the configured model name and alias GCS models into the cache.

    When the name contains `gs://`, the model is appended to
    `config.hf_cache.models` and `config.build.arguments[model_key]` is
    rewritten to the corresponding path under `/app/hf_cache/`.

    Returns:
        The model name as it was configured, before any rewrite.

    Raises:
        KeyError: if `model_key` is absent from `config.build.arguments`.
    """
    build_arguments = config.build.arguments
    if model_key not in build_arguments:
        # NOTE: vLLM and TGI ideally would share a single key for the model.
        raise KeyError(
            "Key for model missing in config or incorrect key used. Use `model` for VLLM and `model_id` for TGI."
        )

    configured_name = build_arguments[model_key]
    if "gs://" not in configured_name:
        return configured_name

    # A bucket-hosted model is registered as a cache entry so it is fetched at
    # build time, and the server is pointed at the in-image copy instead.
    config.hf_cache.models.append(HuggingFaceModel(configured_name))
    path_in_cache = configured_name.replace("gs://", "")
    build_arguments[model_key] = f"/app/hf_cache/{path_in_cache}"
    return configured_name
159+
160+
161+
def get_files_to_cache(config: TrussConfig, truss_dir: Path, build_dir: Path):
139162
def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
140163
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]
141164

142-
if server_name == "TGI":
143-
model_key = "model_id"
144-
elif server_name == "vLLM":
145-
model_key = "model"
146-
147-
if server_name != "TrussServer":
148-
model_name = config.build.arguments[model_key]
149-
if "gs://" in model_name:
150-
# if we are pulling from a gs bucket, we want to alias it as a part of the cache
151-
model_to_cache = {"repo_id": model_name}
152-
if config.hf_cache:
153-
config.hf_cache.models.append(
154-
HuggingFaceModel.from_dict(model_to_cache)
155-
)
156-
else:
157-
config.hf_cache = HuggingFaceCache.from_list([model_to_cache])
158-
config.build.arguments[
159-
model_key
160-
] = f"/app/hf_cache/{model_name.replace('gs://', '')}"
161-
162165
model_files = {}
166+
cached_files: List[str] = []
163167
if config.hf_cache:
164168
curr_dir = Path(__file__).parent.resolve()
165169
copy_into_build_dir(curr_dir / "cache_warmer.py", "cache_warmer.py")
@@ -180,18 +184,52 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
180184
)
181185
)
182186

183-
if "gs://" in repo_id:
184-
repo_id, _ = split_gs_path(repo_id)
185-
repo_id = f"gs://{repo_id}"
187+
cached_files = fetch_files_to_cache(
188+
cached_files, repo_id, filtered_repo_files
189+
)
190+
191+
model_files[repo_id] = {"files": filtered_repo_files, "revision": revision}
186192

187-
model_files[repo_id] = {
188-
"files": filtered_repo_files,
189-
"revision": revision,
190-
}
191193
copy_into_build_dir(
192194
TEMPLATES_DIR / "cache_requirements.txt", "cache_requirements.txt"
193195
)
194-
return model_files
196+
return model_files, cached_files
197+
198+
199+
def fetch_files_to_cache(
    cached_files: list,
    repo_id: str,
    filtered_repo_files: list,
    revision: Optional[str] = None,
):
    """Append the image paths that must be copied for one cached repo.

    Args:
        cached_files: Accumulator of paths gathered so far. It is extended in
            place and also returned.
        repo_id: Either a `gs://` bucket path or a Hugging Face repo id.
        filtered_repo_files: File names within the repo selected for caching.
        revision: Optional Hugging Face revision used when resolving blob
            etags. Defaults to the Hub's default revision, which matches the
            previous behavior. Callers that pin a revision should pass it so
            the listed blobs are the ones actually downloaded.

    Returns:
        The same `cached_files` list. GCS entries are absolute paths under
        `/app/hf_cache/`; Hub entries are relative to the hub cache directory.
    """
    if "gs://" in repo_id:
        bucket_name, _ = split_gs_path(repo_id)
        cached_files.extend(
            f"/app/hf_cache/{bucket_name}/{filename}"
            for filename in filtered_repo_files
        )
        return cached_files

    repo_folder_name = f"models--{repo_id.replace('/', '--')}"
    for filename in filtered_repo_files:
        hf_url = hf_hub_url(repo_id, filename, revision=revision)
        hf_file_metadata = get_hf_file_metadata(hf_url)
        # Each blob is listed separately, named by its etag.
        cached_files.append(f"{repo_folder_name}/blobs/{hf_file_metadata.etag}")

    # snapshots is just a set of folders with symlinks, so the whole directory
    # can be copied on its own.
    cached_files.append(f"{repo_folder_name}/snapshots/")

    # refs just has files with revision commit hashes.
    cached_files.append(f"{repo_folder_name}/refs/")

    # version.txt is shared by every Hub repo in the cache; list it once so a
    # multi-model cache does not produce duplicate COPY instructions.
    if "version.txt" not in cached_files:
        cached_files.append("version.txt")

    return cached_files
223+
224+
225+
def update_config_and_gather_files(
    config: TrussConfig, truss_dir: Path, build_dir: Path
):
    """Normalize the model argument if needed, then collect files to cache.

    For any model server other than TrussServer, the model argument in
    `config.build.arguments` is resolved (and possibly rewritten) first.

    Returns:
        Whatever `get_files_to_cache` returns for the updated config.
    """
    uses_external_server = config.build.model_server != ModelServer.TrussServer
    if uses_external_server:
        update_model_name(config, update_model_key(config))

    return get_files_to_cache(config, truss_dir, build_dir)
195233

196234

197235
def create_tgi_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
@@ -200,19 +238,24 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
200238
if not build_dir.exists():
201239
build_dir.mkdir(parents=True)
202240

203-
model_files = update_config_and_gather_files(config, truss_dir, build_dir, "TGI")
241+
model_files, cached_file_paths = update_config_and_gather_files(
242+
config, truss_dir, build_dir
243+
)
204244

205245
hf_access_token = config.secrets.get(HF_ACCESS_TOKEN_SECRET_NAME)
206246
dockerfile_template = read_template_from_fs(
207247
TEMPLATES_DIR, "tgi/tgi.Dockerfile.jinja"
208248
)
209249

210250
data_dir = build_dir / "data"
251+
credentials_file = data_dir / "service_account.json"
211252
dockerfile_content = dockerfile_template.render(
212253
hf_access_token=hf_access_token,
213254
models=model_files,
214255
hf_cache=config.hf_cache,
215-
data_dir_exists=Path(data_dir).exists(),
256+
data_dir_exists=data_dir.exists(),
257+
credentials_exists=credentials_file.exists(),
258+
cached_files=cached_file_paths,
216259
)
217260
dockerfile_filepath = build_dir / "Dockerfile"
218261
dockerfile_filepath.write_text(dockerfile_content)
@@ -247,7 +290,9 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path)
247290
build_config: Build = config.build
248291
server_endpoint = server_endpoint_config[build_config.arguments.pop("endpoint")]
249292

250-
model_files = update_config_and_gather_files(config, truss_dir, build_dir, "vLLM")
293+
model_files, cached_file_paths = update_config_and_gather_files(
294+
config, truss_dir, build_dir
295+
)
251296

252297
hf_access_token = config.secrets.get(HF_ACCESS_TOKEN_SECRET_NAME)
253298
dockerfile_template = read_template_from_fs(
@@ -256,12 +301,15 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path)
256301
nginx_template = read_template_from_fs(TEMPLATES_DIR, "vllm/proxy.conf.jinja")
257302

258303
data_dir = build_dir / "data"
304+
credentials_file = data_dir / "service_account.json"
259305
dockerfile_content = dockerfile_template.render(
260306
hf_access_token=hf_access_token,
261307
models=model_files,
262308
should_install_server_requirements=True,
263309
hf_cache=config.hf_cache,
264310
data_dir_exists=data_dir.exists(),
311+
credentials_exists=credentials_file.exists(),
312+
cached_files=cached_file_paths,
265313
)
266314
dockerfile_filepath = build_dir / "Dockerfile"
267315
dockerfile_filepath.write_text(dockerfile_content)
@@ -336,8 +384,8 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
336384
download_external_data(self._spec.external_data, data_dir)
337385

338386
# Download from HuggingFace
339-
model_files = update_config_and_gather_files(
340-
config, truss_dir, build_dir, server_name="TrussServer"
387+
model_files, cached_files = update_config_and_gather_files(
388+
config, truss_dir, build_dir
341389
)
342390

343391
# Copy inference server code
@@ -391,7 +439,11 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
391439
(build_dir / SYSTEM_PACKAGES_TXT_FILENAME).write_text(spec.system_packages_txt)
392440

393441
self._render_dockerfile(
394-
build_dir, should_install_server_requirements, model_files, use_hf_secret
442+
build_dir,
443+
should_install_server_requirements,
444+
model_files,
445+
use_hf_secret,
446+
cached_files,
395447
)
396448

397449
def _render_dockerfile(
@@ -400,10 +452,12 @@ def _render_dockerfile(
400452
should_install_server_requirements: bool,
401453
model_files: Dict[str, Any],
402454
use_hf_secret: bool,
455+
cached_files: List[str],
403456
):
404457
config = self._spec.config
405458
data_dir = build_dir / config.data_dir
406459
bundled_packages_dir = build_dir / config.bundled_packages_dir
460+
credentials_file = data_dir / "service_account.json"
407461
dockerfile_template = read_template_from_fs(
408462
TEMPLATES_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME
409463
)
@@ -437,6 +491,9 @@ def _render_dockerfile(
437491
truss_hash=directory_content_hash(self._truss_dir),
438492
models=model_files,
439493
use_hf_secret=use_hf_secret,
494+
cached_files=cached_files,
495+
credentials_exists=credentials_file.exists(),
496+
hf_cache=len(config.hf_cache.models) > 0,
440497
)
441498
docker_file_path = build_dir / MODEL_DOCKERFILE_NAME
442499
docker_file_path.write_text(dockerfile_contents)

truss/templates/base.Dockerfile.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
ARG PYVERSION={{config.python_version}}
2-
FROM {{base_image_name_and_tag}}
2+
FROM {{base_image_name_and_tag}} as truss_server
33

44
ENV PYTHON_EXECUTABLE {{ config.base_image.python_executable_path or 'python3' }}
55

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Build stage that downloads the configured model files so later stages can COPY them.
FROM python:3.11-slim as cache_warmer

RUN mkdir -p /app/hf_cache
WORKDIR /app

{% if hf_access_token %}
# NOTE(review): ENV stores the token in this stage's layers and build cache; consider a BuildKit secret mount.
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
{% endif %}
{%- if credentials_exists %}
COPY ./data/service_account.json /app/data/service_account.json
{%- endif %}

# Steps are chained with && (and curl uses -f) so a failed install or download
# fails the build instead of being masked by the exit status of a later command.
RUN apt-get -y update \
    && apt-get -y install curl \
    && rm -rf /var/lib/apt/lists/* \
    && curl -fsS https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp \
    && chmod +x /app/b10cp
ENV B10CP_PATH_TRUSS /app/b10cp
COPY ./cache_requirements.txt /app/cache_requirements.txt
RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
COPY ./cache_warmer.py /cache_warmer.py
# One RUN per file keeps each download in its own cacheable layer.
{% for repo, hf_dir in models.items() %}
{% for file in hf_dir.files %}
RUN python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}
{% endfor %}
{% endfor %}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{# Emits one COPY per cached path. Absolute /app/hf_cache paths are copied verbatim; all other entries are relative to the Hugging Face hub cache directory. The form is chosen per file so caches mixing both kinds render correctly. -#}
{% for file in cached_files %}
{%- if file.startswith("/app/hf_cache/") %}
COPY --from=cache_warmer {{file}} {{file}}
{%- else %}
COPY --from=cache_warmer ./root/.cache/huggingface/hub/{{file}} {{hf_dst_directory}}{{file}}
{%- endif %}
{% endfor %}

truss/templates/server.Dockerfile.jinja

+11-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
{%- if hf_cache %}
2+
{%- include "cache.Dockerfile.jinja" %}
3+
{%- endif %}
4+
15
{% extends "base.Dockerfile.jinja" %}
26

37
{% block base_image_patch %}
@@ -52,20 +56,6 @@ RUN pip install -r server_requirements.txt --no-cache-dir && rm -rf /root/.cache
5256
COPY ./{{config.data_dir}} /app/data
5357
{%- endif %}
5458

55-
{%- if config.hf_cache != None %}
56-
RUN mkdir -p /app
57-
RUN curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
58-
ENV B10CP_PATH_TRUSS /app/b10cp
59-
COPY ./cache_requirements.txt /app/cache_requirements.txt
60-
RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
61-
COPY ./cache_warmer.py /cache_warmer.py
62-
{% for repo, hf_dir in models.items() %}
63-
{% for file in hf_dir.files %}
64-
{{ "RUN --mount=type=secret,id=hf_access_token,dst=/etc/secrets/hf_access_token" if use_hf_secret else "RUN" }} $PYTHON_EXECUTABLE /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}
65-
{% endfor %}
66-
{% endfor %}
67-
{%- endif %}
68-
6959
COPY ./server /app
7060
COPY ./{{ config.model_module_dir }} /app/model
7161
COPY ./config.yaml /app/config.yaml
@@ -74,6 +64,13 @@ COPY ./control /control
7464
RUN python3 -m venv /control/.env \
7565
&& /control/.env/bin/pip3 install -r /control/requirements.txt
7666
{%- endif %}
67+
68+
{%- if hf_cache %}
69+
{%- set hf_dst_directory="/root/.cache/huggingface/hub/"%}
70+
{%- include "copy_cache_files.Dockerfile.jinja"%}
71+
{%- endif %}
72+
73+
7774
{% endblock %}
7875

7976
{% block run %}

0 commit comments

Comments
 (0)