File size: 3,727 Bytes
7b1a432
 
 
c93b55a
 
 
7b1a432
02a3a52
7b1a432
c93b55a
 
 
 
 
7b1a432
dc25832
c93b55a
 
7b1a432
dfab3a9
4a6aac0
7b1a432
dc25832
 
c93b55a
dc25832
dfab3a9
dc25832
dfab3a9
 
dc25832
c93b55a
dc25832
 
dfab3a9
c93b55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc25832
dfab3a9
b9c75da
dc25832
 
b9c75da
23d76b5
dfab3a9
7b1a432
 
23d76b5
 
 
 
7b1a432
23d76b5
 
 
 
 
 
 
7b1a432
dfab3a9
7b1a432
 
 
 
 
 
 
509d782
7b1a432
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch
import os
import random
import numpy as np
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from spaces import GPU  # Remove if not in HF Space

# 1. Model and LoRA Loading (Before Gradio) — runs once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU; full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
token = os.getenv("HF_TOKEN")  # May be None if the secret is not configured.
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

try:
    # `token=` is the current name of the auth argument (`use_auth_token=` is deprecated).
    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, token=token)
    pipe = pipe.to(device)

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_dir = "./"
    lora_path = os.path.join(lora_dir, lora_filename)

    if not os.path.exists(lora_path):
        print(f"Error: LoRA file not found at: {lora_path}")
        # SystemExit is not an Exception subclass, so the handler below does not
        # swallow it; status 1 signals failure to the Space runtime.
        raise SystemExit(1)  # Stop if LoRA is not found

    # LoRA checkpoints hold adapter tensors whose keys generally do not match the
    # text encoder's own parameter names; text_encoder.load_state_dict(..., strict=False)
    # would silently skip them. load_lora_weights maps the adapter onto the pipeline's
    # sub-models (requires the `peft` package to be installed).
    pipe.load_lora_weights(lora_dir, weight_name=lora_filename)
    print(f"LoRA loaded successfully from: {lora_path}")

    print("Stable Diffusion model and LoRA loaded successfully!")

except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1)


# Largest seed value: the maximum of a signed 32-bit integer (2**31 - 1).
MAX_SEED = int(np.iinfo("int32").max)
# Maximum image side length in pixels (not referenced by the code visible in this file).
MAX_IMAGE_SIZE = 1024

@GPU(duration=65)  # Only if in HF Space
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the module-level ``pipe``.

    Args:
        prompt: Text prompt.
        negative_prompt: Text describing what to avoid.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of using ``seed``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the pipeline's tqdm bar.

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used.

    Raises:
        gr.Error: If generation fails; Gradio displays the message in the UI.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio number/slider widgets may deliver floats; manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=int(width),
            height=int(height),
            generator=generator,
        )
    except Exception as e:
        print(f"Error during image generation: {e}") # Print error for debugging
        # An image output cannot render a string, so raise gr.Error to have
        # Gradio show the message instead of failing on a bogus "image".
        raise gr.Error(f"Error: {e}") from e
    return result.images[0], seed

# ... (the rest of the earlier Gradio code — examples, CSS, etc. — is unchanged and omitted here)

# 4. Image generation function (now decorated)
@GPU(duration=65)  # Only if in HF Space
def generate_image(prompt):
    """Generate a single image for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the model is not loaded or generation fails; Gradio shows
            the message in the UI (an image output cannot display a string).
    """
    # The loading code binds the model to `pipe`; no global named `pipeline` exists.
    if pipe is None:
        print("Error: Pipeline is None (model not loaded)")  # Log this specifically
        raise gr.Error("Error: Model not loaded!")

    try:
        print("Starting image generation...")  # Log before the image generation
        image = pipe(prompt).images[0]
    except Exception as e:
        # RuntimeError (e.g. CUDA out-of-memory) subclasses Exception, so this single
        # handler covers it; a separate `except RuntimeError` after it is unreachable.
        error_message = f"Error during image generation: {type(e).__name__}: {e}"  # Include exception type
        print(f"Full Error Details:\n{error_message}")  # Print full details
        raise gr.Error(error_message) from e
    print("Image generated successfully!")
    return image

# 5. Gradio interface: one prompt box and a button wired to generate_image,
# whose result is shown in the image component.
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")

    # Single-element input/output lists are equivalent to passing the components directly.
    generate_button.click(
        generate_image,
        inputs=[prompt_input],
        outputs=[image_output],
    )

# Start the web server (blocks until the app is stopped).
demo.launch()