prithivMLmods committed (verified)
Commit: dddc2d6
Parent: ea2a25d

Update app.py

Files changed (1):
  1. app.py (+279 -299)

app.py CHANGED
@@ -1,329 +1,309 @@
  import os
- import random
- import uuid
- import json
- import time
- import asyncio
- import re
- from threading import Thread
-
  import gradio as gr
- import spaces
  import torch
- import numpy as np
  from PIL import Image
- import cv2
-
- from transformers import (
-     AutoProcessor,
-     Gemma3ForConditionalGeneration,
-     Qwen2VLForConditionalGeneration,
-     TextIteratorStreamer,
- )
- from transformers.image_utils import load_image
  # Constants
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
- MAX_SEED = np.iinfo(np.int32).max
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- # Helper function to return a progress bar HTML snippet.
- def progress_bar_html(label: str) -> str:
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''
- # Qwen2-VL (for optional image inference)
- MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_VL,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
- def clean_chat_history(chat_history):
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned
- bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
- bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
- default_negative = os.getenv("default_negative", "")
- def check_text(prompt, negative=""):
-     for i in bad_words:
-         if i in prompt:
-             return True
-     for i in bad_words_negative:
-         if i in negative:
-             return True
-     return False
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
- dtype = torch.float16 if device.type == "cuda" else torch.float32
- # Gemma3 Model (default for text, image, & video inference)
- gemma3_model_id = "google/gemma-3-4b-it"  # [or] Duplicate the space to use 12b
- gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-     gemma3_model_id, device_map="auto"
- ).eval()
- gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
- # VIDEO PROCESSING HELPER
- def downsample_video(video_path):
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     # Sample 10 evenly spaced frames.
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             # Convert from BGR to RGB and then to PIL Image.
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
- # MAIN GENERATION FUNCTION
- @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-     lower_text = text.lower().strip()
-     # ----- Qwen2-VL branch (triggered with @qwen2-vl) -----
-     if lower_text.startswith("@qwen2-vl"):
-         prompt_clean = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
-         if files:
-             images = [load_image(f) for f in files]
-             messages = [{
-                 "role": "user",
-                 "content": [
-                     *[{"type": "image", "image": image} for image in images],
-                     {"type": "text", "text": prompt_clean},
-                 ]
-             }]
-             prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-             inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
          else:
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-             ]
-             inputs = processor.apply_chat_template(
-                 messages, add_generation_prompt=True, tokenize=True,
-                 return_dict=True, return_tensors="pt"
-             ).to("cuda", dtype=torch.float16)
-         streamer = TextIteratorStreamer(processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             **inputs,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "temperature": temperature,
-             "top_p": top_p,
-             "top_k": top_k,
-             "repetition_penalty": repetition_penalty,
-         }
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing with Qwen2VL")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-         return
-     # ----- Default branch: Gemma3 (for text, image, & video inference) -----
-     if files:
-         # Check if any provided file is a video based on extension.
-         video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
-         if any(str(f).lower().endswith(video_extensions) for f in files):
-             # Video inference branch.
-             prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
-             video_path = files[0]
-             frames = downsample_video(video_path)
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-             ]
-             # Append each frame (with its timestamp) to the conversation.
-             for frame in frames:
-                 image, timestamp = frame
-                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                 image.save(image_path)
-                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-                 messages[1]["content"].append({"type": "image", "url": image_path})
-             inputs = gemma3_processor.apply_chat_template(
-                 messages, add_generation_prompt=True, tokenize=True,
-                 return_dict=True, return_tensors="pt"
-             ).to(gemma3_model.device, dtype=torch.bfloat16)
-             streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-             generation_kwargs = {
-                 **inputs,
-                 "streamer": streamer,
-                 "max_new_tokens": max_new_tokens,
-                 "do_sample": True,
-                 "temperature": temperature,
-                 "top_p": top_p,
-                 "top_k": top_k,
-                 "repetition_penalty": repetition_penalty,
-             }
-             thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-             thread.start()
-             buffer = ""
-             yield progress_bar_html("Processing video with Gemma3")
-             for new_text in streamer:
-                 buffer += new_text
-                 time.sleep(0.01)
-                 yield buffer
-             return
          else:
