Commit 1dbaf53 · dcahn12 committed · 1 Parent: f0869cb

Define inference utils
- gradio_web_server.py  +2 -2
- infer_utils.py  +118 -0
gradio_web_server.py  CHANGED

@@ -5,7 +5,7 @@ import torch
 import gradio as gr
 from fastapi import FastAPI
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from PIL import Image
 import tempfile
 from decord import VideoReader, cpu
@@ -18,7 +18,7 @@ from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT
 from llava.conversation import conv_templates, SeparatorStyle, Conversation
 from llava.mm_utils import process_images

-from
+from serve.infer_utils import load_video_into_frames
 from serve.utils import load_image, image_ext, video_ext
 from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
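For context, a minimal sketch of how the server might hand an uploaded clip to the newly imported helper; the wrapper name and settings below are hypothetical illustrations, not part of the commit:

# Hypothetical caller sketch (illustrative only; not in the commit).
from serve.infer_utils import load_video_into_frames

def frames_for_chat(video_path):
    # Decode 8 evenly spaced frames as PIL images for the model's processor.
    return load_video_into_frames(video_path,
                                  video_decode_backend='opencv',
                                  num_frames=8,
                                  return_tensor=False)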
infer_utils.py  ADDED

@@ -0,0 +1,118 @@
+import os
+import math
+import json
+import numpy as np
+from PIL import Image
+import requests
+from io import BytesIO
+import torch
+from torchvision.transforms import Compose, Lambda, ToTensor
+from torchvision.transforms.functional import to_pil_image
+
+
+def load_json(file_path):
+    with open(file_path, 'r') as f:
+        return json.load(f)
+
+def load_jsonl(file_path):
+    with open(file_path, 'r') as f:
+        return [json.loads(l) for l in f]
+
+def save_json(data, file_path):
+    with open(file_path, 'w') as f:
+        json.dump(data, f)
+
+def save_jsonl(data, file_path):
+    with open(file_path, 'w') as f:
+        for d in data:
+            f.write(json.dumps(d) + '\n')
+
+def split_list(lst, n):
+    """Split a list into n (roughly) equal-sized chunks"""
+    chunk_size = math.ceil(len(lst) / n)  # ceiling division
+    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+    chunks = split_list(lst, n)
+    return chunks[k]
+
+
+def load_image(image_file):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_file).convert('RGB')
+    return image
+
+
+def load_frames(frame_names, num_frames=None):
+    frame_names.sort()
+    # Sample num_frames evenly spaced frames; otherwise keep every frame.
+    if num_frames is not None and len(frame_names) != num_frames:
+        duration = len(frame_names)
+        frame_id_array = np.linspace(0, duration - 1, num_frames, dtype=int)
+        frame_id_list = frame_id_array.tolist()
+    else:
+        frame_id_list = range(len(frame_names))
+
+    results = []
+    for frame_idx in frame_id_list:
+        frame_name = frame_names[frame_idx]
+        results.append(load_image(frame_name))
+
+    return results
+
+
+def load_video_into_frames(
+        video_path,
+        video_decode_backend='opencv',
+        num_frames=8,
+        return_tensor=False,
+):
+    print("VIDEO PATH !!!", video_path)
+    if video_decode_backend == 'decord':
+        import decord
+        from decord import VideoReader, cpu
+        decord.bridge.set_bridge('torch')
+        decord_vr = VideoReader(video_path, ctx=cpu(0))
+        duration = len(decord_vr)
+        frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
+        video_data = decord_vr.get_batch(frame_id_list)  # (T, H, W, C) uint8
+        if return_tensor:
+            video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
+        else:
+            # to_pil_image expects (C, H, W) tensors, so permute each frame first.
+            video_data = [to_pil_image(f.permute(2, 0, 1)) for f in video_data]
+    elif video_decode_backend == 'frames':
+        frames = load_frames([os.path.join(video_path, imname)
+                              for imname in os.listdir(video_path)],
+                             num_frames=num_frames)
+        video_data = frames
+        if return_tensor:
+            to_tensor = ToTensor()
+            video_data = torch.stack([to_tensor(_) for _ in frames]).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
+    elif video_decode_backend == 'opencv':
+        import cv2
+        cv2_vr = cv2.VideoCapture(video_path)
+        duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
+        frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
+        # frame_id_list = np.linspace(0, duration - 5, num_frames, dtype=int)
+
+        video_data = []
+        for frame_idx in frame_id_list:
+            cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)  # seek to the sampled frame
+            ret, frame = cv2_vr.read()
+            if not ret:
+                raise ValueError(f'video error at {video_path}')
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            if return_tensor:
+                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
+            else:
+                video_data.append(Image.fromarray(frame))
+        cv2_vr.release()
+        if return_tensor:
+            video_data = torch.stack(video_data, dim=1)  # (C, T, H, W)
+    else:
+        raise NameError(f'video_decode_backend should be one of (decord, opencv, frames) but got {video_decode_backend}')
+    return video_data
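A quick usage sketch of the new utilities, assuming a hypothetical local clip at ./example.mp4 and a 4-way shard of a worklist; the path and numbers are illustrative, not from the commit:

# Hypothetical usage (path and values illustrative only).
from serve.infer_utils import load_video_into_frames, get_chunk

# 8 evenly spaced frames as PIL images via the OpenCV backend.
frames = load_video_into_frames('./example.mp4',
                                video_decode_backend='opencv',
                                num_frames=8)

# The same frames as a (C, T, H, W) uint8 tensor.
clip = load_video_into_frames('./example.mp4',
                              video_decode_backend='opencv',
                              num_frames=8,
                              return_tensor=True)

# Shard a list of 100 jobs across 4 workers; shard 0 holds items 0-24.
shard = get_chunk(list(range(100)), 4, 0)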