# Spaces: Running on Zero (page-scrape header, kept as an inert comment)
from typing import List, Dict, Union | |
import os | |
import random | |
import tempfile | |
from PIL import Image, ImageSequence | |
import base64 | |
import io | |
import re | |
import uuid | |
import json | |
import numpy as np | |
import pyarrow.fs as pf | |
import func_timeout | |
from func_timeout import func_set_timeout | |
import math | |
# fmt: on | |
import decord | |
# fmt: off | |
def denorm_box(points, height, width):
    """Convert normalized (x, y) points into integer pixel coordinates
    for an image of the given height and width."""
    return [(round(p[0] * width), round(p[1] * height)) for p in points]
def process_image_for_tiktok(frames: List[Image.Image], mask_boxes):
    """Black out watermark boxes in each frame and crop away everything
    below the highest box bottom (+30 px margin).

    ``mask_boxes`` is a per-frame list of boxes; each box is a sequence of
    normalized (x, y) points — presumably 4 corners ordered
    top-left, top-right, bottom-right, ... judging by the indexing below
    (NOTE(review): confirm against the producer of ``mask_boxes``).
    Returns the processed frames as PIL images.
    """
    # Tolerate more boxes than frames; extra box entries are ignored.
    mask_boxes = mask_boxes[:len(frames)]
    frames = [np.array(f) for f in frames]
    # assert len(mask_boxes) == len(frames)
    height, width = frames[0].shape[:2]
    new_frames = []
    for boxes, frame in zip(mask_boxes, frames):
        # Start from the full frame; only the bottom edge is shrunk below.
        left, top, right, bottom = 0, 0, width, height
        for box in boxes:
            pts = np.array(denorm_box(box, height, width), np.int32)
            # Lowest point of this box plus a 30 px margin; the crop keeps
            # the smallest such bound across all boxes.
            upper_bound = max([p[1] for p in pts]) + 30
            if bottom > upper_bound:
                bottom = upper_bound
            # Zero out the rectangle spanned by the box corners.
            frame[pts[0][1]: pts[2][1], pts[0][0]: pts[1][0]] = 0
        new_frames.append(Image.fromarray(frame[top: bottom, left: right]))
    return new_frames
