File size: 2,655 Bytes
c2a8649
78f6a44
6031dc7
 
 
78f6a44
 
 
6031dc7
78f6a44
 
 
 
 
 
234658e
78f6a44
 
 
 
 
 
 
 
 
 
 
1888310
 
78f6a44
 
 
c2a8649
78f6a44
 
 
 
 
 
9af81fd
6031dc7
 
1888310
9af81fd
c2a8649
 
78f6a44
 
 
 
 
 
 
 
 
 
 
 
 
c2a8649
 
 
6031dc7
78f6a44
6031dc7
 
78f6a44
 
 
 
 
c2a8649
6031dc7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch

# Function to automatically switch between GPU and CPU
def load_model(base_model_id, adapter_model_id):
    """Load a Stable Diffusion pipeline, optionally with LoRA adapter weights.

    The pipeline is placed on CUDA (in float16) when a GPU is available, and
    on the CPU (in float32) otherwise.

    Args:
        base_model_id: Hub ID or local path of the base Stable Diffusion model.
        adapter_model_id: Hub ID or local path of LoRA weights to apply on top
            of the base model. ``None``, empty or blank means "no adapter".

    Returns:
        A ``(pipe, info)`` tuple. On success ``pipe`` is the loaded pipeline
        and ``info`` is a human-readable description of the device in use.
        On failure ``pipe`` is ``None`` and ``info`` is an error message
        (prefixed with "Error loading model:") suitable for display in the UI.
    """
    # Text boxes may hand us None or padded strings; normalise before use so
    # a whitespace-only adapter field is treated as "no adapter".
    base_model_id = (base_model_id or "").strip()
    adapter_model_id = (adapter_model_id or "").strip()
    if not base_model_id:
        return None, "Error loading model: a base model ID is required."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    info = f"Running on {'GPU (CUDA) 🔥' if device == 'cuda' else 'CPU 🥶'}"
    # Half precision is used only on the GPU; the CPU path stays in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # NOTE(review): the model is reloaded on every call; consider caching the
    # pipeline per (base_model_id, adapter_model_id) if load time matters.
    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=dtype,
        ).to(device)

        if adapter_model_id:
            # Load the adapter's LoRA weights directly into the pipeline that
            # will run inference. Loading the adapter as its own pipeline would
            # leave the returned pipeline unchanged.
            pipe.load_lora_weights(adapter_model_id)
            # Re-apply device placement in case injected adapter layers were
            # created elsewhere (presumably a no-op otherwise).
            pipe = pipe.to(device)
    except Exception as e:  # report any load failure in the UI instead of crashing
        return None, f"Error loading model: {e}"

    return pipe, info

# Function for text-to-image generation
def generate_image(base_model_id, adapter_model_id, prompt):
    """Render a single image for ``prompt`` with the requested model(s).

    Args:
        base_model_id: Identifier of the base model, forwarded to ``load_model``.
        adapter_model_id: Optional adapter identifier, forwarded to ``load_model``.
        prompt: Text description passed to the pipeline.

    Returns:
        ``(image, status)``. ``image`` is the first image produced by the
        pipeline, or ``None`` if loading or generation failed. ``status`` is
        either the device info reported by ``load_model`` or an error message.
    """
    pipeline, status = load_model(base_model_id, adapter_model_id)

    # A missing pipeline means loading failed; ``status`` already holds the reason.
    if pipeline is None:
        return None, status

    try:
        # Indexing stays inside the try so an empty result is reported, not raised.
        generated = pipeline(prompt).images
        return generated[0], status
    except Exception as err:
        return None, f"Error generating image: {str(err)}"

# Assemble the Gradio UI: inputs on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")

    with gr.Row():
        # Left column: model selection, prompt, and the trigger button.
        with gr.Column():
            base_model_box = gr.Textbox(
                label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)",
                placeholder="Base Model ID",
            )
            adapter_model_box = gr.Textbox(
                label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)",
                placeholder="Adapter Model ID (optional)",
                value="",
            )
            prompt_box = gr.Textbox(
                label="Enter your prompt",
                placeholder="Describe the image you want to generate",
            )
            generate_button = gr.Button("Generate Image")

        # Right column: the generated image and a status line that shows
        # the device in use or any error message.
        with gr.Column():
            image_output = gr.Image(label="Generated Image")
            status_output = gr.Markdown()

    # Input order mirrors generate_image(base_model_id, adapter_model_id, prompt);
    # output order mirrors its (image, info) return tuple.
    ui_inputs = [base_model_box, adapter_model_box, prompt_box]
    ui_outputs = [image_output, status_output]
    generate_button.click(fn=generate_image, inputs=ui_inputs, outputs=ui_outputs)

# Start the web server.
demo.launch()