# el_brujerizmo / app.py
# cwiz's picture
# web ui
# 7a107be
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
import os
# Initialize the pipeline
def load_model(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    lora_dir="./lora_weights",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="alextime",
):
    """Build the SDXL pipeline and attach the LoRA adapter.

    Args:
        base_model: Model id or path passed to
            ``StableDiffusionXLPipeline.from_pretrained``.
        lora_dir: Folder holding the LoRA weights (this should be your
            uploaded LoRA folder - adjust path as needed).
        weight_name: File name of the LoRA weights inside ``lora_dir``.
        adapter_name: Name under which the LoRA adapter is registered.

    Returns:
        The pipeline with the LoRA weights loaded, moved to CUDA when a GPU
        is available and left on the default device otherwise.
    """
    use_cuda = torch.cuda.is_available()

    # Half precision is only requested when the pipeline will actually run on
    # a GPU; when it stays on the CPU, full precision is used instead because
    # float16 inference on CPU is poorly supported.
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model,
        torch_dtype=dtype,
        use_safetensors=True,
    )

    # Load LoRA weights
    pipe.load_lora_weights(
        lora_dir,
        weight_name=weight_name,
        adapter_name=adapter_name,
    )

    # Move to GPU if available
    if use_cuda:
        pipe = pipe.to("cuda")

    return pipe
# Load model once at startup; generate_image() reads this module-level
# pipeline on every call instead of rebuilding it per request.
pipe = load_model()
def generate_image(prompt, negative_prompt="", num_inference_steps=60, guidance_scale=8.5, seed=42):
    """Generate image using the fine-tuned model.

    Args:
        prompt: User prompt; the 'alextime' trigger word is appended to it.
            An empty or whitespace-only prompt falls back to the trigger
            word alone.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because UI widgets may deliver it as a float.
        guidance_scale: Guidance scale; coerced to ``float``.
        seed: Seed for reproducibility. ``-1`` or ``None`` (e.g. an emptied
            number field) means "no fixed seed".

    Returns:
        The first image of the pipeline output, or ``None`` if generation
        raised an exception (the error is printed).
    """
    # Set seed for reproducibility. None is treated like -1 so that an empty
    # seed field does not end up in manual_seed(None).
    if seed is None or int(seed) == -1:
        generator = None
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # Add your model trigger word/style. Non-empty prompts are forwarded
    # unchanged; an empty one would otherwise produce ", alextime".
    if prompt and prompt.strip():
        enhanced_prompt = f"{prompt}, alextime"
    else:
        enhanced_prompt = "alextime"

    try:
        image = pipe(
            enhanced_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
        ).images[0]
        return image
    except Exception as e:
        # Best-effort by design: report the failure and hand back None so the
        # caller receives an empty result instead of an exception.
        print(f"Error generating image: {e}")
        return None
# Create Gradio interface
with gr.Blocks(title="AlexTime LoRA - Fine-tuned Stable Diffusion XL") as demo:
    # Page header
    gr.Markdown("# AlexTime LoRA - Stable Diffusion XL")
    gr.Markdown("Generate images using your fine-tuned LoRA model. The 'alextime' trigger will be automatically added to your prompt.")

    with gr.Row():
        # Left column: generation controls
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="a dreadful hour, mexico", lines=3)
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2,
            )

            with gr.Row():
                steps_slider = gr.Slider(minimum=20, maximum=100, value=60, step=1, label="Inference Steps")
                guidance_slider = gr.Slider(minimum=1.0, maximum=20.0, value=8.5, step=0.1, label="Guidance Scale")

            seed_input = gr.Number(label="Seed (-1 for random)", value=42, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary")

        # Right column: result
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # The same five widgets, in generate_image()'s parameter order, feed both
    # the examples gallery and the button.
    generation_inputs = [
        prompt_input,
        negative_prompt_input,
        steps_slider,
        guidance_slider,
        seed_input,
    ]

    # Each row: prompt, negative prompt, steps, guidance scale, seed
    example_rows = [
        ["a dreadful hour, mexico", "", 60, 8.5, 42],
        ["a peaceful sunset, alextime style", "", 50, 7.0, 123],
        ["portrait of a person, dramatic lighting", "", 60, 9.0, 456],
    ]

    # Examples
    gr.Examples(
        examples=example_rows,
        inputs=generation_inputs,
        outputs=output_image,
        fn=generate_image,
        cache_examples=True,
    )

    generate_btn.click(fn=generate_image, inputs=generation_inputs, outputs=output_image)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()