|
 import datetime
+import json
 import os
 import subprocess
 import sys
 from pathlib import Path
 from typing import Optional
 
+import boto3
+from botocore.client import Config
 from google.cloud import storage
 from huggingface_hub import hf_hub_download
 
@@ -32,52 +35,105 @@ def _download_from_url_using_b10cp(
     )
 
 
-def split_gs_path(gs_path):
+def split_path(path, prefix="gs://"):
     # Remove the 'gs://' prefix
-    path = gs_path.replace("gs://", "")
+    path = path.replace(prefix, "")
 
     # Split on the first slash
     parts = path.split("/", 1)
 
     bucket_name = parts[0]
-    prefix = parts[1] if len(parts) > 1 else ""
+    path = parts[1] if len(parts) > 1 else ""
 
-    return bucket_name, prefix
+    return bucket_name, path
+
+
+def parse_s3_service_account_file(file_path):
+    # open the json file
+    with open(file_path, "r") as f:
+        data = json.load(f)
+
+    # validate the data
+    if "aws_access_key_id" not in data or "aws_secret_access_key" not in data:
+        raise ValueError("Invalid AWS credentials file")
+
+    # parse the data
+    aws_access_key_id = data["aws_access_key_id"]
+    aws_secret_access_key = data["aws_secret_access_key"]
+    aws_region = data["aws_region"]
+
+    return aws_access_key_id, aws_secret_access_key, aws_region
 
 
 def download_file(
     repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json"
 ):
     # Check if repo_name starts with "gs://"
-    if "gs://" in repo_name:
+    if repo_name.startswith(("gs://", "s3://")):
+        prefix = repo_name[:5]
+
         # Create directory if not exist
-        bucket_name, _ = split_gs_path(repo_name)
-        repo_name = repo_name.replace("gs://", "")
+        bucket_name, _ = split_path(repo_name, prefix=prefix)
+        repo_name = repo_name.replace(prefix, "")
         cache_dir = Path(f"/app/hf_cache/{bucket_name}")
         cache_dir.mkdir(parents=True, exist_ok=True)
 
-        # Connect to GCS storage
-        storage_client = storage.Client.from_service_account_json(key_file)
-        bucket = storage_client.bucket(bucket_name)
-        blob = bucket.blob(file_name)
+        if prefix == "gs://":
+            # Connect to GCS storage
+            storage_client = storage.Client.from_service_account_json(key_file)
+            bucket = storage_client.bucket(bucket_name)
+            blob = bucket.blob(file_name)
 
-        dst_file = Path(f"{cache_dir}/{file_name}")
-        if not dst_file.parent.exists():
-            dst_file.parent.mkdir(parents=True)
+            dst_file = Path(f"{cache_dir}/{file_name}")
+            if not dst_file.parent.exists():
+                dst_file.parent.mkdir(parents=True)
 
-        if not blob.exists(storage_client):
-            raise RuntimeError(f"File not found on GCS bucket: {blob.name}")
+            if not blob.exists(storage_client):
+                raise RuntimeError(f"File not found on GCS bucket: {blob.name}")
 
-        url = blob.generate_signed_url(
-            version="v4",
-            expiration=datetime.timedelta(minutes=15),
-            method="GET",
-        )
-        try:
-            proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
-            proc.wait()
-        except Exception as e:
-            raise RuntimeError(f"Failure downloading file from GCS: {e}")
+            url = blob.generate_signed_url(
+                version="v4",
+                expiration=datetime.timedelta(minutes=15),
+                method="GET",
+            )
+            try:
+                proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
+                proc.wait()
+            except Exception as e:
+                raise RuntimeError(f"Failure downloading file from GCS: {e}")
+        elif prefix == "s3://":
+            (
+                AWS_ACCESS_KEY_ID,
+                AWS_SECRET_ACCESS_KEY,
+                AWS_REGION,
+            ) = parse_s3_service_account_file(key_file)
+            client = boto3.client(
+                "s3",
+                aws_access_key_id=AWS_ACCESS_KEY_ID,
+                aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+                region_name=AWS_REGION,
+                config=Config(signature_version="s3v4"),
+            )
+            bucket_name, _ = split_path(bucket_name, prefix="s3://")
+
+            dst_file = Path(f"{cache_dir}/{file_name}")
+            if not dst_file.parent.exists():
+                dst_file.parent.mkdir(parents=True)
+
+            try:
+                url = client.generate_presigned_url(
+                    "get_object",
+                    Params={"Bucket": bucket_name, "Key": file_name},
+                    ExpiresIn=3600,
+                )
+            except Exception:
+                raise RuntimeError(f"File not found on S3 bucket: {file_name}")
+
+            try:
+                proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
+                proc.wait()
+            except Exception as e:
+                raise RuntimeError(f"Failure downloading file from S3: {e}")
     else:
         secret_path = Path("/etc/secrets/hf-access-token")
         secret = secret_path.read_text().strip() if secret_path.exists() else None
|