import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
import os
import spaces

# Read the access token stored in the Space secrets
hf_token = os.getenv("HF_TOKEN")

# Pre-trained model ID (a gated repo on the Hub, so the token must be
# authorized for it)
model_id = "stabilityai/stable-diffusion-3.5-large"

# FP16 halves memory on GPU; fall back to FP32 on CPU. Device placement is
# handled by the model offloading enabled below, so no explicit .to(device)
# call is needed.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Global variable for the pipeline (initialized once, then cached)
pipeline = None

# Function for initializing and caching the pipeline
def initialize_pipeline():
    global pipeline
    if pipeline is None:
        try:
            # Load the pipeline in half precision on GPU (FP32 on CPU)
            pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_id,
                token=hf_token,           # `use_auth_token` is deprecated in recent diffusers
                torch_dtype=torch_dtype,  # FP16 only when CUDA is available
            )
            if not hasattr(pipeline, "transformer"):
                raise ValueError("Loaded pipeline is missing its transformer component.")
            # Offload idle submodules to CPU and slice the attention computation
            # to reduce peak VRAM usage
            pipeline.enable_model_cpu_offload()
            pipeline.enable_attention_slicing()
            print("Pipeline initialized and cached.")
        except Exception as e:
            # Surface model-loading failures (auth, download, out-of-memory)
            print(f"Error loading the model: {e}")
            raise RuntimeError("Failed to initialize the model pipeline.") from e
    return pipeline

# Image generation; the decorator requests a ZeroGPU slot for up to 65 s per call
@spaces.GPU(duration=65)
def generate_image(prompt):
    pipe = initialize_pipeline()  # Loads and caches the model on the first call only
    # Generate the image using the pipeline
    try:
        image = pipe(prompt).images[0]
    except Exception as e:
        # Catch errors during image generation (e.g. CUDA out-of-memory)
        print(f"Error during image generation: {e}")
        raise RuntimeError("Image generation failed.") from e
    return image
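
# A minimal sketch of exposing common sampling knobs; num_inference_steps and
# guidance_scale are standard diffusers pipeline arguments, but the defaults
# chosen here are illustrative rather than tuned values.
@spaces.GPU(duration=65)
def generate_image_with_options(prompt, steps=28, guidance=4.5):
    pipe = initialize_pipeline()
    return pipe(
        prompt,
        num_inference_steps=steps,  # more steps: slower but often sharper
        guidance_scale=guidance,    # how strongly the image follows the prompt
    ).images[0]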

# Set up Gradio interface with a simple input for text and output for image
interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
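
# A sketch of the same interface with labeled components (gr.Textbox and
# gr.Image are standard Gradio components; the label strings are illustrative):
#
# interface = gr.Interface(
#     fn=generate_image,
#     inputs=gr.Textbox(label="Prompt", placeholder="Describe the image..."),
#     outputs=gr.Image(label="Generated image"),
# )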

# Launch the interface; launch() blocks until the server is shut down, so it
# must be the last statement in the script
interface.launch()
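
# For a quick local smoke test without the web UI, one could comment out the
# launch() call above and run the function directly (a hypothetical check,
# not part of the Space itself):
#
# image = generate_image("a watercolor painting of a fox")
# image.save("fox.png")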