ethix commited on
Commit
c441f1b
·
0 Parent(s):

major: init

Browse files
Files changed (9) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +275 -0
  4. civitai_api.py +446 -0
  5. civitai_constants.py +26 -0
  6. null.png +0 -0
  7. packages.txt +1 -0
  8. requirements.txt +3 -0
  9. utils.py +299 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Civitai Asset Migration and Archiving Tool
3
+ emoji: 🤗
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_modal import Modal
3
+ from civitai_api import (search_civitai, download_civitai, select_civitai_item, add_civitai_item, get_civitai_tag, select_civitai_all_item,
4
+ update_civitai_selection, update_civitai_checkbox, from_civitai_checkbox)
5
+ from civitai_constants import (TYPE, BASEMODEL, SORT, PERIOD, FILETYPE)
6
+ import time
7
+
8
+ css = """
9
+ .title { font-size: 3em; align-items: center; text-align: center; }
10
+ .info { align-items: center; text-align: center; }
11
+ .desc [src$='#float'] { float: right; margin: 20px; }
12
+ #modal-window{ position: absolute; left: 50%; top: 50%; transform: translate(-50%, -50%); max-height: inherit; }
13
+ .modal-container{ max-width:60vw; height:100vh; }
14
+ .prose.desc{ padding: 0px 20px; }
15
+ /* Video hover styles */
16
+ .media-container {
17
+ position: relative;
18
+ width: 100%;
19
+ height: 100%;
20
+ }
21
+ .result{ padding: 16px; }
22
+
23
+ .media-container img,
24
+ .media-container video {
25
+ width: 100%;
26
+ height: 100%;
27
+ object-fit: contain;
28
+ }
29
+
30
+ .media-container video {
31
+ display: none;
32
+ position: absolute;
33
+ top: 0;
34
+ left: 0;
35
+ z-index: 2;
36
+ }
37
+
38
+ .gallery-item:hover .media-container video {
39
+ display: block;
40
+ }
41
+
42
+ .gallery-item:hover .media-container img {
43
+ opacity: 0.3;
44
+ }
45
+
46
+ .gallery-item {
47
+ position: relative;
48
+ cursor: pointer;
49
+ overflow: hidden;
50
+ }
51
+ """
52
+
53
+ # js = """
54
+ # (() => {
55
+ # function setupGalleryHover() {
56
+ # const gallery = document.getElementById('gallery');
57
+ # if (!gallery) return;
58
+
59
+ # function handleMediaContainer(container) {
60
+ # const video = container.querySelector('video');
61
+ # if (!video) return;
62
+
63
+ # video.muted = true;
64
+ # video.preload = "metadata";
65
+
66
+ # container.addEventListener('mouseenter', () => {
67
+ # if (video.paused) {
68
+ # video.play().catch(() => {});
69
+ # }
70
+ # });
71
+
72
+ # container.addEventListener('mouseleave', () => {
73
+ # if (!video.paused) {
74
+ # video.pause();
75
+ # video.currentTime = 0;
76
+ # }
77
+ # });
78
+ # }
79
+
80
+ # function setupContainers() {
81
+ # const containers = gallery.querySelectorAll('.media-container');
82
+ # containers.forEach(handleMediaContainer);
83
+ # }
84
+
85
+ # setupContainers();
86
+
87
+ # const observer = new MutationObserver(setupContainers);
88
+ # observer.observe(gallery, { childList: true, subtree: true });
89
+ # }
90
+
91
+ # if (document.readyState === 'loading') {
92
+ # document.addEventListener('DOMContentLoaded', setupGalleryHover);
93
+ # } else {
94
+ # setupGalleryHover();
95
+ # }
96
+ # })();
97
+ # """
98
+
99
+ common_sense = """
100
+ ## The Common Sense Agreement
101
+
102
+ This tool was created in response to recent censorship and model deactivations on Civitai. This Space should be treated as a tool for artists and creators to easily backup and archive models -- this is NOT a tool to blindly scrape and download every asset your heart desires. For those that wish to do so, may do so on their own local machines. Users of the tool are responsible for their own actions - you have been warned.
103
+
104
+ By continuing on to the tool, you are assumed to have **acknowledged, understood, and are in agreeance** with the following common sense keypoints:
105
+
106
+ 1. **Storage is not free.** You are responsible for understanding your storage limits and the costs that will be incurred if you exceed your tier's free allocation.
107
+ 2. **Bandwidth is not free.** Misuse of the tool will likely be frowned upon by both Huggingface as well as Civitai, potentially leading to negative consequences for all. Don't be that guy. There are better ways if you must do so.
108
+ 3. **Not all models are created equally.** Respect the licensing agreements and the wishes of creators, both big and small.
109
+ 4. **Respect the FOSS community**, give back where you can, and be aware of your surroundings. Protect this industry as best you can, otherwise prepare for a wicked spanking when AGI is here.
110
+
111
+ ### For those with common sense, please continue.
112
+ """
113
+
114
+
115
+ with gr.Blocks(fill_width=True, css=css, delete_cache=(3600, 3600)) as demo:
116
+ gr.Markdown("# 🚨 CivitAI Models & LoRA Conservation Project 📦", elem_classes="title")
117
+
118
+ state = gr.State(value={})
119
+
120
+ # Create browser state for API keys with error handling
121
+ api_keys_state = gr.BrowserState(["", ""], storage_key="civitai2hf_api_keys")
122
+
123
+ with gr.Row():
124
+ show_btn = gr.Button("Show Modal")
125
+ with gr.Tabs() as tabs:
126
+ with gr.TabItem("Search & Download"):
127
+ with gr.Row():
128
+ # Left column for filters (1/3 width)
129
+ with gr.Column(scale=1):
130
+ with gr.Group():
131
+ gr.Markdown("## 2️⃣ Search Filters", container=True)
132
+ with gr.Accordion("Type & Model", open=True):
133
+ search_civitai_type = gr.CheckboxGroup(label="Type", choices=TYPE, value=["Checkpoint", "LORA"])
134
+ search_civitai_basemodel = gr.CheckboxGroup(label="Base Model", choices=BASEMODEL, value=[])
135
+ search_civitai_filetype = gr.CheckboxGroup(label="File type", visible=False, choices=FILETYPE, value=["Model"])
136
+ with gr.Accordion("Search Options", open=True):
137
+ search_civitai_sort = gr.Radio(label="Sort", choices=SORT, value=SORT[0])
138
+ search_civitai_period = gr.Radio(label="Period", choices=PERIOD, value="Month")
139
+ search_civitai_limit = gr.Slider(0, 100, value=50, step=10, label="Limit", info="Maximum items to query per page via API.")
140
+ search_civitai_page = gr.Slider(0, 10, 1, step=1, label="Num Pages", info="If 0, retrieve all pages")
141
+
142
+ # Right column for results and download (2/3 width)
143
+ with gr.Column(scale=2):
144
+ with gr.Group():
145
+ gr.Markdown("## 3️⃣ Search Query", container=True)
146
+ with gr.Row():
147
+ search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
148
+ search_civitai_tag = gr.Dropdown(label="Tag", choices=get_civitai_tag(), value=get_civitai_tag()[0], allow_custom_value=True)
149
+ search_civitai_user = gr.Textbox(label="Username", lines=1)
150
+ search_civitai_submit = gr.Button("Search on Civitai", variant="primary")
151
+
152
+ with gr.Group():
153
+ gr.Markdown("## 4️⃣ Select Models to Backup from Search Results", container=True)
154
+ with gr.Row():
155
+ search_civitai_desc = gr.Markdown(value="", visible=False, elem_classes="desc")
156
+ search_civitai_json = gr.JSON(value={}, visible=False)
157
+ with gr.Row(equal_height=True):
158
+ with gr.Column(scale=9):
159
+ with gr.Accordion("Gallery View ⚠️ Surprise NSFW + RAM Warning ⚠️", open=False):
160
+ search_civitai_gallery = gr.Gallery(
161
+ [],
162
+ label="Select from Results",
163
+ allow_preview=False,
164
+ columns=5,
165
+ elem_id="gallery",
166
+ object_fit="contain",
167
+ show_share_button=False,
168
+ interactive=True,
169
+ preview=False,
170
+ height="auto"
171
+ )
172
+ with gr.Accordion("List View", open=False):
173
+ search_civitai_result_checkbox = gr.CheckboxGroup(label="", choices=[], value=[])
174
+ search_civitai_result = gr.Dropdown(label="Selected Models from Search", choices=[("", "")], value=[],
175
+ allow_custom_value=True, visible=True, multiselect=True)
176
+ search_civitai_result_info = gr.Markdown("Standing by.", elem_classes="info")
177
+ with gr.Row():
178
+ search_civitai_add = gr.Button("5️⃣ Add Selected Models to DL List", variant="primary")
179
+ search_civitai_select_all = gr.Button("Select All", variant="secondary")
180
+
181
+ with gr.Group():
182
+ gr.Markdown("## 6️⃣🚨 Ensure you have set your correct API key settings before moving on.", container=True)
183
+ dl_url = gr.Textbox(label="Download URL(s)", placeholder="https://civitai.com/api/download/models/28907\n...", value="", lines=3, max_lines=255, interactive=True)
184
+ with gr.Row():
185
+ newrepo_id = gr.Textbox(label="Your Repo ID", placeholder="yourid/yourrepo", value="", max_lines=1)
186
+ newrepo_type = gr.Radio(label="Repo Type", choices=["model", "dataset"], value="model")
187
+ with gr.Group():
188
+ is_private = gr.Checkbox(label="Private Repo", value=False)
189
+ is_rename = gr.Checkbox(label="Auto rename", value=True)
190
+ with gr.Row():
191
+ is_info = gr.Checkbox(label="Upload Civitai Metadata Files to HF", value=True)
192
+
193
+ run_button = gr.Button(value="7️⃣ Create Repo & Backup Models", variant="primary")
194
+ uploaded_urls = gr.CheckboxGroup(visible=False, choices=[], value=[]) # hidden
195
+ urls_md = gr.Markdown("<br><br>", elem_classes="result")
196
+ urls_remain = gr.Textbox("Remaining URLs", value="", show_copy_button=True, visible=False)
197
+
198
+ with gr.TabItem("Settings 1️⃣"):
199
+ with gr.Column():
200
+ with gr.Group():
201
+ with gr.Row():
202
+ with gr.Column():
203
+ civitai_key = gr.Textbox(label="Your Civitai Key", value="", max_lines=1)
204
+ gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
205
+ with gr.Column():
206
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
207
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).", elem_classes="info")
208
+
209
+ # Add save status message
210
+ saved_message = gr.Markdown("✅ API keys saved to browser storage", visible=False)
211
+
212
+ # Load saved keys when the app starts with error handling
213
+ @demo.load(inputs=[api_keys_state], outputs=[civitai_key, hf_token])
214
+ def load_api_keys(saved_keys):
215
+ try:
216
+ if not saved_keys or not isinstance(saved_keys, list) or len(saved_keys) != 2:
217
+ return "", ""
218
+ # Ensure we're not trying to use None values
219
+ civitai = saved_keys[0] or ""
220
+ hf = saved_keys[1] or ""
221
+ return civitai, hf
222
+ except Exception as e:
223
+ print(f"Error loading API keys: {e}")
224
+ return "", ""
225
+
226
+ # Save keys when they change with improved error handling
227
+ @gr.on([civitai_key.change, hf_token.change], inputs=[civitai_key, hf_token], outputs=[api_keys_state, saved_message])
228
+ def save_api_keys(civitai_key, hf_token):
229
+ try:
230
+ # Ensure we're storing strings, not None
231
+ civitai = str(civitai_key or "")
232
+ hf = str(hf_token or "")
233
+ timestamp = time.strftime("%I:%M:%S %p")
234
+
235
+ if civitai or hf:
236
+ return [civitai, hf], gr.Markdown(
237
+ f"✅ API keys saved to browser storage at {timestamp}",
238
+ visible=True
239
+ )
240
+ return ["", ""], gr.Markdown("", visible=False)
241
+ except Exception as e:
242
+ print(f"Error saving API keys: {e}")
243
+ return ["", ""], gr.Markdown("Error saving API keys", visible=True)
244
+
245
+ with Modal(visible=True, elem_id="modal-window") as modal:
246
+ gr.Markdown(common_sense)
247
+ gr.on(
248
+ triggers=[run_button.click],
249
+ fn=download_civitai,
250
+ inputs=[dl_url, civitai_key, hf_token, uploaded_urls, newrepo_id, newrepo_type, is_private, is_info, is_rename],
251
+ outputs=[uploaded_urls, urls_md, urls_remain],
252
+ queue=True,
253
+ )
254
+ gr.on(
255
+ triggers=[search_civitai_submit.click, search_civitai_query.submit, search_civitai_user.submit],
256
+ fn=search_civitai,
257
+ inputs=[search_civitai_query, search_civitai_type, search_civitai_basemodel, search_civitai_sort,
258
+ search_civitai_period, search_civitai_tag, search_civitai_user, search_civitai_limit,
259
+ search_civitai_page, search_civitai_filetype, civitai_key, search_civitai_gallery, state],
260
+ outputs=[search_civitai_result, search_civitai_desc, search_civitai_submit, search_civitai_query, search_civitai_gallery,
261
+ search_civitai_result_checkbox, search_civitai_result_info, state],
262
+ queue=False,
263
+ show_api=False,
264
+ )
265
+ search_civitai_result.change(select_civitai_item, [search_civitai_result, state], [search_civitai_desc, search_civitai_json, state], queue=False, show_api=False)\
266
+ .success(update_civitai_checkbox, [search_civitai_result], [search_civitai_result_checkbox], queue=True, show_api=False)
267
+ search_civitai_result_checkbox.select(from_civitai_checkbox, [search_civitai_result_checkbox], [search_civitai_result], queue=False, show_api=False)
268
+ search_civitai_add.click(add_civitai_item, [search_civitai_result, dl_url], [dl_url], queue=False, show_api=False)
269
+ search_civitai_select_all.click(select_civitai_all_item, [search_civitai_select_all, state], [search_civitai_select_all, search_civitai_result], queue=False, show_api=False)
270
+ search_civitai_gallery.select(update_civitai_selection, [search_civitai_result, state], [search_civitai_result], queue=False, show_api=False)
271
+
272
+ show_btn.click(lambda: Modal(visible=True), None, modal)
273
+
274
+ demo.queue()
275
+ demo.launch(ssr_mode=False)
civitai_api.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, hf_hub_url
3
+ import os
4
+ from pathlib import Path
5
+ import gc
6
+ import requests
7
+ from requests.adapters import HTTPAdapter
8
+ from urllib3.util import Retry
9
+ from civitai_constants import PERIOD, SORT
10
+ from utils import (get_token, set_token, is_repo_exists, get_user_agent, get_download_file,
11
+ list_uniq, list_sub, duplicate_hf_repo, HF_SUBFOLDER_NAME, get_state, set_state)
12
+ import re
13
+ from PIL import Image
14
+ import json
15
+ import pandas as pd
16
+ import tempfile
17
+ import hashlib
18
+
19
+ # Huge shoutout to @John6666, saved me many hours.
20
+
21
+ TEMP_DIR = tempfile.mkdtemp()
22
+
23
+
24
+ def parse_urls(s):
25
+ url_pattern = "https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+"
26
+ try:
27
+ urls = re.findall(url_pattern, s)
28
+ return list(urls)
29
+ except Exception:
30
+ return []
31
+
32
+
33
+ def parse_repos(s):
34
+ repo_pattern = r'[^\w_\-\.]?([\w_\-\.]+/[\w_\-\.]+)[^\w_\-\.]?'
35
+ try:
36
+ s = re.sub("https?://[\\w/:%#\\$&\\?\\(\\)~\\.=\\+\\-]+", "", s)
37
+ repos = re.findall(repo_pattern, s)
38
+ return list(repos)
39
+ except Exception:
40
+ return []
41
+
42
+
43
+ def to_urls(l: list[str]):
44
+ return "\n".join(l)
45
+
46
+
47
+ def uniq_urls(s):
48
+ return to_urls(list_uniq(parse_urls(s) + parse_repos(s)))
49
+
50
+
51
+ def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
52
+ output_filename = Path(filename).name
53
+ hf_token = get_token()
54
+ api = HfApi(token=hf_token)
55
+ try:
56
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
57
+ progress(0, desc=f"Start uploading... {filename} to {repo_id}")
58
+ api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
59
+ progress(1, desc="Uploaded.")
60
+ url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
61
+ except Exception as e:
62
+ print(f"Error: Failed to upload to {repo_id}. {e}")
63
+ gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
64
+ return None
65
+ finally:
66
+ if Path(filename).exists(): Path(filename).unlink()
67
+ return url
68
+
69
+
70
+ def is_same_file(filename: str, cmp_sha256: str, cmp_size: int):
71
+ if cmp_sha256:
72
+ sha256_hash = hashlib.sha256()
73
+ with open(filename, "rb") as f:
74
+ for byte_block in iter(lambda: f.read(4096), b""):
75
+ sha256_hash.update(byte_block)
76
+ sha256 = sha256_hash.hexdigest()
77
+ else: sha256 = ""
78
+ size = os.path.getsize(filename)
79
+ if size == cmp_size and sha256 == cmp_sha256: return True
80
+ else: return False
81
+
82
+
83
+ def get_safe_filename(filename, repo_id, repo_type):
84
+ hf_token = get_token()
85
+ api = HfApi(token=hf_token)
86
+ new_filename = filename
87
+ try:
88
+ i = 1
89
+ while api.file_exists(repo_id=repo_id, filename=Path(new_filename).name, repo_type=repo_type, token=hf_token):
90
+ infos = api.get_paths_info(repo_id=repo_id, paths=[Path(new_filename).name], repo_type=repo_type, token=hf_token)
91
+ if infos and len(infos) == 1:
92
+ repo_fs = infos[0].size
93
+ repo_sha256 = infos[0].lfs.sha256 if infos[0].lfs is not None else ""
94
+ if is_same_file(filename, repo_sha256, repo_fs): break
95
+ new_filename = str(Path(Path(filename).parent, f"{Path(filename).stem}_{i}{Path(filename).suffix}"))
96
+ i += 1
97
+ if filename != new_filename:
98
+ print(f"{Path(filename).name} is already exists but file content is different. renaming to {Path(new_filename).name}.")
99
+ Path(filename).rename(new_filename)
100
+ except Exception as e:
101
+ print(f"Error occured when renaming {filename}. {e}")
102
+ finally:
103
+ return new_filename
104
+
105
+
106
+ def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
107
+ download_dir = TEMP_DIR
108
+ progress(0, desc=f"Start downloading... {dl_url}")
109
+ output_filename = get_download_file(download_dir, dl_url, civitai_key)
110
+ return output_filename
111
+
112
+
113
+ def save_civitai_info(dl_url, filename, civitai_key="", progress=gr.Progress(track_tqdm=True)):
114
+ json_str, html_str, image_path = get_civitai_json(dl_url, True, filename, civitai_key)
115
+ if not json_str: return "", "", ""
116
+ json_path = str(Path(TEMP_DIR, Path(filename).stem + ".json"))
117
+ html_path = str(Path(TEMP_DIR, Path(filename).stem + ".html"))
118
+ try:
119
+ with open(json_path, 'w') as f:
120
+ json.dump(json_str, f, indent=2)
121
+ with open(html_path, mode='w', encoding="utf-8") as f:
122
+ f.write(html_str)
123
+ return json_path, html_path, image_path
124
+ except Exception as e:
125
+ print(f"Error: Failed to save info file {json_path}, {html_path} {e}")
126
+ return "", "", ""
127
+
128
+
129
+ def upload_info_to_repo(dl_url, filename, repo_id, repo_type, is_private, civitai_key="", progress=gr.Progress(track_tqdm=True)):
130
+ def upload_file(api, filename, repo_id, repo_type, hf_token):
131
+ if not Path(filename).exists(): return
132
+ api.upload_file(path_or_fileobj=filename, path_in_repo=Path(filename).name, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
133
+ Path(filename).unlink()
134
+
135
+ hf_token = get_token()
136
+ api = HfApi(token=hf_token)
137
+ try:
138
+ if not is_repo_exists(repo_id, repo_type): api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
139
+ progress(0, desc=f"Downloading info... {filename}")
140
+ json_path, html_path, image_path = save_civitai_info(dl_url, filename, civitai_key)
141
+ progress(0, desc=f"Start uploading info... {filename} to {repo_id}")
142
+ if not json_path: return
143
+ else: upload_file(api, json_path, repo_id, repo_type, hf_token)
144
+ if html_path: upload_file(api, html_path, repo_id, repo_type, hf_token)
145
+ if image_path: upload_file(api, image_path, repo_id, repo_type, hf_token)
146
+ progress(1, desc="Info uploaded.")
147
+ return
148
+ except Exception as e:
149
+ print(f"Error: Failed to upload info to {repo_id}. {e}")
150
+ gr.Warning(f"Error: Failed to upload info to {repo_id}. {e}")
151
+ return
152
+
153
+
154
+ def download_civitai(dl_url, civitai_key, hf_token, urls,
155
+ newrepo_id, repo_type="model", is_private=True, is_info=False, is_rename=True, progress=gr.Progress(track_tqdm=True)):
156
+ if hf_token: set_token(hf_token)
157
+ else: set_token(os.environ.get("HF_TOKEN")) # default huggingface write token
158
+ if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
159
+ if not newrepo_id: newrepo_id = os.environ.get("HF_REPO") # default repo to upload
160
+ if not get_token() or not civitai_key: raise gr.Error("HF write token and Civitai API key is required.")
161
+ if not urls: urls = []
162
+ dl_urls = parse_urls(dl_url)
163
+ remain_urls = dl_urls.copy()
164
+ try:
165
+ md = f'### Your repo: [{newrepo_id}]({"https://huggingface.co/datasets/" if repo_type == "dataset" else "https://huggingface.co/"}{newrepo_id})\n'
166
+ for u in dl_urls:
167
+ file = download_file(u, civitai_key)
168
+ if not Path(file).exists() or not Path(file).is_file(): continue
169
+ if is_rename: file = get_safe_filename(file, newrepo_id, repo_type)
170
+ url = upload_safetensors_to_repo(file, newrepo_id, repo_type, is_private)
171
+ if url:
172
+ if is_info: upload_info_to_repo(u, file, newrepo_id, repo_type, is_private, civitai_key)
173
+ urls.append(url)
174
+ remain_urls.remove(u)
175
+ md += f"- Uploaded [{str(u)}]({str(u)})\n"
176
+ dp_repos = parse_repos(dl_url)
177
+ for r in dp_repos:
178
+ url = duplicate_hf_repo(r, newrepo_id, "model", repo_type, is_private, HF_SUBFOLDER_NAME[1])
179
+ if url: urls.append(url)
180
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=False)
181
+ except Exception as e:
182
+ gr.Info(f"Error occured: {e}")
183
+ return gr.update(value=urls, choices=urls), gr.update(value=md), gr.update(value="\n".join(remain_urls), visible=True)
184
+ finally:
185
+ gc.collect()
186
+
187
+
188
+ def search_on_civitai(query: str, types: list[str], allow_model: list[str] = [], limit: int = 100,
189
+ sort: str = "Highest Rated", period: str = "AllTime", tag: str = "", user: str = "", page: int = 1,
190
+ filetype: list[str] = [], api_key: str = "", progress=gr.Progress(track_tqdm=True)):
191
+ user_agent = get_user_agent()
192
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
193
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
194
+ base_url = 'https://civitai.com/api/v1/models'
195
+ params = {'sort': sort, 'period': period, 'limit': int(limit), 'nsfw': 'true'}
196
+ if len(types) != 0: params["types"] = types
197
+ if query: params["query"] = query
198
+ if tag: params["tag"] = tag
199
+ if user: params["username"] = user
200
+ if page != 0: params["page"] = int(page)
201
+ session = requests.Session()
202
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
203
+ session.mount("https://", HTTPAdapter(max_retries=retries))
204
+ rs = []
205
+ try:
206
+ if page == 0:
207
+ progress(0, desc="Searching page 1...")
208
+ print("Searching page 1...")
209
+ r = session.get(base_url, params=params | {'page': 1}, headers=headers, stream=True, timeout=(7.0, 30))
210
+ rs.append(r)
211
+ if r.ok:
212
+ json = r.json()
213
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
214
+ i = 2
215
+ while(next_url is not None):
216
+ progress(0, desc=f"Searching page {i}...")
217
+ print(f"Searching page {i}...")
218
+ r = session.get(next_url, headers=headers, stream=True, timeout=(7.0, 30))
219
+ rs.append(r)
220
+ if r.ok:
221
+ json = r.json()
222
+ next_url = json['metadata']['nextPage'] if 'metadata' in json and 'nextPage' in json['metadata'] else None
223
+ else: next_url = None
224
+ i += 1
225
+ else:
226
+ progress(0, desc="Searching page 1...")
227
+ print("Searching page 1...")
228
+ r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(7.0, 30))
229
+ rs.append(r)
230
+ except requests.exceptions.ConnectTimeout:
231
+ print("Request timed out.")
232
+ except Exception as e:
233
+ print(e)
234
+ items = []
235
+ for r in rs:
236
+ if not r.ok: continue
237
+ json = r.json()
238
+ if 'items' not in json: continue
239
+ for j in json['items']:
240
+ for model in j['modelVersions']:
241
+ item = {}
242
+ if len(allow_model) != 0 and model['baseModel'] not in set(allow_model): continue
243
+ item['name'] = j['name']
244
+ item['creator'] = j['creator']['username'] if 'creator' in j.keys() and 'username' in j['creator'].keys() else ""
245
+ item['tags'] = j['tags'] if 'tags' in j.keys() else []
246
+ item['model_name'] = model['name'] if 'name' in model.keys() else ""
247
+ item['base_model'] = model['baseModel'] if 'baseModel' in model.keys() else ""
248
+ item['description'] = model['description'] if 'description' in model.keys() else ""
249
+ item['md'] = ""
250
+
251
+ # Handle both images and videos
252
+ if 'images' in model.keys() and len(model["images"]) != 0:
253
+ first_media = model["images"][0]
254
+ item['img_url'] = first_media["url"]
255
+ item['is_video'] = first_media.get("type", "image") == "video"
256
+ item['video_url'] = first_media.get("meta", {}).get("video", "") if item['is_video'] else ""
257
+
258
+ if item['is_video']:
259
+ item['md'] += f'<video src="{item["img_url"]}" poster="{item["img_url"]}" muted loop autoplay width="300" height="480" style="float:right;padding:16px;"></video><br>'
260
+ else:
261
+ item['md'] += f'<img src="{item["img_url"]}#float" alt="thumbnail" width="150" height="240"><br>'
262
+ else:
263
+ item['img_url'] = "/home/user/app/null.png"
264
+ item['is_video'] = False
265
+ item['video_url'] = ""
266
+
267
+ item['md'] += f'''Model URL: [https://civitai.com/models/{j["id"]}](https://civitai.com/models/{j["id"]})<br>Model Name: {item["name"]}<br>
268
+ Creator: {item["creator"]}<br>Tags: {", ".join(item["tags"])}<br>Base Model: {item["base_model"]}<br>Description: {item["description"]}'''
269
+ if 'files' in model.keys():
270
+ for f in model['files']:
271
+ i = item.copy()
272
+ i['dl_url'] = f['downloadUrl']
273
+ if len(filetype) != 0 and f['type'] not in set(filetype): continue
274
+ items.append(i)
275
+ else:
276
+ item['dl_url'] = model['downloadUrl']
277
+ items.append(item)
278
+ return items if len(items) > 0 else None
279
+
280
+
281
+ def search_civitai(query, types, base_model=[], sort=SORT[0], period=PERIOD[0], tag="", user="", limit=100, page=1,
282
+ filetype=[], api_key="", gallery=[], state={}, progress=gr.Progress(track_tqdm=True)):
283
+ civitai_last_results = {}
284
+ set_state(state, "civitai_last_choices", [("", "")])
285
+ set_state(state, "civitai_last_gallery", [])
286
+ set_state(state, "civitai_last_results", civitai_last_results)
287
+ results_info = "No item found."
288
+ items = search_on_civitai(query, types, base_model, int(limit), sort, period, tag, user, int(page), filetype, api_key)
289
+ if not items: return gr.update(choices=[("", "")], value=[], visible=True),\
290
+ gr.update(value="", visible=False), gr.update(), gr.update(), gr.update(), gr.update(), results_info, state
291
+ choices = []
292
+ gallery = []
293
+ for item in items:
294
+ base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
295
+ name = f"{item['name']} (for {base_model_name} / By: {item['creator']})"
296
+ value = item['dl_url']
297
+ choices.append((name, value))
298
+
299
+ # For gallery, use tuples with HTML that includes both image and video
300
+ if item.get('is_video') and item.get('video_url'):
301
+ # Create an HTML element that contains both image and video
302
+ media_html = f"""
303
+ <div class="media-container">
304
+ <img src="{item['img_url']}" alt="{name}">
305
+ <video src="{item['video_url']}" muted loop poster="{item['img_url']}"></video>
306
+ </div>
307
+ """
308
+ gallery.append((item['img_url'], name)) # Keep using image as thumbnail
309
+ else:
310
+ gallery.append((item['img_url'], name))
311
+
312
+ civitai_last_results[value] = item
313
+ if len(choices) >= 1:
314
+ results_info = f"{int(len(choices))} items found."
315
+ else:
316
+ choices = [("", "")]
317
+
318
+ md = ""
319
+ set_state(state, "civitai_last_choices", choices)
320
+ set_state(state, "civitai_last_gallery", gallery)
321
+ set_state(state, "civitai_last_results", civitai_last_results)
322
+
323
+ return gr.update(choices=choices, value=[], visible=True),\
324
+ gr.update(value=md, visible=True),\
325
+ gr.update(),\
326
+ gr.update(),\
327
+ gr.update(value=gallery),\
328
+ gr.update(choices=choices, value=[]),\
329
+ results_info,\
330
+ state
331
+
332
+
333
+ def get_civitai_json(dl_url: str, is_html: bool=False, image_baseurl: str="", api_key=""):
334
+ if not image_baseurl: image_baseurl = dl_url
335
+ default = ("", "", "") if is_html else ""
336
+ if "https://civitai.com/api/download/models/" not in dl_url: return default
337
+ user_agent = get_user_agent()
338
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
339
+ if api_key: headers['Authorization'] = f'Bearer {{{api_key}}}'
340
+ base_url = 'https://civitai.com/api/v1/model-versions/'
341
+ params = {}
342
+ session = requests.Session()
343
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
344
+ session.mount("https://", HTTPAdapter(max_retries=retries))
345
+ model_id = re.sub('https://civitai.com/api/download/models/(\\d+)(?:.+)?', '\\1', dl_url)
346
+ url = base_url + model_id
347
+ #url = base_url + str(dl_url.split("/")[-1])
348
+ try:
349
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
350
+ if not r.ok: return default
351
+ json = dict(r.json()).copy()
352
+ html = ""
353
+ image = ""
354
+ if "modelId" in json.keys():
355
+ url = f"https://civitai.com/models/{json['modelId']}"
356
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
357
+ if not r.ok: return json, html, image
358
+ html = r.text
359
+ if 'images' in json.keys() and len(json["images"]) != 0:
360
+ url = json["images"][0]["url"]
361
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(5.0, 15))
362
+ if not r.ok: return json, html, image
363
+ image_temp = str(Path(TEMP_DIR, "image" + Path(url.split("/")[-1]).suffix))
364
+ image = str(Path(TEMP_DIR, Path(image_baseurl.split("/")[-1]).stem + ".png"))
365
+ with open(image_temp, 'wb') as f:
366
+ f.write(r.content)
367
+ Image.open(image_temp).convert('RGBA').save(image)
368
+ return json, html, image
369
+ except Exception as e:
370
+ print(e)
371
+ return default
372
+
373
+
374
+ def get_civitai_tag():
375
+ default = [""]
376
+ user_agent = get_user_agent()
377
+ headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
378
+ base_url = 'https://civitai.com/api/v1/tags'
379
+ params = {'limit': 200}
380
+ session = requests.Session()
381
+ retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
382
+ session.mount("https://", HTTPAdapter(max_retries=retries))
383
+ url = base_url
384
+ try:
385
+ r = session.get(url, params=params, headers=headers, stream=True, timeout=(7.0, 15))
386
+ if not r.ok: return default
387
+ j = dict(r.json()).copy()
388
+ if "items" not in j.keys(): return default
389
+ items = []
390
+ for item in j["items"]:
391
+ items.append([str(item.get("name", "")), int(item.get("modelCount", 0))])
392
+ df = pd.DataFrame(items)
393
+ df.sort_values(1, ascending=False)
394
+ tags = df.values.tolist()
395
+ tags = [""] + [l[0] for l in tags]
396
+ return tags
397
+ except Exception as e:
398
+ print(e)
399
+ return default
400
+
401
+
402
+ def select_civitai_item(results: list[str], state: dict):
403
+ json = {}
404
+ if "http" not in "".join(results) or len(results) == 0: return gr.update(value="", visible=True), gr.update(value=json, visible=False), state
405
+ result = get_state(state, "civitai_last_results")
406
+ last_selects = get_state(state, "civitai_last_selects")
407
+ selects = list_sub(results, last_selects if last_selects else [])
408
+ md = result.get(selects[-1]).get('md', "") if result and isinstance(result, dict) and len(selects) > 0 else ""
409
+ set_state(state, "civitai_last_selects", results)
410
+ return gr.update(value=md, visible=True), gr.update(value=json, visible=False), state
411
+
412
+
413
+ def add_civitai_item(results: list[str], dl_url: str):
414
+ if "http" not in "".join(results): return gr.update(value=dl_url)
415
+ new_url = dl_url if dl_url else ""
416
+ for result in results:
417
+ if "http" not in result: continue
418
+ new_url += f"\n{result}" if new_url else f"{result}"
419
+ new_url = uniq_urls(new_url)
420
+ return gr.update(value=new_url)
421
+
422
+
423
+ def select_civitai_all_item(button_name: str, state: dict):
424
+ if button_name not in ["Select All", "Deselect All"]: return gr.update(value=button_name), gr.Update(visible=True)
425
+ civitai_last_choices = get_state(state, "civitai_last_choices")
426
+ selected = [t[1] for t in civitai_last_choices if t[1] != ""] if button_name == "Select All" else []
427
+ new_button_name = "Select All" if button_name == "Deselect All" else "Deselect All"
428
+ return gr.update(value=new_button_name), gr.update(value=selected, choices=civitai_last_choices)
429
+
430
+
431
+ def update_civitai_selection(evt: gr.SelectData, value: list[str], state: dict):
432
+ try:
433
+ civitai_last_choices = get_state(state, "civitai_last_choices")
434
+ selected_index = evt.index
435
+ selected = list_uniq([v for v in value if v != ""] + [civitai_last_choices[selected_index][1]])
436
+ return gr.update(value=selected)
437
+ except Exception:
438
+ return gr.update()
439
+
440
+
441
+ def update_civitai_checkbox(selected: list[str]):
442
+ return gr.update(value=selected)
443
+
444
+
445
+ def from_civitai_checkbox(selected: list[str]):
446
+ return gr.update(value=selected)
civitai_constants.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CivitAI API Constants
2
+
3
+ TYPE = [
4
+ "Checkpoint", "TextualInversion", "Hypernetwork", "AestheticGradient",
5
+ "LORA", "LoCon", "DoRA", "Controlnet", "Upscaler", "Motion", "VAE",
6
+ "Poses", "Wildcards", "Workflows", "Other"
7
+ ]
8
+
9
+ FILETYPE = [
10
+ "Model", "VAE", "Config", "Training Data"
11
+ ]
12
+
13
+ BASEMODEL = [
14
+ "Pony", "Illustrious", "SDXL 1.0", "SD 1.5", "Flux.1 D", "Flux.1 S",
15
+ "SD 3.5", "CogVideoX", "SVD", "SVD XT", "Wan Video", "Mochi", "LTXV",
16
+ "Hunyuan Video", "HiDream", "Other"
17
+ ]
18
+
19
+ SORT = [
20
+ "Highest Rated", "Most Downloaded", "Most Liked", "Most Discussed",
21
+ "Most Collected", "Most Buzz", "Newest"
22
+ ]
23
+
24
+ PERIOD = [
25
+ "AllTime", "Year", "Month", "Week", "Day"
26
+ ]
null.png ADDED
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git-lfs aria2
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ huggingface-hub
2
+ gdown
3
+ gradio_modal
utils.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
3
+ import os
4
+ from pathlib import Path
5
+ import shutil
6
+ import gc
7
+ import re
8
+ import urllib.parse
9
+ import subprocess
10
+ import time
11
+ from typing import Any
12
+
13
+
14
+ def get_token():
15
+ try:
16
+ token = HfFolder.get_token()
17
+ except Exception:
18
+ token = ""
19
+ return token
20
+
21
+
22
+ def set_token(token):
23
+ try:
24
+ HfFolder.save_token(token)
25
+ except Exception:
26
+ print(f"Error: Failed to save token.")
27
+
28
+
29
+ def get_state(state: dict, key: str):
30
+ if key in state.keys(): return state[key]
31
+ else:
32
+ print(f"State '{key}' not found.")
33
+ return None
34
+
35
+
36
+ def set_state(state: dict, key: str, value: Any):
37
+ state[key] = value
38
+
39
+
40
+ def get_user_agent():
41
+ return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
42
+
43
+
44
+ def is_repo_exists(repo_id: str, repo_type: str="model"):
45
+ hf_token = get_token()
46
+ api = HfApi(token=hf_token)
47
+ try:
48
+ if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
49
+ else: return False
50
+ except Exception as e:
51
+ print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
52
+ return True # for safe
53
+
54
+
55
+ MODEL_TYPE_CLASS = {
56
+ "diffusers:StableDiffusionPipeline": "SD 1.5",
57
+ "diffusers:StableDiffusionXLPipeline": "SDXL",
58
+ "diffusers:FluxPipeline": "FLUX",
59
+ }
60
+
61
+
62
+ def get_model_type(repo_id: str):
63
+ hf_token = get_token()
64
+ api = HfApi(token=hf_token)
65
+ lora_filename = "pytorch_lora_weights.safetensors"
66
+ diffusers_filename = "model_index.json"
67
+ default = "SDXL"
68
+ try:
69
+ if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
70
+ if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
71
+ model = api.model_info(repo_id=repo_id, token=hf_token)
72
+ tags = model.tags
73
+ for tag in tags:
74
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
75
+ except Exception:
76
+ return default
77
+ return default
78
+
79
+
80
+ def list_uniq(l):
81
+ return sorted(set(l), key=l.index)
82
+
83
+
84
+ def list_sub(a, b):
85
+ return [e for e in a if e not in b]
86
+
87
+
88
+ def is_repo_name(s):
89
+ return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
90
+
91
+
92
+ def get_hf_url(repo_id: str, repo_type: str="model"):
93
+ if repo_type == "dataset": url = f"https://huggingface.co/datasets/{repo_id}"
94
+ elif repo_type == "space": url = f"https://huggingface.co/spaces/{repo_id}"
95
+ else: url = f"https://huggingface.co/{repo_id}"
96
+ return url
97
+
98
+
99
+ def split_hf_url(url: str):
100
+ try:
101
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
102
+ if len(s) < 4: return "", "", "", ""
103
+ repo_id = s[1]
104
+ if s[0] == "datasets": repo_type = "dataset"
105
+ elif s[0] == "spaces": repo_type = "space"
106
+ else: repo_type = "model"
107
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
108
+ filename = urllib.parse.unquote(s[3])
109
+ return repo_id, filename, subfolder, repo_type
110
+ except Exception as e:
111
+ print(e)
112
+
113
+
114
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
115
+ hf_token = get_token()
116
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
117
+ try:
118
+ print(f"Downloading {url} to {directory}")
119
+ if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
120
+ else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
121
+ return path
122
+ except Exception as e:
123
+ print(f"Failed to download: {e}")
124
+ return None
125
+
126
+
127
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
128
+ try:
129
+ url = url.strip()
130
+ if "drive.google.com" in url:
131
+ original_dir = os.getcwd()
132
+ os.chdir(directory)
133
+ subprocess.run(f"gdown --fuzzy {url}", shell=True)
134
+ os.chdir(original_dir)
135
+ elif "huggingface.co" in url:
136
+ url = url.replace("?download=true", "")
137
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
138
+ download_hf_file(directory, url)
139
+ elif "civitai.com" in url:
140
+ if civitai_api_key:
141
+ url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"{url}?token={civitai_api_key}"
142
+ print(f"Downloading {url}")
143
+ subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
144
+ else:
145
+ print("You need an API key to download Civitai models.")
146
+ else:
147
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
148
+ except Exception as e:
149
+ print(f"Failed to download: {e}")
150
+
151
+
152
+ def get_local_file_list(dir_path, recursive=False):
153
+ file_list = []
154
+ pattern = "**/*.*" if recursive else "*/*.*"
155
+ for file in Path(dir_path).glob(pattern):
156
+ if file.is_file():
157
+ file_path = str(file)
158
+ file_list.append(file_path)
159
+ return file_list
160
+
161
+
162
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
163
+ try:
164
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
165
+ print(f"Use HF Repo: {url}")
166
+ new_file = url
167
+ elif not "http" in url and Path(url).exists():
168
+ print(f"Use local file: {url}")
169
+ new_file = url
170
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
171
+ print(f"File to download alreday exists: {url}")
172
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
173
+ else:
174
+ print(f"Start downloading: {url}")
175
+ recursive = False if "huggingface.co" in url else True
176
+ before = get_local_file_list(temp_dir, recursive)
177
+ download_thing(temp_dir, url.strip(), civitai_key)
178
+ after = get_local_file_list(temp_dir, recursive)
179
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
180
+ if not new_file:
181
+ print(f"Download failed: {url}")
182
+ return ""
183
+ print(f"Download completed: {url}")
184
+ return new_file
185
+ except Exception as e:
186
+ print(f"Download failed: {url} {e}")
187
+ return ""
188
+
189
+
190
+ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
191
+ hf_token = get_token()
192
+ try:
193
+ snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
194
+ ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"], force_download=True)
195
+ return True
196
+ except Exception as e:
197
+ print(f"Error: Failed to download {repo_id}. {e}")
198
+ gr.Warning(f"Error: Failed to download {repo_id}. {e}")
199
+ return False
200
+
201
+
202
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
203
+ hf_token = get_token()
204
+ api = HfApi(token=hf_token)
205
+ try:
206
+ progress(0, desc="Start uploading...")
207
+ api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
208
+ api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=hf_token)
209
+ progress(1, desc="Uploaded.")
210
+ return get_hf_url(repo_id, "model")
211
+ except Exception as e:
212
+ print(f"Error: Failed to upload to {repo_id}. {e}")
213
+ return ""
214
+
215
+
216
+ def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
217
+ hf_token = get_token()
218
+ api = HfApi(token=hf_token)
219
+ try:
220
+ if gated_str == "auto": gated = "auto"
221
+ elif gated_str == "manual": gated = "manual"
222
+ else: gated = False
223
+ api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=hf_token)
224
+ except Exception as e:
225
+ print(f"Error: Failed to update settings {repo_id}. {e}")
226
+
227
+
228
+ HF_SUBFOLDER_NAME = ["None", "user_repo"]
229
+
230
+
231
+ def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
232
+ is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
233
+ hf_token = get_token()
234
+ api = HfApi(token=hf_token)
235
+ try:
236
+ if subfolder_type == "user_repo": subfolder = src_repo.replace("/", "_")
237
+ else: subfolder = ""
238
+ progress(0, desc="Start duplicating...")
239
+ api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
240
+ for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
241
+ file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type, token=hf_token)
242
+ if not Path(file).exists(): continue
243
+ if Path(file).is_dir(): # unused for now
244
+ api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
245
+ repo_type=dst_repo_type, token=hf_token)
246
+ elif Path(file).is_file():
247
+ api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
248
+ repo_type=dst_repo_type, token=hf_token)
249
+ if Path(file).exists(): Path(file).unlink()
250
+ progress(1, desc="Duplicated.")
251
+ return f"{get_hf_url(dst_repo, dst_repo_type)}/tree/main/{subfolder}" if subfolder else get_hf_url(dst_repo, dst_repo_type)
252
+ except Exception as e:
253
+ print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
254
+ return ""
255
+
256
+
257
+ BASE_DIR = str(Path(__file__).resolve().parent.resolve())
258
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
259
+
260
+
261
+ def get_file(url: str, path: str): # requires aria2, gdown
262
+ print(f"Downloading {url} to {path}...")
263
+ get_download_file(path, url, CIVITAI_API_KEY)
264
+
265
+
266
+ def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
267
+ os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
268
+ os.chdir(Path(BASE_DIR, path))
269
+ print(f"Cloning {url} to {path}...")
270
+ cmd = f'git clone {url}'
271
+ print(f'Running {cmd} at {Path.cwd()}')
272
+ i = subprocess.run(cmd, shell=True).returncode
273
+ if i != 0: print(f'Error occured at running {cmd}')
274
+ p = url.split("/")[-1]
275
+ if not Path(p).exists: return
276
+ if pip:
277
+ os.chdir(Path(BASE_DIR, path, p))
278
+ cmd = f'pip install -r requirements.txt'
279
+ print(f'Running {cmd} at {Path.cwd()}')
280
+ i = subprocess.run(cmd, shell=True).returncode
281
+ if i != 0: print(f'Error occured at running {cmd}')
282
+ if addcmd:
283
+ os.chdir(Path(BASE_DIR, path, p))
284
+ cmd = addcmd
285
+ print(f'Running {cmd} at {Path.cwd()}')
286
+ i = subprocess.run(cmd, shell=True).returncode
287
+ if i != 0: print(f'Error occured at running {cmd}')
288
+
289
+
290
+ def run(cmd: str, timeout: float=0):
291
+ print(f'Running {cmd} at {Path.cwd()}')
292
+ if timeout == 0:
293
+ i = subprocess.run(cmd, shell=True).returncode
294
+ if i != 0: print(f'Error occured at running {cmd}')
295
+ else:
296
+ p = subprocess.Popen(cmd, shell=True)
297
+ time.sleep(timeout)
298
+ p.terminate()
299
+ print(f'Terminated in {timeout} seconds')