# 先将视频分成 n_frames 份。训练时,每份随机抽一帧;测试时,每份抽中间的那一帧。 | |
def _sample_frame_indices_v2( | |
total_frames: int, | |
n_frames: int, | |
is_training=False, | |
video_sampling_strategy = {}, | |
): | |
total_frames_idxs = list(range(total_frames)) | |
if total_frames <= n_frames: | |
return total_frames_idxs | |
k, m = divmod(total_frames, n_frames) | |
frame_splits = [total_frames_idxs[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(n_frames))] | |
if is_training: | |
sample_ids = [random.choice(i) for i in frame_splits] | |
else: | |
sample_ids = [i[(len(i)+1)//2-1] for i in frame_splits] | |
return sample_ids | |
# 均匀抽帧,必采样首尾帧。 | |
def _sample_frame_indices_v1(total_frames: int, n_frames: int, is_training=False, video_sampling_strategy = {}): | |
if n_frames == 1: | |
return [0] # sample first frame in default | |
if total_frames <= n_frames: | |
return list(range(total_frames)) | |
sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)] | |
return sample_ids | |
def conduct_disturb_frame(frame_indices):
    """Randomly perturb a list of frame indices (temporal-disturbance
    augmentation).

    One of four perturbations is applied:
      * ``exchange`` — split into up to 4 near-equal segments, swap two;
      * ``crop``     — keep a random 3/4-length window, resample the same
                       number of frames uniformly inside it;
      * ``reverse``  — reverse a random window of 1/2..1 of the frames;
      * ``discard``  — randomly drop half the frames, order preserved.

    Returns ``(disturb_type, new_frame_indices)``; the input list itself is
    never mutated.
    """
    disturb_type = random.choice(['exchange', 'crop', 'reverse', 'discard'])
    n_frames = len(frame_indices)
    frame_indices_new = []
    if disturb_type == 'exchange':
        # Split into segments of ceil(n/4) frames and swap two of them.
        seg_len = math.ceil(n_frames / 4)
        seg_idxs = list(range(0, n_frames, seg_len))
        # Bug fix: when n_frames is not divisible by 4 there may be fewer
        # than 4 segments; sampling from a hard-coded range(0, 4) raised
        # IndexError on seg_idxs below. Sample from the actual segment count.
        target_idxs = random.sample(range(len(seg_idxs)), 2)
        seg_idxs[target_idxs[0]], seg_idxs[target_idxs[1]] = seg_idxs[target_idxs[1]], seg_idxs[target_idxs[0]]
        for idx in seg_idxs:
            frame_indices_new += frame_indices[idx: idx+seg_len]
    elif disturb_type == 'crop':
        # Keep a random 3/4-length window, then sample n_frames uniformly.
        crop_len = math.ceil(n_frames / 4)
        idx_s = random.choice(range(0, crop_len+1))
        idx_e = n_frames - 1 - (crop_len - idx_s)
        frame_indices_new = np.linspace(frame_indices[idx_s], frame_indices[idx_e], n_frames, dtype=int).tolist()
    elif disturb_type == 'reverse':
        # Reverse a randomly placed window covering 1/2..1 of the frames.
        reverse_len = math.ceil(random.uniform(0.5,1) * n_frames)
        idx_s = random.choice(range(0, n_frames-reverse_len+1))
        idx_e = idx_s + reverse_len - 1
        frame_indices_new = frame_indices[:idx_s] + list(reversed(frame_indices[idx_s: idx_e+1])) + frame_indices[idx_e+1:]
    elif disturb_type == 'discard':
        # Randomly drop half of the frames, keeping the original order.
        frame_indices_new = random.sample(frame_indices, n_frames//2)
        frame_indices_new.sort()
    return disturb_type, frame_indices_new
def _download_file(path):
    """Fetch *path* onto local disk and return the local path.

    HDFS paths are copied into a uniquely named file in the temp directory:
    files over 100 MB are fetched by shelling out to the ``hadoop fs -get``
    CLI (with concurrency flags above 1 GB), smaller ones are streamed
    through pyarrow. Non-HDFS paths are assumed to already be local and are
    returned unchanged.

    Raises FileNotFoundError when the resulting local file does not exist
    (e.g. the shelled-out download failed silently).
    """
    if path.startswith("hdfs"):
        # uuid prefix avoids collisions between concurrent downloads.
        local_path = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4()}_' + os.path.basename(path))
        fs = pf.HadoopFileSystem.from_uri(uri="hdfs://harunava")
        hdfs_file = fs.open_input_file(path)
        file_size = hdfs_file.size()
        if file_size > 1024 * 1024 * 1024:  # > 1 GB: multi-threaded CLI fetch
            os.system(f"hadoop fs -get --ct 8 -c 512 '{path}' '{local_path}' > /dev/null 2>&1")
        elif file_size > 1024 * 1024 * 100:  # > 100 MB: plain CLI fetch
            os.system(f"hadoop fs -get '{path}' '{local_path}' > /dev/null 2>&1")
        else:
            local_fs = pf.LocalFileSystem()
            with local_fs.open_output_stream(local_path) as local_file:
                while True:
                    chunk = hdfs_file.read(1024 * 1024 * 100)  # stream in 100 MB chunks
                    if not chunk:
                        break
                    local_file.write(chunk)
    else:
        local_path = path
    if not os.path.exists(local_path):
        raise FileNotFoundError(f'{local_path}')
    return local_path
def download_file(path):
    """Download *path* via ``_download_file``, converting download timeouts
    into ``ValueError`` so callers can treat them as ordinary data errors."""
    try:
        local_path = _download_file(path)
    except func_timeout.exceptions.FunctionTimedOut as timeout_err:
        raise ValueError(timeout_err)
    return local_path
class VideoReader:
    """Decord-backed reader for a (possibly HDFS-hosted) video file."""

    def __init__(self, path: str) -> None:
        self.path = path
        # Make sure the file is on local disk before handing it to decord.
        self.local_path = self.preprocess()
        self.vr = decord.VideoReader(self.local_path, num_threads=1, ctx=decord.cpu(0), fault_tol=1)
        self.vr.seek(0)
        self._length = len(self.vr)
        self._fps = self.vr.get_avg_fps()

    def length(self):
        """Number of frames in the video."""
        return self._length

    def fps(self):
        """Average frames per second reported by decord."""
        return self._fps

    def sample(self, frame_indices) -> List[Image.Image]:
        """Decode the given frame indices and return them as RGB PIL images."""
        batch = self.vr.get_batch(frame_indices).asnumpy()
        return [Image.fromarray(arr).convert('RGB') for arr in batch]

    def preprocess(self):
        """Download the video when remote; return its local path."""
        return download_file(self.path)

    def postprocess(self):
        """Remove the temporary local copy of a downloaded HDFS file."""
        if self.path.startswith("hdfs"):
            os.remove(self.local_path)
