"""
Example script for running dreaming on a dataset.
The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context.

After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
should be identical if the model were ideal.
"""

import argparse
from pathlib import Path
import os
import subprocess

import cv2
from tensordict import TensorDict
import torch as th
from tqdm import tqdm
import numpy as np
import ffmpegcv
from PIL import Image

import wham.utils as utils


parser = argparse.ArgumentParser(description="Run dreaming.")
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")


parser.add_argument(
    "--protocol",
    type=str,
    default="base",
    choices=["base", "comprehensive"],
    help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use an initial context.")
parser.add_argument("--steps_to_dream", type=int, default=10, help="Batch size for dreaming.")

parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")


def get_context_data(image_context, action_context):
    # Make sure we have CHW images:
    assert image_context.shape[-3] == 3, "Image context should be CHW"

    image_context = th.from_numpy(image_context).cuda()
    action_data = th.from_numpy(action_context).float().cuda()

    return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])


def add_video_metadata(file_path, metadata_config):
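    # Note: this relies on the external `exiftool` binary being installed and available on PATH.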
    # Construct the exiftool command
    cmd = [
        'exiftool',
        '-config', metadata_config,
        f'-ProgramName={utils.PROGRAM_NAME}',
        '-overwrite_original',
        file_path
    ]

    try:
        # Execute the exiftool command
        subprocess.run(cmd, check=True)
        print(f"Metadata modified successfully.")
        # Print the new file metadata
        cmd_output = [
            'exiftool',
            file_path
        ]
        subprocess.run(cmd_output, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error modifying metadata: {e}")


@th.no_grad()
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
    """
    image_contect and action_context provide the initial context for the model to dream from.

    If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions.
    """
    context_data = get_context_data(image_context, action_context)
    encoded_context_data = model.encode_context(context_data)

    encoded_action_sequences = None
    if action_sequences is not None:
        assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
        action_sequences = TensorDict({"actions_output": th.from_numpy(action_sequences).float()}, batch_size=action_sequences.shape[:2]).cuda()
        encoded_action_sequences = model.encode_context(action_sequences)

    encoded_dreamt_steps = []

    for dream_step in range(args.steps_to_dream):
        encoded_predicted_step, _ = model.predictor.predict_next_step(
            encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
        )

        # Remove first step from context if we are at the max context length:
        if encoded_context_data.shape[1] == args.context_length:
            encoded_context_data = encoded_context_data[:, 1:]

        # Add predicted image + action to the context
        append_step = encoded_predicted_step
        if encoded_action_sequences is not None:
            # Replace predicted action with real action
            append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
        encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)

        encoded_dreamt_steps.append(encoded_predicted_step)

    # Decode everything
    dreamed_images = []
    actions_during_dream = []
    for seq_i in range(args.steps_to_dream):
        decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
        dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
        actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())

    dreamed_images = np.concatenate(dreamed_images, axis=1)
    actions_during_dream = np.concatenate(actions_during_dream, axis=1)

    return dreamed_images, actions_during_dream
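
# Minimal usage sketch for do_dreaming (the image resolution and action_dim below are illustrative only):
#   image_context = np.zeros((1, args.context_length, 3, 128, 128), dtype=np.uint8)
#   action_context = np.zeros((1, args.context_length, 16), dtype=np.float32)
#   dreamt_images, dreamt_actions = do_dreaming(model, image_context, action_context, args)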


@th.no_grad()
def encode_decode_images(model, images):
    """
    Pass ground_truth images through the encoding/decoding process of the model.
    """
    context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
    output_images = []
    for seq_i in range(images.shape[1]):
        encoded_images = model.encode_context(context[:, [seq_i]])
        decoded_images = model.decode_context(encoded_images)
        output_images.append(decoded_images["images"].cpu().numpy())
    return np.concatenate(output_images, axis=1)


def main(args):
    total_video_length = args.context_length + args.steps_to_dream

    # Now, load the model:
    model_path = Path(args.model_path)
    assert model_path.is_file(), "Could not find the model!"
    model = utils.load_model_from_checkpoint(model_path).cuda()

    # Glob the dataset to find all the ground truth segments we want to construct a dream for:
    data_path = Path(args.data_path)
    ground_truth_files = list(data_path.rglob("*.npz"))
    num_dreams = len(ground_truth_files)

    if args.max_files is not None:
        # Sort to make sure we always get the same files
        ground_truth_files = sorted(ground_truth_files)
        ground_truth_files = ground_truth_files[: args.max_files]
        num_dreams = len(ground_truth_files)

    output_path = Path(args.output)
    os.makedirs(output_path, exist_ok=True)

    print("=" * 100)
    print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
    print(f"WRITING TO {args.output}")
    print("=" * 100)

    dreams_created = 0
    with tqdm(total=num_dreams, desc="Dreams") as pbar:
        while ground_truth_files:
            # Load up to batch_size episodes:
            batches = min(args.batch_size, len(ground_truth_files))
            batched_image_context = []
            batched_image_sequence = []
            batched_action_context = []
            batched_action_sequence = []
            episode_names = []
            for i in range(batches):
                episode = ground_truth_files.pop()
                episode_names.append(episode)
                try:
                    data = np.load(episode)
                    images = data["images"]
                    actions = data["actions"]
                except Exception:
                    print(f"Failed to load episode {episode} - skipping.")
                    # Drop the failed episode so episode_names stays aligned with the batched lists
                    episode_names.pop()
                    continue

                if actions.shape[0] < total_video_length:
                    # We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough
                    raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
                batched_image_context.append(images[: args.context_length])
                batched_image_sequence.append(images[args.context_length: total_video_length])
                batched_action_context.append(actions[: args.context_length])
                batched_action_sequence.append(actions[args.context_length: total_video_length])

            image_context = np.array(batched_image_context)
            image_sequences = np.array(batched_image_sequence)
            action_context = np.array(batched_action_context)
            action_sequences = np.array(batched_action_sequence)

            if args.protocol == "comprehensive":
                # In the comprehensive protocol the model dreams the actions as well, so we do not condition on the ground-truth action sequences
                action_sequences = None

            full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)

            dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
            encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)

            pbar.update(batches)
            dreams_created += batches

            # Save the dreams:
            # We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names
            # but make them relative to our output folder:
            for i, dream in enumerate(dreamt_images):
                episode = episode_names[i]
                output_file = output_path / episode.relative_to(data_path)
                output_file.parent.mkdir(parents=True, exist_ok=True)
                np.savez(
                    output_file,
                    context_length=args.context_length,
                    steps_to_dream=args.steps_to_dream,
                    raw_context=image_context[i],
                    dreamt_images=dream,
                    all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
                    encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
                )
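                # The saved .npz mirrors the ground-truth file's relative path, so a dream and its
                # reference can be paired by path, e.g. np.load(output_file)["dreamt_images"].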

                video_file = str(output_file.with_suffix(".mp4"))
                writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
                full_sequence = np.concatenate((image_context[i], dream), axis=0)
                for frame in full_sequence:
                    img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
                    # Please DO NOT remove this watermark. This will infringe upon the repo's license agreement
                    (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
                    x = img.shape[1] - text_width - 10  # 10 pixels from the right edge
                    y = img.shape[0] - 10  # 10 pixels from the bottom edge
                    cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)

                    # Add image metadata (0x0131 is the TIFF/EXIF "Software" tag; type 2 = ASCII string)
                    pil_image = Image.fromarray(img)
                    pil_image.info['Id'] = 0x0131
                    pil_image.info['Type'] = 2
                    pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
                    pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1

                    # Convert pil_image to a CV2 format for the video writer
                    cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
                    writer.write(cv_image)
                writer.release()
                add_video_metadata(video_file, args.metadata_config)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)