Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,6 @@ import torch
|
|
13 |
from PIL import Image
|
14 |
import gradio as gr
|
15 |
|
16 |
-
|
17 |
from diffusers import (
|
18 |
DiffusionPipeline,
|
19 |
AutoencoderTiny,
|
@@ -26,13 +25,21 @@ from huggingface_hub import (
|
|
26 |
hf_hub_download,
|
27 |
HfFileSystem,
|
28 |
ModelCard,
|
29 |
-
snapshot_download
|
|
|
|
|
30 |
|
31 |
from diffusers.utils import load_image
|
32 |
|
33 |
import spaces
|
34 |
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def calculate_shift(
|
38 |
image_seq_len,
|
@@ -2089,24 +2096,25 @@ loras = [
|
|
2089 |
]
|
2090 |
|
2091 |
#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
|
2092 |
-
|
2093 |
dtype = torch.bfloat16
|
2094 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
2095 |
base_model = "black-forest-labs/FLUX.1-dev"
|
2096 |
|
2097 |
-
#TAEF1 is very tiny autoencoder
|
2098 |
-
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
2099 |
-
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
2100 |
-
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
2101 |
-
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
2102 |
-
|
2103 |
-
|
2104 |
-
|
2105 |
-
|
2106 |
-
|
2107 |
-
|
2108 |
-
|
2109 |
-
|
|
|
|
|
2110 |
|
2111 |
MAX_SEED = 2**32-1
|
2112 |
|
@@ -2210,7 +2218,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
2210 |
pipe.unload_lora_weights()
|
2211 |
pipe_i2i.unload_lora_weights()
|
2212 |
|
2213 |
-
#LoRA weights flow
|
2214 |
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
2215 |
pipe_to_use = pipe_i2i if image_input is not None else pipe
|
2216 |
weight_name = selected_lora.get("weights", None)
|
@@ -2235,7 +2243,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
2235 |
final_image = None
|
2236 |
step_counter = 0
|
2237 |
for image in image_generator:
|
2238 |
-
step_counter+=1
|
2239 |
final_image = image
|
2240 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
2241 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
@@ -2243,41 +2251,37 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
2243 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
2244 |
|
2245 |
def get_huggingface_safetensors(link):
|
2246 |
-
|
2247 |
-
|
2248 |
-
|
2249 |
-
|
2250 |
-
|
2251 |
|
2252 |
-
|
2253 |
-
|
2254 |
-
|
2255 |
|
2256 |
-
|
2257 |
-
|
2258 |
-
|
2259 |
-
|
2260 |
-
|
2261 |
-
|
2262 |
-
|
2263 |
-
|
2264 |
-
|
2265 |
-
|
2266 |
-
|
2267 |
-
|
2268 |
-
|
2269 |
-
|
2270 |
-
|
2271 |
-
|
2272 |
-
|
2273 |
-
print(e)
|
2274 |
-
gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
2275 |
-
raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
2276 |
-
return split_link[1], link, safetensors_name, trigger_word, image_url
|
2277 |
|
2278 |
def check_custom_model(link):
|
2279 |
-
if
|
2280 |
-
if
|
2281 |
link_split = link.split("huggingface.co/")
|
2282 |
return get_huggingface_safetensors(link_split[1])
|
2283 |
else:
|
@@ -2285,7 +2289,7 @@ def check_custom_model(link):
|
|
2285 |
|
2286 |
def add_custom_lora(custom_lora):
|
2287 |
global loras
|
2288 |
-
if
|
2289 |
try:
|
2290 |
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
|
2291 |
print(f"Loaded custom LoRA: {repo}")
|
@@ -2302,7 +2306,7 @@ def add_custom_lora(custom_lora):
|
|
2302 |
</div>
|
2303 |
'''
|
2304 |
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
|
2305 |
-
if
|
2306 |
new_item = {
|
2307 |
"image": image,
|
2308 |
"title": title,
|
@@ -2316,8 +2320,8 @@ def add_custom_lora(custom_lora):
|
|
2316 |
|
2317 |
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
|
2318 |
except Exception as e:
|
2319 |
-
gr.Warning(
|
2320 |
-
return gr.update(visible=True, value=
|
2321 |
else:
|
2322 |
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
2323 |
|
@@ -2371,7 +2375,7 @@ with gr.Blocks(theme="prithivMLmods/Minecraft-Theme", css=css, delete_cache=(60,
|
|
2371 |
custom_lora_info = gr.HTML(visible=False)
|
2372 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
2373 |
with gr.Column():
|
2374 |
-
progress_bar = gr.Markdown(elem_id="progress",visible=False)
|
2375 |
result = gr.Image(label="Generated Image")
|
2376 |
|
2377 |
with gr.Row():
|
|
|
13 |
from PIL import Image
|
14 |
import gradio as gr
|
15 |
|
|
|
16 |
from diffusers import (
|
17 |
DiffusionPipeline,
|
18 |
AutoencoderTiny,
|
|
|
25 |
hf_hub_download,
|
26 |
HfFileSystem,
|
27 |
ModelCard,
|
28 |
+
snapshot_download,
|
29 |
+
login # imported for one-time authentication
|
30 |
+
)
|
31 |
|
32 |
from diffusers.utils import load_image
|
33 |
|
34 |
import spaces
|
35 |
|
36 |
+
# -------------------------------
|
37 |
+
# Authenticate with Hugging Face once
|
38 |
+
# -------------------------------
|
39 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
40 |
+
if HF_TOKEN:
|
41 |
+
login(HF_TOKEN)
|
42 |
+
print("Authenticated with Hugging Face.")
|
43 |
|
44 |
def calculate_shift(
|
45 |
image_seq_len,
|
|
|
2096 |
]
|
2097 |
|
2098 |
#--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
|
|
|
2099 |
dtype = torch.bfloat16
|
2100 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
2101 |
base_model = "black-forest-labs/FLUX.1-dev"
|
2102 |
|
2103 |
+
# TAEF1 is a very tiny autoencoder using the same "latent API" as FLUX.1's VAE.
|
2104 |
+
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
|
2105 |
+
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
|
2106 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, use_auth_token=HF_TOKEN).to(device)
|
2107 |
+
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
2108 |
+
base_model,
|
2109 |
+
vae=good_vae,
|
2110 |
+
transformer=pipe.transformer,
|
2111 |
+
text_encoder=pipe.text_encoder,
|
2112 |
+
tokenizer=pipe.tokenizer,
|
2113 |
+
text_encoder_2=pipe.text_encoder_2,
|
2114 |
+
tokenizer_2=pipe.tokenizer_2,
|
2115 |
+
torch_dtype=dtype,
|
2116 |
+
use_auth_token=HF_TOKEN
|
2117 |
+
)
|
2118 |
|
2119 |
MAX_SEED = 2**32-1
|
2120 |
|
|
|
2218 |
pipe.unload_lora_weights()
|
2219 |
pipe_i2i.unload_lora_weights()
|
2220 |
|
2221 |
+
# LoRA weights flow
|
2222 |
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
2223 |
pipe_to_use = pipe_i2i if image_input is not None else pipe
|
2224 |
weight_name = selected_lora.get("weights", None)
|
|
|
2243 |
final_image = None
|
2244 |
step_counter = 0
|
2245 |
for image in image_generator:
|
2246 |
+
step_counter += 1
|
2247 |
final_image = image
|
2248 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
2249 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
|
|
2251 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
2252 |
|
2253 |
def get_huggingface_safetensors(link):
|
2254 |
+
split_link = link.split("/")
|
2255 |
+
if len(split_link) == 2:
|
2256 |
+
model_card = ModelCard.load(link)
|
2257 |
+
base_model = model_card.data.get("base_model")
|
2258 |
+
print(base_model)
|
2259 |
|
2260 |
+
# Allows Both
|
2261 |
+
if (base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell"):
|
2262 |
+
raise Exception("Flux LoRA Not Found!")
|
2263 |
|
2264 |
+
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
2265 |
+
trigger_word = model_card.data.get("instance_prompt", "")
|
2266 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
2267 |
+
fs = HfFileSystem()
|
2268 |
+
try:
|
2269 |
+
list_of_files = fs.ls(link, detail=False)
|
2270 |
+
for file in list_of_files:
|
2271 |
+
if file.endswith(".safetensors"):
|
2272 |
+
safetensors_name = file.split("/")[-1]
|
2273 |
+
if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
|
2274 |
+
image_elements = file.split("/")
|
2275 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
|
2276 |
+
except Exception as e:
|
2277 |
+
print(e)
|
2278 |
+
gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
2279 |
+
raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
|
2280 |
+
return split_link[1], link, safetensors_name, trigger_word, image_url
|
|
|
|
|
|
|
|
|
2281 |
|
2282 |
def check_custom_model(link):
|
2283 |
+
if link.startswith("https://"):
|
2284 |
+
if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
|
2285 |
link_split = link.split("huggingface.co/")
|
2286 |
return get_huggingface_safetensors(link_split[1])
|
2287 |
else:
|
|
|
2289 |
|
2290 |
def add_custom_lora(custom_lora):
|
2291 |
global loras
|
2292 |
+
if custom_lora:
|
2293 |
try:
|
2294 |
title, repo, path, trigger_word, image = check_custom_model(custom_lora)
|
2295 |
print(f"Loaded custom LoRA: {repo}")
|
|
|
2306 |
</div>
|
2307 |
'''
|
2308 |
existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
|
2309 |
+
if existing_item_index is None:
|
2310 |
new_item = {
|
2311 |
"image": image,
|
2312 |
"title": title,
|
|
|
2320 |
|
2321 |
return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
|
2322 |
except Exception as e:
|
2323 |
+
gr.Warning("Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
|
2324 |
+
return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
|
2325 |
else:
|
2326 |
return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
|
2327 |
|
|
|
2375 |
custom_lora_info = gr.HTML(visible=False)
|
2376 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
2377 |
with gr.Column():
|
2378 |
+
progress_bar = gr.Markdown(elem_id="progress", visible=False)
|
2379 |
result = gr.Image(label="Generated Image")
|
2380 |
|
2381 |
with gr.Row():
|