class ImageSeqReader:
    """Reader that treats a list of image paths as video frames."""

    def __init__(self, path: List[str]) -> None:
        self.path = path
        self.local_path = self.preprocess()
        self._length = len(self.local_path)
        # An image sequence has no intrinsic frame rate.
        self._fps = None

    def length(self):
        """Number of frames (images) in the sequence."""
        return self._length

    def fps(self):
        """Always None: frame rate is undefined for an image sequence."""
        return self._fps

    def sample(self, frame_indices):
        """Load and return the images at the requested indices."""
        return [read_image(self.local_path[idx]) for idx in frame_indices]

    def preprocess(self):
        """Return a shallow copy of the path list (no downloading here;
        ``read_image`` downloads lazily at sample time)."""
        return list(self.path)

    def postprocess(self):
        """Nothing to clean up."""
        pass
class GIFReader:
    """PIL-based reader for animated GIF files."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.local_path = self.preprocess()
        self.gif = Image.open(self.local_path)
        self._length = self.gif.n_frames
        # The GIF 'duration' field is per-frame display time in milliseconds.
        frame_seconds = self.gif.info.get('duration', 0) / 1000
        if frame_seconds > 0:
            self._fps = 1 / frame_seconds
        else:
            self._fps = None

    def length(self):
        """Number of frames in the GIF."""
        return self._length

    def fps(self):
        """Frames per second derived from the frame duration, or None when
        the GIF carries no duration metadata."""
        return self._fps

    def sample(self, frame_indices):
        """Return the frames whose positions appear in ``frame_indices``,
        in playback order, converted to RGB."""
        frames = []
        for pos, frame in enumerate(ImageSequence.Iterator(self.gif)):
            if pos in frame_indices:
                frames.append(frame.convert('RGB'))
        return frames

    def preprocess(self):
        """Download the GIF when remote; return its local path."""
        return download_file(self.path)

    def postprocess(self):
        """Remove the temporary local copy of a downloaded HDFS file."""
        if self.path.startswith("hdfs"):
            os.remove(self.local_path)
def check_frame_indices(frame_indices, total_frames, video_path):
    """Clamp and validate frame indices against the frame count.

    A trailing index equal to ``total_frames`` (a common rounding artifact)
    is clamped to the last valid frame. Out-of-range indices are dropped
    with an error log. Returns the list of valid indices.

    Bug fix: an empty ``frame_indices`` previously raised IndexError on the
    ``[-1]`` access; it now returns an empty list.
    """
    if not frame_indices:
        return []
    if frame_indices[-1] == total_frames:
        frame_indices[-1] = total_frames - 1
    valid_frame_indices = [i for i in frame_indices if 0 <= i < total_frames]
    if len(valid_frame_indices) != len(frame_indices):
        print(f'[Error] frame out of index. video_path={video_path}, frame_indices={frame_indices}, total_frames={total_frames}', flush=True)
    return valid_frame_indices
def sample_video(
    video_path: Union[str, List[str]],
    frame_indices: List[int] = None,
    start_frame:int=None,
    end_frame:int=None,
    n_frames:int = None,
    time_indices: List[float] = None,
    start_time:int=None,
    end_time:int=None,
    sampling_fps:float=None,
    mask_boxes=None,
    is_training:bool=False,
    video_sampling_strategy={'video_sampler_version': 'v1'},
    return_frame_ids: bool=False,
) -> List[Image.Image]:
    """Read frames from a video file, GIF, or image-path list.

    Frame selection, in priority order: explicit ``frame_indices``; stride
    sampling at ``sampling_fps``; ``time_indices`` converted via fps; else
    uniform/chunked sampling of ``n_frames`` over the (optional)
    ``start_frame``..``end_frame`` / ``start_time``..``end_time`` window.
    Optional extras: watermark masking (``mask_boxes``), temporal
    disturbance augmentation, and padding with black frames so the frame
    count is divisible by ``force_frames_n_divisible``.

    Returns the frames; with ``return_frame_ids`` a ``(frames, indices)``
    tuple, and with ``do_frame_disturb`` a ``(frames, info_dict)`` tuple.
    """
    do_frame_disturb = video_sampling_strategy.get('do_frame_disturb', False)
    # Pick a reader implementation based on the source type.
    if isinstance(video_path, str):
        if video_path.endswith('.gif'):
            reader = GIFReader(video_path)
        else:
            reader = VideoReader(video_path)
    else:
        reader = ImageSeqReader(video_path)
    # Bug fix: length/fps are methods on all reader classes — the original
    # assigned the bound methods themselves, breaking every use below.
    total_frames = reader.length()
    fps = reader.fps()
    if sampling_fps is not None:
        # Stride sampling at the requested fps; max(1, ...) guards against a
        # zero step when sampling_fps exceeds the video fps.
        frame_indices = list(range(0, total_frames, max(1, round(fps / sampling_fps))))
        if len(frame_indices) > n_frames:
            # Too many frames at this rate: fall back to uniform sampling.
            frame_indices = None
    if time_indices is not None:
        frame_indices = [round(float(i) * fps) for i in time_indices]
    if start_time is not None and end_time is not None:
        start_frame = round(start_time * fps)
        end_frame = round(end_time * fps)
    if frame_indices is None:
        start_frame = 0 if start_frame is None else round(start_frame)
        end_frame = total_frames - 1 if end_frame is None else round(end_frame)
        if end_frame == total_frames:
            end_frame -= 1
        if video_sampling_strategy['video_sampler_version'] == 'v1':
            # Uniform sampling that always includes the first and last frames.
            frame_indices = _sample_frame_indices_v1(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
        elif video_sampling_strategy['video_sampler_version'] == 'v2':
            # Chunked sampling: random per chunk (train) / middle (eval).
            frame_indices = _sample_frame_indices_v2(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
        else:
            raise ValueError(f"video_sampler_version={video_sampling_strategy['video_sampler_version']} must be 'v1' or 'v2'")
        frame_indices = [i + start_frame for i in frame_indices]
    frame_indices = check_frame_indices(frame_indices, total_frames, video_path)
    if do_frame_disturb:
        frame_disturb_type, frame_indices_new = conduct_disturb_frame(frame_indices)
        frame_indices_raw = frame_indices[:]
        frame_indices = frame_indices_new
    frames = reader.sample(frame_indices)
    if mask_boxes is not None:
        frames = process_image_for_tiktok(frames, mask_boxes)
    # Pad with black frames so the count is divisible by the requested n.
    n = video_sampling_strategy.get('force_frames_n_divisible', 1)
    if n > 1 and len(frames) % n != 0:
        pad = n - len(frames) % n
        frames.extend([Image.new(mode='RGB', size=frames[-1].size) for _ in range(pad)])
    reader.postprocess()
    if do_frame_disturb:
        return frames, {"frame_indices": frame_indices, "disturb_type": frame_disturb_type, "frame_indices_raw": frame_indices_raw}
    if return_frame_ids:
        return frames, frame_indices
    return frames
def load_image_from_base64String(img_path):
    """Read a file containing base64-encoded image bytes and return it as a
    PIL image.

    Bug fix: the original ``open(...).read()`` never closed the file handle;
    a ``with`` block now guarantees closure.
    """
    with open(img_path, "rb") as fh:
        img_bytes = base64.b64decode(fh.read())
    buf = io.BytesIO(img_bytes)
    return Image.open(buf)
def read_image(image_path):
    """Fetch *image_path* (local or HDFS) and return it as a PIL image.

    ``.dat`` files are treated as base64-encoded image dumps; anything else
    is opened directly and converted to RGB. The temporary copy of an HDFS
    file is removed after loading."""
    local_file = download_file(image_path)
    if local_file.endswith('.dat'):
        image = load_image_from_base64String(local_file)
    else:
        image = Image.open(local_file).convert('RGB')
    is_remote = image_path.startswith("hdfs")
    if is_remote:
        os.remove(local_file)
    return image
def adjust_bbox(text, frame):
    """Remap normalized bbox coordinates embedded in *text* from a
    square-padded coordinate system to the actual (non-square) frame.

    Every ``[x, y, ...]`` list found in *text* is rewritten: for non-square
    frames the letterboxed axis is unpadded and renormalized (rounded to
    2 decimals); square frames leave values unchanged (the list is still
    reformatted via ``str``)."""
    width, height = frame.size

    def _remap(match):
        coords = [float(c) for c in re.findall(r"([0-9.]+)", match.group(0))]
        adjusted = []
        for axis, value in enumerate(coords):
            if width > height and axis % 2 != 0:
                # y coordinate: undo vertical letterbox padding.
                value = round((value * height + (width - height) // 2) / width, 2)
            elif height > width and axis % 2 == 0:
                # x coordinate: undo horizontal letterbox padding.
                value = round((value * width + (height - width) // 2) / height, 2)
            adjusted.append(value)
        return str(adjusted)

    return re.sub(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', _remap, text)
def bbox_area(vertices, convert_format = True):
    """Area of an axis-aligned box given two opposite corners.

    With ``convert_format`` (default), ``vertices`` is a flat
    ``[x0, y0, x1, y1, ...]`` list and is paired up first; only the first
    two points are used."""
    if convert_format:
        vertices = list(zip(vertices[::2], vertices[1::2]))
    (x0, y0), (x1, y1) = vertices[0], vertices[1]
    return abs((x1 - x0) * (y1 - y0))
def polygon_area(vertices, convert_format = True):
    """Polygon area computed with the shoelace formula.

    ``vertices`` may be a flat ``[x0, y0, x1, y1, ...]`` list (default) or a
    list of (x, y) pairs when ``convert_format`` is false. Exactly two
    points are treated as opposite corners of an axis-aligned box."""
    if convert_format:
        vertices = list(zip(vertices[::2], vertices[1::2]))
    n = len(vertices)  # number of polygon vertices
    if n == 2:
        # Degenerate input: delegate to the rectangle area helper.
        return bbox_area(vertices, convert_format=False)
    cross_sum = 0
    for idx, (x1, y1) in enumerate(vertices):
        x2, y2 = vertices[(idx + 1) % n]
        cross_sum += x1 * y2 - x2 * y1
    return abs(cross_sum) / 2
def get_text_len(text_line):
    """Visual length of a text line: CJK unified ideographs count as 1,
    every other character as 0.5 (half-width)."""
    return sum(1 if '\u4e00' <= ch <= '\u9fff' else 0.5 for ch in text_line)
def filter_ocr_polygon(response, area_threshold=0.0005):
    """Drop OCR lines whose polygon is too small for their text length.

    ``response`` is a JSON string of ``[coords, text]`` pairs (``coords`` is
    a flat coordinate list). Lines with empty text, or whose
    area-per-character falls below ``area_threshold``, are removed; the
    survivors are re-serialized to JSON. A non-JSON ``response`` is returned
    unchanged.

    Bug fix: the bare ``except:`` (which also swallowed KeyboardInterrupt
    and SystemExit) is narrowed to the exceptions ``json.loads`` raises.
    """
    try:
        resp = json.loads(response)
    except (TypeError, ValueError):
        return response
    new_resp = []
    for coords, text_line in resp:
        area = polygon_area(coords, convert_format=True)
        text_len = get_text_len(text_line)
        if text_len == 0:
            continue
        if area / text_len < area_threshold:
            continue
        new_resp.append([coords, text_line])
    return json.dumps(new_resp, ensure_ascii=False)
def put_pred_to_data_dict(prediction, data_dict):
    """Write *prediction* into ``data_dict['messages']`` as the assistant
    reply: overwrite the last text item of a trailing assistant turn if one
    exists, otherwise append a fresh assistant message. Mutates
    ``data_dict`` in place."""
    messages = data_dict['messages']
    last = messages[-1]
    if last['role'] != 'assistant':
        messages.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prediction}]
        })
    else:
        last['content'][-1]['text'] = prediction
def get_prompt_from_data_dict(data_dict):
    """Render ``data_dict['messages']`` as a flat prompt string.

    Each content item becomes ``[role]: ...`` — text verbatim, images and
    videos as ``<image>``/``<video>`` placeholders. Empty text items
    contribute nothing; every message is terminated by a newline."""
    pieces = []
    for msg in data_dict['messages']:
        role = msg['role']
        assert role in {'system', 'user', 'assistant'}
        for item in msg['content']:
            kind = item['type']
            if kind == 'text':
                if item['text']:
                    pieces.append(f"[{role}]: {item['text']}")
            elif kind == 'image':
                pieces.append(f"[{role}]: <image>")
            elif kind == 'video':
                pieces.append(f"[{role}]: <video>")
        pieces.append('\n')
    return ''.join(pieces)