Anurag181011 committed
Commit e311d80 · verified · 1 Parent(s): abcffb3

Update app.py

Files changed (1): app.py +59 -55
app.py CHANGED
@@ -13,7 +13,6 @@ import torch
 from PIL import Image
 import gradio as gr
 
-
 from diffusers import (
     DiffusionPipeline,
     AutoencoderTiny,
@@ -26,13 +25,21 @@ from huggingface_hub import (
     hf_hub_download,
     HfFileSystem,
     ModelCard,
-    snapshot_download)
+    snapshot_download,
+    login  # imported for one-time authentication
+)
 
 from diffusers.utils import load_image
 
 import spaces
 
-
+# -------------------------------
+# Authenticate with Hugging Face once
+# -------------------------------
+HF_TOKEN = os.environ.get("HF_TOKEN")
+if HF_TOKEN:
+    login(HF_TOKEN)
+    print("Authenticated with Hugging Face.")
 
 def calculate_shift(
     image_seq_len,
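Note on the new authentication block: huggingface_hub.login() caches the token process-wide, so a single call at startup covers every later download. The hunk assumes `os` is already imported earlier in app.py and that HF_TOKEN is configured as a Space secret. A minimal standalone sketch of the same flow, under those assumptions:

import os
from huggingface_hub import login

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)  # cached once; later from_pretrained() calls reuse it
else:
    # Without a token, gated repos such as black-forest-labs/FLUX.1-dev
    # will fail to download.
    print("HF_TOKEN not set; gated model downloads may fail.")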
@@ -2089,24 +2096,25 @@ loras = [
 ]
 
 #--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
-
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 base_model = "black-forest-labs/FLUX.1-dev"
 
-#TAEF1 is very tiny autoencoder which uses the same "latent API" as FLUX.1's VAE. FLUX.1 is useful for real-time previewing of the FLUX.1 generation process.#
-taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
-good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
-pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
-    vae=good_vae,
-    transformer=pipe.transformer,
-    text_encoder=pipe.text_encoder,
-    tokenizer=pipe.tokenizer,
-    text_encoder_2=pipe.text_encoder_2,
-    tokenizer_2=pipe.tokenizer_2,
-    torch_dtype=dtype
-)
+# TAEF1 is a very tiny autoencoder using the same "latent API" as FLUX.1's VAE.
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype, use_auth_token=HF_TOKEN).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1, use_auth_token=HF_TOKEN).to(device)
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+    base_model,
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    torch_dtype=dtype,
+    use_auth_token=HF_TOKEN
+)
 
 MAX_SEED = 2**32-1
 
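Why two VAEs: TAEF1 is wired in as the pipeline's VAE so intermediate latents can be decoded cheaply for live previews, while good_vae (the full AutoencoderKL) produces the final image. One caveat, stated as an assumption about recent library versions: diffusers has been deprecating use_auth_token= in favor of token=, so the keyword may warn or break on newer releases. A runnable sketch of the preview decode, with shapes that are illustrative for a 1024x1024 generation:

import torch
from diffusers import AutoencoderTiny

dtype = torch.float32  # the app uses bfloat16 on GPU; float32 keeps this CPU-friendly
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)

latents = torch.randn(1, 16, 128, 128, dtype=dtype)  # stand-in for FLUX latents
preview = taef1.decode(latents).sample               # fast, low-fidelity decode
print(preview.shape)                                 # torch.Size([1, 3, 1024, 1024])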
 
@@ -2210,7 +2218,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
     pipe.unload_lora_weights()
     pipe_i2i.unload_lora_weights()
 
-    #LoRA weights flow
+    # LoRA weights flow
     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
         pipe_to_use = pipe_i2i if image_input is not None else pipe
         weight_name = selected_lora.get("weights", None)
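The hunk ends just before the actual load call. For context, a hedged sketch of one plausible continuation; the weight_name and low_cpu_mem_usage arguments are assumptions modeled on similar Spaces, not lines from this diff:

# Continues inside the calculateDuration block above.
pipe_to_use.load_lora_weights(
    selected_lora["repo"],
    weight_name=weight_name,   # pick a specific .safetensors file when the card names one
    low_cpu_mem_usage=True,    # assumption: commonly passed in similar Spaces
)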
@@ -2235,7 +2243,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
     final_image = None
     step_counter = 0
     for image in image_generator:
-        step_counter+=1
+        step_counter += 1
         final_image = image
         progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
         yield image, seed, gr.update(value=progress_bar, visible=True)
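This loop streams one frame per denoising step: run_lora is a generator, so each yield pushes an intermediate image and a refreshed progress bar to the UI. A minimal self-contained sketch of the pattern (the string payload stands in for the app's (image, seed, update) tuple):

def streaming_job(steps):
    for step in range(1, steps + 1):
        # In the app this yields (image, seed, gr.update(...)) instead of a string.
        yield f"step {step}/{steps}"

for update in streaming_job(4):
    print(update)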
@@ -2243,41 +2251,37 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
     yield final_image, seed, gr.update(value=progress_bar, visible=False)
 
 def get_huggingface_safetensors(link):
-    split_link = link.split("/")
-    if(len(split_link) == 2):
-        model_card = ModelCard.load(link)
-        base_model = model_card.data.get("base_model")
-        print(base_model)
+    split_link = link.split("/")
+    if len(split_link) == 2:
+        model_card = ModelCard.load(link)
+        base_model = model_card.data.get("base_model")
+        print(base_model)
 
-        #Allows Both
-        if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
-            raise Exception("Flux LoRA Not Found!")
+        # Allow both FLUX.1-dev and FLUX.1-schnell base models
+        if (base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell"):
+            raise Exception("Flux LoRA Not Found!")
 
-        # Only allow "black-forest-labs/FLUX.1-dev"
-        #if base_model != "black-forest-labs/FLUX.1-dev":
-            #raise Exception("Only FLUX.1-dev is supported, other LoRA models are not allowed!")
-
-        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-        trigger_word = model_card.data.get("instance_prompt", "")
-        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-        fs = HfFileSystem()
-        try:
-            list_of_files = fs.ls(link, detail=False)
-            for file in list_of_files:
-                if(file.endswith(".safetensors")):
-                    safetensors_name = file.split("/")[-1]
-                if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
-                    image_elements = file.split("/")
-                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-        except Exception as e:
-            print(e)
-            gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-            raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
-    return split_link[1], link, safetensors_name, trigger_word, image_url
+        image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+        trigger_word = model_card.data.get("instance_prompt", "")
+        image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+        fs = HfFileSystem()
+        try:
+            list_of_files = fs.ls(link, detail=False)
+            for file in list_of_files:
+                if file.endswith(".safetensors"):
+                    safetensors_name = file.split("/")[-1]
+                if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
+                    image_elements = file.split("/")
+                    image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+        except Exception as e:
+            print(e)
+            gr.Warning("You didn't include a valid link or a Hugging Face repository with a *.safetensors LoRA")
+            raise Exception("You didn't include a valid link or a Hugging Face repository with a *.safetensors LoRA")
+    return split_link[1], link, safetensors_name, trigger_word, image_url
 
 def check_custom_model(link):
-    if(link.startswith("https://")):
-        if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
+    if link.startswith("https://"):
+        if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
             link_split = link.split("huggingface.co/")
             return get_huggingface_safetensors(link_split[1])
     else:
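A hedged usage sketch for the two helpers above; the repo id is purely illustrative. check_custom_model accepts a full huggingface.co URL or a bare user/repo id and returns (title, repo, weight filename, trigger word, preview image URL). Note also a pre-existing wrinkle the rewrite keeps: if the repo contains no .safetensors file, safetensors_name is never assigned and the final return raises UnboundLocalError.

# Illustrative call; "some-user/some-flux-lora" is a placeholder repo id.
title, repo, path, trigger_word, image = check_custom_model(
    "https://huggingface.co/some-user/some-flux-lora"
)
print(title, path, trigger_word)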
@@ -2285,7 +2289,7 @@ def check_custom_model(link):
 
 def add_custom_lora(custom_lora):
     global loras
-    if(custom_lora):
+    if custom_lora:
         try:
             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
             print(f"Loaded custom LoRA: {repo}")
@@ -2302,7 +2306,7 @@ def add_custom_lora(custom_lora):
             </div>
             '''
             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
-            if(not existing_item_index):
+            if existing_item_index is None:
                 new_item = {
                     "image": image,
                     "title": title,
@@ -2316,8 +2320,8 @@ def add_custom_lora(custom_lora):
 
             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
         except Exception as e:
-            gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
-            return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
+            gr.Warning("Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA")
+            return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link or a non-FLUX LoRA"), gr.update(visible=False), gr.update(), "", None, ""
     else:
         return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
@@ -2371,7 +2375,7 @@ with gr.Blocks(theme="prithivMLmods/Minecraft-Theme", css=css, delete_cache=(60,
             custom_lora_info = gr.HTML(visible=False)
             custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
         with gr.Column():
-            progress_bar = gr.Markdown(elem_id="progress",visible=False)
+            progress_bar = gr.Markdown(elem_id="progress", visible=False)
             result = gr.Image(label="Generated Image")
 
         with gr.Row():
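The spacing fix here is cosmetic, but this component is how the progress HTML from run_lora reaches the page: a hidden Markdown slot with elem_id="progress" that the generator shows and hides via gr.update. A minimal self-contained sketch of that wiring (the button and label are illustrative, not from the app):

import gradio as gr

with gr.Blocks() as demo:
    progress_bar = gr.Markdown(elem_id="progress", visible=False)
    btn = gr.Button("Run")
    # Reveal the slot with new content; the app yields such updates per step.
    btn.click(lambda: gr.update(value="**running...**", visible=True),
              outputs=progress_bar)

# demo.launch()  # uncomment to serve locally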
 