import json
import os
from typing import Optional

import torch
import torch.nn.functional as F
from transformers.utils import WEIGHTS_NAME

from ..trainer_videobase import VideoBaseTrainer


class VQVAETrainer(VideoBaseTrainer):
    """Trainer that optimizes a VQ-VAE on batches of video.

    The training objective is a scaled mean-squared reconstruction error
    between the (rescaled) input video and the decoder output, plus the
    commitment loss reported by the model's codebook.
    """

    # The input video is divided by this value before it is encoded, and the
    # reconstruction is compared against the divided tensor.
    # NOTE(review): presumably maps videos normalized to [-1, 1] into
    # [-0.5, 0.5]; confirm against the dataset preprocessing.
    video_input_divisor = 2
    # The MSE reconstruction loss is divided by this constant before the
    # commitment loss is added.
    # NOTE(review): looks like a data-variance normalizer; confirm its origin.
    recon_loss_divisor = 0.06

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the VQ-VAE loss for one batch.

        Args:
            model: The network, optionally wrapped by a container that exposes
                the underlying network as ``.module`` (e.g. DataParallel /
                DistributedDataParallel). It must provide ``encoder``,
                ``pre_vq_conv``, ``codebook``, ``post_vq_conv`` and ``decoder``.
            inputs: Mapping holding the batch. Only the ``"video"`` entry is read.
            return_outputs: When True, return ``(loss, outputs)`` instead of
                just ``loss``. The Trainer's evaluation path relies on this.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it as a keyword; it is not used here.

        Returns:
            The scalar loss, or ``(loss, {"x_recon": reconstruction})`` when
            ``return_outputs`` is True.

        Raises:
            KeyError: If ``inputs`` has no (non-None) ``"video"`` entry.
        """
        # Wrapped models expose the real network as ``.module``. Fall back to
        # the model itself when it is not wrapped (e.g. single-process runs).
        model = getattr(model, "module", model)

        x = inputs.get("video")
        if x is None:
            raise KeyError("compute_loss expects a 'video' entry in inputs")
        x = x / self.video_input_divisor

        # Encode -> quantize -> decode.
        z = model.pre_vq_conv(model.encoder(x))
        vq_output = model.codebook(z)
        x_recon = model.decoder(model.post_vq_conv(vq_output["embeddings"]))

        recon_loss = F.mse_loss(x_recon, x) / self.recon_loss_divisor
        commitment_loss = vq_output["commitment_loss"]
        loss = recon_loss + commitment_loss

        if return_outputs:
            return loss, {"x_recon": x_recon}
        return loss