wildoctopus commited on
Commit
6648f32
·
verified ·
1 Parent(s): d537479

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -11
app.py CHANGED
@@ -1,16 +1,58 @@
1
- import numpy as np
2
  import gradio as gr
 
 
 
 
3
 
4
- def sepia(input_img):
5
- sepia_filter = np.array([
6
- [0.393, 0.769, 0.189],
7
- [0.349, 0.686, 0.168],
8
- [0.272, 0.534, 0.131]
9
- ])
10
- sepia_img = input_img.dot(sepia_filter.T)
11
- sepia_img /= sepia_img.max()
12
- return sepia_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- demo = gr.Interface(sepia, gr.Image(), "image")
15
  if __name__ == "__main__":
16
  demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ from process import load_seg_model, get_palette, generate_mask
4
+ from PIL import Image
5
+ import os
6
 
7
+ # Model initialization
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ def load_models():
11
+ try:
12
+ net = load_seg_model("model/cloth_segm.pth", device=device)
13
+ palette = get_palette(4)
14
+ return net, palette
15
+ except Exception as e:
16
+ raise gr.Error(f"Model failed to load: {str(e)}")
17
+
18
+ net, palette = load_models()
19
+
20
+ def predict(image):
21
+ if image is None:
22
+ raise gr.Error("Please upload or capture an image first")
23
+ try:
24
+ if not isinstance(image, Image.Image):
25
+ image = Image.fromarray(image)
26
+ return generate_mask(image, net=net, palette=palette, device=device)
27
+ except Exception as e:
28
+ raise gr.Error(f"Processing error: {str(e)}")
29
+
30
+ # Interface
31
+ with gr.Blocks(title="Cloth Segmentation") as demo:
32
+ gr.Markdown("## 👕 Cloth Segmentation Tool")
33
+
34
+ with gr.Row():
35
+ with gr.Column():
36
+ img_input = gr.Image(sources=["upload", "webcam"],
37
+ type="pil",
38
+ label="Input Image")
39
+ btn = gr.Button("Generate Mask", variant="primary")
40
+
41
+ with gr.Column():
42
+ img_output = gr.Image(label="Segmentation Result")
43
+
44
+ # Optional examples
45
+ if os.path.exists("examples"):
46
+ gr.Examples(
47
+ examples=[os.path.join("examples", f) for f in os.listdir("examples")
48
+ if f.endswith(('.png','.jpg','.jpeg'))],
49
+ inputs=img_input,
50
+ outputs=img_output,
51
+ fn=predict,
52
+ cache_examples=True
53
+ )
54
+
55
+ btn.click(predict, inputs=img_input, outputs=img_output)
56
 
 
57
  if __name__ == "__main__":
58
  demo.launch()