prateekbh committed on
Commit 6e02423 · verified · 1 Parent(s): e3b409e

Update app.py

Files changed (1)
  1. app.py +29 -30
app.py CHANGED
@@ -1,14 +1,14 @@
-from threading import Thread
 import gradio as gr
 import torch
-from transformers import AutoModel, AutoProcessor
-from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
 import numpy as np
 import torch.nn.functional as F
+import PIL
+from threading import Thread
+from transformers import AutoModel, AutoProcessor
+from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
 from torchvision.transforms.functional import normalize
 from huggingface_hub import hf_hub_download
 from briarmbg import BriaRMBG
-import PIL
 from PIL import Image
 from typing import Tuple
 
@@ -106,31 +106,30 @@ def process(image):
     orig_image = Image.fromarray(image)
     w,h = orig_im_size = orig_image.size
     image = resize_image(orig_image)
-    return image.tobytes()
-    # im_np = np.array(image)
-    # im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
-    # im_tensor = torch.unsqueeze(im_tensor,0)
-    # im_tensor = torch.divide(im_tensor,255.0)
-    # im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
-    # if torch.cuda.is_available():
-    # im_tensor=im_tensor.cuda()
-
-    # #inference
-    # result=net(im_tensor)
-    # # post process
-    # result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
-    # ma = torch.max(result)
-    # mi = torch.min(result)
-    # result = (result-mi)/(ma-mi)
-    # # image to pil
-    # im_array = (result*255).cpu().data.numpy().astype(np.uint8)
-    # pil_im = Image.fromarray(np.squeeze(im_array))
-    # # paste the mask on the original image
-    # new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
-    # new_im.paste(orig_image, mask=pil_im)
-    # # new_orig_image = orig_image.convert('RGBA')
-
-    # return new_im
+    im_np = np.array(image)
+    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
+    im_tensor = torch.unsqueeze(im_tensor,0)
+    im_tensor = torch.divide(im_tensor,255.0)
+    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
+    if torch.cuda.is_available():
+        im_tensor=im_tensor.cuda()
+
+    #inference
+    result=net(im_tensor)
+    # post process
+    result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
+    ma = torch.max(result)
+    mi = torch.min(result)
+    result = (result-mi)/(ma-mi)
+    # image to pil
+    im_array = (result*255).cpu().data.numpy().astype(np.uint8)
+    pil_im = Image.fromarray(np.squeeze(im_array))
+    # paste the mask on the original image
+    new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
+    new_im.paste(orig_image, mask=pil_im)
+    # new_orig_image = orig_image.convert('RGBA')
+
+    return new_im.tobytes()
 
 
 title = """<h1 style="text-align: center;">Product description generator</h1>"""
@@ -149,7 +148,7 @@ with gr.Blocks(css=css) as demo:
         chat = gr.Chatbot(show_label=False)
         submit = gr.Button(value="Upload", variant="primary")
         with gr.Column():
-            output = gr.Image(type="pil", interactive=False)
+            output = gr.Image(type="pil")
 
     response_handler = (
        response,
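
For readability, here is a sketch (not part of the commit) of how process() reads after this change, with indentation and the relevant imports restored. It assumes net is the BriaRMBG model and resize_image is the resize helper defined elsewhere in app.py; the indentation is reconstructed, so treat it as an approximation rather than a verbatim copy of the committed file.

# Sketch only: reconstruction of process() after this commit, assuming
# `net` (the loaded BriaRMBG model) and `resize_image` are defined earlier in app.py.
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize

def process(image):
    orig_image = Image.fromarray(image)
    w, h = orig_im_size = orig_image.size
    image = resize_image(orig_image)

    # numpy array -> (1, 3, H, W) float tensor, scaled to [0, 1] and normalized
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference; result[0][0] is used as the mask below
    result = net(im_tensor)

    # post-process: resize the mask back to the original size and min-max normalize it
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # mask tensor -> 8-bit PIL image
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a transparent canvas, using the mask as alpha
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    # raw RGBA bytes are returned (replacing the earlier image.tobytes())
    return new_im.tobytes()

Before this commit, the early return image.tobytes() short-circuited the function; after it, the full BriaRMBG masking path runs and the composited RGBA image is returned as bytes.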