"""Gradio demo for the AlexTime LoRA on top of Stable Diffusion XL."""
import logging
import os

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline

logger = logging.getLogger(__name__)

BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
# Folder holding the uploaded LoRA; override with the LORA_PATH env var.
LORA_PATH = os.environ.get("LORA_PATH", "./lora_weights")
LORA_WEIGHT_NAME = "pytorch_lora_weights.safetensors"
# Trigger word appended to prompts; also used as the LoRA adapter name.
TRIGGER_WORD = "alextime"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only practical on GPU: on CPU many half-precision ops are
# unsupported or extremely slow, so fall back to float32 there.
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32


def load_model():
    """Load the SDXL base pipeline with the AlexTime LoRA attached.

    Returns:
        A ``StableDiffusionXLPipeline`` moved to ``DEVICE`` in ``TORCH_DTYPE``.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        use_safetensors=True,
    )
    pipe.load_lora_weights(
        LORA_PATH,
        weight_name=LORA_WEIGHT_NAME,
        adapter_name=TRIGGER_WORD,
    )
    return pipe.to(DEVICE)


# Load model once at startup so every request reuses the same pipeline.
pipe = load_model()


def _build_prompt(prompt):
    """Return *prompt* with the trigger word appended unless already present."""
    text = (prompt or "").strip()
    if TRIGGER_WORD in text.lower():
        # Avoid "..., alextime style, alextime" for prompts that already
        # mention the trigger.
        return text
    return f"{text}, {TRIGGER_WORD}" if text else TRIGGER_WORD


def generate_image(prompt, negative_prompt="", num_inference_steps=60,
                   guidance_scale=8.5, seed=42):
    """Generate one image with the fine-tuned model.

    Args:
        prompt: User prompt; the LoRA trigger word is appended if missing.
        negative_prompt: Text describing what to avoid.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        guidance_scale: Classifier-free guidance strength.
        seed: RNG seed for reproducibility; ``-1`` or ``None`` means random.

    Returns:
        The first generated image from the pipeline output.

    Raises:
        gr.Error: If the pipeline fails, so the UI shows the message instead
            of a silently empty output.
    """
    # Gradio Number/Slider values may arrive as floats, and a cleared Number
    # box may arrive as None; torch and the scheduler expect ints.
    if seed is None or int(seed) == -1:
        generator = None
    else:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    try:
        result = pipe(
            _build_prompt(prompt),
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
        )
    except Exception as err:
        logger.exception("Error generating image")
        raise gr.Error(f"Error generating image: {err}") from err
    return result.images[0]


# Create Gradio interface
with gr.Blocks(title="AlexTime LoRA - Fine-tuned Stable Diffusion XL") as demo:
    gr.Markdown("# AlexTime LoRA - Stable Diffusion XL")
    gr.Markdown(
        "Generate images using your fine-tuned LoRA model. "
        "The 'alextime' trigger will be automatically added to your prompt."
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="a dreadful hour, mexico",
                lines=3,
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="blurry, low quality, distorted",
                lines=2,
            )
            with gr.Row():
                steps_slider = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=60,
                    step=1,
                    label="Inference Steps",
                )
                guidance_slider = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    value=8.5,
                    step=0.1,
                    label="Guidance Scale",
                )
            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=42,
                precision=0,
            )
            generate_btn = gr.Button("Generate Image", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generation_inputs = [
        prompt_input,
        negative_prompt_input,
        steps_slider,
        guidance_slider,
        seed_input,
    ]

    # Examples
    gr.Examples(
        examples=[
            ["a dreadful hour, mexico", "", 60, 8.5, 42],
            ["a peaceful sunset, alextime style", "", 50, 7.0, 123],
            ["portrait of a person, dramatic lighting", "", 60, 9.0, 456],
        ],
        inputs=generation_inputs,
        outputs=output_image,
        fn=generate_image,
        # Caching runs every example through SDXL at startup; only do that
        # when a GPU is available, otherwise startup would take very long.
        cache_examples=torch.cuda.is_available(),
    )

    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=output_image,
    )


if __name__ == "__main__":
    demo.launch()