nftnik committed (verified)
Commit cc0fe43 · Parent(s): 7969a4a

Update app.py

Files changed (1): app.py +209 -212
app.py CHANGED
@@ -1,212 +1,209 @@
- import spaces
- import os
- import gradio as gr
- import torch
- import numpy as np
- import random
- import requests
- import re
- from diffusers import FluxPipeline
- from translatepy import Translator
- from huggingface_hub import hf_hub_download
-
- # Environment setup
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
- translator = Translator()
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
- # Constants
- MODEL_ID = "black-forest-labs/FLUX.1-dev"
- DEFAULT_LORA = "nftnik/BR_ohwx_V1"
- DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
- MAX_SEED = np.iinfo(np.int32).max
-
- CSS = """
- footer {
-     visibility: hidden;
- }
- """
-
- JS = """function () {
-     gradioURL = window.location.href;
-     if (!gradioURL.endsWith('?__theme=dark')) {
-         window.location.replace(gradioURL + '?__theme=dark');
-     }
- }"""
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using {device.upper()}")
-
- # Initialize pipeline and load default LoRA weights
- pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
- pipe.load_lora_weights(DEFAULT_LORA, weight_name=DEFAULT_WEIGHT_NAME)
-
- def scrape_lora_link(url: str):
-     try:
-         response = requests.get(url)
-         response.raise_for_status()
-         content = response.text
-         pattern = r'href="(.*?lora.*?\.safetensors\?download=true)"'
-         pattern2 = r'href="(.*?\.safetensors\?download=true)"'
-         match = re.search(pattern, content)
-         match2 = re.search(pattern2, content)
-         if match:
-             safetensors_url = match.group(1)
-             filename = safetensors_url.split('/')[-1].split('?')[0]
-             return filename
-         elif match2:
-             safetensors_url = match2.group(1)
-             filename = safetensors_url.split('/')[-1].split('?')[0]
-             return filename
-         else:
-             return None
-     except requests.RequestException as e:
-         raise gr.Error(f"An error occurred while fetching the URL: {e}")
-
- def enable_lora(lora_add: str, progress=gr.Progress(track_tqdm=True)):
-     pipe.unload_lora_weights()
-     if not lora_add:
-         gr.Info("No LoRA Loaded, using base model")
-         return gr.update(value="")
-     else:
-         url = f"https://huggingface.co/{lora_add}/tree/main"
-         lora_name = scrape_lora_link(url)
-         if lora_name:
-             print(f"Loading LoRA: {lora_add}/{lora_name}")
-             pipe.load_lora_weights(lora_add, weight_name=lora_name)
-             gr.Info(f"{lora_add} Loaded")
-             return gr.update(label="LoRA Loaded Now")
-         else:
-             try:
-                 pipe.load_lora_weights(lora_add)
-                 print(f"Loading LoRA: {lora_add}")
-                 gr.Info(f"{lora_add} Loaded")
-                 return gr.update(label="LoRA Loaded Now")
-             except Exception as e:
-                 raise gr.Error(f"{lora_add} load failed: {e}")
-
- # Placeholder function to update flux scheduler and sampler settings.
- def update_flux_settings(scheduler_choice: str, sampler_choice: str):
-     # Replace the code below with actual logic to update your pipeline's scheduler/sampler.
-     print(f"Setting scheduler to {scheduler_choice} and sampler to {sampler_choice}")
-     # e.g.:
-     # pipe.scheduler = SchedulerClassMapping[scheduler_choice].from_config(pipe.scheduler.config)
-     # pipe.sampler = SamplerClassMapping[sampler_choice](**pipe.sampler_config)
-     return f"Scheduler set to {scheduler_choice} and Sampler set to {sampler_choice}"
-
- @spaces.GPU()
- def generate_image(
-     prompt: str,
-     lora_word: str,
-     lora_scale: float = 0.8,
-     width: int = 896,
-     height: int = 1152,
-     guidance_scale: float = 3.5,
-     steps: int = 25,
-     seed: int = -1,
-     nums: int = 1,
-     progress=gr.Progress(track_tqdm=True)
- ):
-     # Ensure the pipeline is on the correct device.
-     pipe.to(device)
-     if seed == -1:
-         seed = random.randint(0, MAX_SEED)
-     seed = int(seed)
-
-     # Translate prompt to English.
-     translation = translator.translate(prompt, "English")
-     prompt_english = str(translation)  # Adjust if translatepy returns a different attribute.
-     full_prompt = f"{prompt_english} {lora_word}"
-     print(f"Prompt: {full_prompt}")
-
-     generator = torch.Generator().manual_seed(seed)
-     result = pipe(
-         prompt=full_prompt,
-         height=height,
-         width=width,
-         guidance_scale=guidance_scale,
-         output_type="pil",
-         num_inference_steps=steps,
-         max_sequence_length=512,
-         num_images_per_prompt=nums,
-         generator=generator,
-         joint_attention_kwargs={"scale": lora_scale},
-     )
-     return result.images, seed
-
- # Sample examples
- examples = [
-     ["close-up portrait of a futuristic alien in cyberpunk attire", "ohwx", 0.9],
-     ["full-body shot of an alien running through a neon-lit cityscape", "ohwx", 0.9],
-     ["portrait of a blue alien with techwear in a virtual reality environment", "ohwx", 0.9],
-     ["cyberpunk style portrait of an alien with glowing eyes", "ohwx", 0.9]
- ]
-
- with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
-     gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
-     gr.HTML("<p><center>Load the LoRA model on the menu</center></p>")
-     with gr.Row():
-         with gr.Column(scale=4):
-             gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
-             with gr.Row():
-                 prompt_input = gr.Textbox(
-                     label="Enter Your Prompt (Multi-Languages)",
-                     lines=2,
-                     placeholder="Enter prompt...",
-                     scale=6
-                 )
-                 generate_btn = gr.Button(scale=1, variant="primary")
-         with gr.Accordion("Advanced Options", open=True):
-             with gr.Column(scale=1):
-                 width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=896)
-                 height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=1152)
-                 guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=3.5)
-                 steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
-                 seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
-                 nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=4, step=1, value=1)
-             with gr.Column(scale=1):
-                 lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
-                 lora_add_text = gr.Textbox(
-                     label="Flux LoRA",
-                     info="Copy the HF LoRA model name here",
-                     lines=1,
-                     value="nftnik/BR_ohwx_V1"
-                 )
-                 lora_word_text = gr.Textbox(
-                     label="Flux LoRA Trigger Word",
-                     info="Add the Trigger Word",
-                     lines=1,
-                     value="ohwx"
-                 )
-                 load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
-                 # New dropdowns for flux scheduler and sampler
-                 flux_scheduler = gr.Dropdown(
-                     label="Flux Scheduler",
-                     choices=["DDIM", "PNDM", "DPMSolver"],
-                     value="DDIM"
-                 )
-                 flux_sampler = gr.Dropdown(
-                     label="Flux Sampler",
-                     choices=["Default", "k_euler", "k_lms"],
-                     value="Default"
-                 )
-                 update_flux_btn = gr.Button(value="Update Flux Settings", variant="secondary")
-                 flux_status = gr.Textbox(label="Flux Settings Status", interactive=False)
-
-     gr.Examples(
-         examples=examples,
-         inputs=[prompt_input, lora_word_text, lora_scale_slider],
-         cache_examples=False,
-         examples_per_page=4,
-     )
-
-     load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
-     update_flux_btn.click(fn=update_flux_settings, inputs=[flux_scheduler, flux_sampler], outputs=flux_status)
-     generate_btn.click(
-         fn=generate_image,
-         inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
-         outputs=[gallery, seed_slider],
-         api_name="run",
-     )
-
- demo.queue().launch()
 
+ import spaces
+ import os
+ import gradio as gr
+ import torch
+ import numpy as np
+ import random
+ import requests
+ import re
+ from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
+ from translatepy import Translator
+ from huggingface_hub import hf_hub_download
+
+ # Environment setup
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+ translator = Translator()
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
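Note: HF_HUB_ENABLE_HF_TRANSFER=1 only takes effect when the optional hf_transfer package is installed; if the flag is set without it, huggingface_hub raises an error on the first download. A minimal guard, sketched here rather than taken from the commit:

    # Enable accelerated downloads only when the optional hf_transfer backend exists.
    try:
        import hf_transfer  # noqa: F401
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    except ImportError:
        os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)  # fall back to the default downloader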
+ # Constants and configuration
+ MODEL_ID = "black-forest-labs/FLUX.1-dev"
+ DEFAULT_LORA = "nftnik/BR_ohwx_V1"
+ DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
+ MAX_SEED = np.iinfo(np.int32).max
+
+ CSS = """
+ footer {
+     visibility: hidden;
+ }
+ """
+
+ JS = """function () {
+     const gradioURL = window.location.href;
+     if (!gradioURL.endsWith('?__theme=dark')) {
+         window.location.replace(gradioURL + '?__theme=dark');
+     }
+ }"""
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using {device.upper()}")
+
+ # Initialize the Flux pipeline
+ pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
+
+ # Set the default sampler/scheduler (Euler sampling over a beta sigma schedule).
+ # FluxPipeline has no set_sampler/set_scheduler methods; in diffusers the scheduler is
+ # swapped by assigning to pipe.scheduler. use_beta_sigmas needs a recent diffusers release.
+ print("Configuring FlowMatchEulerDiscreteScheduler with beta sigmas ...")
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+     pipe.scheduler.config, use_beta_sigmas=True
+ )
+
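FLUX.1-dev ships with FlowMatchEulerDiscreteScheduler as its default, so the swap above only changes the sigma schedule; on a diffusers release without `use_beta_sigmas`, the safe fallback is to drop the override and keep the stock scheduler.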
+ # Load the default LoRA weights
+ pipe.load_lora_weights(DEFAULT_LORA, weight_name=DEFAULT_WEIGHT_NAME)
+
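`hf_hub_download` is imported but never called. It could prefetch the LoRA file so a bad repo name or filename fails fast at startup; a sketch under that assumption (recent diffusers accepts a local .safetensors path in load_lora_weights):

    # Hypothetical prefetch: resolve and cache the LoRA file, then load from the local path.
    lora_path = hf_hub_download(repo_id=DEFAULT_LORA, filename=DEFAULT_WEIGHT_NAME, token=HF_TOKEN)
    pipe.load_lora_weights(lora_path)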
+ def scrape_lora_link(url: str):
+     try:
+         response = requests.get(url)
+         response.raise_for_status()
+         content = response.text
+         pattern = r'href="(.*?lora.*?\.safetensors\?download=true)"'
+         pattern2 = r'href="(.*?\.safetensors\?download=true)"'
+         match = re.search(pattern, content)
+         match2 = re.search(pattern2, content)
+         if match:
+             safetensors_url = match.group(1)
+             filename = safetensors_url.split('/')[-1].split('?')[0]
+             return filename
+         elif match2:
+             safetensors_url = match2.group(1)
+             filename = safetensors_url.split('/')[-1].split('?')[0]
+             return filename
+         else:
+             return None
+     except requests.RequestException as e:
+         raise gr.Error(f"An error occurred while fetching the URL: {e}")
+
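scrape_lora_link regex-parses the repo's HTML file listing, which breaks whenever the Hub markup changes. The already-imported huggingface_hub package exposes the same information through its API; an equivalent lookup might look like this sketch:

    from huggingface_hub import list_repo_files

    def find_safetensors_file(repo_id: str):
        # Mirror the regex priority above: prefer filenames containing "lora".
        files = [f for f in list_repo_files(repo_id) if f.endswith(".safetensors")]
        for name in files:
            if "lora" in name.lower():
                return name
        return files[0] if files else None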
+ def enable_lora(lora_add: str, progress=gr.Progress(track_tqdm=True)):
+     pipe.unload_lora_weights()
+     if not lora_add:
+         gr.Info("No LoRA Loaded, using base model")
+         return gr.update(value="")
+     else:
+         url = f"https://huggingface.co/{lora_add}/tree/main"
+         lora_name = scrape_lora_link(url)
+         if lora_name:
+             print(f"Loading LoRA: {lora_add}/{lora_name}")
+             pipe.load_lora_weights(lora_add, weight_name=lora_name)
+             gr.Info(f"{lora_add} Loaded")
+             return gr.update(label="LoRA Loaded Now")
+         else:
+             try:
+                 pipe.load_lora_weights(lora_add)
+                 print(f"Loading LoRA: {lora_add}")
+                 gr.Info(f"{lora_add} Loaded")
+                 return gr.update(label="LoRA Loaded Now")
+             except Exception as e:
+                 raise gr.Error(f"{lora_add} load failed: {e}")
+
+ @spaces.GPU()
+ def generate_image(
+     prompt: str,
+     lora_word: str,
+     lora_scale: float = 0.8,
+     width: int = 896,
+     height: int = 1152,
+     guidance_scale: float = 3.5,
+     steps: int = 25,
+     seed: int = -1,
+     nums: int = 1,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     # Make sure the pipeline is on the correct device.
+     pipe.to(device)
+     if seed == -1:
+         seed = random.randint(0, MAX_SEED)
+     seed = int(seed)
+
+     # Translate the prompt to English.
+     translation = translator.translate(prompt, "English")
+     prompt_english = str(translation)  # Adjust if translatepy returns a different attribute
+     full_prompt = f"{prompt_english} {lora_word}"
+     print(f"Prompt: {full_prompt}")
+
+     generator = torch.Generator().manual_seed(seed)
+     result = pipe(
+         prompt=full_prompt,
+         height=height,
+         width=width,
+         guidance_scale=guidance_scale,
+         output_type="pil",
+         num_inference_steps=steps,
+         max_sequence_length=512,
+         num_images_per_prompt=nums,
+         generator=generator,
+         joint_attention_kwargs={"scale": lora_scale},
+     )
+     return result.images, seed
+
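generate_image seeds a CPU torch.Generator, which diffusers happily pairs with a CUDA pipeline; pinning the generator to the pipeline's device is an equally valid variant, though the two do not produce bit-identical images:

    # Device-pinned alternative: reproducible per device, but not identical to CPU-seeded runs.
    generator = torch.Generator(device=device).manual_seed(seed)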
+ # Example prompts for demonstration
+ examples = [
+     ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom. The background consists of a modern, clean showroom with diffused color neon lighting, creating a high-end, sophisticated aesthetic", "ohwx", 0.9],
+     ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape. A neon purple and magenta glow reflects on his skin as he stands inside a Metaverse oasis. His expression is focused, with his hands outstretched, interacting with the ambient. The environment is sleek and futuristic, with deep shadows and vibrant lighting creating a cinematic composition", "ohwx", 0.9],
+     ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night. His posture is dynamic, mid-stride, arms pumping. The wet pavement reflects the bright neon signs from above, casting colorful reflections on his sleek techwear. The deep shadows and dramatic lighting emphasize the futuristic setting", "ohwx", 0.9],
+     ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience. His eyes glow with digital overlays. The lighting is a mix of deep grey ambient hues with bright cyan highlights from the AR projections.", "ohwx", 0.9]
+ ]
+
+ # Build the Gradio interface
+ with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
+     gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
+     gr.HTML("<p><center>Load the LoRA model on the menu</center></p>")
+     with gr.Row():
+         with gr.Column(scale=4):
+             gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
+             with gr.Row():
+                 prompt_input = gr.Textbox(
+                     label="Enter Your Prompt (Multi-Languages)",
+                     lines=2,
+                     placeholder="Enter prompt...",
+                     scale=6
+                 )
+                 generate_btn = gr.Button(scale=1, variant="primary")
+         with gr.Accordion("Advanced Options", open=True):
+             with gr.Column(scale=1):
+                 width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=896)
+                 height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=1152)
+                 guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=3.5)
+                 steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
+                 seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
+                 nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=4, step=1, value=1)
+             with gr.Column(scale=1):
+                 lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
+                 lora_add_text = gr.Textbox(
+                     label="Flux LoRA",
+                     info="Copy the HF LoRA model name here",
+                     lines=1,
+                     value="nftnik/BR_ohwx_V1"
+                 )
+                 lora_word_text = gr.Textbox(
+                     label="Flux LoRA Trigger Word",
+                     info="Add the Trigger Word",
+                     lines=1,
+                     value="ohwx"
+                 )
+                 load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
+
+     gr.Examples(
+         examples=examples,
+         inputs=[prompt_input, lora_word_text, lora_scale_slider],
+         cache_examples=False,
+         examples_per_page=4,
+     )
+
+     load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
+     generate_btn.click(
+         fn=generate_image,
+         inputs=[
+             prompt_input,
+             lora_word_text,
+             lora_scale_slider,
+             width_slider,
+             height_slider,
+             guidance_slider,
+             steps_slider,
+             seed_slider,
+             nums_slider
+         ],
+         outputs=[gallery, seed_slider],
+         api_name="run",
+     )
+
+ demo.queue().launch(ssr_mode=False)  # Gradio 5's launch() flag is ssr_mode, not ssr
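demo.queue() makes simultaneous generate requests line up for the single GPU instead of running concurrently; outside of Spaces, run `python app.py` and open the printed local URL to try the app.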