nftnik commited on
Commit
8d6c96d
·
verified ·
1 Parent(s): fd95bae

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ import requests
8
+ import re
9
+ from diffusers import FluxPipeline
10
+ from translatepy import Translator
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ # Environment setup
14
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
15
+ translator = Translator()
16
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
17
+
18
+ # Constants
19
+ MODEL_ID = "black-forest-labs/FLUX.1-dev"
20
+ DEFAULT_LORA = "nftnik/BR_ohwx_V1"
21
+ DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+
24
+ CSS = """
25
+ footer {
26
+ visibility: hidden;
27
+ }
28
+ """
29
+
30
+ JS = """function () {
31
+ gradioURL = window.location.href;
32
+ if (!gradioURL.endsWith('?__theme=dark')) {
33
+ window.location.replace(gradioURL + '?__theme=dark');
34
+ }
35
+ }"""
36
+
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ print(f"Using {device.upper()}")
39
+
40
+ # Initialize pipeline and load default LoRA weights
41
+ pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
42
+ pipe.load_lora_weights(DEFAULT_LORA, weight_name=DEFAULT_WEIGHT_NAME)
43
+
44
+ def scrape_lora_link(url: str):
45
+ try:
46
+ response = requests.get(url)
47
+ response.raise_for_status()
48
+ content = response.text
49
+ pattern = r'href="(.*?lora.*?\.safetensors\?download=true)"'
50
+ pattern2 = r'href="(.*?\.safetensors\?download=true)"'
51
+ match = re.search(pattern, content)
52
+ match2 = re.search(pattern2, content)
53
+ if match:
54
+ safetensors_url = match.group(1)
55
+ filename = safetensors_url.split('/')[-1].split('?')[0]
56
+ return filename
57
+ elif match2:
58
+ safetensors_url = match2.group(1)
59
+ filename = safetensors_url.split('/')[-1].split('?')[0]
60
+ return filename
61
+ else:
62
+ return None
63
+ except requests.RequestException as e:
64
+ raise gr.Error(f"An error occurred while fetching the URL: {e}")
65
+
66
+ def enable_lora(lora_add: str, progress=gr.Progress(track_tqdm=True)):
67
+ pipe.unload_lora_weights()
68
+ if not lora_add:
69
+ gr.Info("No LoRA Loaded, using base model")
70
+ return gr.update(value="")
71
+ else:
72
+ url = f"https://huggingface.co/{lora_add}/tree/main"
73
+ lora_name = scrape_lora_link(url)
74
+ if lora_name:
75
+ print(f"Loading LoRA: {lora_add}/{lora_name}")
76
+ pipe.load_lora_weights(lora_add, weight_name=lora_name)
77
+ gr.Info(f"{lora_add} Loaded")
78
+ return gr.update(label="LoRA Loaded Now")
79
+ else:
80
+ try:
81
+ pipe.load_lora_weights(lora_add)
82
+ print(f"Loading LoRA: {lora_add}")
83
+ gr.Info(f"{lora_add} Loaded")
84
+ return gr.update(label="LoRA Loaded Now")
85
+ except Exception as e:
86
+ raise gr.Error(f"{lora_add} load failed: {e}")
87
+
88
+ # Placeholder function to update flux scheduler and sampler settings.
89
+ def update_flux_settings(scheduler_choice: str, sampler_choice: str):
90
+ # Replace the code below with actual logic to update your pipeline's scheduler/sampler.
91
+ print(f"Setting scheduler to {scheduler_choice} and sampler to {sampler_choice}")
92
+ # e.g.:
93
+ # pipe.scheduler = SchedulerClassMapping[scheduler_choice].from_config(pipe.scheduler.config)
94
+ # pipe.sampler = SamplerClassMapping[sampler_choice](**pipe.sampler_config)
95
+ return f"Scheduler set to {scheduler_choice} and Sampler set to {sampler_choice}"
96
+
97
+ @spaces.GPU()
98
+ def generate_image(
99
+ prompt: str,
100
+ lora_word: str,
101
+ lora_scale: float = 0.8,
102
+ width: int = 896,
103
+ height: int = 1152,
104
+ guidance_scale: float = 3.5,
105
+ steps: int = 25,
106
+ seed: int = -1,
107
+ nums: int = 1,
108
+ progress=gr.Progress(track_tqdm=True)
109
+ ):
110
+ # Ensure the pipeline is on the correct device.
111
+ pipe.to(device)
112
+ if seed == -1:
113
+ seed = random.randint(0, MAX_SEED)
114
+ seed = int(seed)
115
+
116
+ # Translate prompt to English.
117
+ translation = translator.translate(prompt, "English")
118
+ prompt_english = str(translation) # Adjust if translatepy returns a different attribute.
119
+ full_prompt = f"{prompt_english} {lora_word}"
120
+ print(f"Prompt: {full_prompt}")
121
+
122
+ generator = torch.Generator().manual_seed(seed)
123
+ result = pipe(
124
+ prompt=full_prompt,
125
+ height=height,
126
+ width=width,
127
+ guidance_scale=guidance_scale,
128
+ output_type="pil",
129
+ num_inference_steps=steps,
130
+ max_sequence_length=512,
131
+ num_images_per_prompt=nums,
132
+ generator=generator,
133
+ joint_attention_kwargs={"scale": lora_scale},
134
+ )
135
+ return result.images, seed
136
+
137
+ # Sample examples
138
+ examples = [
139
+ ["close-up portrait of a futuristic alien in cyberpunk attire", "ohwx", 0.9],
140
+ ["full-body shot of an alien running through a neon-lit cityscape", "ohwx", 0.9],
141
+ ["portrait of a blue alien with techwear in a virtual reality environment", "ohwx", 0.9],
142
+ ["cyberpunk style portrait of an alien with glowing eyes", "ohwx", 0.9]
143
+ ]
144
+
145
+ with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
146
+ gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
147
+ gr.HTML("<p><center>Load the LoRA model on the menu</center></p>")
148
+ with gr.Row():
149
+ with gr.Column(scale=4):
150
+ gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
151
+ with gr.Row():
152
+ prompt_input = gr.Textbox(
153
+ label="Enter Your Prompt (Multi-Languages)",
154
+ lines=2,
155
+ placeholder="Enter prompt...",
156
+ scale=6
157
+ )
158
+ generate_btn = gr.Button(scale=1, variant="primary")
159
+ with gr.Accordion("Advanced Options", open=True):
160
+ with gr.Column(scale=1):
161
+ width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=896)
162
+ height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=1152)
163
+ guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=3.5)
164
+ steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
165
+ seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
166
+ nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=4, step=1, value=1)
167
+ with gr.Column(scale=1):
168
+ lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
169
+ lora_add_text = gr.Textbox(
170
+ label="Flux LoRA",
171
+ info="Copy the HF LoRA model name here",
172
+ lines=1,
173
+ value="nftnik/BR_ohwx_V1"
174
+ )
175
+ lora_word_text = gr.Textbox(
176
+ label="Flux LoRA Trigger Word",
177
+ info="Add the Trigger Word",
178
+ lines=1,
179
+ value="ohwx"
180
+ )
181
+ load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
182
+ # New dropdowns for flux scheduler and sampler
183
+ flux_scheduler = gr.Dropdown(
184
+ label="Flux Scheduler",
185
+ choices=["DDIM", "PNDM", "DPMSolver"],
186
+ value="DDIM"
187
+ )
188
+ flux_sampler = gr.Dropdown(
189
+ label="Flux Sampler",
190
+ choices=["Default", "k_euler", "k_lms"],
191
+ value="Default"
192
+ )
193
+ update_flux_btn = gr.Button(value="Update Flux Settings", variant="secondary")
194
+ flux_status = gr.Textbox(label="Flux Settings Status", interactive=False)
195
+
196
+ gr.Examples(
197
+ examples=examples,
198
+ inputs=[prompt_input, lora_word_text, lora_scale_slider],
199
+ cache_examples=False,
200
+ examples_per_page=4,
201
+ )
202
+
203
+ load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
204
+ update_flux_btn.click(fn=update_flux_settings, inputs=[flux_scheduler, flux_sampler], outputs=flux_status)
205
+ generate_btn.click(
206
+ fn=generate_image,
207
+ inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
208
+ outputs=[gallery, seed_slider],
209
+ api_name="run",
210
+ )
211
+
212
+ demo.queue().launch()