from typing import Dict, List
import random
from PIL import Image, ImageDraw, ImageFont

from .utils import sample_video


class VideoPermutationParser:
    """Builds a frame-ordering task: the video's frames are shuffled, the prompt reveals
    the chronological indices of the first ~3/8 of the shuffled frames, and the model
    must output the chronological indices of the remaining frames."""

    def __init__(
        self,
        n_frames=8,
        is_training=True,
        frame_nums=None,
        video_sampling_strategy=None,
    ):
        self.n_frames = n_frames
        self.is_training = is_training
        # Avoid mutable default arguments; the fallbacks match the original defaults
        # (candidate frame counts of 8-24 and an empty sampling-strategy dict).
        self.frame_nums = frame_nums if frame_nums is not None else list(range(8, 25))
        self.video_sampling_strategy = video_sampling_strategy if video_sampling_strategy is not None else {}
        # fmt: off
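        # Example of the expected input format; file paths and frame indices are dataset-specific.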
        self.data_temp = {
            "text": [{
                "prompt": "<video>",
                "response": ""
            }],
            "video": [{
                "video_file": {
                    "yg": "/mnt/bn/videonasyg/videos/webvid_10M_download/011851_011900/1047443473.mp4",
                    "lq": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4"
                },
                "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
            }],
        }
        # fmt: on

    def check_format(self, data_dict: Dict):
        # Format validation is currently a no-op; re-enable the assertions below to
        # require every key of self.data_temp to be present in data_dict.
        pass
        # for k in self.data_temp.keys():
        #     assert k in data_dict

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> List[Dict]:
        """Build a (user, assistant) message pair for the frame-ordering task."""
        self.check_format(data_dict)

        frames = self.load_video_item(data_dict['video'][0])

        # frames = self.add_text_to_frames(frames)  # uncomment to overlay frame indices for debugging

        # idxs[k] is the original (chronological) 1-based index of the k-th shuffled frame.
        idxs = list(range(1, len(frames) + 1))
        random.shuffle(idxs)

        # The chronological indices of the first ~3/8 of the shuffled frames are revealed
        # in the prompt; the model must predict the indices of the remaining frames.
        prefix_len = int(3 / 8 * len(idxs))

        shuffled_frames = [frames[i - 1] for i in idxs]

        prompt = f'Output the correct chronological order of scrambled video frames. The order of the first {prefix_len} ones are:\n'
        prompt += '\n'.join([str(i) for i in idxs[:prefix_len]]) + '\nOutput the order of the following frames:'
        response = '\n'.join([str(i) for i in idxs[prefix_len:]])

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": shuffled_frames},
                    {"type": "text", "text": prompt},
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]

        return messages


    def load_video_item(self, video_item) -> List[Image.Image]:
        """
        video_item:
        {"video_file": "/path/to/video", "n_frames": 8} 
        {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3} 
        {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
        {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
        {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
        {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}
        """

        # check format
        if ("image_file" not in video_item) and ("video_file" not in video_item):
            raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
    
        video_path = video_item.get('video_file', video_item.get('image_file'))
        n_frames = video_item.get('n_frames', None)
        frame_indices = video_item.get('frame_indices', None)
        start_frame = video_item.get('start_frame', None)
        end_frame = video_item.get('end_frame', None)
        time_indices = video_item.get('time_indices', None)
        start_time = video_item.get('start_time', None)
        end_time = video_item.get('end_time', None)
        mask_boxes = video_item.get('mask_boxes', None)

        # For the permutation task the per-item 'n_frames' is ignored: the number of
        # frames to sample is drawn at random from self.frame_nums instead.
        n_frames = random.choice(self.frame_nums)
        # Optionally round n_frames up to the nearest multiple of 'force_frames_n_divisible'.
        n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
        if n > 1 and n_frames % n != 0:
            n_frames += n - n_frames % n

        frames, frame_indices = sample_video(
            video_path=video_path,
            frame_indices=frame_indices,
            start_frame=start_frame,
            end_frame=end_frame,
            n_frames=n_frames,
            time_indices=time_indices,
            start_time=start_time,
            end_time=end_time,
            mask_boxes=mask_boxes,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
            return_frame_ids=True,
        )
        return frames
    

    def add_text_to_frames(self, frames: List[Image.Image]):
        """Debug helper: overlay the 1-based frame index on each frame in red."""
        new_frames = []
        for i, image in enumerate(frames):
            image = image.copy()  # draw on a copy so the caller's frames are left untouched
            draw = ImageDraw.Draw(image)
            # Assumes the DejaVu bold font is installed at this path (common on Debian/Ubuntu images).
            font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 100)
            draw.text((50, 50), f'{i + 1}', font=font, fill=(255, 0, 0))
            new_frames.append(image)
        return new_frames
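

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes that
# `sample_video` in .utils can decode the given video file, that the path below
# exists, and that the module is run with `python -m` from the package root so
# the relative import resolves. The data_dict mirrors the format shown in
# `data_temp` above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = VideoPermutationParser(is_training=True, frame_nums=[8])
    data_dict = {
        "text": [{"prompt": "<video>", "response": ""}],
        "video": [{"video_file": "/path/to/video.mp4"}],  # hypothetical path
    }
    messages = parser.transform(data_dict)
    # messages[0] is the user turn (shuffled frames + ordering prompt),
    # messages[1] is the assistant turn (the remaining chronological indices).
    print(messages[0]["content"][1]["text"])
    print(messages[1]["content"][0]["text"])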