import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image 
import numpy as np 
import os 
import tempfile
import gradio as gr

import cv2
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed; segmentation visualization is disabled.")
    
# Load the model and tokenizer 
model_path = "ByteDance/Sa2VA-4B"
 
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval()  # device_map="auto" already places the weights on the GPU, so no extra .cuda() call is needed

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
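
# `predict_forward` (used below) is a custom inference helper defined in the Sa2VA model's
# remote code loaded via trust_remote_code; it is not part of the standard transformers API.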

# VideoReader is provided by the repo-local `third_parts` module bundled with this demo.
from third_parts import VideoReader

def read_video(video_path, video_interval):
    # Sample every `video_interval`-th frame and convert BGR (OpenCV) frames to RGB PIL images.
    vid_frames = VideoReader(video_path)[::video_interval]
    frames = []
    for frame in vid_frames:
        frame = frame[..., ::-1]  # BGR (OpenCV) -> RGB
        frames.append(Image.fromarray(frame))
    return frames

def visualize(pred_mask, image_path, work_dir):
    # Draw the predicted binary mask over the image (green, 40% alpha) and save it to work_dir.
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path
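
# Optional fallback (a minimal sketch, not part of the original demo): if mmengine is
# unavailable, a plain OpenCV overlay can produce a similar visualization. The helper
# name `overlay_mask_cv2` is ours and is not wired into the Gradio callbacks below.
def overlay_mask_cv2(pred_mask, image_path, work_dir):
    img = cv2.imread(image_path)
    mask = np.squeeze(np.asarray(pred_mask)).astype(bool)  # (H, W) boolean mask
    overlay = img.copy()
    overlay[mask] = (0, 255, 0)  # pure green in BGR
    blended = cv2.addWeighted(img, 0.6, overlay, 0.4, 0)  # 40% green tint on masked pixels
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, blended)
    return output_path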

def image_vision(image_input_path, prompt):
    # Run Sa2VA on a single image; "<image>" marks where the image is injected into the prompt.
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    print(return_dict)
    answer = return_dict["prediction"]  # the text answer
    pred_masks = return_dict.get("prediction_masks")  # may be missing if no mask was predicted

    # A "[SEG]" token in the answer signals that segmentation masks were produced.
    if '[SEG]' in answer and pred_masks is not None and Visualizer is not None:
        temp_dir = tempfile.mkdtemp()
        seg_result = visualize(pred_masks[0], image_input_path, temp_dir)
        return answer, seg_result
    return answer, None
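
# Example call outside Gradio (assuming a local image "demo.jpg" exists):
#   answer, seg_path = image_vision("demo.jpg", "Please segment the person in the image.")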

def video_vision(video_input_path, prompt):
    # Sample the video and query Sa2VA; "<image>" is the placeholder for the video frames.
    vid_frames = read_video(video_input_path, video_interval=6)
    question = f"<image>{prompt}"
    result = model.predict_forward(
        video=vid_frames,
        text=question,
        tokenizer=tokenizer,
    )
    prediction = result['prediction']
    print(prediction)

    # Only the text answer is returned; video mask rendering is not implemented in this demo.
    return prediction, None
    


# Gradio UI

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
        with gr.Tab("Single Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Image IN", type="filepath")
                    with gr.Row():
                        instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_image_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    output_res = gr.Textbox(label="Response")
                    output_image = gr.Image(label="Segmentation", type="numpy")
    
            submit_image_btn.click(
                fn=image_vision,
                inputs=[image_input, instruction],
                outputs=[output_res, output_image]
            )
        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video IN")  # gr.Video passes a file path to video_vision
                    with gr.Row():
                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_video_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    vid_output_res = gr.Textbox(label="Response")
                    output_video = gr.Video(label="Segmentation")
            
            submit_video_btn.click(
                fn=video_vision,
                inputs=[video_input, vid_instruction],
                outputs=[vid_output_res, output_video]
            )

# queue() enables request queuing so long-running inference calls don't block or time out.
demo.queue().launch(show_api=False, show_error=True)