blanchon committed
Commit 7c54b21 · verified · 1 Parent(s): f8dd2d5

Update rgb2x/gradio_demo_rgb2x.py

Files changed (1)
  1. rgb2x/gradio_demo_rgb2x.py +23 -29
rgb2x/gradio_demo_rgb2x.py CHANGED
@@ -38,21 +38,16 @@ def generate(
     generator = torch.Generator(device="cuda").manual_seed(seed)
     photo_name = photo.name
     if photo_name.endswith(".exr"):
-        photo = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
-    elif (
-        photo_name.endswith(".png")
-        or photo_name.endswith(".jpg")
-        or photo_name.endswith(".jpeg")
-    ):
-        photo = load_ldr_image(photo_name, from_srgb=True).to("cuda")
-
-    # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
-    old_height = photo.shape[1]
-    old_width = photo.shape[2]
-    new_height = old_height
-    new_width = old_width
+        photo_tensor = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
+    else:
+        photo_tensor = load_ldr_image(photo_name, from_srgb=True).to("cuda")
+
+    # Resize to multiple of 8
+    old_height = photo_tensor.shape[1]
+    old_width = photo_tensor.shape[2]
     radio = old_height / old_width
     max_side = 1000
+
     if old_height > old_width:
         new_height = max_side
         new_width = int(new_height / radio)
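For reference, a minimal sketch of the simplified loading branch above. load_photo is a hypothetical helper name; the sketch assumes this file's existing imports (torch, load_exr_image, load_ldr_image) are in scope and that the loaders return CHW float tensors, as they are used in the diff.

import torch

# Hypothetical helper, not part of this commit.
def load_photo(path: str) -> torch.Tensor:
    # EXR inputs are HDR, so they are tonemapped and clamped on load;
    # every other supported extension is read as an sRGB LDR image.
    if path.endswith(".exr"):
        return load_exr_image(path, tonemaping=True, clamp=True).to("cuda")
    return load_ldr_image(path, from_srgb=True).to("cuda")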
@@ -60,11 +55,10 @@ def generate(
         new_width = max_side
         new_height = int(new_width * radio)
 
-    if new_width % 8 != 0 or new_height % 8 != 0:
-        new_width = new_width // 8 * 8
-        new_height = new_height // 8 * 8
+    new_width = new_width // 8 * 8
+    new_height = new_height // 8 * 8
 
-    photo = torchvision.transforms.Resize((new_height, new_width))(photo)
+    photo_resized = torchvision.transforms.Resize((new_height, new_width))(photo_tensor)
 
     required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
     prompts = {
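The two hunks above compute the working resolution in place: the longer side is capped at max_side = 1000 while keeping the aspect ratio, then both sides are floored so they stay divisible by 8, as the pipeline expects. A self-contained sketch of the same rule, with a hypothetical helper name:

# Hypothetical helper mirroring the sizing logic above (not part of the commit).
def target_size(old_height: int, old_width: int, max_side: int = 1000) -> tuple[int, int]:
    ratio = old_height / old_width  # the file names this variable "radio"
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    # Floor both sides to a multiple of 8.
    return new_height // 8 * 8, new_width // 8 * 8

# e.g. target_size(1080, 1920) == (560, 1000)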
@@ -75,36 +69,37 @@ def generate(
         "irradiance": "Irradiance (diffuse lighting)",
     }
 
-    return_list = []
+    return_list: list[Image.Image] = []
+
     for i in range(num_samples):
         for aov_name in required_aovs:
             prompt = prompts[aov_name]
-            generated_image = pipe(
+            result = pipe(
                 prompt=prompt,
-                photo=photo,
+                photo=photo_resized,
                 num_inference_steps=inference_step,
                 height=new_height,
                 width=new_width,
                 generator=generator,
                 required_aovs=[aov_name],
-            ).images[0][0]  # type: ignore
-
-            generated_image = torchvision.transforms.Resize((old_height, old_width))(
-                generated_image
             )
+            image_tensor = result.images[0][0]  # type: ignore
+            image_tensor = torchvision.transforms.Resize((old_height, old_width))(image_tensor)
+            image_pil = torchvision.transforms.ToPILImage()(image_tensor.cpu())
+            return_list.append(image_pil)
 
-            generated_image = (generated_image, f"Generated {aov_name} {i}")
-            return_list.append(generated_image)
+    # Also return the input image at the end
+    input_image_pil = torchvision.transforms.ToPILImage()(photo_tensor.cpu())
+    return_list.append(input_image_pil)
 
-    return_list.append((photo_name, "Input Image"))
     return return_list
 
 
 with gr.Blocks() as demo:
     with gr.Row():
         gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
+
     with gr.Row():
-        # Input side
         with gr.Column():
             gr.Markdown("### Given Image")
             photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
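In the updated loop above, each AOV comes back from the pipeline as a tensor, is resized to the input's original resolution, and is converted to a PIL image before being appended to the result list (the input photo is appended the same way at the end). A hedged sketch of that post-processing step, assuming the pipeline output is a CHW float tensor in [0, 1]; the helper name is illustrative:

import torch
import torchvision
from PIL import Image

# Hypothetical helper matching the per-AOV post-processing above.
def to_gallery_image(image_tensor: torch.Tensor, old_height: int, old_width: int) -> Image.Image:
    # Resize back to the input's original resolution ...
    image_tensor = torchvision.transforms.Resize((old_height, old_width))(image_tensor)
    # ... then move to CPU and convert to PIL so gr.Gallery can display it.
    return torchvision.transforms.ToPILImage()(image_tensor.cpu())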
@@ -134,7 +129,6 @@ with gr.Blocks() as demo:
                 value=1,
             )
 
-        # Output side
         with gr.Column():
             gr.Markdown("### Output Gallery")
             result_gallery = gr.Gallery(
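The net interface change in this commit is that generate() now returns a plain list of PIL images (the generated AOVs plus the input photo) instead of (image, caption) tuples. A self-contained toy sketch, not taken from this repo, showing that gr.Gallery accepts exactly that return type:

import gradio as gr
from PIL import Image

def make_images() -> list[Image.Image]:
    # Solid-color placeholders standing in for the generated AOV maps.
    return [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue")]

with gr.Blocks() as toy_demo:
    show_button = gr.Button("Show")
    gallery = gr.Gallery(label="Output Gallery")
    show_button.click(fn=make_images, inputs=[], outputs=[gallery])

if __name__ == "__main__":
    toy_demo.launch()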