Update app.py
app.py CHANGED
@@ -1,212 +1,209 @@
-import spaces
-import os
-import gradio as gr
-import torch
-import numpy as np
-import random
-import requests
-import re
-from diffusers import FluxPipeline
-from translatepy import Translator
-from huggingface_hub import hf_hub_download
-
-# Environment setup
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-translator = Translator()
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
-# Constants
-MODEL_ID = "black-forest-labs/FLUX.1-dev"
-DEFAULT_LORA = "nftnik/BR_ohwx_V1"
-DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
-MAX_SEED = np.iinfo(np.int32).max
-
-CSS = """
-footer {
-    visibility: hidden;
-}
-"""
-
-JS = """function () {
-    gradioURL = window.location.href;
-    if (!gradioURL.endsWith('?__theme=dark')) {
-        window.location.replace(gradioURL + '?__theme=dark');
-    }
-}"""
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using {device.upper()}")
-
-# Initialize
-pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
[… old lines 42-96 were not captured in this page view …]
-@spaces.GPU()
-def generate_image(
-    prompt: str,
-    lora_word: str,
-    lora_scale: float = 0.8,
-    width: int = 896,
-    height: int = 1152,
-    guidance_scale: float = 3.5,
-    steps: int = 25,
-    seed: int = -1,
-    nums: int = 1,
-    progress=gr.Progress(track_tqdm=True)
-):
-    #
-    pipe.to(device)
-    if seed == -1:
-        seed = random.randint(0, MAX_SEED)
-    seed = int(seed)
-
-    # Translate prompt to English.
-    translation = translator.translate(prompt, "English")
-    prompt_english = str(translation)  # Adjust if translatepy returns a different attribute
-    full_prompt = f"{prompt_english} {lora_word}"
-    print(f"Prompt: {full_prompt}")
-
-    generator = torch.Generator().manual_seed(seed)
-    result = pipe(
-        prompt=full_prompt,
-        height=height,
-        width=width,
-        guidance_scale=guidance_scale,
-        output_type="pil",
-        num_inference_steps=steps,
-        max_sequence_length=512,
-        num_images_per_prompt=nums,
-        generator=generator,
-        joint_attention_kwargs={"scale": lora_scale},
-    )
-    return result.images, seed
-
-#
-examples = [
-    ["
-    ["
-    ["
-    ["
-]
-
-
-
-gr.HTML("<
[… old lines 148-209 were not captured in this page view …]
-)
-
-demo.queue().launch()
+import spaces
+import os
+import gradio as gr
+import torch
+import numpy as np
+import random
+import requests
+import re
+from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
+from translatepy import Translator
+from huggingface_hub import hf_hub_download
+
+# Environment setup
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+translator = Translator()
+HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+# Constants and configuration
+MODEL_ID = "black-forest-labs/FLUX.1-dev"
+DEFAULT_LORA = "nftnik/BR_ohwx_V1"
+DEFAULT_WEIGHT_NAME = "BR_ohwx.safetensors"
+MAX_SEED = np.iinfo(np.int32).max
+
+CSS = """
+footer {
+    visibility: hidden;
+}
+"""
+
+JS = """function () {
+    const gradioURL = window.location.href;
+    if (!gradioURL.endsWith('?__theme=dark')) {
+        window.location.replace(gradioURL + '?__theme=dark');
+    }
+}"""
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using {device.upper()}")
+
+# Initialize the Flux pipeline
+pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
+
+# Scheduler configuration.
+# NOTE: diffusers' FluxPipeline has no `set_sampler`/`set_scheduler` methods.
+# FLUX already samples with FlowMatchEulerDiscreteScheduler, an Euler-style
+# sampler; re-create it from its own config. `use_beta_sigmas` is assumed to
+# be available in the installed diffusers version and selects a beta schedule.
+print("Configuring scheduler: FlowMatchEulerDiscreteScheduler (euler, beta sigmas) ...")
+pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
+
+# Load the default LoRA weights
+pipe.load_lora_weights(DEFAULT_LORA, weight_name=DEFAULT_WEIGHT_NAME)
+
+def scrape_lora_link(url: str):
+    try:
+        response = requests.get(url)
+        response.raise_for_status()
+        content = response.text
+        pattern = r'href="(.*?lora.*?\.safetensors\?download=true)"'
+        pattern2 = r'href="(.*?\.safetensors\?download=true)"'
+        match = re.search(pattern, content)
+        match2 = re.search(pattern2, content)
+        if match:
+            safetensors_url = match.group(1)
+            filename = safetensors_url.split('/')[-1].split('?')[0]
+            return filename
+        elif match2:
+            safetensors_url = match2.group(1)
+            filename = safetensors_url.split('/')[-1].split('?')[0]
+            return filename
+        else:
+            return None
+    except requests.RequestException as e:
+        raise gr.Error(f"An error occurred while fetching the URL: {e}")
+
+def enable_lora(lora_add: str, progress=gr.Progress(track_tqdm=True)):
+    pipe.unload_lora_weights()
+    if not lora_add:
+        gr.Info("No LoRA Loaded, using base model")
+        return gr.update(value="")
+    else:
+        url = f"https://huggingface.co/{lora_add}/tree/main"
+        lora_name = scrape_lora_link(url)
+        if lora_name:
+            print(f"Loading LoRA: {lora_add}/{lora_name}")
+            pipe.load_lora_weights(lora_add, weight_name=lora_name)
+            gr.Info(f"{lora_add} Loaded")
+            return gr.update(label="LoRA Loaded Now")
+        else:
+            try:
+                pipe.load_lora_weights(lora_add)
+                print(f"Loading LoRA: {lora_add}")
+                gr.Info(f"{lora_add} Loaded")
+                return gr.update(label="LoRA Loaded Now")
+            except Exception as e:
+                raise gr.Error(f"{lora_add} load failed: {e}")
+
+@spaces.GPU()
+def generate_image(
+    prompt: str,
+    lora_word: str,
+    lora_scale: float = 0.8,
+    width: int = 896,
+    height: int = 1152,
+    guidance_scale: float = 3.5,
+    steps: int = 25,
+    seed: int = -1,
+    nums: int = 1,
+    progress=gr.Progress(track_tqdm=True)
+):
+    # Make sure the pipeline is on the correct device.
+    pipe.to(device)
+    if seed == -1:
+        seed = random.randint(0, MAX_SEED)
+    seed = int(seed)
+
+    # Translate the prompt to English.
+    translation = translator.translate(prompt, "English")
+    prompt_english = str(translation)  # Adjust if translatepy returns a different attribute
+    full_prompt = f"{prompt_english} {lora_word}"
+    print(f"Prompt: {full_prompt}")
+
+    generator = torch.Generator().manual_seed(seed)
+    result = pipe(
+        prompt=full_prompt,
+        height=height,
+        width=width,
+        guidance_scale=guidance_scale,
+        output_type="pil",
+        num_inference_steps=steps,
+        max_sequence_length=512,
+        num_images_per_prompt=nums,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+    )
+    return result.images, seed
+
+# Example prompts for demonstration
+examples = [
+    ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom. The background consists of a modern, clean showroom with diffused color neon lighting, creating a high-end, sophisticated aesthetic", "ohwx", 0.9],
+    ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape. A neon purple and magenta glow reflects on his skin as he stands inside a Metaverse oasis. His expression is focused, with his hands outstretched, interacting with the ambient. The environment is sleek and futuristic, with deep shadows and vibrant lighting creating a cinematic composition", "ohwx", 0.9],
+    ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night. His posture is dynamic, mid-stride, arms pumping. The wet pavement reflects the bright neon signs from above, casting colorful reflections on his sleek techwear. The deep shadows and dramatic lighting emphasize the futuristic setting", "ohwx", 0.9],
+    ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience. His eyes glow with digital overlays. The lighting is a mix of deep grey ambient hues with bright cyan highlights from the AR projections.", "ohwx", 0.9]
+]
+
+# Build the Gradio interface
+with gr.Blocks(css=CSS, js=JS, theme="Nymbo/Nymbo_Theme") as demo:
+    gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
+    gr.HTML("<p><center>Load the LoRA model on the menu</center></p>")
+    with gr.Row():
+        with gr.Column(scale=4):
+            gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
+            with gr.Row():
+                prompt_input = gr.Textbox(
+                    label="Enter Your Prompt (Multi-Languages)",
+                    lines=2,
+                    placeholder="Enter prompt...",
+                    scale=6
+                )
+                generate_btn = gr.Button(scale=1, variant="primary")
+        with gr.Accordion("Advanced Options", open=True):
+            with gr.Column(scale=1):
+                width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=896)
+                height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=1152)
+                guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=3.5)
+                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=25)
+                seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=MAX_SEED, step=1, value=-1)
+                nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=4, step=1, value=1)
+            with gr.Column(scale=1):
+                lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
+                lora_add_text = gr.Textbox(
+                    label="Flux LoRA",
+                    info="Copy the HF LoRA model name here",
+                    lines=1,
+                    value="nftnik/BR_ohwx_V1"
+                )
+                lora_word_text = gr.Textbox(
+                    label="Flux LoRA Trigger Word",
+                    info="Add the Trigger Word",
+                    lines=1,
+                    value="ohwx"
+                )
+                load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
+
+    gr.Examples(
+        examples=examples,
+        inputs=[prompt_input, lora_word_text, lora_scale_slider],
+        cache_examples=False,
+        examples_per_page=4,
+    )
+
+    load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)
+    generate_btn.click(
+        fn=generate_image,
+        inputs=[
+            prompt_input,
+            lora_word_text,
+            lora_scale_slider,
+            width_slider,
+            height_slider,
+            guidance_slider,
+            steps_slider,
+            seed_slider,
+            nums_slider
+        ],
+        outputs=[gallery, seed_slider],
+        api_name="run",
+    )
+
+demo.queue().launch(ssr_mode=False)
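A note on scrape_lora_link: it discovers the .safetensors filename by scraping the repo's HTML file listing, which is brittle if the Hub markup changes. Below is a minimal sketch of the same lookup done through the official huggingface_hub API (the package is already a dependency above); find_safetensors_file is a hypothetical helper name, and the preference for filenames containing "lora" mirrors the stricter of the two regexes.

from huggingface_hub import HfApi

def find_safetensors_file(repo_id: str):
    # List the repo's files through the Hub API instead of scraping HTML.
    files = HfApi().list_repo_files(repo_id)
    candidates = [f for f in files if f.endswith(".safetensors")]
    # Prefer filenames that mention "lora", mirroring the stricter regex.
    lora_named = [f for f in candidates if "lora" in f.lower()]
    return (lora_named or candidates or [None])[0]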
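Because generate_btn.click is registered with api_name="run", the Space can also be driven programmatically. A minimal sketch using the gradio_client package; the Space id "nftnik/BR-metaverso" is a hypothetical placeholder, and the nine positional inputs follow the order wired into generate_btn.click above.

from gradio_client import Client

client = Client("nftnik/BR-metaverso")  # hypothetical Space id
images, seed = client.predict(
    "ohwx blue alien portrait",  # prompt
    "ohwx",                      # LoRA trigger word
    0.9,                         # LoRA scale
    896,                         # width
    1152,                        # height
    3.5,                         # guidance scale
    25,                          # steps
    -1,                          # seed (-1 = random)
    1,                           # image count
    api_name="/run",
)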
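generate_image returns the seed it actually used, and the UI writes that seed back into the seed slider, so a run can be reproduced by passing it back in. A sketch under the assumption that the model, LoRA, and all other parameters are unchanged; outside a ZeroGPU Space the @spaces.GPU() decorator is a no-op, so the function can be called directly.

# First call: seed=-1 picks a random seed and returns it.
images, used_seed = generate_image("ohwx blue alien portrait", "ohwx", seed=-1)
# Second call: the returned seed reproduces the same image
# (up to minor GPU nondeterminism).
images_again, _ = generate_image("ohwx blue alien portrait", "ohwx", seed=used_seed)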