File size: 5,437 Bytes
4120479
 
 
a8a382e
 
4120479
 
 
1475e41
4120479
 
 
 
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cfc613
84e6fa0
4120479
1d5c6d0
4120479
1d5c6d0
4120479
 
 
 
1d5c6d0
4120479
 
 
 
1d5c6d0
4120479
 
 
 
 
 
 
 
ea9cf0a
4120479
 
 
 
 
 
 
 
a56c826
 
4120479
a56c826
 
4120479
 
 
 
 
88dfda8
4120479
 
 
 
 
1d5c6d0
a56c826
4120479
 
 
1d5c6d0
a56c826
4120479
 
 
 
 
 
 
 
 
 
 
 
a56c826
 
4120479
 
 
a56c826
4120479
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

import torch
import json
import random
import copy

# Fetch the LoRA catalogue (sdxl_loras.json) from the multimodalart/LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Catalogue images that are not absolute https:// URLs are resolved against that Space's files.
_LORA_IMAGE_BASE_URL = "https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/"

# JSON is UTF-8 by spec; be explicit so non-ASCII titles/trigger words do not
# depend on the platform's locale encoding.
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalise each catalogue entry; optional keys fall back to defaults.
sdxl_loras = [
    {
        "image": item["image"] if item["image"].startswith("https://") else f'{_LORA_IMAGE_BASE_URL}{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False)
    }
    for item in data
]

# Download every LoRA weight file up front; hf_hub_download returns the local
# path of each. NOTE(review): `saved_names` is not read elsewhere in this file —
# presumably the point is to warm the hub cache before the first request.
saved_names = [
    hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
]

# Custom CSS for the Blocks UI: centres the title and the "+" column, fixes the
# app width at 700px, narrows the prompt input and absolutely positions the Run
# button at its right end (with squared-off adjoining corners) so they look joined.
css = '''
#title{text-align:center}
#plus_column{align-self: center}
#plus_button{font-size: 250%; text-align: center}
.gradio-container{width: 700px !important; margin: 0 auto !important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 57px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
'''

# Load SDXL base once, in half precision. `original_pipe` is the pristine template
# that merge_and_run deep-copies for every request, so it is never mutated.
# No deepcopy is needed here: nothing at module level modifies either name, and a
# copy would only duplicate all model weights in memory. `pipe` is kept as an
# alias so the module-level name remains available.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
original_pipe = pipe

#@spaces.GPU
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh copy of SDXL and generate one image.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Negative prompt; an empty string means "no negative prompt".
        shuffled_items: The LoRA entries picked by ``shuffle_images`` (dicts with at
            least ``repo`` and ``weights`` keys), read from a ``gr.State``.
        lora_1_scale: Strength at which the first LoRA is fused.
        lora_2_scale: Strength at which the second LoRA is fused.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors tqdm progress bars.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        gr.Error: If two LoRAs have not been selected yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("LoRAs are not loaded yet, please press 'Reshuffle LoRAs!' and try again.")

    print("Copying pipe")
    # Work on a fresh copy so fused LoRA weights never carry over to the next request.
    pipe = copy.deepcopy(original_pipe)
    print("Loading LoRAs")
    for item, scale in zip(shuffled_items, (lora_1_scale, lora_2_scale)):
        pipe.load_lora_weights(item['repo'], weight_name=item['weights'])
        # The scale must be passed by keyword: fuse_lora's first positional
        # parameter is `fuse_unet`, so a positional scale left lora_scale at 1.0.
        pipe.fuse_lora(lora_scale=scale)

    pipe.to(torch_dtype=torch.float16)
    pipe.to("cuda")
    # diffusers documents `None` (not False) as "no negative prompt".
    if not negative_prompt:
        negative_prompt = None
    print("Running inference")
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, guidance_scale=7).images[0]
    return image

def get_description(item):
    """Describe a LoRA's trigger word for display.

    Returns a ``(markdown, trigger_word)`` pair. When the item's trigger word is
    falsy, the markdown says the LoRA is applied automatically; the raw trigger
    word is always returned unchanged as the second element.
    """
    word = item["trigger_word"]
    if not word:
        return "LoRA trigger word: `none`, will be applied automatically", word
    return f"LoRA trigger word: `{word}`", word
    
def shuffle_images():
    """Pick two distinct random compatible LoRAs and build the UI updates for them.

    Returns:
        A 6-tuple in the order of the ``outputs`` wired to this handler: image
        update for LoRA 1, its trigger-word markdown update, image update for
        LoRA 2, its markdown update, the pre-filled prompt update, and the list
        of the two selected LoRA dicts (stored in a ``gr.State``).
    """
    compatible_items = [item for item in sdxl_loras if item['is_compatible']]
    # Draw two distinct entries without shuffling the whole catalogue.
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first['title'], value=first['image'])
    title_2 = gr.update(label=second['title'], value=second['image'])

    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)

    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Join only non-empty trigger words so the prompt has no stray spaces.
    prompt = gr.update(value=" ".join(word for word in (trigger_word_1, trigger_word_2) if word))

    return title_1, prompt_description_1, title_2, prompt_description_2, prompt, two_shuffled_items

with gr.Blocks(css=css) as demo:
    # Holds the two LoRA dicts picked by shuffle_images; read back by merge_and_run.
    shuffled_items = gr.State()
    title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        <h4>These 2 LoRAs are loaded into SDXL at random, find a fun way to combine them 🎨</h4>
        ''',
        elem_id="title"
    )
    with gr.Row():
        with gr.Column(min_width=10, scale=6):
            lora_1 = gr.Image(interactive=False, height=300)
            lora_1_prompt = gr.Markdown(visible=False)
        with gr.Column(min_width=10, scale=1, elem_id="plus_column"):
            plus = gr.HTML("+", elem_id="plus_button")
        with gr.Column(min_width=10, scale=6):
            lora_2 = gr.Image(interactive=False, height=300)
            lora_2_prompt = gr.Markdown(visible=False)
    with gr.Row():
        prompt = gr.Textbox(label="Your prompt", info="arrange the trigger words of the two LoRAs in a coherent sentence", interactive=True, elem_id="prompt")
        run_btn = gr.Button("Run", elem_id="run_button")

    output_image = gr.Image()
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle LoRAs!")

    # Order must match the tuple returned by shuffle_images.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items]
    # Order must match merge_and_run's positional parameters.
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]

    # Pick a fresh pair on page load and whenever the user asks for a reshuffle.
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

    # The Run button and pressing Enter in the prompt box both trigger generation.
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])

demo.queue()
demo.launch()