Steelskull commited on
Commit
2c89359
·
verified ·
1 Parent(s): 86a02df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -19
app.py CHANGED
@@ -8,24 +8,41 @@ import gradio as gr
8
  def calculate_weight_diff(base_weight, chat_weight):
9
  return torch.abs(base_weight - chat_weight).mean().item()
10
 
11
- def calculate_layer_diffs(base_model, chat_model):
12
  layer_diffs = []
13
- for base_layer, chat_layer in tqdm(zip(base_model.model.layers, chat_model.model.layers), total=len(base_model.model.layers)):
14
- layer_diff = {
15
- 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
16
- 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
17
- 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
18
- 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
19
- 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
20
- 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
21
- 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
22
- 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
23
- 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
24
- }
25
- layer_diffs.append(layer_diff)
 
 
 
26
 
27
- base_layer, chat_layer = None, None
28
- del base_layer, chat_layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  return layer_diffs
31
 
@@ -50,11 +67,11 @@ def visualize_layer_diffs(layer_diffs):
50
  plt.tight_layout()
51
  return fig
52
 
53
- def gradio_interface(base_model_name, chat_model_name):
54
  base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
55
  chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
56
 
57
- layer_diffs = calculate_layer_diffs(base_model, chat_model)
58
  fig = visualize_layer_diffs(layer_diffs)
59
 
60
  return fig
@@ -63,7 +80,8 @@ iface = gr.Interface(
63
  fn=gradio_interface,
64
  inputs=[
65
  gr.Textbox(lines=2, placeholder="Enter base model name"),
66
- gr.Textbox(lines=2, placeholder="Enter chat model name")
 
67
  ],
68
  outputs="image",
69
  title="Model Weight Difference Visualizer"
 
8
  def calculate_weight_diff(base_weight, chat_weight):
9
  return torch.abs(base_weight - chat_weight).mean().item()
10
 
11
+ def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
12
  layer_diffs = []
13
+ layers = zip(base_model.model.layers, chat_model.model.layers)
14
+
15
+ if load_one_at_a_time:
16
+ for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
17
+ layer_diff = {
18
+ 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
19
+ 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
20
+ 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
21
+ 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
22
+ 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
23
+ 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
24
+ 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
25
+ 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
26
+ 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
27
+ }
28
+ layer_diffs.append(layer_diff)
29
 
30
+ base_layer, chat_layer = None, None
31
+ del base_layer, chat_layer
32
+ else:
33
+ for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
34
+ layer_diff = {
35
+ 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
36
+ 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
37
+ 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
38
+ 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
39
+ 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
40
+ 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
41
+ 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
42
+ 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
43
+ 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
44
+ }
45
+ layer_diffs.append(layer_diff)
46
 
47
  return layer_diffs
48
 
 
67
  plt.tight_layout()
68
  return fig
69
 
70
+ def gradio_interface(base_model_name, chat_model_name, load_one_at_a_time=False):
71
  base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
72
  chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16)
73
 
74
+ layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
75
  fig = visualize_layer_diffs(layer_diffs)
76
 
77
  return fig
 
80
  fn=gradio_interface,
81
  inputs=[
82
  gr.Textbox(lines=2, placeholder="Enter base model name"),
83
+ gr.Textbox(lines=2, placeholder="Enter chat model name"),
84
+ gr.Checkbox(label="Load one layer at a time")
85
  ],
86
  outputs="image",
87
  title="Model Weight Difference Visualizer"