Kims12 commited on
Commit
797deb4
ยท
verified ยท
1 Parent(s): 3550322

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -49
app.py CHANGED
@@ -1,19 +1,12 @@
1
- import gradio as gr
2
- from gradio_imageslider import ImageSlider
3
- from loadimg import load_img
4
- from transformers import AutoModelForImageSegmentation
5
- import torch
6
- from torchvision import transforms
7
- from PIL import Image
8
- import os
9
 
10
- # ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ CPU๋กœ ์„ค์ •
11
  birefnet = AutoModelForImageSegmentation.from_pretrained(
12
  "ZhengPeng7/BiRefNet", trust_remote_code=True
13
  )
14
  birefnet.to("cpu") # GPU -> CPU๋กœ ๋ณ€๊ฒฝ
15
 
16
- # ์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ
17
  transform_image = transforms.Compose(
18
  [
19
  transforms.Resize((1024, 1024)),
@@ -22,10 +15,20 @@ transform_image = transforms.Compose(
22
  ]
23
  )
24
 
 
 
 
 
 
 
 
 
 
 
25
  def process(image):
26
  image_size = image.size
27
- input_images = transform_image(image).unsqueeze(0).to("cpu") # CPU๋กœ ๋ณ€๊ฒฝ
28
- # ์˜ˆ์ธก ์ˆ˜ํ–‰
29
  with torch.no_grad():
30
  preds = birefnet(input_images)[-1].sigmoid().cpu()
31
  pred = preds[0].squeeze()
@@ -34,50 +37,32 @@ def process(image):
34
  image.putalpha(mask)
35
  return image
36
 
37
- def fn(image):
38
- im = load_img(image, output_type="pil")
 
39
  im = im.convert("RGB")
40
- origin = im.copy()
41
- processed_image = process(im)
42
-
43
- # JPG๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์ €์žฅ
44
- jpg_image = origin.copy()
45
- jpg_image = jpg_image.convert("RGB")
46
- jpg_path = "output.jpg"
47
- jpg_image.save(jpg_path, format="JPEG")
48
-
49
- return [processed_image], jpg_path # ImageSlider๋Š” ๋ฆฌ์ŠคํŠธ๋ฅผ ๊ธฐ๋Œ€ํ•จ
50
-
51
- def convert_to_jpg(image):
52
- if image is None:
53
- return None
54
- jpg_image = image.convert("RGB")
55
- jpg_path = "downloaded_output.jpg"
56
- jpg_image.save(jpg_path, format="JPEG")
57
- return jpg_path
58
 
59
- # Gradio ์ปดํฌ๋„ŒํŠธ ์ •์˜
60
  slider1 = ImageSlider(label="Processed Image", type="pil")
 
61
  image_upload = gr.Image(label="Upload an image")
62
- output_download = gr.File(label="Download JPG File")
 
 
63
 
64
- # ์ƒˆ๋กœ์šด ์ƒ˜ํ”Œ ์ด๋ฏธ์ง€ ์ถ”๊ฐ€ (app.py์™€ ๋™์ผํ•œ ํด๋”์— ์œ„์น˜ํ•ด์•ผ ํ•จ)
65
- sample_images = [["1.png"], ["2.jpg"], ["3.png"]]
 
66
 
67
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
68
- tab1 = gr.Interface(
69
- fn=fn,
70
- inputs=image_upload,
71
- outputs=[slider1, output_download],
72
- examples=sample_images,
73
- api_name="image"
74
- )
75
 
76
- demo = gr.Interface(
77
- tab1,
78
- title="Background Removal Tool",
79
- description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋ฐฐ๊ฒฝ์ด ์ œ๊ฑฐ๋œ ์ด๋ฏธ์ง€๋ฅผ ํ™•์ธํ•˜๊ณ  JPG ํŒŒ์ผ๋กœ ๋‹ค์šด๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
80
  )
81
 
82
  if __name__ == "__main__":
83
- demo.launch(show_error=True)
 
1
+ # GPU ์„ค์ •์„ CPU๋กœ ๋ณ€๊ฒฝ
2
+ # GPU ์„ค์ •์„ ์‚ญ์ œํ•˜๊ฑฐ๋‚˜ "cuda"๋ฅผ "cpu"๋กœ ๋ณ€๊ฒฝ
3
+ # torch.set_float32_matmul_precision("high")๋Š” CPU์—์„  ํ•„์š” ์—†์Œ.
 
 
 
 
 
4
 
 
5
  birefnet = AutoModelForImageSegmentation.from_pretrained(
6
  "ZhengPeng7/BiRefNet", trust_remote_code=True
7
  )
8
  birefnet.to("cpu") # GPU -> CPU๋กœ ๋ณ€๊ฒฝ
9
 
 
10
  transform_image = transforms.Compose(
11
  [
12
  transforms.Resize((1024, 1024)),
 
15
  ]
16
  )
17
 
18
+ def fn(image):
19
+ im = load_img(image, output_type="pil")
20
+ im = im.convert("RGB")
21
+ origin = im.copy()
22
+ processed_image = process(im)
23
+ return (processed_image, origin)
24
+
25
+ # @spaces.GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ์ œ๊ฑฐ
26
+ # CPU ํ™˜๊ฒฝ์—์„œ ๋™์ž‘ํ•˜๋„๋ก ์„ค์ •
27
+
28
  def process(image):
29
  image_size = image.size
30
+ input_images = transform_image(image).unsqueeze(0).to("cpu") # GPU -> CPU๋กœ ๋ณ€๊ฒฝ
31
+ # Prediction
32
  with torch.no_grad():
33
  preds = birefnet(input_images)[-1].sigmoid().cpu()
34
  pred = preds[0].squeeze()
 
37
  image.putalpha(mask)
38
  return image
39
 
40
+ def process_file(f):
41
+ name_path = f.rsplit(".", 1)[0] + ".png"
42
+ im = load_img(f, output_type="pil")
43
  im = im.convert("RGB")
44
+ transparent = process(im)
45
+ transparent.save(name_path)
46
+ return name_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
48
  slider1 = ImageSlider(label="Processed Image", type="pil")
49
+ slider2 = ImageSlider(label="Processed Image from URL", type="pil")
50
  image_upload = gr.Image(label="Upload an image")
51
+ image_file_upload = gr.Image(label="Upload an image", type="filepath")
52
+ url_input = gr.Textbox(label="Paste an image URL")
53
+ output_file = gr.File(label="Output PNG File")
54
 
55
+ # Example images
56
+ chameleon = load_img("butterfly.jpg", output_type="pil")
57
+ url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
58
 
59
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
60
+ tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
61
+ tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
 
 
 
 
 
62
 
63
+ demo = gr.TabbedInterface(
64
+ [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
 
 
65
  )
66
 
67
  if __name__ == "__main__":
68
+ demo.launch(show_error=True)