reiirene commited on
Commit
e235264
·
verified ·
1 Parent(s): 0bd8737

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -28
app.py CHANGED
@@ -1,23 +1,27 @@
 
 
1
  import torch
2
  from diffusers import DiffusionPipeline
3
- import huggingface_hub
4
- import requests
5
  from PIL import Image
6
- from io import BytesIO
7
- import numpy as np
8
- import gradio as gr
9
 
10
- multi_view_diffusion_pipeline = DiffusionPipeline.from_pretrained(
11
- "reiirene/multi-view-diffusion",
12
- custom_pipeline="reiirene/multi-view-diffusion",
 
 
 
 
 
 
 
 
 
 
 
13
  torch_dtype=torch.float16,
14
  trust_remote_code=True,
15
  ).to("cuda")
16
 
17
- image_url = "https://huggingface.co/datasets/dylanebert/3d-arena/resolve/main/inputs/images/a_cat_statue.jpg"
18
- response = requests.get(image_url)
19
- image = Image.open(BytesIO(response.content))
20
- image
21
 
22
  def create_image_grid(images):
23
  images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
@@ -32,26 +36,53 @@ def create_image_grid(images):
32
 
33
  return grid_img
34
 
35
- image = np.array(image, dtype=np.float32) / 255.0
36
- images = multi_view_diffusion_pipeline("", image, guidance_scale=5, num_inference_steps=30, elevation=0)
37
 
38
- create_image_grid(images)
 
 
 
 
 
 
39
 
40
- def run(image):
41
- image = np.array(image, dtype=np.float32) / 255.0
42
- images = multi_view_diffusion_pipeline("", image, guidance_scale=5, num_inference_steps=30, elevation=0)
 
 
 
 
43
 
44
- images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
45
 
46
- width, height = images[0].size
47
- grid_img = Image.new("RGB", (2 * width, 2 * height))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- grid_img.paste(images[0], (0, 0))
50
- grid_img.paste(images[1], (width, 0))
51
- grid_img.paste(images[2], (0, height))
52
- grid_img.paste(images[3], (width, height))
53
 
54
- return grid_img
55
 
56
- demo = gr.Interface(fn=run, inputs="image", outputs="image")
57
- demo.launch(debug=True)
 
1
+ import gradio as gr
2
+ import spaces
3
  import torch
4
  from diffusers import DiffusionPipeline
 
 
5
  from PIL import Image
 
 
 
6
 
7
+
8
+ # Text-to-Multi-View Diffusion pipeline
9
+ text_pipeline = DiffusionPipeline.from_pretrained(
10
+ "dylanebert/mvdream",
11
+ custom_pipeline="dylanebert/multi-view-diffusion",
12
+ torch_dtype=torch.float16,
13
+ trust_remote_code=True,
14
+ ).to("cuda")
15
+
16
+
17
+ # Image-to-Multi-View Diffusion pipeline
18
+ image_pipeline = DiffusionPipeline.from_pretrained(
19
+ "dylanebert/multi-view-diffusion",
20
+ custom_pipeline="dylanebert/multi-view-diffusion",
21
  torch_dtype=torch.float16,
22
  trust_remote_code=True,
23
  ).to("cuda")
24
 
 
 
 
 
25
 
26
  def create_image_grid(images):
27
  images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
 
36
 
37
  return grid_img
38
 
 
 
39
 
40
+ @spaces.GPU
41
+ def text_to_mv(prompt):
42
+ images = text_pipeline(
43
+ prompt, guidance_scale=5, num_inference_steps=30, elevation=0
44
+ )
45
+ return create_image_grid(images)
46
+
47
 
48
+ @spaces.GPU
49
+ def image_to_mv(image, prompt):
50
+ image = image.astype("float32") / 255.0
51
+ images = image_pipeline(
52
+ prompt, image, guidance_scale=5, num_inference_steps=30, elevation=0
53
+ )
54
+ return create_image_grid(images)
55
 
 
56
 
57
+ with gr.Blocks() as demo:
58
+ with gr.Row():
59
+ with gr.Column():
60
+ with gr.Tab("Text Input"):
61
+ text_input = gr.Textbox(
62
+ lines=2,
63
+ show_label=False,
64
+ placeholder="Enter a prompt here (e.g. 'a cat statue')",
65
+ )
66
+ text_btn = gr.Button("Generate Multi-View Images")
67
+ with gr.Tab("Image Input"):
68
+ image_input = gr.Image(
69
+ label="Image Input",
70
+ type="numpy",
71
+ )
72
+ optional_text_input = gr.Textbox(
73
+ lines=2,
74
+ show_label=False,
75
+ placeholder="Enter an optional prompt here",
76
+ )
77
+ image_btn = gr.Button("Generate Multi-View Images")
78
+ with gr.Column():
79
+ output = gr.Image(label="Generated Images")
80
 
81
+ text_btn.click(fn=text_to_mv, inputs=text_input, outputs=output)
82
+ image_btn.click(
83
+ fn=image_to_mv, inputs=[image_input, optional_text_input], outputs=output
84
+ )
85
 
 
86
 
87
+ if __name__ == "__main__":
88
+ demo.queue().launch()