haakohu's picture
fix
44539fc
raw
history blame contribute delete
714 Bytes
from dp2.metrics.ppl import calculate_ppl
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip
def final_eval_fn(*args, **kwargs):
result = compute_metrics_iteratively(*args, **kwargs)
result2 = calculate_ppl(*args, **kwargs,)
result3 = compute_fid_clip(*args, **kwargs)
assert all(key not in result for key in result2)
result.update(result2)
result.update(result3)
return result
def train_eval_fn(*args, **kwargs):
result = compute_metrics_iteratively(*args, **kwargs)
result2 = compute_fid_clip(*args, **kwargs)
assert all(key not in result for key in result2)
result.update(result2)
return result