not-lain committed on
Commit 94d1b20 · 1 Parent(s): f65ed8a
Files changed (1)
  1. app.py +233 -369
app.py CHANGED
@@ -1,21 +1,19 @@
 import gradio as gr
 import spaces
 import torch
-from loadimg import load_img  # Assuming loadimg.py exists with load_img function
 from torchvision import transforms
 from transformers import AutoModelForImageSegmentation, pipeline
 from diffusers import FluxFillPipeline
 from PIL import Image, ImageOps
 import numpy as np
 from simple_lama_inpainting import SimpleLama
 from contextlib import contextmanager
 import gc

-# --- Add Translation Imports ---
-from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-
-
-# --- Utility Functions ---
 @contextmanager
 def float32_high_matmul_precision():
     torch.set_float32_matmul_precision("high")
@@ -25,33 +23,14 @@ def float32_high_matmul_precision():
     torch.set_float32_matmul_precision("highest")


-# --- Model Loading ---
-# Use context manager for precision during model loading if needed
-with float32_high_matmul_precision():
-    pipe = FluxFillPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
-    ).to("cuda")
-
-birefnet = AutoModelForImageSegmentation.from_pretrained(
-    "ZhengPeng7/BiRefNet", trust_remote_code=True
-).to("cuda")
-
-simple_lama = SimpleLama()  # Initialize Lama globally if used often
-
-# --- Translation Model and Tokenizer Loading ---
-translation_model_name = "facebook/mbart-large-50-many-to-many-mmt"
-try:
-    translation_model = MBartForConditionalGeneration.from_pretrained(
-        translation_model_name
-    ).to("cuda")  # Move to GPU
-    translation_tokenizer = MBart50TokenizerFast.from_pretrained(translation_model_name)
-except Exception as e:
-    print(f"Error loading translation model/tokenizer: {e}")
-    # Consider exiting or disabling the translation tab if loading fails
-    translation_model = None
-    translation_tokenizer = None

-# --- Image Processing Functions ---

 transform_image = transforms.Compose(
     [
@@ -70,6 +49,7 @@ def prepare_image_and_mask(
     padding_right=0,
 ):
     image = load_img(image).convert("RGB")
     background = ImageOps.expand(
         image,
         border=(padding_left, padding_top, padding_right, padding_bottom),
@@ -97,19 +77,19 @@ def outpaint(
     background, mask = prepare_image_and_mask(
         image, padding_top, padding_bottom, padding_left, padding_right
     )
-    with (
-        float32_high_matmul_precision()
-    ):  # Apply precision context if needed for inference
-        result = pipe(
-            prompt=prompt,
-            height=background.height,
-            width=background.width,
-            image=background,
-            mask_image=mask,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-        ).images[0]
     result = result.convert("RGBA")
     return result
@@ -122,391 +102,275 @@ def inpaint(
 ):
     background = image.convert("RGB")
     mask = mask.convert("L")
-    with (
-        float32_high_matmul_precision()
-    ):  # Apply precision context if needed for inference
-        result = pipe(
-            prompt=prompt,
-            height=background.height,
-            width=background.width,
-            image=background,
-            mask_image=mask,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-        ).images[0]
     result = result.convert("RGBA")
     return result


 def rmbg(image=None, url=None):
-    if image is None and url:
-        # Basic check for URL format, improve as needed
-        if not url.startswith(("http://", "https://")):
-            return "Invalid URL provided."
-        image = url  # load_img should handle URLs if configured correctly
-    elif image is None:
-        return "Please provide an image or a URL."
-
-    try:
-        image_pil = load_img(image).convert("RGB")
-    except Exception as e:
-        return f"Error loading image: {e}"
-
-    image_size = image_pil.size
-    input_images = transform_image(image_pil).unsqueeze(0).to("cuda")
     with float32_high_matmul_precision():
         with torch.no_grad():
             preds = birefnet(input_images)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
     pred_pil = transforms.ToPILImage()(pred)
     mask = pred_pil.resize(image_size)
-    image_pil.putalpha(mask)
-    # Clean up GPU memory if needed
-    del input_images, preds, pred
-    torch.cuda.empty_cache()
-    gc.collect()
-    return image_pil

  def erase(image=None, mask=None):
-    if image is None or mask is None:
-        return "Please provide both an image and a mask."
-    try:
-        image_pil = load_img(image)
-        mask_pil = load_img(mask).convert("L")
-        result = simple_lama(image_pil, mask_pil)
-        # Clean up
-        gc.collect()
-        return result
-    except Exception as e:
-        return f"Error during erase operation: {e}"
-
-
-# --- Translation Functionality ---
-
-# Language Mapping
-lang_data = {
-    "Arabic": "ar_AR",
-    "Czech": "cs_CZ",
-    "German": "de_DE",
-    "English": "en_XX",
-    "Spanish": "es_XX",
-    "Estonian": "et_EE",
-    "Finnish": "fi_FI",
-    "French": "fr_XX",
-    "Gujarati": "gu_IN",
-    "Hindi": "hi_IN",
-    "Italian": "it_IT",
-    "Japanese": "ja_XX",
-    "Kazakh": "kk_KZ",
-    "Korean": "ko_KR",
-    "Lithuanian": "lt_LT",
-    "Latvian": "lv_LV",
-    "Burmese": "my_MM",
-    "Nepali": "ne_NP",
-    "Dutch": "nl_XX",
-    "Romanian": "ro_RO",
-    "Russian": "ru_RU",
-    "Sinhala": "si_LK",
-    "Turkish": "tr_TR",
-    "Vietnamese": "vi_VN",
-    "Chinese": "zh_CN",
-    "Afrikaans": "af_ZA",
-    "Azerbaijani": "az_AZ",
-    "Bengali": "bn_IN",
-    "Persian": "fa_IR",
-    "Hebrew": "he_IL",
-    "Croatian": "hr_HR",
-    "Indonesian": "id_ID",
-    "Georgian": "ka_GE",
-    "Khmer": "km_KH",
-    "Macedonian": "mk_MK",
-    "Malayalam": "ml_IN",
-    "Mongolian": "mn_MN",
-    "Marathi": "mr_IN",
-    "Polish": "pl_PL",
-    "Pashto": "ps_AF",
-    "Portuguese": "pt_XX",
-    "Swedish": "sv_SE",
-    "Swahili": "sw_KE",
-    "Tamil": "ta_IN",
-    "Telugu": "te_IN",
-    "Thai": "th_TH",
-    "Tagalog": "tl_XX",
-    "Ukrainian": "uk_UA",
-    "Urdu": "ur_PK",
-    "Xhosa": "xh_ZA",
-    "Galician": "gl_ES",
-    "Slovene": "sl_SI",
-}
-language_names = sorted(list(lang_data.keys()))
-
-
-def translate_text(text_to_translate, source_language_name, target_language_name):
-    """
-    Translates text using the loaded mBART model.
-    """
-    if translation_model is None or translation_tokenizer is None:
-        return "Translation model not loaded. Cannot perform translation."
-    if not text_to_translate:
-        return "Please enter text to translate."
-    if not source_language_name:
-        return "Please select a source language."
-    if not target_language_name:
-        return "Please select a target language."
-
-    try:
-        source_lang_code = lang_data[source_language_name]
-        target_lang_code = lang_data[target_language_name]
-
-        translation_tokenizer.src_lang = source_lang_code
-        encoded_text = translation_tokenizer(text_to_translate, return_tensors="pt").to(
-            "cuda"
-        )  # Move input to GPU
-        target_lang_id = translation_tokenizer.lang_code_to_id[target_lang_code]
-
-        # Generate translation on GPU
-        with torch.no_grad():  # Use no_grad for inference
-            generated_tokens = translation_model.generate(
-                **encoded_text, forced_bos_token_id=target_lang_id, max_length=200
-            )
-
-        translated_text = translation_tokenizer.batch_decode(
-            generated_tokens, skip_special_tokens=True
-        )
-
-        # Clean up GPU memory
-        del encoded_text, generated_tokens
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        return translated_text[0]
-
-    except KeyError as e:
-        return f"Error: Language code not found for {e}. Check language mappings."
-    except Exception as e:
-        print(f"Translation error: {e}")
-        # Clean up GPU memory on error too
-        torch.cuda.empty_cache()
-        gc.collect()
-        return f"An error occurred during translation: {e}"
-
-
-# --- Main Function Router (for image tasks) ---
-# Note: Translation uses its own function directly
-@spaces.GPU(duration=120)  # Keep GPU decorator if needed for image tasks
  def main(*args):
     api_num = args[0]
     args = args[1:]
-    gc.collect()  # Try to collect garbage before starting task
-    torch.cuda.empty_cache()  # Clear cache before starting task

-    result = None
-    try:
-        if api_num == 1:
-            result = rmbg(*args)
-        elif api_num == 2:
-            result = outpaint(*args)
-        elif api_num == 3:
-            result = inpaint(*args)
-        # elif api_num == 4:  # Keep commented out as in original
-        #     return mask_generation(*args)
-        elif api_num == 5:
-            result = erase(*args)
-        else:
-            result = "Invalid API number."
-    except Exception as e:
-        print(f"Error in main task routing (api_num={api_num}): {e}")
-        result = f"An error occurred: {e}"
-    finally:
-        # Ensure memory cleanup happens even if there's an error
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    return result
-
-
-# --- Define Gradio Interfaces for Each Tab ---

-# Image Task Tabs
  rmbg_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(1, interactive=False, visible=False),  # Hide API number
-        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
-        gr.Text(label="Or Image URL (optional)"),
     ],
-    outputs=gr.Image(label="Output Image", type="pil"),
-    title="Remove Background",
-    description="Upload an image or provide a URL to remove its background.",
     api_name="rmbg",
-    # examples=[[1, "./assets/sample_rmbg.png", ""]],  # Update example path if needed
     cache_examples=False,
 )

  outpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(2, interactive=False, visible=False),
-        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
-        gr.Number(value=0, label="Padding Top (pixels)"),
-        gr.Number(value=0, label="Padding Bottom (pixels)"),
-        gr.Number(value=0, label="Padding Left (pixels)"),
-        gr.Number(value=0, label="Padding Right (pixels)"),
-        gr.Text(
-            label="Prompt (optional)",
-            info="Describe what to fill the extended area with",
-        ),
-        gr.Slider(
-            minimum=10, maximum=100, step=1, value=28, label="Inference Steps"
-        ),  # Use slider for steps
-        gr.Slider(
-            minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"
-        ),  # Use slider for guidance
     ],
-    outputs=gr.Image(label="Outpainted Image", type="pil"),
-    title="Outpainting",
-    description="Extend an image by adding padding and filling the new area using a diffusion model.",
     api_name="outpainting",
-    # examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 28, 50]],  # Update example path
     cache_examples=False,
 )

  inpaint_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(3, interactive=False, visible=False),
-        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
-        gr.Image(
-            label="Mask Image (White=Inpaint Area)",
-            type="pil",
-            sources=["upload", "clipboard"],
-        ),
-        gr.Text(
-            label="Prompt (optional)", info="Describe what to fill the masked area with"
-        ),
-        gr.Slider(minimum=10, maximum=100, step=1, value=28, label="Inference Steps"),
-        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Guidance Scale"),
     ],
-    outputs=gr.Image(label="Inpainted Image", type="pil"),
-    title="Inpainting",
-    description="Fill in the white areas of a mask applied to an image using a diffusion model.",
     api_name="inpaint",
-    # examples=[[3, "./assets/rocket.png", "./assets/Inpainting_mask.png", "", 28, 50]],  # Update example paths
     cache_examples=False,
 )

  erase_tab = gr.Interface(
-    fn=main,
     inputs=[
-        gr.Number(5, interactive=False, visible=False),
-        gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"]),
-        gr.Image(
-            label="Mask Image (White=Erase Area)",
-            type="pil",
-            sources=["upload", "clipboard"],
-        ),
     ],
-    outputs=gr.Image(label="Result Image", type="pil"),
-    title="Erase Object (LAMA)",
-    description="Erase objects from an image based on a mask using the LaMa inpainting model.",
     api_name="erase",
-    # examples=[[5, "./assets/rocket.png", "./assets/Inpainting_mask.png"]],  # Update example paths
     cache_examples=False,
 )

-# --- Define Translation Tab using gr.Blocks ---
-with gr.Blocks() as translation_tab:
-    gr.Markdown(
-        """
-        ## Multilingual Translation (mBART-50)
-        Translate text between 50 different languages.
-        Select the source and target languages, enter your text, and click Translate.
-        """
-    )
-    with gr.Row():
-        with gr.Column(scale=1):
-            source_lang_dropdown = gr.Dropdown(
-                label="Source Language",
-                choices=language_names,
-                info="Select the language of your input text.",
-            )
-            target_lang_dropdown = gr.Dropdown(
-                label="Target Language",
-                choices=language_names,
-                info="Select the language you want to translate to.",
-            )
-        with gr.Column(scale=2):
-            input_textbox = gr.Textbox(
-                label="Text to Translate",
-                lines=6,  # Increased lines
-                placeholder="Enter text here...",
-            )
-            translate_button = gr.Button(
-                "Translate", variant="primary"
-            )  # Added variant
-            output_textbox = gr.Textbox(
-                label="Translated Text",
-                lines=6,  # Increased lines
-                interactive=False,  # Make output read-only
-            )
-
-    # Connect components to the translation function directly
-    translate_button.click(
-        fn=translate_text,
-        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
-        outputs=output_textbox,
-        api_name="translate",  # Add API name for the translation endpoint
-    )
-
-    # Add translation examples
-    gr.Examples(
-        examples=[
-            [
-                "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है",
-                "Hindi",
-                "French",
-            ],
-            [
-                "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا.",
-                "Arabic",
-                "English",
-            ],
-            [
-                "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie.",
-                "French",
-                "German",
-            ],
-            ["Hello world! How are you today?", "English", "Spanish"],
-            ["Guten Tag!", "German", "Japanese"],
-            ["これはテストです", "Japanese", "English"],
-        ],
-        inputs=[input_textbox, source_lang_dropdown, target_lang_dropdown],
-        outputs=output_textbox,
-        fn=translate_text,
-        cache_examples=False,
-    )
-
-# --- Combine all tabs ---
 demo = gr.TabbedInterface(
     [
         rmbg_tab,
         outpaint_tab,
         inpaint_tab,
         erase_tab,
-        translation_tab,  # Add the translation tab
-        # sam2_tab,  # Keep commented out
     ],
     [
-        "Remove Background",  # Tab title
-        "Outpainting",  # Tab title
-        "Inpainting",  # Tab title
-        "Erase (LAMA)",  # Tab title
-        "Translate",  # Tab title for translation
         # "sam2",
     ],
-    title="Image & Text Utilities (GPU)",  # Updated title
 )

-demo.launch()
app.py (new version)

 import gradio as gr
 import spaces
 import torch
+from loadimg import load_img
 from torchvision import transforms
 from transformers import AutoModelForImageSegmentation, pipeline
 from diffusers import FluxFillPipeline
 from PIL import Image, ImageOps
+
+# from sam2.sam2_image_predictor import SAM2ImagePredictor
 import numpy as np
 from simple_lama_inpainting import SimpleLama
 from contextlib import contextmanager
+# import whisperx
 import gc


 @contextmanager
 def float32_high_matmul_precision():
     torch.set_float32_matmul_precision("high")

     torch.set_float32_matmul_precision("highest")
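
The diff viewer collapses the middle of this context manager; only its first and last statements are visible. A minimal sketch of the presumed full definition, assuming the elided lines wrap a `yield` in `try`/`finally`:

```python
from contextlib import contextmanager

import torch

@contextmanager
def float32_high_matmul_precision():
    # Allow faster, lower-precision float32 matmuls inside the block
    torch.set_float32_matmul_precision("high")
    try:
        yield
    finally:
        # Restore full float32 matmul precision on exit
        torch.set_float32_matmul_precision("highest")
```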

+pipe = FluxFillPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
+).to("cuda")

+birefnet = AutoModelForImageSegmentation.from_pretrained(
+    "ZhengPeng7/BiRefNet", trust_remote_code=True
+)
+birefnet.to("cuda")

  transform_image = transforms.Compose(
     [

     padding_right=0,
 ):
     image = load_img(image).convert("RGB")
+    # expand image (left,top,right,bottom)
     background = ImageOps.expand(
         image,
         border=(padding_left, padding_top, padding_right, padding_bottom),
     background, mask = prepare_image_and_mask(
         image, padding_top, padding_bottom, padding_left, padding_right
     )
+
+    result = pipe(
+        prompt=prompt,
+        height=background.height,
+        width=background.width,
+        image=background,
+        mask_image=mask,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+    ).images[0]
+
     result = result.convert("RGBA")
+
     return result


 ):
     background = image.convert("RGB")
     mask = mask.convert("L")
+
+    result = pipe(
+        prompt=prompt,
+        height=background.height,
+        width=background.width,
+        image=background,
+        mask_image=mask,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
+    ).images[0]
+
     result = result.convert("RGBA")
+
     return result

  def rmbg(image=None, url=None):
+    if image is None:
+        image = url
+    image = load_img(image).convert("RGB")
+    image_size = image.size
+    input_images = transform_image(image).unsqueeze(0).to("cuda")
     with float32_high_matmul_precision():
+        # Prediction
         with torch.no_grad():
             preds = birefnet(input_images)[-1].sigmoid().cpu()
     pred = preds[0].squeeze()
     pred_pil = transforms.ToPILImage()(pred)
     mask = pred_pil.resize(image_size)
+    image.putalpha(mask)
+    return image
+
+
+# def mask_generation(image=None, d=None):
+#     # use bfloat16 for the entire notebook
+#     # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+#     # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+#     # if torch.cuda.get_device_properties(0).major >= 8:
+#     #     torch.backends.cuda.matmul.allow_tf32 = True
+#     #     torch.backends.cudnn.allow_tf32 = True
+#     d = eval(d)  # convert this to dictionary
+#     with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+#         predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
+#         predictor.set_image(image)
+#         input_point = np.array(d["input_points"])
+#         input_label = np.array(d["input_labels"])
+#         masks, scores, logits = predictor.predict(
+#             point_coords=input_point,
+#             point_labels=input_label,
+#             multimask_output=True,
+#         )
+#         sorted_ind = np.argsort(scores)[::-1]
+#         masks = masks[sorted_ind]
+#         scores = scores[sorted_ind]
+#         logits = logits[sorted_ind]
+
+#     out = []
+#     for i in range(len(masks)):
+#         m = Image.fromarray(masks[i] * 255).convert("L")
+#         comp = Image.composite(image, m, m)
+#         out.append((comp, f"image {i}"))
+
+#     return out

  def erase(image=None, mask=None):
+    simple_lama = SimpleLama()
+    image = load_img(image)
+    mask = load_img(mask).convert("L")
+    return simple_lama(image, mask)
+
+
+# def transcribe(audio):
+#     if audio is None:
+#         raise gr.Error("No audio file submitted!")
+
+#     device = "cuda" if torch.cuda.is_available() else "cpu"
+#     compute_type = "float16"
+#     batch_size = 8  # reduced batch size to be conservative with memory
+
+#     try:
+#         # 1. Load model and transcribe
+#         model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+#         audio_input = whisperx.load_audio(audio)
+#         result = model.transcribe(audio_input, batch_size=batch_size)
+
+#         # Clear GPU memory
+#         del model
+#         gc.collect()
+#         torch.cuda.empty_cache()
+
+#         # 2. Align whisper output
+#         model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+#         result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
+
+#         # Clear GPU memory
+#         del model_a
+#         gc.collect()
+#         torch.cuda.empty_cache()
+
+#         # 3. Assign speaker labels
+#         diarize_model = whisperx.DiarizationPipeline(device=device)
+#         diarize_segments = diarize_model(audio_input)
+
+#         # Combine transcription with speaker diarization
+#         result = whisperx.assign_word_speakers(diarize_segments, result)
+
+#         # Format output with speaker labels and timestamps
+#         formatted_text = []
+#         for segment in result["segments"]:
+#             if not isinstance(segment, dict):
+#                 continue
+
+#             speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
+#             start_time = f"{float(segment.get('start', 0)):.2f}"
+#             end_time = f"{float(segment.get('end', 0)):.2f}"
+#             text = segment.get('text', '').strip()
+#             formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")
+
+#         return "\n".join(formatted_text)
+
+#     except Exception as e:
+#         raise gr.Error(f"Transcription failed: {str(e)}")
+#     finally:
+#         # Ensure GPU memory is cleared even if an error occurs
+#         gc.collect()
+#         torch.cuda.empty_cache()
+
+
+@spaces.GPU(duration=120)
  def main(*args):
     api_num = args[0]
     args = args[1:]
+    if api_num == 1:
+        return rmbg(*args)
+    elif api_num == 2:
+        return outpaint(*args)
+    elif api_num == 3:
+        return inpaint(*args)
+    # elif api_num == 4:
+    #     return mask_generation(*args)
+    elif api_num == 5:
+        return erase(*args)
+    # elif api_num == 6:
+    #     return transcribe(*args)

  rmbg_tab = gr.Interface(
     fn=main,
     inputs=[
+        gr.Number(1, interactive=False),
+        "image",
+        gr.Text("", label="url"),
     ],
+    outputs=["image"],
     api_name="rmbg",
+    examples=[[1, "./assets/Inpainting mask.png", ""]],
     cache_examples=False,
+    description="pass an image or a url of an image",
 )
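
Each tab registers an `api_name`, so the Space can also be driven programmatically. A minimal sketch using `gradio_client`, where the Space id and file path are placeholders (older client versions accept a plain filepath instead of `handle_file`):

```python
from gradio_client import Client, handle_file

client = Client("user/space-name")  # placeholder Space id
result = client.predict(
    1,                         # hidden api_num that routes main() to rmbg()
    handle_file("input.png"),  # placeholder image path
    "",                        # url field, unused when an image is given
    api_name="/rmbg",
)
print(result)  # filepath of the background-removed image
```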

 outpaint_tab = gr.Interface(
     fn=main,
     inputs=[
+        gr.Number(2, interactive=False),
+        gr.Image(label="image", type="pil"),
+        gr.Number(label="padding top"),
+        gr.Number(label="padding bottom"),
+        gr.Number(label="padding left"),
+        gr.Number(label="padding right"),
+        gr.Text(label="prompt"),
+        gr.Number(value=50, label="num_inference_steps"),
+        gr.Number(value=28, label="guidance_scale"),
     ],
+    outputs=["image"],
     api_name="outpainting",
+    examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
     cache_examples=False,
 )
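
The mask construction inside `prepare_image_and_mask` is collapsed in this view; only the `ImageOps.expand` call on the image is visible. A plausible sketch of the idea, assuming the mask is white over the new border and black over the original content:

```python
from PIL import Image, ImageOps

image = Image.open("./assets/rocket.png").convert("RGB")
border = (0, 100, 0, 0)  # (left, top, right, bottom) padding in pixels

# Grow the canvas; the pipeline fills the new border region
background = ImageOps.expand(image, border=border, fill="white")

# White (255) marks what the pipeline may repaint, black (0) is preserved
mask = ImageOps.expand(Image.new("L", image.size, 0), border=border, fill=255)
```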

+
 inpaint_tab = gr.Interface(
     fn=main,
     inputs=[
+        gr.Number(3, interactive=False),
+        gr.Image(label="image", type="pil"),
+        gr.Image(label="mask", type="pil"),
+        gr.Text(label="prompt"),
+        gr.Number(value=50, label="num_inference_steps"),
+        gr.Number(value=28, label="guidance_scale"),
     ],
+    outputs=["image"],
     api_name="inpaint",
+    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png"]],
     cache_examples=False,
+    description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then invert it before sending it to this space",
 )
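
The description asks callers to invert editor-produced masks before submitting, since the pipeline repaints the white region of the mask. A minimal sketch with PIL, using a placeholder filename:

```python
from PIL import Image, ImageOps

mask = Image.open("mask_from_editor.png").convert("L")  # placeholder path
ImageOps.invert(mask).save("mask_inverted.png")
```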

+# sam2_tab = gr.Interface(
+#     main,
+#     inputs=[
+#         gr.Number(4, interactive=False),
+#         gr.Image(type="pil"),
+#         gr.Text(),
+#     ],
+#     outputs=gr.Gallery(),
+#     examples=[
+#         [
+#             4,
+#             "./assets/truck.jpg",
+#             '{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
+#         ]
+#     ],
+#     api_name="sam2",
+#     cache_examples=False,
+# )
+
 erase_tab = gr.Interface(
+    main,
     inputs=[
+        gr.Number(5, interactive=False),
+        gr.Image(type="pil"),
+        gr.Image(type="pil"),
+    ],
+    outputs=gr.Image(),
+    examples=[
+        [
+            5,
+            "./assets/rocket.png",
+            "./assets/Inpainting mask.png",
+        ]
     ],
     api_name="erase",
     cache_examples=False,
 )
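
`erase()` instantiates `SimpleLama` on every call; the same operation works outside Gradio. A minimal sketch with placeholder paths (the LaMa weights are downloaded on first use):

```python
from PIL import Image
from simple_lama_inpainting import SimpleLama

simple_lama = SimpleLama()
image = Image.open("./assets/rocket.png").convert("RGB")
mask = Image.open("./assets/Inpainting mask.png").convert("L")  # white = erase
result = simple_lama(image, mask)
result.save("erased.png")
```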

+transcribe_tab = gr.Interface(
+    fn=main,
+    inputs=[
+        gr.Number(value=6, interactive=False),  # API number
+        gr.Audio(type="filepath", label="Audio File"),
+    ],
+    outputs=gr.Textbox(label="Transcription"),
+    title="Audio Transcription",
+    description="Upload an audio file to extract text using WhisperX with speaker diarization",
+    api_name="transcribe",
+    examples=[]
+)

 demo = gr.TabbedInterface(
     [
         rmbg_tab,
         outpaint_tab,
         inpaint_tab,
+        # sam2_tab,
         erase_tab,
+        transcribe_tab,
     ],
     [
+        "remove background",
+        "outpainting",
+        "inpainting",
         # "sam2",
+        "erase",
+        # "transcribe",
     ],
+    title="Utilities that require GPU",
 )


+demo.launch()