File size: 1,589 Bytes
5494b47
 
 
 
 
 
 
 
 
 
 
 
 
6be2e5b
 
 
 
 
5494b47
 
 
 
 
 
6be2e5b
5494b47
6be2e5b
 
5494b47
6be2e5b
 
 
 
 
 
 
 
 
 
5494b47
 
 
6be2e5b
 
 
 
 
5494b47
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderKL

# Replacement VAE for SDXL — presumably an fp16-compatible fix (per the repo
# name "sdxl-vae-fp16-fix"); loaded in half precision to match the pipeline.
# NOTE(review): both from_pretrained calls download weights at import time.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Shared SDXL base pipeline; LoRA weights are loaded into it later by
# load_model(), and it is moved to CUDA there rather than here.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae, torch_dtype=torch.float16, variant="fp16",
    use_safetensors=True
)


def load_model(custom_model):
    """Load user-supplied LoRA weights into the shared SDXL pipeline.

    Args:
        custom_model: Hub model ID (e.g. "user/trained_model") of the
            LoRA weights to load.

    Returns:
        A status string shown in the UI.
    """
    # Drop any previously loaded adapter first — without this, loading a
    # second model stacks LoRA weights on top of the earlier ones instead
    # of replacing them.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(custom_model)
    pipe.to("cuda")
    return "Model loaded!"

def infer(prompt):
    """Generate one image from *prompt* using the shared SDXL pipeline."""
    result = pipe(prompt=prompt, num_inference_steps=50)
    return result.images[0]

# Constrain and center the main column.
css = """
#col-container {max-width: 580px; margin-left: auto; margin-right: auto;}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# SD-XL Custom Model Inference
        """)
        with gr.Row():
            with gr.Column():
                model_id_box = gr.Textbox(
                    label="Your custom model ID",
                    placeholder="your_username/your_trained_model_name",
                )
                status_box = gr.Textbox(label="model status", interactive=False)
            load_btn = gr.Button("Load my model")

        prompt_box = gr.Textbox(label="Prompt")
        run_btn = gr.Button("Submit")
        output_image = gr.Image(label="Image output")

    # Wire the buttons to the backend helpers defined above.
    load_btn.click(fn=load_model, inputs=[model_id_box], outputs=[status_box])
    run_btn.click(fn=infer, inputs=[prompt_box], outputs=[output_image])

demo.queue().launch()