2
2
from typing import Any , Dict , Optional
3
3
4
4
import yaml
5
- from google .cloud import storage
6
5
from huggingface_hub import list_repo_files
7
6
from huggingface_hub .utils import filter_repo_objects
8
7
from truss .constants import (
28
27
)
29
28
from truss .contexts .truss_context import TrussContext
30
29
from truss .patch .hash import directory_content_hash
31
- from truss .truss_config import (
32
- Build ,
33
- HuggingFaceCache ,
34
- HuggingFaceModel ,
35
- ModelServer ,
36
- TrussConfig ,
37
- )
30
+ from truss .truss_config import Build , ModelServer , TrussConfig
38
31
from truss .truss_spec import TrussSpec
39
32
from truss .util .download import download_external_data
40
33
from truss .util .jinja import read_template_from_fs
@@ -82,12 +75,7 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path):
82
75
supervisord_filepath .write_text (supervisord_contents )
83
76
84
77
85
- def create_vllm_build_dir (config : TrussConfig , build_dir : Path , truss_dir : Path ):
86
- def copy_into_build_dir (from_path : Path , path_in_build_dir : str ):
87
- copy_tree_or_file (from_path , build_dir / path_in_build_dir ) # type: ignore[operator]
88
-
89
- copy_tree_path (truss_dir , build_dir )
90
-
78
+ def create_vllm_build_dir (config : TrussConfig , build_dir : Path ):
91
79
server_endpoint_config = {
92
80
"Completions" : "/v1/completions" ,
93
81
"ChatCompletions" : "/v1/chat/completions" ,
@@ -97,58 +85,13 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
97
85
98
86
build_config : Build = config .build
99
87
server_endpoint = server_endpoint_config [build_config .arguments .pop ("endpoint" )]
100
-
101
- model_name = build_config .arguments .pop ("model" )
102
- if "gs://" in model_name :
103
- # if we are pulling from a gs bucket, we want to alias it as a part of the cache
104
- model_to_cache = {"repo_id" : model_name }
105
- if config .hf_cache :
106
- config .hf_cache .models .append (HuggingFaceModel .from_dict (model_to_cache ))
107
- else :
108
- config .hf_cache = HuggingFaceCache .from_list ([model_to_cache ])
109
- build_config .arguments [
110
- "model"
111
- ] = f"/app/hf_cache/{ model_name .replace ('gs://' , '' )} "
112
-
113
88
hf_access_token = config .secrets .get (HF_ACCESS_TOKEN_SECRET_NAME )
114
89
dockerfile_template = read_template_from_fs (
115
90
TEMPLATES_DIR , "vllm/vllm.Dockerfile.jinja"
116
91
)
117
92
nginx_template = read_template_from_fs (TEMPLATES_DIR , "vllm/proxy.conf.jinja" )
118
- copy_into_build_dir (
119
- TEMPLATES_DIR / "cache_requirements.txt" , "cache_requirements.txt"
120
- )
121
93
122
- model_files = {}
123
- if config .hf_cache :
124
- curr_dir = Path (__file__ ).parent .resolve ()
125
- copy_into_build_dir (curr_dir / "cache_warmer.py" , "cache_warmer.py" )
126
- for model in config .hf_cache .models :
127
- repo_id = model .repo_id
128
- revision = model .revision
129
-
130
- allow_patterns = model .allow_patterns
131
- ignore_patterns = model .ignore_patterns
132
-
133
- filtered_repo_files = list (
134
- filter_repo_objects (
135
- items = list_files (
136
- repo_id , truss_dir / config .data_dir , revision = revision
137
- ),
138
- allow_patterns = allow_patterns ,
139
- ignore_patterns = ignore_patterns ,
140
- )
141
- )
142
- model_files [repo_id ] = {
143
- "files" : filtered_repo_files ,
144
- "revision" : revision ,
145
- }
146
-
147
- dockerfile_content = dockerfile_template .render (
148
- hf_access_token = hf_access_token ,
149
- models = model_files ,
150
- should_install_server_requirements = True ,
151
- )
94
+ dockerfile_content = dockerfile_template .render (hf_access_token = hf_access_token )
152
95
dockerfile_filepath = build_dir / "Dockerfile"
153
96
dockerfile_filepath .write_text (dockerfile_content )
154
97
@@ -167,47 +110,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
167
110
supervisord_filepath .write_text (supervisord_contents )
168
111
169
112
170
- def split_gs_path (gs_path ):
171
- # Remove the 'gs://' prefix
172
- path = gs_path .replace ("gs://" , "" )
173
-
174
- # Split on the first slash
175
- parts = path .split ("/" , 1 )
176
-
177
- bucket_name = parts [0 ]
178
- prefix = parts [1 ] if len (parts ) > 1 else ""
179
-
180
- return bucket_name , prefix
181
-
182
-
183
- def list_bucket_files (bucket_name , data_dir , is_trusted = False ):
184
- # TODO(varun): provide support for aws s3
185
-
186
- if is_trusted :
187
- storage_client = storage .Client .from_service_account_json (
188
- data_dir / "service_account.json"
189
- )
190
- else :
191
- storage_client = storage .Client ()
192
- print (bucket_name .replace ("gs://" , "" ))
193
- bucket_name , prefix = split_gs_path (bucket_name )
194
- blobs = storage_client .list_blobs (bucket_name , prefix = prefix )
195
-
196
- all_objects = []
197
- for blob in blobs :
198
- all_objects .append (Path (blob .name ).name )
199
- print (Path (blob .name ).name )
200
- return all_objects
201
-
202
-
203
- def list_files (repo_id , data_dir , revision = None ):
204
- if repo_id .startswith (("s3://" , "gs://" )):
205
- return list_bucket_files (repo_id , data_dir , is_trusted = True )
206
- else :
207
- # we assume it's a HF bucket
208
- list_repo_files (repo_id , revision = revision )
209
-
210
-
211
113
class ServingImageBuilderContext (TrussContext ):
212
114
@staticmethod
213
115
def run (truss_dir : Path ):
@@ -241,7 +143,7 @@ def prepare_image_build_dir(
241
143
create_tgi_build_dir (config , build_dir )
242
144
return
243
145
elif config .build .model_server is ModelServer .VLLM :
244
- create_vllm_build_dir (config , build_dir , truss_dir )
146
+ create_vllm_build_dir (config , build_dir )
245
147
return
246
148
247
149
data_dir = build_dir / config .data_dir # type: ignore[operator]
@@ -273,7 +175,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
273
175
274
176
filtered_repo_files = list (
275
177
filter_repo_objects (
276
- items = list_files (repo_id , data_dir , revision = revision ),
178
+ items = list_repo_files (repo_id , revision = revision ),
277
179
allow_patterns = allow_patterns ,
278
180
ignore_patterns = ignore_patterns ,
279
181
)
0 commit comments