dcahn12 committed
Commit 1dbaf53 · 1 Parent(s): f0869cb

Define inference utils

Files changed (2):
  1. gradio_web_server.py +2 -2
  2. infer_utils.py +118 -0
gradio_web_server.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import gradio as gr
 from fastapi import FastAPI
 import os
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from PIL import Image
 import tempfile
 from decord import VideoReader, cpu
@@ -18,7 +18,7 @@ from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT
 from llava.conversation import conv_templates, SeparatorStyle, Conversation
 from llava.mm_utils import process_images

- from Evaluation.infer_utils import load_video_into_frames
+ from serve.infer_utils import load_video_into_frames
 from serve.utils import load_image, image_ext, video_ext
 from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
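Not part of this commit: a minimal sketch of how the relocated helper might be wired into the server's upload path. The handler name `prepare_visual_input` is hypothetical, and it assumes `video_ext` (imported above from serve.utils) holds bare extensions like 'mp4'.

import os
from serve.infer_utils import load_video_into_frames
from serve.utils import video_ext  # assumption: list of recognized video extensions

def prepare_visual_input(video_path, num_frames=8):
    # Hypothetical handler: decode an uploaded video into PIL frames
    # before handing them to the model's image processor.
    ext = os.path.splitext(video_path)[1].lstrip('.').lower()
    if ext not in video_ext:
        raise ValueError(f'unsupported video format: {ext}')
    return load_video_into_frames(video_path,
                                  video_decode_backend='opencv',
                                  num_frames=num_frames)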
infer_utils.py ADDED
@@ -0,0 +1,118 @@
+ import os
+ import math
+ import json
+ import numpy as np
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ import torch
+ from torchvision.transforms import Compose, Lambda, ToTensor
+ from torchvision.transforms.functional import to_pil_image
+
+
+ def load_json(file_path):
+     with open(file_path, 'r') as f:
+         return json.load(f)
+
+ def load_jsonl(file_path):
+     with open(file_path, 'r') as f:
+         return [json.loads(l) for l in f]
+
+ def save_json(data, file_path):
+     with open(file_path, 'w') as f:
+         json.dump(data, f)
+
+ def save_jsonl(data, file_path):
+     with open(file_path, 'w') as f:
+         for d in data:
+             f.write(json.dumps(d) + '\n')
+
+ def split_list(lst, n):
+     """Split a list into n (roughly) equal-sized chunks."""
+     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so at most n chunks
+     return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+ def get_chunk(lst, n, k):
+     chunks = split_list(lst, n)
+     return chunks[k]
+
+
+ def load_image(image_file):
+     if image_file.startswith('http://') or image_file.startswith('https://'):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+
+ def load_frames(frame_names, num_frames=None):
+     frame_names.sort()
+     # sample num_frames evenly spaced frames; otherwise keep them all
+     if num_frames is not None and len(frame_names) != num_frames:
+         duration = len(frame_names)
+         frame_id_array = np.linspace(0, duration-1, num_frames, dtype=int)
+         frame_id_list = frame_id_array.tolist()
+     else:
+         frame_id_list = range(len(frame_names))
+
+     results = []
+     for frame_idx in frame_id_list:
+         frame_name = frame_names[frame_idx]
+         results.append(load_image(frame_name))
+
+     return results
+
+
+ def load_video_into_frames(
+     video_path,
+     video_decode_backend='opencv',
+     num_frames=8,
+     return_tensor=False,
+ ):
+     print("VIDEO PATH !!!", video_path)
+     if video_decode_backend == 'decord':
+         import decord
+         from decord import VideoReader, cpu
+         decord.bridge.set_bridge('torch')
+         decord_vr = VideoReader(video_path, ctx=cpu(0))
+         duration = len(decord_vr)
+         frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
+         video_data = decord_vr.get_batch(frame_id_list)  # (T, H, W, C) uint8 tensor
+         if return_tensor:
+             video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
+         else:
+             # to_pil_image expects channels-first tensors, so permute (H, W, C) -> (C, H, W)
+             video_data = [to_pil_image(f.permute(2, 0, 1)) for f in video_data]
+     elif video_decode_backend == 'frames':
+         frames = load_frames([os.path.join(video_path, imname)
+                               for imname in os.listdir(video_path)],
+                              num_frames=num_frames)
+         video_data = frames
+         if return_tensor:
+             to_tensor = ToTensor()
+             video_data = torch.stack([to_tensor(_) for _ in frames]).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
+     elif video_decode_backend == 'opencv':
+         import cv2
+         cv2_vr = cv2.VideoCapture(video_path)
+         duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
+         frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
+         # frame_id_list = np.linspace(0, duration-5, num_frames, dtype=int)
+
+         video_data = []
+         for frame_idx in frame_id_list:
+             cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)  # seek to the sampled frame
+             ret, frame = cv2_vr.read()
+             if not ret:
+                 raise ValueError(f'video error at {video_path}')
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             if return_tensor:
+                 video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
+             else:
+                 video_data.append(Image.fromarray(frame))
+         cv2_vr.release()
+         if return_tensor:
+             video_data = torch.stack(video_data, dim=1)  # (C, T, H, W)
+     else:
+         raise ValueError(f'video_decode_backend should be one of (decord, opencv, frames) but got {video_decode_backend}')
+     return video_data
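
For reference, a short usage sketch of the new utilities. File paths are placeholders, and it assumes the module resolves as serve.infer_utils, matching the server import above.

from serve.infer_utils import load_video_into_frames, load_jsonl, get_chunk

# 8 evenly spaced frames as PIL images via the default OpenCV backend.
frames = load_video_into_frames('demo.mp4', num_frames=8)

# The same clip as a (C, T, H, W) uint8 tensor decoded with decord.
clip = load_video_into_frames('demo.mp4', video_decode_backend='decord',
                              num_frames=8, return_tensor=True)

# Shard an evaluation set across 4 workers; this process takes shard 0.
samples = load_jsonl('questions.jsonl')
shard = get_chunk(samples, n=4, k=0)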