# Hugging Face Space (page status when captured: "Sleeping")
import logging
import os

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
# Initialize the pipeline
def load_model(
    *,
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    lora_path="./lora_weights",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="alextime",
):
    """Load the SDXL base pipeline and attach the AlexTime LoRA adapter.

    Args:
        model_id: Hub id (or local path) of the base SDXL checkpoint.
        lora_path: Folder holding the uploaded LoRA weights (relative paths
            are resolved against the current working directory).
        weight_name: File name of the LoRA weights inside ``lora_path``.
        adapter_name: Name under which the LoRA adapter is registered.

    Returns:
        The ``StableDiffusionXLPipeline``, moved to ``"cuda"`` when a GPU is
        available; otherwise it is left where ``from_pretrained`` placed it.
    """
    use_cuda = torch.cuda.is_available()
    # float16 is only dependable on GPU; many CPU kernels have no Half
    # implementation, so fall back to float32 when running on CPU.
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
    )
    # Load LoRA weights - adjust path as needed
    pipe.load_lora_weights(
        lora_path,  # This should be your uploaded LoRA folder
        weight_name=weight_name,
        adapter_name=adapter_name,
    )
    # Move to GPU if available
    if use_cuda:
        pipe = pipe.to("cuda")
    return pipe
# Build the pipeline a single time, when this module is imported, so every
# request from the UI reuses the same loaded weights instead of reloading them.
pipe = load_model()
def generate_image(prompt, negative_prompt="", num_inference_steps=60, guidance_scale=8.5, seed=42):
    """Generate one image with the LoRA-adapted SDXL pipeline.

    The "alextime" trigger word is appended to the prompt unless the prompt
    already contains it (case-insensitive).

    Args:
        prompt: User prompt text; ``None`` is treated as an empty prompt.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.
        num_inference_steps: Number of denoising steps; converted to ``int``
            because UI sliders may deliver floats.
        guidance_scale: Guidance scale passed through to the pipeline.
        seed: RNG seed. ``-1`` (or ``None``, e.g. a cleared number box) means
            no fixed generator is used, so results are not reproducible.

    Returns:
        The first image produced by the pipeline, or ``None`` if the pipeline
        call raised (the error and its traceback are logged).
    """
    trigger = "alextime"

    # Set seed for reproducibility. gr.Number may hand back a float or None,
    # while Generator.manual_seed requires an int.
    if seed is None or int(seed) == -1:
        generator = None
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # Add the model trigger word, but don't duplicate it when the user already
    # typed it, and don't produce a dangling ", " for an empty prompt.
    text = (prompt or "").strip()
    if trigger in text.lower():
        enhanced_prompt = text
    elif text:
        enhanced_prompt = f"{text}, {trigger}"
    else:
        enhanced_prompt = trigger

    try:
        image = pipe(
            enhanced_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    except Exception:
        # Best-effort: the UI receives an empty result instead of crashing,
        # but the full traceback is logged rather than just the message.
        logging.getLogger(__name__).exception("Error generating image")
        return None
    return image
# Create Gradio interface
with gr.Blocks(title="AlexTime LoRA - Fine-tuned Stable Diffusion XL") as demo:
    gr.Markdown("# AlexTime LoRA - Stable Diffusion XL")
    gr.Markdown("Generate images using your fine-tuned LoRA model. The 'alextime' trigger will be automatically added to your prompt.")

    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="a dreadful hour, mexico",
                lines=3,
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2,
            )
            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=60,
                    step=1,
                    label="Inference Steps",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    value=8.5,
                    step=0.1,
                    label="Guidance Scale",
                )
            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=42,
                precision=0,
            )
            generate_btn = gr.Button("Generate Image", variant="primary")

        # Right column: result.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Examples. Caching runs generate_image on every example before the app
    # is served; on CPU-only hardware that is several full SDXL generations,
    # so only pre-compute example outputs when a GPU is present.
    gr.Examples(
        examples=[
            ["a dreadful hour, mexico", "", 60, 8.5, 42],
            ["a peaceful sunset, alextime style", "", 50, 7.0, 123],
            ["portrait of a person, dramatic lighting", "", 60, 9.0, 456],
        ],
        inputs=[prompt_input, negative_prompt_input, steps_slider, guidance_slider, seed_input],
        outputs=output_image,
        fn=generate_image,
        cache_examples=torch.cuda.is_available(),
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, negative_prompt_input, steps_slider, guidance_slider, seed_input],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()