File size: 2,431 Bytes
7b1a432
 
 
 
 
02a3a52
7b1a432
dfab3a9
02a3a52
dfab3a9
7b1a432
dfab3a9
 
509d782
dfab3a9
dc25832
02a3a52
 
 
 
 
 
 
 
 
 
 
 
7b1a432
dfab3a9
4a6aac0
7b1a432
dc25832
 
 
 
dfab3a9
dc25832
dfab3a9
 
dc25832
 
 
 
dfab3a9
 
dc25832
dfab3a9
 
dc25832
 
dfab3a9
 
7b1a432
 
dfab3a9
7b1a432
 
 
 
dfab3a9
7b1a432
 
 
 
 
 
 
509d782
7b1a432
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
import os
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # Remove if not in HF Space

# 1. Define model ID and HF_TOKEN (at the VERY beginning)
model_id = "stabilityai/stable-diffusion-3.5-large"  # Or your preferred model ID
hf_token = os.getenv("HF_TOKEN")  # For private models (set in HF Space settings)

# 2. Initialize pipeline (to None initially)
# Stays None only if the try-block below fails before assignment;
# generate_image() guards against that case.
pipeline = None

# 3. Load Stable Diffusion and LoRA (before Gradio)
try:
    # BUGFIX: the original if/else branches were identical, so hf_token was
    # read but never forwarded and gated/private models could not be loaded.
    # from_pretrained accepts token=None, which means "anonymous access",
    # so a single call covers both cases.
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        cache_dir="./model_cache",  # For caching
        token=hf_token,  # None => anonymous; set HF_TOKEN for gated models
    )

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)

    if os.path.exists(lora_path):
        # NOTE(review): loading raw safetensors straight into the text encoder's
        # state dict only works if the tensor keys happen to match its modules;
        # the supported diffusers API is pipeline.load_lora_weights(lora_path).
        # Confirm which format this LoRA file actually uses.
        lora_weights = load_file(lora_path)
        text_encoder = pipeline.text_encoder
        text_encoder.load_state_dict(lora_weights, strict=False)
        print(f"LoRA loaded successfully from: {lora_path}")
    else:
        print(f"Error: LoRA file not found at: {lora_path}")
        exit()  # Stop if LoRA is not found

    print("Stable Diffusion model loaded successfully!")

except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    exit()  # Stop if model loading fails

# 4. Image generation function (now decorated)
@GPU(duration=65)  # ONLY if in a HF Space, remove if not
def generate_image(prompt):
    """Run the text-to-image pipeline on *prompt*.

    Returns the first generated PIL image, or an error string when the
    model failed to load or inference raises.
    """
    global pipeline
    # Defensive guard: loading above exits the process on failure, so this
    # branch should be unreachable in practice.
    if pipeline is None:
        return "Error: Model not loaded!"

    try:
        result = pipeline(prompt)
    except Exception as e:
        return f"Error generating image: {e}"
    # The pipeline returns a batch; the UI shows a single image.
    return result.images[0]

# 5. Gradio interface: one prompt box, one output image, one trigger button.
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")

    # Wire the button to the generation function.
    generate_button.click(
        generate_image,
        inputs=[prompt_input],
        outputs=[image_output],
    )

demo.launch()