File size: 2,670 Bytes
5494b47
 
 
 
 
 
 
 
 
 
 
 
 
6be2e5b
 
 
 
 
5494b47
c5267aa
 
 
 
 
 
 
 
 
5494b47
 
 
6be2e5b
5494b47
6be2e5b
 
5494b47
6be2e5b
 
 
 
 
2b68859
6be2e5b
 
 
 
c5267aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5494b47
 
 
6be2e5b
 
 
 
 
5494b47
 
c5267aa
5494b47
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderKL

# Half-precision VAE checkpoint, passed to the SDXL pipeline below through its
# `vae` argument.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# SDXL base pipeline, created once at import time from the "fp16" safetensors
# variant; `load_model` and `infer` both operate on this single instance.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)


def load_model(custom_model):
    """Attach the user's LoRA weights to the shared pipeline and move it to the GPU.

    Args:
        custom_model: Model identifier typed into the "Your custom model ID"
            textbox. Surrounding whitespace is stripped before it is handed
            to ``pipe.load_lora_weights``.

    Returns:
        A status string shown in the "model status" textbox.
    """
    model_id = (custom_model or "").strip()
    if not model_id:
        # Nothing to load: report it instead of letting the loader fail on an
        # empty identifier.
        return "Please enter a model ID first."

    # Drop any LoRA left over from a previous click so that weights from
    # different models are not combined on the shared pipeline.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(model_id)
    pipe.to("cuda")
    return "Model loaded!"

def infer(prompt, inf_steps, guidance_scale, seed, lora_weigth, progress=gr.Progress(track_tqdm=True)):
    """Generate one image from the shared pipeline for the given prompt.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        inf_steps: Number of inference steps; coerced to ``int``.
        guidance_scale: Guidance scale; coerced to ``float``.
        seed: Seed for the CUDA random generator; coerced to ``int``.
        lora_weigth: LoRA scale, forwarded as ``cross_attention_kwargs["scale"]``.
            The misspelled name is kept so keyword callers keep working.
        progress: Gradio progress tracker. It is not used in the body; it is
            declared with ``track_tqdm=True`` in the signature.

    Returns:
        The first image of the pipeline output.
    """
    # The body previously read `lora_weight`, which is not this function's
    # parameter but the module-level slider component, so float() raised a
    # TypeError on every call. Read the argument that was actually passed in.
    lora_scale = float(lora_weigth)

    # Slider values may arrive as floats; the seed and step count must be ints.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    result = pipe(
        prompt=prompt,
        num_inference_steps=int(inf_steps),
        guidance_scale=float(guidance_scale),
        generator=generator,
        cross_attention_kwargs={"scale": lora_scale},
    )
    return result.images[0]

# Page-level CSS handed to gr.Blocks: caps the width of the element with id
# "col-container" at 580px and centers it horizontally via auto margins.
css = """
#col-container {max-width: 580px; margin-left: auto; margin-right: auto;}
"""

# Single-page UI: a model-loading row, the generation controls, and the output
# image, followed by the event wiring and the app launch.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# SD-XL Custom Model Inference
        """)

        # Model selection: the ID to load, a read-only status box, and the
        # button that triggers `load_model`.
        with gr.Row():
            with gr.Column():
                custom_model = gr.Textbox(
                    label="Your custom model ID",
                    placeholder="your_username/your_trained_model_name",
                    info="Make sure your model is set to PUBLIC "
                )
                model_status = gr.Textbox(label="model status", interactive=False)
            load_model_btn = gr.Button("Load my model")

        # Generation controls; they are passed to `infer` in the order listed
        # in `submit_btn.click` below.
        prompt_in = gr.Textbox(label="Prompt")
        inf_steps = gr.Slider(
            label="Inference steps",
            minimum=12,
            maximum=50,
            step=1,
            value=25
        )
        guidance_scale = gr.Slider(
            label="Guidance scale",
            minimum=0.1,
            # The maximum used to be 0.9, which put the 7.5 default outside
            # the slider's own range; widen the range so it contains it.
            maximum=20.0,
            step=0.1,
            value=7.5
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=500000,
            step=1,
            value=42
        )
        lora_weight = gr.Slider(
            label="LoRa weigth",
            minimum=0.0,
            maximum=10.0,
            step=0.01,
            value=0.9
        )
        submit_btn = gr.Button("Submit")
        image_out = gr.Image(label="Image output")

    # "Load my model" writes the loader's status string into the status box.
    load_model_btn.click(
        fn=load_model,
        inputs=[custom_model],
        outputs=[model_status]
    )
    # "Submit" runs generation; the inputs are passed positionally to `infer`.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in, inf_steps, guidance_scale, seed, lora_weight],
        outputs=[image_out]
    )

# Enable the request queue, then start the server.
demo.queue().launch()