Spaces:
Runtime error
| import torch | |
| import numpy as np | |
| from tracing.perm.permute import permute_model | |
def main(base_model, ft_model, test_stat, num_perm, emb_dim=4096, mlp_dim=11008,
         *, generator=None):
    """Run a permutation test of ``test_stat`` between two models.

    The statistic is first computed on the unmodified pair.  Then, ``num_perm``
    times, fresh random permutations of size ``mlp_dim`` and ``emb_dim`` are
    drawn, applied to ``ft_model`` via ``permute_model``, and the statistic is
    recomputed.  The observed statistic is finally scored against the
    permuted ones with ``p_value_exact`` and ``p_value_approx``.  Intermediate
    values are printed as they are produced.

    Args:
        base_model: first model passed to ``test_stat``.
        ft_model: second model passed to ``test_stat``; this is the one that
            gets permuted.
        test_stat: callable ``test_stat(base_model, ft_model)`` returning a
            comparable statistic.
        num_perm: number of permutation rounds.  Should be at least 1:
            ``p_value_approx`` divides by the number of permuted statistics.
        emb_dim: length of the embedding-dimension permutation.  NOTE(review):
            the defaults (4096 / 11008) presumably match a LLaMA-7B-sized
            model; confirm for other architectures.
        mlp_dim: length of the MLP-dimension permutation.
        generator: optional ``torch.Generator`` forwarded to
            ``torch.randperm`` so the permutations are reproducible.  ``None``
            uses torch's global RNG.

    Returns:
        Tuple ``(exact, approx, unperm_stat, perm_stats)``:
        ``exact`` and ``approx`` are the results of ``p_value_exact`` and
        ``p_value_approx``, ``unperm_stat`` is the statistic on the
        unpermuted pair, and ``perm_stats`` lists the per-round statistics.

    NOTE(review): the return value of ``permute_model`` is discarded, so it
    presumably mutates ``ft_model`` in place.  Confirm this.  If so, the
    permutations compound across rounds, and ``ft_model`` is still permuted
    when this function returns.
    """
    unperm_stat = test_stat(base_model, ft_model)
    print(unperm_stat)

    perm_stats = []
    for i in range(num_perm):
        # Draw the MLP permutation before the embedding permutation, so a
        # seeded generator reproduces the same sequence every run.
        mlp_permutation = torch.randperm(mlp_dim, generator=generator)
        emb_permutation = torch.randperm(emb_dim, generator=generator)
        permute_model(ft_model, mlp_permutation, emb_permutation)
        perm_stat = test_stat(base_model, ft_model)
        perm_stats.append(perm_stat)
        print(i, perm_stat)
    print(perm_stats)

    # Pass copies so neither scorer can alter the list we return.
    exact = p_value_exact(unperm_stat, perm_stats.copy())
    approx = p_value_approx(unperm_stat, perm_stats.copy())
    print(exact, approx)
    return exact, approx, unperm_stat, perm_stats
def p_value_exact(unpermuted, permuted, ndigits=2):
    """Return the add-one permutation-test p-value of an observed statistic.

    Smaller statistics are treated as more extreme.  The p-value is
    ``(b + 1) / (m + 1)``, where:

    - ``m = len(permuted)``;
    - ``b`` counts the permuted statistics at least as extreme as the
      observed one (``a <= unpermuted``).

    The +1 terms keep the estimate strictly positive, and an empty
    ``permuted`` yields 1.0.

    Args:
        unpermuted: statistic computed on the original (unpermuted) data.
        permuted: sized iterable of statistics from the permuted runs,
            comparable with ``unpermuted``.
        ndigits: decimal places passed to ``round``; ``None`` returns the
            unrounded value.  With the default of 2, the smallest attainable
            p-value, ``1 / (m + 1)``, rounds to 0.0 once ``m`` reaches about
            200.  Pass ``None`` (or more digits) for large permutation counts.

    Returns:
        The p-value as a float.
    """
    # Ties count as "at least as extreme". A strict `<` would understate the
    # p-value whenever permuted statistics equal the observed one.
    count = sum(1 for a in permuted if a <= unpermuted)
    p_value = (count + 1) / (len(permuted) + 1)
    return p_value if ndigits is None else round(p_value, ndigits)
def p_value_approx(unpermuted, permuted, ddof=0):
    """Return the z-score of the observed statistic against the permuted ones.

    Despite the name, this returns a standardized score rather than a
    p-value: ``(unpermuted - mean(permuted)) / std(permuted)``.  A negative
    result means the observed statistic lies below the permutation mean.

    Args:
        unpermuted: statistic computed on the original (unpermuted) data.
        permuted: non-empty sized sequence of statistics from the permuted
            runs.
        ddof: delta degrees of freedom forwarded to ``np.std``.  The default
            of 0 gives the population standard deviation; pass 1 for the
            sample estimate.

    Returns:
        The z-score.  NOTE(review): if all permuted statistics are equal, the
        standard deviation is 0.  numpy then typically yields ``inf``/``nan``
        with a RuntimeWarning rather than raising, so confirm callers handle
        that case.

    Raises:
        ValueError: if ``permuted`` is empty.
    """
    n = len(permuted)
    if n == 0:
        raise ValueError("p_value_approx requires at least one permuted statistic")
    mean = sum(permuted) / n
    std = np.std(permuted, ddof=ddof)
    zscore = (unpermuted - mean) / std
    return zscore