not-lain committed
Commit b9bde42 · 1 Parent(s): cce19ac

rollback to last stable

Files changed (1)
  1. app.py  +158 -15
app.py CHANGED
@@ -3,11 +3,28 @@ import spaces
 import torch
 from loadimg import load_img
 from torchvision import transforms
-from transformers import AutoModelForImageSegmentation
+from transformers import AutoModelForImageSegmentation, pipeline
 from diffusers import FluxFillPipeline
 from PIL import Image, ImageOps
 
-torch.set_float32_matmul_precision(["high", "highest"][0])
+# from sam2.sam2_image_predictor import SAM2ImagePredictor
+import numpy as np
+from simple_lama_inpainting import SimpleLama
+from contextlib import contextmanager
+
+
+@contextmanager
+def float32_high_matmul_precision():
+    torch.set_float32_matmul_precision("high")
+    try:
+        yield
+    finally:
+        torch.set_float32_matmul_precision("highest")
+
+
+pipe = FluxFillPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
+).to("cuda")
 
 birefnet = AutoModelForImageSegmentation.from_pretrained(
     "ZhengPeng7/BiRefNet", trust_remote_code=True
@@ -22,10 +39,6 @@ transform_image = transforms.Compose(
     ]
 )
 
-pipe = FluxFillPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
-).to("cuda")
-
 
 def prepare_image_and_mask(
     image,
@@ -110,9 +123,10 @@ def rmbg(image=None, url=None):
     image = load_img(image).convert("RGB")
     image_size = image.size
     input_images = transform_image(image).unsqueeze(0).to("cuda")
-    # Prediction
-    with torch.no_grad():
-        preds = birefnet(input_images)[-1].sigmoid().cpu()
+    with float32_high_matmul_precision():
+        # Prediction
+        with torch.no_grad():
+            preds = birefnet(input_images)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
     pred_pil = transforms.ToPILImage()(pred)
     mask = pred_pil.resize(image_size)
@@ -120,7 +134,65 @@ def rmbg(image=None, url=None):
     return image
 
 
-@spaces.GPU
+# def mask_generation(image=None, d=None):
+#     # use bfloat16 for the entire notebook
+#     # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+#     # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+#     # if torch.cuda.get_device_properties(0).major >= 8:
+#     #     torch.backends.cuda.matmul.allow_tf32 = True
+#     #     torch.backends.cudnn.allow_tf32 = True
+#     d = eval(d)  # convert this to dictionary
+#     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+#         predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
+#         predictor.set_image(image)
+#         input_point = np.array(d["input_points"])
+#         input_label = np.array(d["input_labels"])
+#         masks, scores, logits = predictor.predict(
+#             point_coords=input_point,
+#             point_labels=input_label,
+#             multimask_output=True,
+#         )
+#         sorted_ind = np.argsort(scores)[::-1]
+#         masks = masks[sorted_ind]
+#         scores = scores[sorted_ind]
+#         logits = logits[sorted_ind]
+
+#         out = []
+#         for i in range(len(masks)):
+#             m = Image.fromarray(masks[i] * 255).convert("L")
+#             comp = Image.composite(image, m, m)
+#             out.append((comp, f"image {i}"))
+
+#         return out
+
+
+def erase(image=None, mask=None):
+    simple_lama = SimpleLama()
+    image = load_img(image)
+    mask = load_img(mask).convert("L")
+    return simple_lama(image, mask)
+
+
+# Initialize Whisper model
+whisper = pipeline(
+    task="automatic-speech-recognition",
+    model="openai/whisper-large-v3",
+    chunk_length_s=30,
+    device="cuda" if torch.cuda.is_available() else "cpu",
+)
+
+
+def transcribe(audio, task="transcribe"):
+    if audio is None:
+        raise gr.Error("No audio file submitted!")
+
+    text = whisper(
+        audio, batch_size=8, generate_kwargs={"task": task}, return_timestamps=True
+    )["text"]
+    return text
+
+
+@spaces.GPU(duration=120)
 def main(*args):
     api_num = args[0]
     args = args[1:]
@@ -130,12 +202,18 @@ def main(*args):
         return outpaint(*args)
     elif api_num == 3:
         return inpaint(*args)
+    # elif api_num == 4:
+    #     return mask_generation(*args)
+    elif api_num == 5:
+        return erase(*args)
+    elif api_num == 6:
+        return transcribe(*args)
 
 
 rmbg_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(1, visible=False),
+        gr.Number(1, interactive=False),
         "image",
         gr.Text("", label="url"),
     ],
@@ -149,7 +227,7 @@ rmbg_tab = gr.Interface(
 outpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(2, visible=False),
+        gr.Number(2, interactive=False),
         gr.Image(label="image", type="pil"),
         gr.Number(label="padding top"),
         gr.Number(label="padding bottom"),
@@ -169,7 +247,7 @@ outpaint_tab = gr.Interface(
 inpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(3, visible=False),
+        gr.Number(3, interactive=False),
         gr.Image(label="image", type="pil"),
         gr.Image(label="mask", type="pil"),
         gr.Text(label="prompt"),
@@ -183,9 +261,74 @@ inpaint_tab = gr.Interface(
     description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
 )
 
+
+# sam2_tab = gr.Interface(
+#     main,
+#     inputs=[
+#         gr.Number(4, interactive=False),
+#         gr.Image(type="pil"),
+#         gr.Text(),
+#     ],
+#     outputs=gr.Gallery(),
+#     examples=[
+#         [
+#             4,
+#             "./assets/truck.jpg",
+#             '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
+#         ]
+#     ],
+#     api_name="sam2",
+#     cache_examples=False,
+# )
+
+erase_tab = gr.Interface(
+    main,
+    inputs=[
+        gr.Number(5, interactive=False),
+        gr.Image(type="pil"),
+        gr.Image(type="pil"),
+    ],
+    outputs=gr.Image(),
+    examples=[
+        [
+            5,
+            "./assets/rocket.png",
+            "./assets/Inpainting mask.png",
+        ]
+    ],
+    api_name="erase",
+    cache_examples=False,
+)
+
+transcribe_tab = gr.Interface(
+    fn=main,
+    inputs=[
+        gr.Number(6, interactive=False),
+        gr.Audio(type="filepath"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
+    outputs="text",
+    api_name="transcribe",
+    description="Upload an audio file to extract text using Whisper Large V3",
+)
+
 demo = gr.TabbedInterface(
-    [rmbg_tab, outpaint_tab, inpaint_tab],
-    ["remove background", "outpainting", "inpainting"],
+    [
+        rmbg_tab,
+        outpaint_tab,
+        inpaint_tab,
+        # sam2_tab,
+        erase_tab,
+        transcribe_tab,
+    ],
+    [
+        "remove background",
+        "outpainting",
+        "inpainting",
+        # "sam2",
+        "erase",
+        "transcribe",
+    ],
     title="Utilities that require GPU",
 )
 
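Note on the precision change restored above: float32_high_matmul_precision() lowers float32 matmul precision to "high" (which permits TF32 kernels on Ampere-class GPUs) only around BiRefNet inference, then sets it back to "highest". A minimal hedged variant, not part of this commit, would save and restore whatever precision was previously active rather than assuming "highest" was the prior setting:

import torch
from contextlib import contextmanager


@contextmanager
def float32_high_matmul_precision():
    # remember the caller's precision instead of hardcoding "highest" on exit
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(previous)

The try/finally guarantees the setting is restored even if inference raises inside the with block.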
334