blanchon committed on
Commit
207e147
·
verified ·
1 Parent(s): 0b9d113

Update rgb2x/gradio_demo_rgb2x.py

Browse files
Files changed (1) hide show
  1. rgb2x/gradio_demo_rgb2x.py +30 -24
rgb2x/gradio_demo_rgb2x.py CHANGED
@@ -38,16 +38,21 @@ def generate(
38
  generator = torch.Generator(device="cuda").manual_seed(seed)
39
  photo_name = photo.name
40
  if photo_name.endswith(".exr"):
41
- photo_tensor = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
42
- else:
43
- photo_tensor = load_ldr_image(photo_name, from_srgb=True).to("cuda")
44
-
45
- # Resize to multiple of 8
46
- old_height = photo_tensor.shape[1]
47
- old_width = photo_tensor.shape[2]
 
 
 
 
 
 
48
  radio = old_height / old_width
49
  max_side = 1000
50
-
51
  if old_height > old_width:
52
  new_height = max_side
53
  new_width = int(new_height / radio)
@@ -55,10 +60,11 @@ def generate(
55
  new_width = max_side
56
  new_height = int(new_width * radio)
57
 
58
- new_width = new_width // 8 * 8
59
- new_height = new_height // 8 * 8
 
60
 
61
- photo_resized = torchvision.transforms.Resize((new_height, new_width))(photo_tensor)
62
 
63
  required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
64
  prompts = {
@@ -69,37 +75,36 @@ def generate(
69
  "irradiance": "Irradiance (diffuse lighting)",
70
  }
71
 
72
- return_list: list[Image.Image] = []
73
-
74
  for i in range(num_samples):
75
  for aov_name in required_aovs:
76
  prompt = prompts[aov_name]
77
- result = pipe(
78
  prompt=prompt,
79
- photo=photo_resized,
80
  num_inference_steps=inference_step,
81
  height=new_height,
82
  width=new_width,
83
  generator=generator,
84
  required_aovs=[aov_name],
 
 
 
 
85
  )
86
- image_tensor = result.images[0][0] # type: ignore
87
- image_tensor = torchvision.transforms.Resize((old_height, old_width))(image_tensor)
88
- image_pil = torchvision.transforms.ToPILImage()(image_tensor.cpu())
89
- return_list.append(image_pil)
90
 
91
- # Also return the input image at the end
92
- input_image_pil = torchvision.transforms.ToPILImage()(photo_tensor.cpu())
93
- return_list.append(input_image_pil)
94
 
 
95
  return return_list
96
 
97
 
98
  with gr.Blocks() as demo:
99
  with gr.Row():
100
  gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
101
-
102
  with gr.Row():
 
103
  with gr.Column():
104
  gr.Markdown("### Given Image")
105
  photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
@@ -129,6 +134,7 @@ with gr.Blocks() as demo:
129
  value=1,
130
  )
131
 
 
132
  with gr.Column():
133
  gr.Markdown("### Output Gallery")
134
  result_gallery = gr.Gallery(
@@ -162,4 +168,4 @@ with gr.Blocks() as demo:
162
 
163
 
164
  if __name__ == "__main__":
165
- demo.launch(debug=False, share=False, show_api=False)
 
38
  generator = torch.Generator(device="cuda").manual_seed(seed)
39
  photo_name = photo.name
40
  if photo_name.endswith(".exr"):
41
+ photo = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
42
+ elif (
43
+ photo_name.endswith(".png")
44
+ or photo_name.endswith(".jpg")
45
+ or photo_name.endswith(".jpeg")
46
+ ):
47
+ photo = load_ldr_image(photo_name, from_srgb=True).to("cuda")
48
+
49
+ # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
50
+ old_height = photo.shape[1]
51
+ old_width = photo.shape[2]
52
+ new_height = old_height
53
+ new_width = old_width
54
  radio = old_height / old_width
55
  max_side = 1000
 
56
  if old_height > old_width:
57
  new_height = max_side
58
  new_width = int(new_height / radio)
 
60
  new_width = max_side
61
  new_height = int(new_width * radio)
62
 
63
+ if new_width % 8 != 0 or new_height % 8 != 0:
64
+ new_width = new_width // 8 * 8
65
+ new_height = new_height // 8 * 8
66
 
67
+ photo = torchvision.transforms.Resize((new_height, new_width))(photo)
68
 
69
  required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
70
  prompts = {
 
75
  "irradiance": "Irradiance (diffuse lighting)",
76
  }
77
 
78
+ return_list = []
 
79
  for i in range(num_samples):
80
  for aov_name in required_aovs:
81
  prompt = prompts[aov_name]
82
+ generated_image = pipe(
83
  prompt=prompt,
84
+ photo=photo,
85
  num_inference_steps=inference_step,
86
  height=new_height,
87
  width=new_width,
88
  generator=generator,
89
  required_aovs=[aov_name],
90
+ ).images[0][0] # type: ignore
91
+
92
+ generated_image = torchvision.transforms.Resize((old_height, old_width))(
93
+ generated_image
94
  )
 
 
 
 
95
 
96
+ generated_image = (generated_image, f"Generated {aov_name} {i}")
97
+ return_list.append(generated_image)
 
98
 
99
+ return_list.append((photo_name, "Input Image"))
100
  return return_list
101
 
102
 
103
  with gr.Blocks() as demo:
104
  with gr.Row():
105
  gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
 
106
  with gr.Row():
107
+ # Input side
108
  with gr.Column():
109
  gr.Markdown("### Given Image")
110
  photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
 
134
  value=1,
135
  )
136
 
137
+ # Output side
138
  with gr.Column():
139
  gr.Markdown("### Output Gallery")
140
  result_gallery = gr.Gallery(
 
168
 
169
 
170
  if __name__ == "__main__":
171
+ demo.launch(debug=False, share=False, show_api=False)