xp3857 commited on
Commit
a65f78b
·
1 Parent(s): f7ebd0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -55
app.py CHANGED
@@ -66,35 +66,30 @@ def inferRestoration(img, model_name):
66
  result = transforms.ToPILImage()(result)
67
  return result
68
 
69
- def inferColorization(img,model_name):
70
- #print(model_name)
71
- if model_name == "Pix2Pix Resnet 9block":
72
- model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b')
73
- elif model_name == "Pix2Pix Unet 256":
74
- model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_unet256')
75
- elif model_name == "Deoldify":
76
- model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
77
- transform_list = [
78
- transforms.ToTensor(),
79
- transforms.Normalize((0.5,), (0.5,))
80
- ]
81
- transform = transforms.Compose(transform_list)
82
- #a = transforms.ToTensor()(a)
83
- img = img.convert('L')
84
- img = transform(img)
85
- img = torch.unsqueeze(img, 0)
86
- result = model(img)
87
-
88
- result = result[0].detach()
89
- result = (result +1)/2.0
90
-
91
- #img = transforms.Grayscale(3)(img)
92
- #img = transforms.ToTensor()(img)
93
- #img = torch.unsqueeze(img, 0)
94
- #result = model(img)
95
- #result = torch.clip(result, min=0, max=1)
96
- image_pil = transforms.ToPILImage()(result)
97
- return image_pil
98
 
99
  transform_seq = get_transform(model_name)
100
  img = transform_seq(img)
@@ -129,37 +124,31 @@ def run_cmd(command):
129
  sys.exit(1)
130
 
131
  def run(image,Restoration_mode, Colorizaition_mode):
132
- if Restoration_mode == "BOPBTL":
133
- if os.path.isdir("Temp"):
134
- shutil.rmtree("Temp")
135
-
136
- os.makedirs("Temp")
137
- os.makedirs("Temp/input")
138
- print(type(image))
139
- cv2.imwrite("Temp/input/input_img.png", image)
140
-
141
- command = ("python run.py --input_folder "
142
- + "Temp/input"
143
- + " --output_folder "
144
- + "Temp"
145
- + " --GPU "
146
- + "-1"
147
- + " --with_scratch")
148
- run_cmd(command)
149
-
150
- result_restoration = Image.open("Temp/final_output/input_img.png")
151
  shutil.rmtree("Temp")
152
 
153
- elif Restoration_mode == "Pix2Pix":
154
- result_restoration = inferRestoration(image, Restoration_mode)
155
- print("Restoration_mode",Restoration_mode)
 
 
 
 
 
 
 
 
 
 
156
 
 
 
 
157
  result_colorization = inferColorization(result_restoration,Colorizaition_mode)
158
 
159
  return result_colorization
160
 
161
- examples = [['example/1.jpeg',"BOPBTL","Deoldify"],['example/2.jpg',"BOPBTL","Deoldify"],['example/3.jpg',"BOPBTL","Deoldify"],['example/4.jpg',"BOPBTL","Deoldify"]]
162
  iface = gr.Interface(run,
163
- [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify", "Pix2Pix Resnet 9block","Pix2Pix Unet 256"])],
164
- outputs="image",
165
- examples=examples).launch(debug=True,share=False)
 
66
  result = transforms.ToPILImage()(result)
67
  return result
68
 
69
+ def inferColorization(img):
70
+ model_name == "Deoldify"
71
+ model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization')
72
+ transform_list = [
73
+ transforms.ToTensor(),
74
+ transforms.Normalize((0.5,), (0.5,))
75
+ ]
76
+ transform = transforms.Compose(transform_list)
77
+ #a = transforms.ToTensor()(a)
78
+ img = img.convert('L')
79
+ img = transform(img)
80
+ img = torch.unsqueeze(img, 0)
81
+ result = model(img)
82
+
83
+ result = result[0].detach()
84
+ result = (result +1)/2.0
85
+
86
+ #img = transforms.Grayscale(3)(img)
87
+ #img = transforms.ToTensor()(img)
88
+ #img = torch.unsqueeze(img, 0)
89
+ #result = model(img)
90
+ #result = torch.clip(result, min=0, max=1)
91
+ image_pil = transforms.ToPILImage()(result)
92
+ return image_pil
 
 
 
 
 
93
 
94
  transform_seq = get_transform(model_name)
95
  img = transform_seq(img)
 
124
  sys.exit(1)
125
 
126
  def run(image,Restoration_mode, Colorizaition_mode):
127
+ Restoration_mode == "BOPBTL"
128
+ if os.path.isdir("Temp"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  shutil.rmtree("Temp")
130
 
131
+ os.makedirs("Temp")
132
+ os.makedirs("Temp/input")
133
+ print(type(image))
134
+ cv2.imwrite("Temp/input/input_img.png", image)
135
+
136
+ command = ("python run.py --input_folder "
137
+ + "Temp/input"
138
+ + " --output_folder "
139
+ + "Temp"
140
+ + " --GPU "
141
+ + "-1"
142
+ + " --with_scratch")
143
+ run_cmd(command)
144
 
145
+ result_restoration = Image.open("Temp/final_output/input_img.png")
146
+ shutil.rmtree("Temp")
147
+
148
  result_colorization = inferColorization(result_restoration,Colorizaition_mode)
149
 
150
  return result_colorization
151
 
 
152
  iface = gr.Interface(run,
153
+ [gr.inputs.Image(),gr.inputs.Radio(["BOPBTL", "Pix2Pix"]),gr.inputs.Radio(["Deoldify"])],
154
+ outputs="image").launch(debug=True,share=False)