alvdansen committed
Commit 6d9c61b · verified
Parent: e5e853f

Update app.py

Files changed (1):
  1. app.py  +140 -149
app.py CHANGED (old version; removed lines are marked with -)
@@ -1,24 +1,37 @@
  import json
  import random
- import requests
  import gradio as gr
  import numpy as np
  import spaces
  import torch
  from diffusers import DiffusionPipeline, LCMScheduler
- from PIL import Image
- import os

- # Load the JSON data
  with open("sdxl_lora.json", "r") as file:
      data = json.load(file)
- sdxl_loras_raw = sorted(data, key=lambda x: x["likes"], reverse=True)

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"

  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
  pipe.to(device=DEVICE, dtype=torch.float16)

  MAX_SEED = np.iinfo(np.int32).max
@@ -29,46 +42,6 @@ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
      trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
      return lora_id, trigger_word

- def load_lora_for_style(style_repo):
-     pipe.unload_lora_weights()
-     pipe.load_lora_weights(style_repo, adapter_name="lora")
-
- def get_image(image_data):
-     if isinstance(image_data, str):
-         return image_data
-
-     if isinstance(image_data, dict):
-         local_path = image_data.get('local_path')
-         hf_url = image_data.get('hf_url')
-     else:
-         print(f"Unexpected image_data format: {type(image_data)}")
-         return None
-
-     # Try loading from local path first
-     if local_path and os.path.exists(local_path):
-         try:
-             Image.open(local_path).verify()  # Verify that it's a valid image
-             return local_path
-         except Exception as e:
-             print(f"Error loading local image {local_path}: {e}")
-
-     # If local path fails or doesn't exist, try URL
-     if hf_url:
-         try:
-             response = requests.get(hf_url)
-             if response.status_code == 200:
-                 img = Image.open(requests.get(hf_url, stream=True).raw)
-                 img.verify()  # Verify that it's a valid image
-                 img.save(local_path)  # Save for future use
-                 return local_path
-             else:
-                 print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
-         except Exception as e:
-             print(f"Error loading image from URL {hf_url}: {e}")
-
-     print(f"Failed to load image for {image_data}")
-     return None
-
  @spaces.GPU
  def infer(
      pre_prompt,
@@ -82,7 +55,19 @@ def infer(
      user_lora_weight,
      progress=gr.Progress(track_tqdm=True),
  ):
-     load_lora_for_style(user_lora_selector)

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
@@ -103,129 +88,141 @@
      return image

  css = """
- body {
-     background-color: #1a1a1a;
-     color: #ffffff;
- }
- .container {
-     max-width: 900px;
-     margin: auto;
-     padding: 20px;
- }
- h1, h2 {
-     color: #4CAF50;
      text-align: center;
  }
- .gallery {
-     display: flex;
-     flex-wrap: wrap;
-     justify-content: center;
- }
- .gallery img {
-     margin: 10px;
-     border-radius: 10px;
-     transition: transform 0.3s ease;
- }
- .gallery img:hover {
-     transform: scale(1.05);
- }
- .gradio-slider input[type="range"] {
-     background-color: #4CAF50;
- }
- .gradio-button {
-     background-color: #4CAF50 !important;
  }
  """

  with gr.Blocks(css=css) as demo:
      gr.Markdown(
-         """
-         # ⚡ FlashDiffusion: Araminta K's FlashLoRA Showcase ⚡
-
-         This interactive demo showcases [Araminta K's models](https://huggingface.co/alvdansen) using [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) technology.
-
-         ## Acknowledgments
-         - Original Flash Diffusion technology by the Jasper AI team
-         - Based on the paper: [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) by Clément Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin
-         - Models showcased here are created by Araminta K at Alvdansen Labs
-
-         Explore the power of FlashLoRA with Araminta K's unique artistic styles!
-         """
      )

      gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
      gr_lora_id = gr.State(value="")

      with gr.Row():
-         with gr.Column(scale=2):
-             gallery = gr.Gallery(
-                 value=[(get_image(item["image"]), item["title"]) for item in sdxl_loras_raw if get_image(item["image"]) is not None],
-                 label="SDXL LoRA Gallery",
-                 show_label=False,
-                 elem_id="gallery",
-                 columns=3,
-                 height=600,
-             )
-
-             user_lora_selector = gr.Textbox(
-                 label="Current Selected LoRA",
-                 interactive=False,
-             )
-
-         with gr.Column(scale=3):
-             prompt = gr.Textbox(
-                 label="Prompt",
-                 placeholder="Enter your prompt",
-                 lines=3,
-             )

              with gr.Row():
-                 run_button = gr.Button("Run", variant="primary")
-                 clear_button = gr.Button("Clear")

-             result = gr.Image(label="Result", height=512)

              with gr.Accordion("Advanced Settings", open=False):
-                 pre_prompt = gr.Textbox(
                      label="Pre-Prompt",
                      placeholder="Pre Prompt from the LoRA config",
-                     lines=2,
                  )

                  with gr.Row():
-                     seed = gr.Slider(
-                         label="Seed",
-                         minimum=0,
-                         maximum=MAX_SEED,
                          step=1,
-                         value=0,
                      )
-                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

-                     num_inference_steps = gr.Slider(
-                         label="Number of inference steps",
-                         minimum=4,
-                         maximum=8,
-                         step=1,
-                         value=4,
-                     )

-                     guidance_scale = gr.Slider(
-                         label="Guidance Scale",
-                         minimum=1,
-                         maximum=6,
-                         step=0.5,
-                         value=1,
                      )

-                 negative_prompt = gr.Textbox(
                      label="Negative Prompt",
                      placeholder="Enter a negative Prompt",
-                     lines=2,
                  )

      gr.on(
-         [run_button.click, prompt.submit],
          fn=infer,
          inputs=[
              pre_prompt,
@@ -236,30 +233,24 @@ with gr.Blocks(css=css) as demo:
              negative_prompt,
              guidance_scale,
              user_lora_selector,
-             gr.Slider(label="Selected LoRA Weight", minimum=0.5, maximum=3, step=0.1, value=1),
          ],
          outputs=[result],
      )

-     clear_button.click(lambda: "", outputs=[prompt, result])
-
      gallery.select(
          fn=update_selection,
          inputs=[gr_sdxl_loras],
-         outputs=[user_lora_selector, pre_prompt],
      )

      gr.Markdown(
-         """
-         ## Unleash Your Creativity!
-
-         This showcase brings together the speed of Flash Diffusion and the artistic flair of Araminta K's models.
-         Craft your prompts, adjust the settings, and watch as AI brings your ideas to life in stunning detail.
-
-         Remember to use this tool ethically and respect copyright and individual privacy.
-
-         Enjoy exploring these unique artistic styles!
-         """
      )

  demo.queue().launch()
 
app.py (new version; added lines are marked with +)

  import json
  import random
+
  import gradio as gr
  import numpy as np
  import spaces
  import torch
  from diffusers import DiffusionPipeline, LCMScheduler

  with open("sdxl_lora.json", "r") as file:
      data = json.load(file)
+ sdxl_loras_raw = [
+     {
+         "image": item["image"],
+         "title": item["title"],
+         "repo": item["repo"],
+         "trigger_word": item["trigger_word"],
+         "weights": item["weights"],
+         "is_pivotal": item.get("is_pivotal", False),
+         "text_embedding_weights": item.get("text_embedding_weights", None),
+         "likes": item.get("likes", 0),
+     }
+     for item in data
+ ]
+
+ # Sort the loras by likes
+ sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"

  pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
  pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="flash_lora")
  pipe.to(device=DEVICE, dtype=torch.float16)

  MAX_SEED = np.iinfo(np.int32).max
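
Each entry read above is expected to supply at least the keys used in the comprehension. The following is a minimal sketch of one sdxl_lora.json entry, with purely illustrative values (the path, repo id and filename below are hypothetical; only the key names come from the code):

    # Illustrative only: one hypothetical entry of sdxl_lora.json after json.load().
    example_entry = {
        "image": "images/example_preview.png",        # gallery thumbnail (path or URL)
        "title": "Example Style",                     # caption shown in the gallery
        "repo": "alvdansen/example-style-lora",       # Hub repo id used to load the LoRA
        "trigger_word": "example style",              # fills the Pre-Prompt box on selection
        "weights": "example_style.safetensors",       # weight filename; not used in the hunks shown here
        "is_pivotal": False,                          # optional, defaults to False
        "text_embedding_weights": None,               # optional, defaults to None
        "likes": 0,                                   # optional, used only for sorting
    }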
 
      trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
      return lora_id, trigger_word

  @spaces.GPU
  def infer(
      pre_prompt,
 
      user_lora_weight,
      progress=gr.Progress(track_tqdm=True),
  ):
+     flash_sdxl_id = "jasperai/flash-sdxl"
+
+     new_adapter_id = user_lora_selector.replace("/", "_")
+     loaded_adapters = pipe.get_list_adapters()
+
+     if new_adapter_id not in loaded_adapters["unet"]:
+         gr.Info("Loading new LoRA")
+         pipe.unload_lora_weights()
+         pipe.load_lora_weights(flash_sdxl_id, adapter_name="flash_lora")
+         pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
+
+     pipe.set_adapters(["flash_lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
+     gr.Info("LoRA setup complete")

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
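
For context, the adapter handling added to infer() composes the distilled flash-sdxl LoRA with the currently selected style LoRA at inference time. Below is a minimal standalone sketch of that pattern, assuming a hypothetical style-LoRA repo id and the demo's default step and guidance settings:

    # Rough standalone sketch of the adapter composition used in infer() above.
    # "alvdansen/some-style-lora" is a placeholder; any SDXL LoRA repo id from the
    # gallery's "repo" field would take its place.
    import torch
    from diffusers import DiffusionPipeline, LCMScheduler

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16"
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.to(device="cuda", dtype=torch.float16)

    # Load the distilled Flash Diffusion LoRA and a style LoRA under separate adapter
    # names, then activate both; the second weight is what the "Selected LoRA Weight"
    # slider controls in the demo.
    pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="flash_lora")
    pipe.load_lora_weights("alvdansen/some-style-lora", adapter_name="style")  # placeholder
    pipe.set_adapters(["flash_lora", "style"], adapter_weights=[1.0, 1.0])

    image = pipe(
        prompt="a sketchbook drawing of a fox",  # any prompt
        num_inference_steps=4,                   # the demo exposes 4-8 steps
        guidance_scale=1.0,                      # model is distilled for guidance = 1
    ).images[0]

The handler above additionally checks pipe.get_list_adapters() so that a LoRA already attached to the pipeline is not reloaded on every run.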
 
      return image

  css = """
+ h1 {
      text-align: center;
+     display:block;
  }
+ p {
+     text-align: justify;
+     display:block;
  }
  """

+ if torch.cuda.is_available():
+     power_device = "GPU"
+ else:
+     power_device = "CPU"
+
  with gr.Blocks(css=css) as demo:
      gr.Markdown(
+         f"""
+         # ⚡ FlashDiffusion: FlashLoRA ⚡
+         This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
+
+         The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by Clément Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
+         The LoRAs can be added **without** any retraining for similar results in most cases. Feel free to tweak the parameters and use your own LoRAs by giving a look at the [Github Repo](https://github.com/gojasper/flash-diffusion)
+         """
+     )
+     gr.Markdown(
+         "If you enjoy the space, please also promote *open-source* by giving a ⭐ to our repo [![GitHub Stars](https://img.shields.io/github/stars/gojasper/flash-diffusion?style=social)](https://github.com/gojasper/flash-diffusion)"
      )

      gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
      gr_lora_id = gr.State(value="")

      with gr.Row():
+         with gr.Blocks():
+             with gr.Column():
+                 user_lora_selector = gr.Textbox(
+                     label="Current Selected LoRA",
+                     max_lines=1,
+                     interactive=False,
+                 )

+                 user_lora_weight = gr.Slider(
+                     label="Selected LoRA Weight",
+                     minimum=0.5,
+                     maximum=3,
+                     step=0.1,
+                     value=1,
+                 )
+
+                 gallery = gr.Gallery(
+                     value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
+                     label="SDXL LoRA Gallery",
+                     allow_preview=False,
+                     columns=3,
+                     elem_id="gallery",
+                     show_share_button=False,
+                 )
+
+         with gr.Column():
              with gr.Row():
+                 prompt = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
+                     scale=5,
+                 )
+
+                 run_button = gr.Button("Run", scale=1)

+             result = gr.Image(label="Result", show_label=False)

              with gr.Accordion("Advanced Settings", open=False):
+                 pre_prompt = gr.Text(
                      label="Pre-Prompt",
+                     show_label=True,
+                     max_lines=1,
                      placeholder="Pre Prompt from the LoRA config",
+                     container=True,
+                     scale=5,
                  )

+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
                  with gr.Row():
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=4,
+                         maximum=8,
                          step=1,
+                         value=4,
                      )

+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance Scale",
+                         minimum=1,
+                         maximum=6,
+                         step=0.5,
+                         value=1,
+                     )

+                 hint_negative = gr.Markdown(
+                     """💡 _Hint : Negative Prompt will only work with Guidance > 1 but the model was
+                     trained to be used with guidance = 1 (ie. without guidance).
+                     Can degrade the results, use cautiously._"""
                  )

+                 negative_prompt = gr.Text(
                      label="Negative Prompt",
+                     show_label=False,
+                     max_lines=1,
                      placeholder="Enter a negative Prompt",
+                     container=False,
                  )

      gr.on(
+         [
+             run_button.click,
+             seed.change,
+             randomize_seed.change,
+             prompt.submit,
+             negative_prompt.change,
+             negative_prompt.submit,
+             guidance_scale.change,
+         ],
          fn=infer,
          inputs=[
              pre_prompt,
 
              negative_prompt,
              guidance_scale,
              user_lora_selector,
+             user_lora_weight,
          ],
          outputs=[result],
      )

      gallery.select(
          fn=update_selection,
          inputs=[gr_sdxl_loras],
+         outputs=[
+             user_lora_selector,
+             pre_prompt,
+         ],
+         show_progress="hidden",
      )

+     gr.Markdown("**Disclaimer:**")
      gr.Markdown(
+         "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
      )

  demo.queue().launch()