|
import gradio as gr |
|
from PIL import Image |
|
import os |
|
import spaces |
|
|
|
from OmniGen import OmniGenPipeline |
|
|
|
pipe = OmniGenPipeline.from_pretrained( |
|
"shitao/tmp-preview" |
|
) |
|
|
|
@spaces.GPU
def generate_image(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
    """Run the OmniGen pipeline on a prompt plus up to three reference images.

    Image slots left empty in the UI arrive as ``None`` and are dropped;
    when no reference image is supplied at all, the pipeline is invoked
    in text-only mode (``input_images=None``). Returns the first generated
    image from the pipeline output.
    """
    # Keep only the image slots the user actually filled in.
    references = [candidate for candidate in (img1, img2, img3) if candidate is not None]
    if not references:
        references = None  # text-only generation

    results = pipe(
        prompt=text,
        input_images=references,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        img_guidance_scale=1.6,
        num_inference_steps=inference_steps,
        separate_cfg_infer=True,
        use_kv_cache=False,
    )
    return results[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_example():
    """Return the demo's example rows for ``gr.Examples``.

    Each row matches the input order of ``generate_image``:
    ``[prompt, img1, img2, img3, height, width, guidance_scale, inference_steps]``.
    Image slots that an example does not use are ``None``.

    Fix: corrected the typo "Thw woman" -> "The woman" in the first
    example prompt.
    """
    case = [
        [
            "A woman holds a bouquet of flowers and faces the camera. The woman is the one in <img><|image_1|></img>.",
            "./imgs/test_cases/liuyifei.png",
            None,
            None,
            1024,
            1024,
            3.0,
            20,
        ],
        [
            "Three zebras are standing side by side on a vibrant savannah, each showcasing unique patterns and characteristics that highlight their individuality. The zebra on the left has a strikingly bold black and white stripe pattern, with wider stripes that create a dramatic contrast against its sleek body. In the middle, the zebra features a more subtle stripe arrangement, with thinner stripes that blend seamlessly into a slightly sandy-colored coat, giving it a softer appearance. On the right, the zebra's stripes are more irregular, with a distinct patch of brown fur near its shoulder, adding a layer of uniqueness to its overall look. Together, these zebras create a captivating scene, each representing the diverse beauty of their species in the wild. The right zebras is the zebras from <img><|image_1|></img>. The center zebras is from <img><|image_2|></img>. The left zebras is the zebras from <img><|image_3|></img>.",
            "./imgs/test_cases/img1.jpg",
            "./imgs/test_cases/img2.jpg",
            "./imgs/test_cases/img3.jpg",
            1024,
            1024,
            3.0,
            20,
        ],
    ]
    return case
|
|
|
def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, inference_steps):
    """Adapter used by ``gr.Examples``: forwards every field unchanged to
    ``generate_image`` and returns its result."""
    params = (text, img1, img2, img3, height, width, guidance_scale, inference_steps)
    return generate_image(*params)
|
|
|
|
|
|
|
# --- Gradio UI ------------------------------------------------------------
# Two-column layout: left column holds the prompt, up to three reference
# image slots, and the generation sliders; right column shows the result.
# Component construction order inside the `with` context managers determines
# on-screen placement, so the statement order here is load-bearing.
with gr.Blocks() as demo:
    gr.Markdown("## Text + Multiple Images to Image Generator")
    with gr.Row():
        with gr.Column():

            # Free-text prompt; reference images are addressed inside the
            # prompt via the <img><|image_N|></img> placeholders.
            prompt_input = gr.Textbox(
                label="Enter your prompt", placeholder="Type your prompt here..."
            )

            with gr.Row(equal_height=True):

                # Three optional reference-image slots. `type="filepath"`
                # hands the pipeline a path string, not a PIL image.
                image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
                image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
                image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")

            # Output resolution; step of 16 keeps sizes model-friendly.
            height_input = gr.Slider(
                label="Height", minimum=256, maximum=2048, value=1024, step=16
            )
            width_input = gr.Slider(
                label="Width", minimum=256, maximum=2048, value=1024, step=16
            )

            # Text guidance scale (image guidance is fixed at 1.6 inside
            # generate_image).
            guidance_scale_input = gr.Slider(
                label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1
            )

            # NOTE(review): slider defaults to 50 steps while both examples
            # use 20 — presumably intentional (quality vs. speed); confirm.
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=1, maximum=50, value=50, step=1
            )

            generate_button = gr.Button("Generate Image")

        with gr.Column():

            output_image = gr.Image(label="Output Image")

    # Wire the button to the generation function; input order must match
    # generate_image's parameter order.
    generate_button.click(
        generate_image,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            height_input,
            width_input,
            guidance_scale_input,
            num_inference_steps,
        ],
        outputs=output_image,
    )

    # Clickable example rows; each row populates the same inputs and runs
    # run_for_examples (a thin wrapper around generate_image).
    gr.Examples(
        examples=get_example(),
        fn=run_for_examples,
        inputs=[
            prompt_input,
            image_input_1,
            image_input_2,
            image_input_3,
            height_input,
            width_input,
            guidance_scale_input,
            num_inference_steps,
        ],
        outputs=output_image,
    )

# Launch the app (blocking call).
demo.launch()