Skip to content

Commit b4602ef

Browse files
committed
refactor
1 parent e4d70c7 commit b4602ef

File tree

3 files changed

+96
-80
lines changed

3 files changed

+96
-80
lines changed

config/config.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package config
2+
3+
import (
4+
"os"
5+
6+
"gopkg.in/yaml.v3"
7+
)
8+
9+
type Config struct {
10+
Civitai struct {
11+
Token string `yaml:"token"`
12+
} `yaml:"civitai"`
13+
ComfyUI struct {
14+
BaseModelPath string `yaml:"base_model_path"`
15+
} `yaml:"comfyui"`
16+
}
17+
18+
func LoadConfig(filename string) (*Config, error) {
19+
data, err := os.ReadFile(filename)
20+
if err != nil {
21+
return nil, err
22+
}
23+
24+
config := &Config{}
25+
err = yaml.Unmarshal(data, config)
26+
if err != nil {
27+
return nil, err
28+
}
29+
30+
return config, nil
31+
}

downloader/downloader.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package downloader
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"os"
8+
"path/filepath"
9+
"strings"
10+
11+
"github.com/schollz/progressbar/v3"
12+
)
13+
14+
func DownloadFile(outputPath string, url string) error {
15+
dir := filepath.Dir(outputPath)
16+
if err := os.MkdirAll(dir, 0755); err != nil {
17+
return fmt.Errorf("failed to create directories: %v", err)
18+
}
19+
20+
resp, err := http.Get(url)
21+
if err != nil {
22+
return fmt.Errorf("failed to fetch %s: %v", url, err)
23+
}
24+
defer resp.Body.Close()
25+
26+
if resp.StatusCode != http.StatusOK {
27+
return fmt.Errorf("HTTP error for %s: %v", url, resp.StatusCode)
28+
}
29+
30+
header := resp.Header.Get("content-disposition")
31+
if header != "" {
32+
parts := strings.Split(header, "filename=")
33+
if len(parts) > 1 {
34+
outputPath = filepath.Join(filepath.Dir(outputPath), strings.Trim(parts[1], "\""))
35+
}
36+
}
37+
38+
file, err := os.Create(outputPath)
39+
if err != nil {
40+
return fmt.Errorf("failed to create file %s: %v", outputPath, err)
41+
}
42+
defer file.Close()
43+
44+
bar := progressbar.NewOptions(
45+
int(resp.ContentLength),
46+
progressbar.OptionSetWidth(15),
47+
progressbar.OptionEnableColorCodes(true),
48+
progressbar.OptionSetDescription("[Downloading] "),
49+
progressbar.OptionSetTheme(
50+
progressbar.Theme{
51+
Saucer: "[green]=[reset]",
52+
SaucerHead: "[green]>[reset]",
53+
SaucerPadding: " ",
54+
BarStart: "|",
55+
BarEnd: "|",
56+
}),
57+
)
58+
59+
_, err = io.Copy(io.MultiWriter(file, bar), resp.Body)
60+
return err
61+
}

main.go

+4-80
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,15 @@ package main
22

33
import (
44
"fmt"
5-
"io"
6-
"net/http"
75
"os"
86
"path/filepath"
97
"strings"
108
"time"
119

12-
"gopkg.in/yaml.v3"
13-
14-
"github.com/schollz/progressbar/v3"
10+
"github.com/yansigit/civitai-downloader/config"
11+
"github.com/yansigit/civitai-downloader/downloader"
1512
)
1613

17-
type Config struct {
18-
Civitai struct {
19-
Token string `yaml:"token"`
20-
} `yaml:"civitai"`
21-
ComfyUI struct {
22-
BaseModelPath string `yaml:"base_model_path"`
23-
} `yaml:"comfyui"`
24-
}
25-
2614
func main() {
2715
if len(os.Args) != 3 {
2816
fmt.Println("Usage: ", os.Args[0], " <model_type> <model_url|AIR>")
@@ -38,7 +26,7 @@ func main() {
3826
}
3927

4028
configPath := os.Getenv("HOME") + "/.civitai-downloader/config.yaml"
41-
config, err := loadConfig(configPath)
29+
config, err := config.LoadConfig(configPath)
4230
if err != nil {
4331
fmt.Println("Error loading config:", err)
4432
return
@@ -71,75 +59,11 @@ func main() {
7159

7260
outputPath := filepath.Join(baseModelPath, modelType, fmt.Sprintf("temp-%d.safetensors", time.Now().UnixNano()))
7361

74-
err = downloadFile(outputPath, downloadURL)
62+
err = downloader.DownloadFile(outputPath, downloadURL)
7563
if err != nil {
7664
fmt.Printf("Error downloading %s: %v\n", modelType, err)
7765
return
7866
}
7967

8068
fmt.Printf("Model downloaded successfully to %s\n", filepath.Dir(outputPath))
8169
}
82-
83-
func loadConfig(filename string) (*Config, error) {
84-
data, err := os.ReadFile(filename)
85-
if err != nil {
86-
return nil, err
87-
}
88-
89-
config := &Config{}
90-
err = yaml.Unmarshal(data, config)
91-
if err != nil {
92-
return nil, err
93-
}
94-
95-
return config, nil
96-
}
97-
98-
func downloadFile(outputPath string, url string) error {
99-
dir := filepath.Dir(outputPath)
100-
if err := os.MkdirAll(dir, 0755); err != nil {
101-
return fmt.Errorf("failed to create directories: %v", err)
102-
}
103-
104-
resp, err := http.Get(url)
105-
if err != nil {
106-
return fmt.Errorf("failed to fetch %s: %v", url, err)
107-
}
108-
defer resp.Body.Close()
109-
110-
if resp.StatusCode != http.StatusOK {
111-
return fmt.Errorf("HTTP error for %s: %v", url, resp.StatusCode)
112-
}
113-
114-
header := resp.Header.Get("content-disposition")
115-
if header != "" {
116-
parts := strings.Split(header, "filename=")
117-
if len(parts) > 1 {
118-
outputPath = filepath.Join(filepath.Dir(outputPath), strings.Trim(parts[1], "\""))
119-
}
120-
}
121-
122-
file, err := os.Create(outputPath)
123-
if err != nil {
124-
return fmt.Errorf("failed to create file %s: %v", outputPath, err)
125-
}
126-
defer file.Close()
127-
128-
bar := progressbar.NewOptions(
129-
int(resp.ContentLength),
130-
progressbar.OptionSetWidth(15),
131-
progressbar.OptionEnableColorCodes(true),
132-
progressbar.OptionSetDescription("[Downloading] "),
133-
progressbar.OptionSetTheme(
134-
progressbar.Theme{
135-
Saucer: "[green]=[reset]",
136-
SaucerHead: "[green]>[reset]",
137-
SaucerPadding: " ",
138-
BarStart: "|",
139-
BarEnd: "|",
140-
}),
141-
)
142-
143-
_, err = io.Copy(io.MultiWriter(file, bar), resp.Body)
144-
return err
145-
}

0 commit comments

Comments
 (0)