not-lain commited on
Commit
f9858d8
·
verified ·
1 Parent(s): 16bb0cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
- import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
@@ -11,7 +11,8 @@ torch.set_float32_matmul_precision(["high", "highest"][0])
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
14
- birefnet.to("cuda")
 
15
  transform_image = transforms.Compose(
16
  [
17
  transforms.Resize((1024, 1024)),
@@ -21,14 +22,14 @@ transform_image = transforms.Compose(
21
  )
22
 
23
 
24
- @spaces.GPU
25
  def fn(image):
26
  im = load_img(image, output_type="pil")
27
  im = im.convert("RGB")
28
  image_size = im.size
29
  origin = im.copy()
30
  image = load_img(im)
31
- input_images = transform_image(image).unsqueeze(0).to("cuda")
32
  # Prediction
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
1
  import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
+ # import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
 
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
14
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15
+ birefnet.to(device)
16
  transform_image = transforms.Compose(
17
  [
18
  transforms.Resize((1024, 1024)),
 
22
  )
23
 
24
 
25
+ # @spaces.GPU
26
  def fn(image):
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
  image_size = im.size
30
  origin = im.copy()
31
  image = load_img(im)
32
+ input_images = transform_image(image).unsqueeze(0).to(birefnet)
33
  # Prediction
34
  with torch.no_grad():
35
  preds = birefnet(input_images)[-1].sigmoid().cpu()