-             # Image inference branch.
-             prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
-             images = [load_image(f) for f in files]
-             messages = [{
-                 "role": "user",
-                 "content": [
-                     *[{"type": "image", "image": image} for image in images],
-                     {"type": "text", "text": prompt_clean},
-                 ]
-             }]
-             inputs = gemma3_processor.apply_chat_template(
-                 messages, tokenize=True, add_generation_prompt=True,
-                 return_dict=True, return_tensors="pt"
-             ).to(gemma3_model.device, dtype=torch.bfloat16)
-             streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-             generation_kwargs = {
-                 **inputs,
-                 "streamer": streamer,
-                 "max_new_tokens": max_new_tokens,
-                 "do_sample": True,
-                 "temperature": temperature,
-                 "top_p": top_p,
-                 "top_k": top_k,
-                 "repetition_penalty": repetition_penalty,
-             }
-             thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-             thread.start()
-             buffer = ""
-             yield progress_bar_html("Processing with Gemma3")
-             for new_text in streamer:
-                 buffer += new_text
-                 time.sleep(0.01)
-                 yield buffer
-             return
-     else:
-         # Text-only inference branch.
-         messages = [
-             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-             {"role": "user", "content": [{"type": "text", "text": text}]}
-         ]
-         inputs = gemma3_processor.apply_chat_template(
-             messages, add_generation_prompt=True, tokenize=True,
-             return_dict=True, return_tensors="pt"
-         ).to(gemma3_model.device, dtype=torch.bfloat16)
-         streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             **inputs,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "temperature": temperature,
-             "top_p": top_p,
-             "top_k": top_k,
-             "repetition_penalty": repetition_penalty,
-         }
-         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-         thread.start()
-         outputs = []
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)
-         final_response = "".join(outputs)
-         yield final_response
- # Gradio Interface
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         [{"text": "Create a short story based on the image.", "files": ["examples/1111.jpg"]}],
-         [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
-         [{"text": "Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
-         [{"text": "Which movie character is this?", "files": ["examples/9999.jpg"]}],
-         ["Explain Critical Temperature of Substance"],
-         [{"text": "@qwen2-vl Transcription of the letter", "files": ["examples/222.png"]}],
-         [{"text": "Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
-         [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
-         [{"text": "Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
-         [{"text": "Summarize the events in this video", "files": ["examples/sky.mp4"]}],
-         [{"text": "What is in the video ?", "files": ["examples/redlight.mp4"]}],
-         ["Python Program for Array Rotation"],
-         ["Explain Critical Temperature of Substance"]
-     ],
-     cache_examples=False,
-     type="messages",
-     description="# **Gemma 3 Multimodal** \n`Use @qwen2-vl to switch to Qwen2-VL OCR for image inference and @video-infer for video input`",
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag with @qwen2-vl for Qwen2-VL inference if needed."),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )
- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
  import os
  import gradio as gr
+ import json
+ import logging
  import torch
  from PIL import Image
+ import random
+ import time
+ from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
+ from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler  # missing in the original; FAST_MODEL_CONFIG references this class (import path assumed from the hi_diffusers package layout)
+ from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
+ from huggingface_hub import ModelCard
  # Constants
+ MODEL_PREFIX = "HiDream-ai"
+ LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ FAST_MODEL_CONFIG = {
+     "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
+     "guidance_scale": 5.0,
+     "num_inference_steps": 50,
+     "shift": 3.0,
+     "scheduler": FlowUniPCMultistepScheduler
+ }
+ RESOLUTION_OPTIONS = [
+     "1024 × 1024 (Square)",
+     "768 × 1360 (Portrait)",
+     "1360 × 768 (Landscape)",
+     "880 × 1168 (Portrait)",
+     "1168 × 880 (Landscape)",
+     "1248 × 832 (Landscape)",
+     "832 × 1248 (Portrait)"
+ ]
+ # Load LoRAs from JSON file (assumed to be compatible with Hi-Dream)
+ with open('loras.json', 'r') as f:
+     loras = json.load(f)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ MAX_SEED = 2**32 - 1
+ # Parse resolution string to height and width
+ def parse_resolution(res_str):
+     mapping = {
+         "1024 × 1024": (1024, 1024),
+         "768 × 1360": (768, 1360),
+         "1360 × 768": (1360, 768),
+         "880 × 1168": (880, 1168),
+         "1168 × 880": (1168, 880),
+         "1248 × 832": (1248, 832),
+         "832 × 1248": (832, 1248)
+     }
+     for key, (h, w) in mapping.items():
+         if key in res_str:
+             return h, w
+     return 1024, 1024  # fallback
+ # Load the Hi-Dream Fast Model pipeline
+ pipe, MODEL_CONFIG = None, None
+ def load_fast_model():
+     global pipe, MODEL_CONFIG
+     config = FAST_MODEL_CONFIG
+     scheduler = config["scheduler"](
+         num_train_timesteps=1000,
+         shift=config["shift"],
+         use_dynamic_shifting=False
+     )
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(
+         LLAMA_MODEL_NAME,
+         use_fast=False
+     )
+     text_encoder = LlamaForCausalLM.from_pretrained(
+         LLAMA_MODEL_NAME,
+         output_hidden_states=True,
+         output_attentions=True,
+         torch_dtype=torch.bfloat16
+     ).to(device)
+     transformer = HiDreamImageTransformer2DModel.from_pretrained(
+         config["path"],
+         subfolder="transformer",
+         torch_dtype=torch.bfloat16
+     ).to(device)
+     pipe = HiDreamImagePipeline.from_pretrained(
+         config["path"],
+         scheduler=scheduler,
+         tokenizer_4=tokenizer,
+         text_encoder_4=text_encoder,
+         torch_dtype=torch.bfloat16
+     ).to(device, torch.bfloat16)
+     pipe.transformer = transformer
+     MODEL_CONFIG = config
+     return pipe, config
+ # Generate image
+ def generate_image(prompt, resolution, seed, guidance_scale, num_inference_steps):
+     global pipe, MODEL_CONFIG
+     if pipe is None:
+         pipe, MODEL_CONFIG = load_fast_model()
+     height, width = parse_resolution(resolution)
+     if seed == -1 or seed is None:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+     result = pipe(
+         prompt=prompt,
+         height=height,
+         width=width,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         num_images_per_prompt=1,
+         generator=generator
+     )
+     return result.images[0], seed
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
          else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+ def update_selection(evt: gr.SelectData, resolution):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             resolution = "768 × 1360 (Portrait)"
+         elif selected_lora["aspect"] == "landscape":
+             resolution = "1360 × 768 (Landscape)"
          else:
+             resolution = "1024 × 1024 (Square)"
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         resolution,
+     )
+
+ def run_lora(prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed):
+     global pipe
+     if pipe is None:
+         pipe, _ = load_fast_model()
+
+     if selected_index is not None:
+         selected_lora = loras[selected_index]
+         lora_path = selected_lora["repo"]
+         weight_name = selected_lora.get("weights", None)
+         with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+             pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True)
+         trigger_word = selected_lora.get("trigger_word", "")
+         if trigger_word:
+             if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
+                 prompt = f"{trigger_word} {prompt}"
+             else:
+                 prompt = f"{prompt} {trigger_word}"
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     with calculateDuration("Generating image"):
+         final_image, used_seed = generate_image(prompt, resolution, seed, cfg_scale, steps)
+     return final_image, used_seed
+
+ def check_custom_model(link):
+     split_link = link.split("/")
+     if len(split_link) != 2:
+         raise Exception("Invalid Hugging Face repository link format.")
+     model_card = ModelCard.load(link)
+     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+     trigger_word = model_card.data.get("instance_prompt", "")
+     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+     safetensors_name = None  # Simplified; assumes a safetensors file exists
+     return split_link[1], link, safetensors_name, trigger_word, image_url
+
+ def add_custom_lora(custom_lora):
+     global loras
+     if custom_lora:
+         try:
+             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
+             card = f'''
+             <div class="custom_lora_card">
+                 <span>Loaded custom LoRA:</span>
+                 <div class="card_internal">
+                     <img src="{image}" />
+                     <div>
+                         <h3>{title}</h3>
+                         <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found."}</small>
+                     </div>
+                 </div>
+             </div>
+             '''
+             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+             if existing_item_index is None:  # explicit None check so an existing entry at index 0 is not re-added
+                 new_item = {
+                     "image": image,
+                     "title": title,
+                     "repo": repo,
+                     "weights": path,
+                     "trigger_word": trigger_word
+                 }
+                 existing_item_index = len(loras)
+                 loras.append(new_item)
+             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
+         except Exception as e:
+             gr.Warning(f"Invalid LoRA: {str(e)}")
+             return gr.update(visible=True, value=f"Invalid LoRA: {str(e)}"), gr.update(visible=True), gr.update(), "", None, ""
+     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+ def remove_custom_lora():
+     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+ css = '''
+ #gen_btn{height: 100%}
+ #gen_column{align-self: stretch}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
+ .card_internal img{margin-right: 1em}
+ .styler{--form-gap-width: 0px !important}
+ '''
+ font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
+ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
+     title = gr.HTML(
+         """<h1>Hi-Dream Full LoRA DLC 🤩</h1>""",
+         elem_id="title",
+     )
+     selected_index = gr.State(None)
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
+         with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+     with gr.Row():
+         with gr.Column():
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Gallery",
+                 allow_preview=False,
+                 columns=3,
+                 elem_id="gallery",
+                 show_share_button=False
+             )
+             with gr.Group():
+                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
+                 gr.Markdown("[Check the list of Hi-Dream LoRAs]", elem_id="lora_list")
+             custom_lora_info = gr.HTML(visible=False)
+             custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
+         with gr.Column():
+             result = gr.Image(label="Generated Image")
+     with gr.Row():
+         with gr.Accordion("Advanced Settings", open=False):
+             cfg_scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=FAST_MODEL_CONFIG["guidance_scale"])
+             steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=FAST_MODEL_CONFIG["num_inference_steps"])
+             resolution = gr.Radio(
+                 choices=RESOLUTION_OPTIONS,
+                 value=RESOLUTION_OPTIONS[0],
+                 label="Resolution"
+             )
+             randomize_seed = gr.Checkbox(True, label="Randomize seed")
+             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+     gallery.select(
+         update_selection,
+         inputs=[resolution],
+         outputs=[prompt, selected_info, selected_index, resolution]
+     )
+     custom_lora.input(
+         add_custom_lora,
+         inputs=[custom_lora],
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
+     )
+     custom_lora_button.click(
+         remove_custom_lora,
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
+     )
+     gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=run_lora,
+         inputs=[prompt, resolution, cfg_scale, steps, selected_index, randomize_seed, seed],
+         outputs=[result, seed]
+     )
+ app.queue()
+ app.launch()