Yuanshi committed on
Commit
9fbf1b0
·
1 Parent(s): 759b91c
Files changed (2) hide show
  1. app.py +163 -4
  2. ominicontrol.py +129 -0
app.py CHANGED
@@ -1,7 +1,166 @@
1
  import gradio as gr
 
2
 
3
def greet(name):
    """Return the greeting text for ``name``: ``"Hello <name>!!"``."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
5
 
6
# Minimal text-in / text-out interface around ``greet``, launched on import.
demo = gr.Interface(greet, "text", "text")
demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from ominicontrol import generate_image
3
 
4
+ import spaces
 
5
 
6
# When True, ``infer`` is wrapped with ``spaces.GPU`` before the UI is built.
USE_ZERO_GPU = True

# Custom stylesheet passed to ``gr.Blocks``. Disabled rules are kept inside
# ``/* ... */`` because a leading ``#`` does not start a comment in CSS.
css = """
.inputPanel {
    width: 320px;
    display: flex;
    align-items: center;
}
.outputPanel {
    display: flex;
    align-items: center;
}
.hint {
    font-size: 14px;
    color: #777;
    /* border: 1px solid #ccc; */
    padding: 4px;
    border-radius: 5px;
    /* background-color: #efefef; */
}
"""
27
+
28
+
29
def style_transfer(image, style):
    """Placeholder stylizer: hand ``image`` back untouched; ``style`` is unused."""
    unchanged = image
    return unchanged
31
+
32
+
33
# Style names offered by the style selector, in display order.
styles = ["Studio Ghibli", "Irasutoya Illustration", "The Simpsons", "Snoopy"]
39
+
40
+
41
def gradio_interface():
    """Build the Gradio Blocks UI for OminiControl image stylization.

    The left panel collects the condition image, the target style and a
    collapsed group of advanced settings; the right panel shows the output
    image. Clicking the button calls ``infer`` with the current widget values.

    Returns:
        The assembled ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# 🌍 OminiControl (Image Stylization)")
        with gr.Row(equal_height=False):
            with gr.Column(variant="panel", elem_classes="inputPanel"):
                original_image = gr.Image(
                    type="pil",
                    label="Condition Image",
                    width=400,
                    height=400,
                )
                style = gr.Radio(
                    styles,
                    label="🎨 Select Style",
                    value=styles[0],
                )
                # Advanced settings
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    inference_mode = gr.Radio(
                        ["High Quality", "Fast"],
                        value="High Quality",
                        label="Generating Mode",
                    )
                    image_ratio = gr.Radio(
                        ["Auto", "Square(1:1)", "Portrait(2:3)", "Landscape(3:2)"],
                        label="Image Ratio",
                        value="Auto",
                    )
                    use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        # Deliver the seed as an int rather than a float.
                        precision=0,
                        visible=(not use_random_seed.value),
                    )
                    # The manual seed field is only shown when random seeding is off.
                    use_random_seed.change(
                        lambda x: gr.update(visible=(not x)),
                        use_random_seed,
                        seed,
                        show_progress="hidden",
                    )
                    image_guidance = gr.Slider(
                        label="Image Guidance",
                        minimum=1.1,
                        maximum=5,
                        value=1.5,
                        step=0.1,
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    # The guidance slider is only editable in "High Quality" mode.
                    inference_mode.change(
                        lambda x: gr.update(interactive=(x == "High Quality")),
                        inference_mode,
                        image_guidance,
                        show_progress="hidden",
                    )

                btn = gr.Button("Generate Image")

            with gr.Column(elem_classes="outputPanel"):
                output_images = gr.Image(
                    type="pil",
                    width=640,
                    height=640,
                    label="Output Image",
                )

        # The input order must match the positional parameters of ``infer``.
        btn.click(
            fn=infer,
            inputs=[
                style,
                original_image,
                inference_mode,
                image_guidance,
                image_ratio,
                use_random_seed,
                seed,
                steps,
            ],
            outputs=output_images,
        )

    return demo
130
+
131
+
132
def infer(
    style,
    original_image,
    inference_mode,
    image_guidance,
    image_ratio,
    use_random_seed,
    seed,
    steps,
):
    """Handle one "Generate Image" request from the UI.

    Logs the request settings and forwards every argument to
    ``generate_image`` by keyword.

    Args:
        style: Style name selected in the UI.
        original_image: Condition image, or ``None`` when nothing was uploaded.
        inference_mode: Selected generating mode.
        image_guidance: Image guidance value from the slider.
        image_ratio: Selected output ratio option.
        use_random_seed: Whether the random-seed checkbox is ticked.
        seed: Value of the seed field.
        steps: Number of steps from the slider.

    Returns:
        Whatever ``generate_image`` returns for these settings.

    Raises:
        gr.Error: If no condition image was provided.
    """
    if original_image is None:
        # Fail early with a readable message instead of an attribute error
        # further down when the image size is read.
        raise gr.Error("Please upload a condition image first.")

    print(
        f"Style: {style}, Inference Mode: {inference_mode}, Image Guidance: {image_guidance}, Image Ratio: {image_ratio}, Use Random Seed: {use_random_seed}, Seed: {seed}"
    )
    result_image = generate_image(
        image=original_image,
        style=style,
        inference_mode=inference_mode,
        image_guidance=image_guidance,
        image_ratio=image_ratio,
        use_random_seed=use_random_seed,
        seed=seed,
        steps=steps,
    )
    return result_image
156
+
157
+
158
if USE_ZERO_GPU:
    # Rebind ``infer`` at import time, before the UI is built, so the click
    # handler registered in ``gradio_interface`` picks up the wrapped callable.
    infer = spaces.GPU(infer, duration=360)

if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(debug=True, server_name="0.0.0.0")
ominicontrol.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.pipelines import FluxPipeline
3
+ from OminiControl.src.flux.condition import Condition
4
+ from PIL import Image
5
+ import random
6
+ import os
7
+
8
+ from OminiControl.src.flux.generate import generate, seed_everything
9
+
10
# Hugging Face Hub access token; ``None`` when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

print("Loading model...")
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    token=HF_TOKEN,
)
pipe = pipe.to("cuda")

