Spaces:
Runtime error
Runtime error
Commit
·
b6eec52
1
Parent(s):
2b6d53c
Basic version is complete and working.
Browse files- app.py +16 -8
- pickle_lama_model.ipynb +12 -2
app.py
CHANGED
|
@@ -1,17 +1,25 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
|
| 6 |
-
def predict(input_img):
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
gradio_app = gr.Interface(
|
| 11 |
predict,
|
| 12 |
-
inputs=
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
|
| 4 |
|
| 5 |
+
model = torch.jit.load("models/lama.pt")
|
| 6 |
|
| 7 |
+
def predict(input_img, input_mask):
|
| 8 |
+
# numpy gives the image as (w,h,c)
|
| 9 |
+
# Image shape should be (1, 3, 512, 512) and be in the range 0-1.
|
| 10 |
+
# Mask shape should be (1, 1, 512, 512) AND have values 0.0 or 1.0, not in-between.
|
| 11 |
+
#out = model(torch.tensor(input_img[None, (2,0,1), :, :])/255.0, torch.tensor(1 * (input_mask[:,:,0] > 0)).unsqueeze(0))
|
| 12 |
+
out = model((pil_to_tensor(input_img.convert('RGB')) / 255.0).unsqueeze(0), 1 * (pil_to_tensor(input_mask.convert('L')) > 0).unsqueeze(0))[0]
|
| 13 |
+
return to_pil_image(out)
|
| 14 |
|
| 15 |
gradio_app = gr.Interface(
|
| 16 |
predict,
|
| 17 |
+
inputs=[
|
| 18 |
+
gr.Image(label="Select Base Image", sources=['upload',], type="pil"),
|
| 19 |
+
gr.Image(label="Select Image Mask (White will be inpainted)", sources=['upload',], type="pil"),
|
| 20 |
+
],
|
| 21 |
+
outputs=[gr.Image(label="Inpainted Image"),],
|
| 22 |
+
title="LAMA Inpainting",
|
| 23 |
)
|
| 24 |
|
| 25 |
if __name__ == "__main__":
|
pickle_lama_model.ipynb
CHANGED
|
@@ -160,7 +160,7 @@
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
| 163 |
-
"execution_count":
|
| 164 |
"id": "163db07c-93a3-40d2-837d-4fade79b07f0",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [
|
|
@@ -181,12 +181,22 @@
|
|
| 181 |
},
|
| 182 |
"metadata": {},
|
| 183 |
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
}
|
| 185 |
],
|
| 186 |
"source": [
|
| 187 |
"print(out['predicted_image'].shape)\n",
|
| 188 |
"import numpy\n",
|
| 189 |
-
"display(tvf.to_pil_image((out['predicted_image'])[0]))"
|
|
|
|
|
|
|
| 190 |
]
|
| 191 |
},
|
| 192 |
{
|
|
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
| 163 |
+
"execution_count": 76,
|
| 164 |
"id": "163db07c-93a3-40d2-837d-4fade79b07f0",
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [
|
|
|
|
| 181 |
},
|
| 182 |
"metadata": {},
|
| 183 |
"output_type": "display_data"
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"name": "stdout",
|
| 187 |
+
"output_type": "stream",
|
| 188 |
+
"text": [
|
| 189 |
+
"tensor(1.)\n",
|
| 190 |
+
"tensor(1)\n"
|
| 191 |
+
]
|
| 192 |
}
|
| 193 |
],
|
| 194 |
"source": [
|
| 195 |
"print(out['predicted_image'].shape)\n",
|
| 196 |
"import numpy\n",
|
| 197 |
+
"display(tvf.to_pil_image((out['predicted_image'])[0]))\n",
|
| 198 |
+
"print(torch.max(image))\n",
|
| 199 |
+
"print(torch.max(mask))"
|
| 200 |
]
|
| 201 |
},
|
| 202 |
{
|