Spaces:
Runtime error
Runtime error
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/config.py | |
from functools import partial | |
import torch.nn as nn | |
from i2v_enhance.thirdparty.VFI.model import feature_extractor | |
from i2v_enhance.thirdparty.VFI.model import flow_estimation | |
'''==========Model config==========''' | |
def init_model_config(F=32, W=7, depth=(2, 2, 2, 4, 4)):
    """Build the pair of architecture config dicts used by EMA-VFI.

    Note from the original authors: this function should not be modified
    (i.e. the values it produces are intentionally fixed).

    Args:
        F: Base channel width. The five entries of ``embed_dims`` are
            F, 2F, 4F, 8F and 16F.
        W: Window size; used for both entries of ``window_sizes``.
        depth: Per-stage block counts (any sequence; at least two entries,
            since the last two are used to compute ``motion_dims``). It is
            copied into a fresh list for each returned dict, so mutating a
            returned config can never alter the caller's sequence, the other
            dict, or the default used by later calls.

    Returns:
        A 2-tuple of plain dicts. Presumably the first configures
        ``feature_extractor`` and the second ``flow_estimation`` (matching the
        order of ``MODEL_TYPE`` in this module) -- confirm against the models.
    """
    first_cfg = {
        'embed_dims': [F, 2 * F, 4 * F, 8 * F, 16 * F],
        # Zero for the first three stages; the last two use
        # (stage width) // (stage depth).
        'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]],
        # One head per 32 channels of the 8F and 16F stage widths.
        'num_heads': [8 * F // 32, 16 * F // 32],
        'mlp_ratios': [4, 4],
        'qkv_bias': True,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
        # Fresh copy: the two dicts must not share (or alias the default of)
        # the same list object.
        'depths': list(depth),
        'window_sizes': [W, W],
    }
    second_cfg = {
        'embed_dims': [F, 2 * F, 4 * F, 8 * F, 16 * F],
        'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]],
        'depths': list(depth),
        'num_heads': [8 * F // 32, 16 * F // 32],
        'window_sizes': [W, W],
        'scales': [4, 8, 16],
        'hidden_dims': [4 * F, 4 * F],
        'c': F,
    }
    return first_cfg, second_cfg
# Active model configuration: the 'ours' variant (F=32, depths 2-2-2-4-4).
# The commented-out 'ours_small' configuration below is the lighter
# alternative (F=16, all depths 2).
MODEL_CONFIG = dict(
    LOGNAME='ours',
    MODEL_TYPE=(feature_extractor, flow_estimation),
    MODEL_ARCH=init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]),
)
# MODEL_CONFIG = { | |
# 'LOGNAME': 'ours_small', | |
# 'MODEL_TYPE': (feature_extractor, flow_estimation), | |
# 'MODEL_ARCH': init_model_config( | |
# F = 16, | |
# W = 7, | |
# depth = [2, 2, 2, 2, 2] | |
# ) | |
# } |