File size: 3,819 Bytes
3ee92ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f4cbf4
 
3ee92ac
 
 
d46ad28
 
 
 
3ee92ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f30af
1602533
022f833
4c82dea
3ee92ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f057b7d
3ee92ac
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-

import sys
import io
import requests
import json
import base64
from PIL import Image
import numpy as np
import gradio as gr

def inference_mask1(prompt,
              img,
              img_):
    """Segment two test images with the remote Painter service.

    Args:
        prompt: Mapping with an "image" entry and a "mask" entry; each is
            encoded with ``resizeImg`` before being sent.
        img: First test image, encoded with ``resizeImg``.
        img_: Second test image, encoded with ``resizeImg``.

    Returns:
        A list of ``uint8`` numpy arrays, one per base64-encoded image in
        the service's JSON response, in response order.

    Raises:
        requests.exceptions.Timeout: If the service does not connect or
            answer within the configured timeout.
        requests.exceptions.HTTPError: If the service answers with a
            4xx/5xx status.
    """
    payload = {
        "pimage": resizeImg(prompt["image"]),
        "pmask": resizeImg(prompt["mask"]),
        "img": resizeImg(img),
        "img_": resizeImg(img_),
    }
    # Alternative endpoint (currently disabled):
    # https://flagstudio.baai.ac.cn/painter/run
    response = requests.post(
        "http://120.92.79.209/painter/run",
        json=payload,
        # (connect, read) seconds. Without a timeout a stalled service would
        # block this handler forever. NOTE(review): the read timeout is a
        # generous guess -- tune it to the service's real latency.
        timeout=(10, 300),
    )
    # Fail with a clear HTTP error instead of trying to decode an error page.
    response.raise_for_status()
    encoded_images = json.loads(response.text)
    return [
        np.uint8(np.array(Image.open(io.BytesIO(base64.b64decode(item)))))
        for item in encoded_images
    ]

def resizeImg(img):
    """Encode an image array as a base64 WEBP string at 448x448.

    Args:
        img: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        ASCII ``str`` holding the base64 encoding of the WEBP bytes of the
        image after conversion to RGB and resizing to 448x448.
    """
    target_size = (448, 448)
    buffer = io.BytesIO()
    rgb_image = Image.fromarray(img).convert("RGB")
    rgb_image.resize(target_size).save(buffer, format="WEBP")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('ascii')

def inference_mask_cat(prompt, img, img_):
    """Return the two test images unchanged as a two-element list.

    ``prompt`` is accepted but not used.
    """
    return [img, img_]


# define app features and run

# Example rows offered in the UI. Each row has three image paths, in the same
# order as the interface inputs (prompt, img1, img2).
examples = [
    ["./images/hmbb_1.jpg", "./images/hmbb_2.jpg", "./images/hmbb_3.jpg"],
    ["./images/rainbow_1.jpg", "./images/rainbow_2.jpg", "./images/rainbow_3.jpg"],
    ["./images/earth_1.jpg", "./images/earth_2.jpg", "./images/earth_3.jpg"],
    ["./images/obj_1.jpg", "./images/obj_2.jpg", "./images/obj_3.jpg"],
    ["./images/xray_1.jpg", "./images/xray_2.jpg", "./images/xray_3.jpg"],
    # This row deliberately lists file _2 first, then _1.
    ["./images/ydt_2.jpg", "./images/ydt_1.jpg", "./images/ydt_3.jpg"],
]

# One-shot segmentation UI: a maskable prompt image plus two test images go
# in, and two result images come out of inference_mask1.
demo_mask = gr.Interface(
    fn=inference_mask1,
    inputs=[
        gr.ImageMask(brush_radius=8, label="prompt (提示图)"),
        gr.Image(label="img1 (测试图1)"),
        gr.Image(label="img2 (测试图2)"),
    ],
    outputs=[
        gr.Image(label="output1 (输出图1)").style(height=384, width=384),
        gr.Image(label="output2 (输出图2)").style(height=384, width=384),
    ],
    examples=examples,
    # The literal below uses backslash continuations, so the leading
    # whitespace of every continued line is part of the string itself.
    description="<p> \
                    Choose an example below &#128293; &#128293;  &#128293; <br>\
                    Or, upload by yourself: <br>\
                    1. Upload images to be tested to 'img1' and/or 'img2'. <br>2. Upload a prompt image to 'prompt' and draw a mask.  <br>\
                            Tips: The more accurate you annotate, the more accurate the model predicts.;) \
</p>",
    cache_examples=False,
    allow_flagging="never",
)


# HTML banner shown above the tabs: heading, paper/code links, teaser image
# and a one-sentence summary. Each fragment ends with a single space.
title = (
    "SegGPT: Segmenting Everything In Context<br> "
    "<div align='center'> "
    "<h2><a href='https://arxiv.org/abs/2304.03284' target='_blank' rel='noopener'>[paper]</a> "
    "<a href='https://github.com/baaivision/Painter' target='_blank' rel='noopener'>[code]</a></h2> "
    "<br> "
    "<image src='file/seggpt_teaser.png' width='720px' /> "
    "<h2>SegGPT performs arbitrary segmentation tasks in images or videos via in-context inference, such as object instance, stuff, part, contour, and text, with only one single model.</h2> "
    "</div> "
)

# Wrap the single interface in a tabbed layout that carries the HTML title.
demo = gr.TabbedInterface([demo_mask], ['General 1-shot'], title=title)

# Start the app. An alternative launch call is kept for reference:
#demo.launch(share=True, auth=("baai", "vision"))
demo.launch(enable_queue=False)
#demo.launch(server_name="0.0.0.0", server_port=34311)
# -