Skip to content

Commit 9accdad

Browse files
committed
update to download metadata and preview image too
1 parent acc46c4 commit 9accdad

File tree

3 files changed

+150
-15
lines changed

3 files changed

+150
-15
lines changed

README.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,9 @@ A Simple and easy tool to download models from [Civitai](https://civitai.com/)
1111
3. Run the binary and download a model you like. That's it!
1212
```bash
1313
./civitai-downloader <sub_dir> <model_url|AIR>
14-
```
14+
```
15+
16+
## TODO
17+
18+
- [v] Add support for downloading models from [Civitai](https://civitai.com/)
19+
- [v] Download related image and metadata along with the model.

downloader/downloader.go

+134-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package downloader
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"io"
67
"net/http"
@@ -11,20 +12,55 @@ import (
1112
"github.com/schollz/progressbar/v3"
1213
)
1314

14-
func DownloadFile(outputPath string, url string) error {
15-
dir := filepath.Dir(outputPath)
16-
if err := os.MkdirAll(dir, 0755); err != nil {
17-
return fmt.Errorf("failed to create directories: %v", err)
15+
// ModelVersion represents the model version information from the Civitai API
16+
type ModelVersion struct {
17+
ID int64 `json:"id"`
18+
ModelID int64 `json:"modelId"`
19+
Name string `json:"name"`
20+
Files []File `json:"files"`
21+
Images []Image `json:"images"`
22+
Description string `json:"description"`
23+
DownloadURL string `json:"downloadUrl"`
24+
}
25+
26+
// File represents a file associated with the model version
27+
type File struct {
28+
ID int64 `json:"id"`
29+
SizeKB float64 `json:"sizeKB"`
30+
Name string `json:"name"`
31+
Type string `json:"type"`
32+
DownloadURL string `json:"downloadUrl"`
33+
}
34+
35+
// Image represents an image associated with the model version
36+
type Image struct {
37+
URL string `json:"url"`
38+
Type string `json:"type"`
39+
Width int `json:"width"`
40+
Height int `json:"height"`
41+
}
42+
43+
const (
44+
APIModelVersions = "https://civitai.com/api/v1/model-versions/"
45+
)
46+
47+
// DownloadFile downloads a single file from the given URL to the specified path and returns the saved filename
48+
func DownloadFile(outputPath, url, modelVersionId string) (string, error) {
49+
outputDir := filepath.Dir(outputPath)
50+
if _, err := os.Stat(outputDir); os.IsNotExist(err) {
51+
if err := os.MkdirAll(outputDir, 0755); err != nil {
52+
return "", fmt.Errorf("failed to create directory: %w", err)
53+
}
1854
}
1955

2056
resp, err := http.Get(url)
2157
if err != nil {
22-
return fmt.Errorf("failed to fetch %s: %v", url, err)
58+
return "", fmt.Errorf("failed to download file: %w", err)
2359
}
2460
defer resp.Body.Close()
2561

2662
if resp.StatusCode != http.StatusOK {
27-
return fmt.Errorf("HTTP error for %s: %v", url, resp.StatusCode)
63+
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
2864
}
2965

3066
header := resp.Header.Get("content-disposition")
@@ -37,7 +73,7 @@ func DownloadFile(outputPath string, url string) error {
3773

3874
file, err := os.Create(outputPath)
3975
if err != nil {
40-
return fmt.Errorf("failed to create file %s: %v", outputPath, err)
76+
return "", fmt.Errorf("failed to create file: %w", err)
4177
}
4278
defer file.Close()
4379

@@ -57,5 +93,95 @@ func DownloadFile(outputPath string, url string) error {
5793
)
5894

5995
_, err = io.Copy(io.MultiWriter(file, bar), resp.Body)
60-
return err
96+
if err != nil {
97+
return "", fmt.Errorf("failed to write to file: %w", err)
98+
}
99+
100+
return outputPath, nil
101+
}
102+
103+
// DownloadAll downloads all available files for a given model ID
104+
func DownloadAll(modelType, baseModelPath, modelID string) error {
105+
modelURL := fmt.Sprintf("%s%s", APIModelVersions, modelID)
106+
resp, err := http.Get(modelURL)
107+
if err != nil {
108+
return fmt.Errorf("failed to fetch model version: %w", err)
109+
}
110+
defer resp.Body.Close()
111+
112+
if resp.StatusCode != http.StatusOK {
113+
return fmt.Errorf("failed to fetch model version. Status code: %d", resp.StatusCode)
114+
}
115+
116+
var modelVersion ModelVersion
117+
err = json.NewDecoder(resp.Body).Decode(&modelVersion)
118+
if err != nil {
119+
return fmt.Errorf("failed to decode model version response: %w", err)
120+
}
121+
122+
// Create a subdirectory based on modelType
123+
dir := filepath.Join(baseModelPath, modelType)
124+
if err := os.MkdirAll(dir, 0755); err != nil {
125+
return fmt.Errorf("failed to create model type directory: %w", err)
126+
}
127+
128+
outputPath, err := DownloadFile(filepath.Join(dir, filepath.Base(baseModelPath)), modelVersion.DownloadURL, modelID)
129+
if err != nil {
130+
return fmt.Errorf("failed to download model file: %w", err)
131+
}
132+
133+
baseName := strings.TrimSuffix(filepath.Base(outputPath), filepath.Ext(outputPath))
134+
135+
for _, image := range modelVersion.Images {
136+
if image.Type == "image" {
137+
imgPath := fmt.Sprintf("%s.preview.png", filepath.Join(filepath.Dir(outputPath), baseName))
138+
if _, err := DownloadFile(imgPath, image.URL, modelID); err != nil {
139+
return fmt.Errorf("failed to download image: %w", err)
140+
}
141+
} else if image.Type == "video" {
142+
imgPath := fmt.Sprintf("%s.preview.mp4", filepath.Join(filepath.Dir(outputPath), baseName))
143+
if _, err := DownloadFile(imgPath, image.URL, modelID); err != nil {
144+
return fmt.Errorf("failed to download image: %w", err)
145+
}
146+
}
147+
}
148+
149+
metadataPath := fmt.Sprintf("%s.civitai.info", filepath.Join(filepath.Dir(outputPath), baseName))
150+
metadata, err := json.MarshalIndent(modelVersion, "", " ")
151+
if err != nil {
152+
return fmt.Errorf("failed to marshal metadata: %w", err)
153+
}
154+
if err := os.WriteFile(metadataPath, metadata, 0644); err != nil {
155+
return fmt.Errorf("failed to save metadata: %w", err)
156+
}
157+
158+
if modelVersion.Description != "" {
159+
descPath := fmt.Sprintf("%s.description.txt", filepath.Join(filepath.Dir(outputPath), baseName))
160+
if err := os.WriteFile(descPath, []byte(modelVersion.Description), 0644); err != nil {
161+
return fmt.Errorf("failed to save description: %w", err)
162+
}
163+
}
164+
165+
fmt.Printf("Successfully downloaded model files to %s\n", filepath.Dir(outputPath))
166+
return nil
167+
}
168+
169+
// GetModelID retrieves the model ID from the URL
170+
func GetModelID(url string) (string, error) {
171+
resp, err := http.Get(url)
172+
if err != nil {
173+
return "", fmt.Errorf("failed to fetch model page: %w", err)
174+
}
175+
defer resp.Body.Close()
176+
177+
if resp.StatusCode != http.StatusOK {
178+
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
179+
}
180+
181+
parts := url[len("https://civitai.com/models/"):]
182+
if len(parts) == 0 {
183+
return "", fmt.Errorf("invalid URL format")
184+
}
185+
modelID := parts
186+
return modelID, nil
61187
}

main.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"os"
66
"path/filepath"
7+
"regexp"
78
"strings"
89
"time"
910

@@ -44,22 +45,25 @@ func main() {
4445
fmt.Println("Base model path not specified, using current directory")
4546
}
4647

47-
var downloadURL string
48+
var modelVersionId string
4849
if strings.HasPrefix(modelIdentifier, "urn:air:") {
4950
parts := strings.Split(modelIdentifier, ":")
5051
if modelType == "" {
5152
modelType = parts[3]
5253
}
5354
modelInfo := strings.Split(parts[len(parts)-1], "@")
54-
version := modelInfo[1]
55-
downloadURL = fmt.Sprintf("https://civitai.com/api/download/models/%s?token=%s", version, token)
55+
modelVersionId = modelInfo[1]
5656
} else {
57-
downloadURL = modelIdentifier + "?token=" + token
57+
re := regexp.MustCompile(`https://civitai.com/models/(\d+)`)
58+
matches := re.FindStringSubmatch(modelIdentifier)
59+
if len(matches) == 2 {
60+
modelVersionId = matches[1]
61+
}
5862
}
59-
63+
fmt.Println("Model version ID:", modelVersionId)
6064
outputPath := filepath.Join(baseModelPath, modelType, fmt.Sprintf("temp-%d.safetensors", time.Now().UnixNano()))
6165

62-
err = downloader.DownloadFile(outputPath, downloadURL)
66+
err = downloader.DownloadAll(modelType, baseModelPath, modelVersionId)
6367
if err != nil {
6468
fmt.Printf("Error downloading %s: %v\n", modelType, err)
6569
return

0 commit comments

Comments
 (0)