1
1
package downloader
2
2
3
3
import (
4
+ "encoding/json"
4
5
"fmt"
5
6
"io"
6
7
"net/http"
@@ -11,20 +12,55 @@ import (
11
12
"github.com/schollz/progressbar/v3"
12
13
)
13
14
14
- func DownloadFile (outputPath string , url string ) error {
15
- dir := filepath .Dir (outputPath )
16
- if err := os .MkdirAll (dir , 0755 ); err != nil {
17
- return fmt .Errorf ("failed to create directories: %v" , err )
15
+ // ModelVersion represents the model version information from the Civitai API
16
+ type ModelVersion struct {
17
+ ID int64 `json:"id"`
18
+ ModelID int64 `json:"modelId"`
19
+ Name string `json:"name"`
20
+ Files []File `json:"files"`
21
+ Images []Image `json:"images"`
22
+ Description string `json:"description"`
23
+ DownloadURL string `json:"downloadUrl"`
24
+ }
25
+
26
+ // File represents a file associated with the model version
27
+ type File struct {
28
+ ID int64 `json:"id"`
29
+ SizeKB float64 `json:"sizeKB"`
30
+ Name string `json:"name"`
31
+ Type string `json:"type"`
32
+ DownloadURL string `json:"downloadUrl"`
33
+ }
34
+
35
+ // Image represents an image associated with the model version
36
+ type Image struct {
37
+ URL string `json:"url"`
38
+ Type string `json:"type"`
39
+ Width int `json:"width"`
40
+ Height int `json:"height"`
41
+ }
42
+
43
+ const (
44
+ APIModelVersions = "https://civitai.com/api/v1/model-versions/"
45
+ )
46
+
47
+ // DownloadFile downloads a single file from the given URL to the specified path and returns the saved filename
48
+ func DownloadFile (outputPath , url , modelVersionId string ) (string , error ) {
49
+ outputDir := filepath .Dir (outputPath )
50
+ if _ , err := os .Stat (outputDir ); os .IsNotExist (err ) {
51
+ if err := os .MkdirAll (outputDir , 0755 ); err != nil {
52
+ return "" , fmt .Errorf ("failed to create directory: %w" , err )
53
+ }
18
54
}
19
55
20
56
resp , err := http .Get (url )
21
57
if err != nil {
22
- return fmt .Errorf ("failed to fetch %s : %v" , url , err )
58
+ return "" , fmt .Errorf ("failed to download file : %w" , err )
23
59
}
24
60
defer resp .Body .Close ()
25
61
26
62
if resp .StatusCode != http .StatusOK {
27
- return fmt .Errorf ("HTTP error for %s : %v" , url , resp .StatusCode )
63
+ return "" , fmt .Errorf ("unexpected status code : %d" , resp .StatusCode )
28
64
}
29
65
30
66
header := resp .Header .Get ("content-disposition" )
@@ -37,7 +73,7 @@ func DownloadFile(outputPath string, url string) error {
37
73
38
74
file , err := os .Create (outputPath )
39
75
if err != nil {
40
- return fmt .Errorf ("failed to create file %s : %v" , outputPath , err )
76
+ return "" , fmt .Errorf ("failed to create file: %w" , err )
41
77
}
42
78
defer file .Close ()
43
79
@@ -57,5 +93,95 @@ func DownloadFile(outputPath string, url string) error {
57
93
)
58
94
59
95
_ , err = io .Copy (io .MultiWriter (file , bar ), resp .Body )
60
- return err
96
+ if err != nil {
97
+ return "" , fmt .Errorf ("failed to write to file: %w" , err )
98
+ }
99
+
100
+ return outputPath , nil
101
+ }
102
+
103
+ // DownloadAll downloads all available files for a given model ID
104
+ func DownloadAll (modelType , baseModelPath , modelID string ) error {
105
+ modelURL := fmt .Sprintf ("%s%s" , APIModelVersions , modelID )
106
+ resp , err := http .Get (modelURL )
107
+ if err != nil {
108
+ return fmt .Errorf ("failed to fetch model version: %w" , err )
109
+ }
110
+ defer resp .Body .Close ()
111
+
112
+ if resp .StatusCode != http .StatusOK {
113
+ return fmt .Errorf ("failed to fetch model version. Status code: %d" , resp .StatusCode )
114
+ }
115
+
116
+ var modelVersion ModelVersion
117
+ err = json .NewDecoder (resp .Body ).Decode (& modelVersion )
118
+ if err != nil {
119
+ return fmt .Errorf ("failed to decode model version response: %w" , err )
120
+ }
121
+
122
+ // Create a subdirectory based on modelType
123
+ dir := filepath .Join (baseModelPath , modelType )
124
+ if err := os .MkdirAll (dir , 0755 ); err != nil {
125
+ return fmt .Errorf ("failed to create model type directory: %w" , err )
126
+ }
127
+
128
+ outputPath , err := DownloadFile (filepath .Join (dir , filepath .Base (baseModelPath )), modelVersion .DownloadURL , modelID )
129
+ if err != nil {
130
+ return fmt .Errorf ("failed to download model file: %w" , err )
131
+ }
132
+
133
+ baseName := strings .TrimSuffix (filepath .Base (outputPath ), filepath .Ext (outputPath ))
134
+
135
+ for _ , image := range modelVersion .Images {
136
+ if image .Type == "image" {
137
+ imgPath := fmt .Sprintf ("%s.preview.png" , filepath .Join (filepath .Dir (outputPath ), baseName ))
138
+ if _ , err := DownloadFile (imgPath , image .URL , modelID ); err != nil {
139
+ return fmt .Errorf ("failed to download image: %w" , err )
140
+ }
141
+ } else if image .Type == "video" {
142
+ imgPath := fmt .Sprintf ("%s.preview.mp4" , filepath .Join (filepath .Dir (outputPath ), baseName ))
143
+ if _ , err := DownloadFile (imgPath , image .URL , modelID ); err != nil {
144
+ return fmt .Errorf ("failed to download image: %w" , err )
145
+ }
146
+ }
147
+ }
148
+
149
+ metadataPath := fmt .Sprintf ("%s.civitai.info" , filepath .Join (filepath .Dir (outputPath ), baseName ))
150
+ metadata , err := json .MarshalIndent (modelVersion , "" , " " )
151
+ if err != nil {
152
+ return fmt .Errorf ("failed to marshal metadata: %w" , err )
153
+ }
154
+ if err := os .WriteFile (metadataPath , metadata , 0644 ); err != nil {
155
+ return fmt .Errorf ("failed to save metadata: %w" , err )
156
+ }
157
+
158
+ if modelVersion .Description != "" {
159
+ descPath := fmt .Sprintf ("%s.description.txt" , filepath .Join (filepath .Dir (outputPath ), baseName ))
160
+ if err := os .WriteFile (descPath , []byte (modelVersion .Description ), 0644 ); err != nil {
161
+ return fmt .Errorf ("failed to save description: %w" , err )
162
+ }
163
+ }
164
+
165
+ fmt .Printf ("Successfully downloaded model files to %s\n " , filepath .Dir (outputPath ))
166
+ return nil
167
+ }
168
+
169
+ // GetModelID retrieves the model ID from the URL
170
+ func GetModelID (url string ) (string , error ) {
171
+ resp , err := http .Get (url )
172
+ if err != nil {
173
+ return "" , fmt .Errorf ("failed to fetch model page: %w" , err )
174
+ }
175
+ defer resp .Body .Close ()
176
+
177
+ if resp .StatusCode != http .StatusOK {
178
+ return "" , fmt .Errorf ("unexpected status code: %d" , resp .StatusCode )
179
+ }
180
+
181
+ parts := url [len ("https://civitai.com/models/" ):]
182
+ if len (parts ) == 0 {
183
+ return "" , fmt .Errorf ("invalid URL format" )
184
+ }
185
+ modelID := parts
186
+ return modelID , nil
61
187
}
0 commit comments