pipe.unload_lora_weights()

# Register one LoRA adapter per style. The adapter name doubles as the
# weight file stem under ``v0/`` in the style repository.
for _adapter in ("ghibli", "irasutoya", "simpsons", "snoopy"):
    pipe.load_lora_weights(
        "Yuanshi/OminiControlStyle",
        weight_name=f"v0/{_adapter}.safetensors",
        adapter_name=_adapter,
        token=HF_TOKEN,
    )
44
+
45
+
46
+
47
def generate_image(
    image,
    style,
    inference_mode,
    image_guidance,
    image_ratio,
    steps,
    use_random_seed,
    seed,
):
    """Generate a stylized image from ``image`` with the adapter for ``style``.

    Args:
        image: Condition image (PIL image). It is scaled so its longest side
            is 512 px and then center-cropped to multiples of 16.
        style: One of "Studio Ghibli", "Irasutoya Illustration",
            "The Simpsons" or "Snoopy"; selects the active LoRA adapter.
        inference_mode: "Fast" forces the image guidance scale to 1.0; any
            other value keeps ``image_guidance``.
        image_guidance: Image guidance scale used outside "Fast" mode.
        image_ratio: "Auto" picks the output ratio closest to the condition
            image; otherwise "Square(1:1)", "Portrait(2:3)" or
            "Landscape(3:2)".
        steps: Number of inference steps.
        use_random_seed: When true, a fresh random seed is drawn.
        seed: Seed used when ``use_random_seed`` is false. A missing seed
            falls back to a random one.

    Returns:
        The first image of the ``generate`` result.

    Raises:
        ValueError: If ``image`` is missing or has a zero-sized side.
        KeyError: If ``style`` or ``image_ratio`` is not a known option.
    """
    if image is None or min(image.size) == 0:
        raise ValueError("A non-empty condition image is required.")

    def crop_to_multiple(img, factor=16):
        # Center-crop so both sides are exact multiples of ``factor``.
        w, h = img.size
        new_w, new_h = w // factor * factor, h // factor * factor
        left, top = (w - new_w) // 2, (h - new_h) // 2
        return img.crop((left, top, left + new_w, top + new_h))

    # Set adapter
    adapter_by_style = {
        "Studio Ghibli": "ghibli",
        "Irasutoya Illustration": "irasutoya",
        "The Simpsons": "simpsons",
        "Snoopy": "snoopy",
    }
    pipe.set_adapters(adapter_by_style[style])

    # Prepare condition. Each scaled side is clamped to at least 16 px so an
    # extreme aspect ratio cannot be cropped down to a zero-sized image.
    scale = 512 / max(image.size)
    scaled_size = tuple(max(16, int(side * scale)) for side in image.size)
    image = crop_to_multiple(image.resize(scaled_size, Image.LANCZOS))
    delta = -image.size[0] // 16
    condition = Condition(
        "subject",
        image,
        position_delta=(0, delta),
    )

    # Prepare seed
    if use_random_seed or seed is None:
        seed = random.randint(0, 2**32 - 1)
    else:
        # The UI may deliver the seed as a float. Coerce it to an int and
        # keep it in the same range as the randomly drawn seeds.
        seed = int(seed) % 2**32
    seed_everything(seed)

    # Image guidance scale
    image_guidance = 1.0 if inference_mode == "Fast" else image_guidance

    # Slider values may arrive as floats.
    steps = int(steps)

    # Output size
    size_by_ratio = {
        0.67: (640, 960),
        1: (640, 640),
        1.5: (960, 640),
    }
    if image_ratio == "Auto":
        aspect = image.size[0] / image.size[1]
        ratio = min(size_by_ratio, key=lambda candidate: abs(candidate - aspect))
    else:
        ratio = {
            "Square(1:1)": 1,
            "Portrait(2:3)": 0.67,
            "Landscape(3:2)": 1.5,
        }[image_ratio]
    width, height = size_by_ratio[ratio]

    print(
        f"Image Ratio: {image_ratio}, Inference Mode: {inference_mode}, Image Guidance: {image_guidance}, Seed: {seed}, Steps: {steps}, Size: {width}x{height}"
    )
    # Generate
    result_img = generate(
        pipe,
        prompt="",
        conditions=[condition],
        num_inference_steps=steps,
        width=width,
        height=height,
        image_guidance_scale=image_guidance,
        default_lora=True,
        max_sequence_length=32,
    ).images[0]

    return result_img