File size: 4,910 Bytes
86a02df
 
 
 
 
 
 
 
 
 
2c89359
86a02df
2c89359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86a02df
2c89359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86a02df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c89359
86a02df
 
 
2c89359
86a02df
 
 
 
 
 
 
 
2c89359
 
86a02df
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr

def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference between two tensors.

    Half-precision inputs (``float16`` / ``bfloat16``) are promoted to
    ``float32`` before subtracting and averaging, so the result is not
    rounded to the handful of significant digits those formats carry.
    Wider dtypes (e.g. ``float64``) are used as they are.

    Args:
        base_weight: Tensor holding the reference weights.
        chat_weight: Tensor of the same shape holding the weights to compare.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python float.

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    # Without this check broadcasting could silently produce a meaningless number.
    if base_weight.shape != chat_weight.shape:
        raise ValueError(
            "Cannot compare weights of different shapes: "
            f"{tuple(base_weight.shape)} vs {tuple(chat_weight.shape)}"
        )

    compute_dtype = torch.promote_types(base_weight.dtype, chat_weight.dtype)
    if compute_dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32

    # Model parameters usually require grad; a scalar statistic needs no autograd graph.
    with torch.no_grad():
        delta = base_weight.to(compute_dtype) - chat_weight.to(compute_dtype)
        return delta.abs().mean().item()

def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Compute per-layer weight differences between two models.

    Both models must expose their layers as ``model.layers``, and every layer
    must provide the submodules listed in ``component_paths`` below, each with
    a ``weight`` attribute. Layers are paired positionally with ``zip``, so if
    the models have a different number of layers only the common prefix is
    compared.

    Args:
        base_model: Reference model.
        chat_model: Model to compare against the reference.
        load_one_at_a_time: Accepted for backward compatibility. Both models
            are already fully loaded when this function runs, so the flag
            does not change the result.

    Returns:
        A list with one dict per compared layer. Each dict maps a component
        name (e.g. ``'mlp_down_proj'``) to the value ``calculate_weight_diff``
        returns for that component's weights in the two models.
    """
    # Result key -> dotted path of the submodule inside a layer.
    # The order here is the key order of every returned dict.
    component_paths = (
        ('input_layernorm', 'input_layernorm'),
        ('mlp_down_proj', 'mlp.down_proj'),
        ('mlp_gate_proj', 'mlp.gate_proj'),
        ('mlp_up_proj', 'mlp.up_proj'),
        ('post_attention_layernorm', 'post_attention_layernorm'),
        ('self_attn_q_proj', 'self_attn.q_proj'),
        ('self_attn_k_proj', 'self_attn.k_proj'),
        ('self_attn_v_proj', 'self_attn.v_proj'),
        ('self_attn_o_proj', 'self_attn.o_proj'),
    )

    def resolve_weight(layer, dotted_path):
        """Follow ``dotted_path`` from ``layer`` and return that module's weight."""
        module = layer
        for attr in dotted_path.split('.'):
            module = getattr(module, attr)
        return module.weight

    layer_diffs = []
    paired_layers = zip(base_model.model.layers, chat_model.model.layers)
    for base_layer, chat_layer in tqdm(paired_layers, total=len(base_model.model.layers)):
        layer_diffs.append({
            name: calculate_weight_diff(
                resolve_weight(base_layer, path),
                resolve_weight(chat_layer, path),
            )
            for name, path in component_paths
        })

    return layer_diffs

def visualize_layer_diffs(layer_diffs, base_model_name=None, chat_model_name=None):
    """Draw one single-column heatmap per weight component.

    Args:
        layer_diffs: Non-empty list of dicts (one per layer) mapping a
            component name to its difference value. The keys of the first
            dict decide which components are plotted and in which order.
        base_model_name: Optional model name shown first in the figure title.
        chat_model_name: Optional model name shown second in the figure title.
            When neither name is given the figure has no title.

    Returns:
        The matplotlib figure holding the heatmaps.

    Raises:
        ValueError: If ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("layer_diffs is empty; there is nothing to visualize")

    components = list(layer_diffs[0])
    num_layers = len(layer_diffs)

    # squeeze=False keeps `axs` two-dimensional even when there is one component.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)

    # The names are passed in explicitly; they are not module-level globals.
    title = " <> ".join(name for name in (base_model_name, chat_model_name) if name)
    if title:
        fig.suptitle(title, fontsize=16)

    for ax, component in tqdm(zip(axs[0], components), total=len(components)):
        # One row per layer, one column: a vertical strip of cells.
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=ax, cbar=False)
        ax.set_title(component)
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(range(num_layers))
        ax.set_yticklabels(range(num_layers))
        # Flip the y-axis relative to the orientation the heatmap call left it in.
        ax.invert_yaxis()

    plt.tight_layout()
    return fig


def gradio_interface(base_model_name, chat_model_name, load_one_at_a_time=False):
    """Load two models, diff their layers and return the heatmap figure.

    Args:
        base_model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained`` for the reference model.
        chat_model_name: Identifier passed to
            ``AutoModelForCausalLM.from_pretrained`` for the compared model.
        load_one_at_a_time: Forwarded unchanged to ``calculate_layer_diffs``.

    Returns:
        The matplotlib figure produced by ``visualize_layer_diffs``.

    Raises:
        ValueError: If either model name is empty after stripping whitespace.
    """
    # The names are typed into multi-line textboxes and may carry stray
    # whitespace or newlines.
    base_model_name = (base_model_name or "").strip()
    chat_model_name = (chat_model_name or "").strip()
    if not base_model_name or not chat_model_name:
        raise ValueError("Both a base model name and a chat model name are required")

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)

    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)

    # Only the scalar diffs are needed for plotting; drop the model references first.
    del base_model, chat_model

    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)

# Web UI: two model-name textboxes and a checkbox feed `gradio_interface`.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter base model name"),
        gr.Textbox(lines=2, placeholder="Enter chat model name"),
        gr.Checkbox(label="Load one layer at a time")
    ],
    # `gradio_interface` returns a matplotlib figure, so the output must be a
    # Plot component. NOTE(review): the Image component is understood to accept
    # only arrays, PIL images or file paths; confirm against the installed
    # Gradio version.
    outputs=gr.Plot(),
    title="Model Weight Difference Visualizer"
)

# Start the server as soon as the module is executed.
iface.launch()