wildoctopus commited on
Commit
1369068
·
verified ·
1 Parent(s): d03c73f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -104
app.py CHANGED
@@ -4,138 +4,74 @@ import gradio as gr
4
  import os
5
  from process import load_seg_model, get_palette, generate_mask
6
 
7
-
8
-
9
  device = 'cpu'
10
 
11
  def read_content(file_path: str) -> str:
12
- """read the content of target file
13
- """
14
  with open(file_path, 'r', encoding='utf-8') as f:
15
- content = f.read()
16
-
17
- return content
18
 
19
  def initialize_and_load_models():
20
-
21
  checkpoint_path = 'model/cloth_segm.pth'
22
- net = load_seg_model(checkpoint_path, device=device)
23
-
24
- return net
25
 
26
  net = initialize_and_load_models()
27
  palette = get_palette(4)
28
 
29
-
30
  def run(img):
31
-
32
  cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
33
  return cloth_seg
34
 
35
- # Define input and output interfaces
36
- input_image = gr.Image(label="Input Image", type="pil")
37
-
38
- # Define the Gradio interface
39
- cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")
40
-
41
- title = "Demo for Cloth Segmentation"
42
- description = "An app for Cloth Segmentation"
43
- inputs = [input_image]
44
- outputs = [cloth_seg_image]
45
-
46
  css = '''
47
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
48
  #image_upload{min-height:400px}
49
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
50
- #mask_radio .gr-form{background:transparent; border: none}
51
- #word_mask{margin-top: .75em !important}
52
- #word_mask textarea:disabled{opacity: 0.3}
53
  .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
54
  .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
55
  .dark .footer {border-color: #303030}
56
  .dark .footer>p {background: #0b0f19}
57
  .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
58
  #image_upload .touch-none{display: flex}
59
- @keyframes spin {
60
- from {
61
- transform: rotate(0deg);
62
- }
63
- to {
64
- transform: rotate(360deg);
65
- }
66
- }
67
- #share-btn-container {
68
- display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
69
- }
70
- #share-btn {
71
- all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
72
- }
73
- #share-btn * {
74
- all: unset;
75
- }
76
- #share-btn-container div:nth-child(-n+2){
77
- width: auto !important;
78
- min-height: 0px !important;
79
- }
80
- #share-btn-container .wrap {
81
- display: none !important;
82
- }
83
  '''
84
- example={}
85
- image_dir='input'
86
 
87
- image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir)]
 
 
 
88
  image_list.sort()
89
 
90
-
91
- image_blocks = gr.Blocks(css=css)
92
- with image_blocks as demo:
93
  gr.HTML(read_content("header.html"))
94
- with gr.Group():
95
- with gr.Box():
96
- with gr.Row():
97
- with gr.Column():
98
- image = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Input Image")
99
-
100
-
101
- with gr.Column():
102
- image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
103
-
104
-
105
-
106
-
107
-
108
- with gr.Row():
109
- with gr.Column():
110
- gr.Examples(image_list, inputs=[image],label="Examples - Input Images",examples_per_page=12)
111
- with gr.Column():
112
- with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
113
- btn = gr.Button("Run!").style(
114
- margin=False,
115
- rounded=(False, True, True, False),
116
- full_width=True,
117
- )
118
-
119
-
120
-
121
- btn.click(fn=run, inputs=[image], outputs=[image_out])
122
-
123
-
124
-
125
-
126
- gr.HTML(
127
- """
128
- <div class="footer">
129
- <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face
130
- </p>
131
- </div>
132
- <div class="acknowledgments">
133
- <p><h4>ACKNOWLEDGEMENTS</h4></p>
134
- <p>
135
- U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" style="text-decoration: underline;" target="_blank">Xuebin Qin</a> for amazing repo.</p>
136
- <p>Codes are modified from <a href="https://github.com/levindabhi/cloth-segmentation" style="text-decoration: underline;" target="_blank">levindabhi/cloth-segmentation</a>
137
- </p>
138
- """
139
- )
140
 
141
- image_blocks.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import os
5
  from process import load_seg_model, get_palette, generate_mask
6
 
 
 
7
  device = 'cpu'
8
 
9
  def read_content(file_path: str) -> str:
 
 
10
  with open(file_path, 'r', encoding='utf-8') as f:
11
+ return f.read()
 
 
12
 
13
  def initialize_and_load_models():
 
14
  checkpoint_path = 'model/cloth_segm.pth'
15
+ return load_seg_model(checkpoint_path, device=device)
 
 
16
 
17
  net = initialize_and_load_models()
18
  palette = get_palette(4)
19
 
 
20
  def run(img):
 
21
  cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
22
  return cloth_seg
23
 
24
+ # CSS styling
 
 
 
 
 
 
 
 
 
 
25
  css = '''
26
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
27
  #image_upload{min-height:400px}
28
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
 
 
 
29
  .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
30
  .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
31
  .dark .footer {border-color: #303030}
32
  .dark .footer>p {background: #0b0f19}
33
  .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
34
  #image_upload .touch-none{display: flex}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  '''
 
 
36
 
37
+ # Collect example images
38
+ example = {}
39
+ image_dir = 'input'
40
+ image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
41
  image_list.sort()
42
 
43
+ with gr.Blocks(css=css) as demo:
 
 
44
  gr.HTML(read_content("header.html"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ with gr.Row():
47
+ with gr.Column():
48
+ image = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Input Image")
49
+
50
+ with gr.Column():
51
+ image_out = gr.Image(label="Output", elem_id="output-img")
52
+
53
+ with gr.Row():
54
+ gr.Examples(
55
+ examples=image_list,
56
+ inputs=[image],
57
+ label="Examples - Input Images",
58
+ examples_per_page=12
59
+ )
60
+ btn = gr.Button("Run!")
61
+
62
+ btn.click(fn=run, inputs=[image], outputs=[image_out])
63
+
64
+ gr.HTML(
65
+ """
66
+ <div class="footer">
67
+ <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
68
+ </div>
69
+ <div class="acknowledgments">
70
+ <p><h4>ACKNOWLEDGEMENTS</h4></p>
71
+ <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
72
+ <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
73
+ </div>
74
+ """
75
+ )
76
+
77
+ demo.launch()