Vis_Diff / app.py
Steelskull's picture
Update app.py
ed2f92c verified
raw
history blame
5.47 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
import io
import PIL.Image
def calculate_weight_diff(base_weight, chat_weight):
return torch.abs(base_weight - chat_weight).mean().item()
def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
layer_diffs = []
layers = zip(base_model.model.layers, chat_model.model.layers)
if load_one_at_a_time:
for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
layer_diff = {
'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
}
layer_diffs.append(layer_diff)
base_layer, chat_layer = None, None
del base_layer, chat_layer
else:
for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
layer_diff = {
'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
}
layer_diffs.append(layer_diff)
return layer_diffs
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
num_layers = len(layer_diffs)
num_components = len(layer_diffs[0])
fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
axs[i].set_title(component)
axs[i].set_xlabel("Difference")
axs[i].set_ylabel("Layer")
axs[i].set_xticks([])
axs[i].set_yticks(range(num_layers))
axs[i].set_yticklabels(range(num_layers))
axs[i].invert_yaxis()
plt.tight_layout()
# Convert plot to image
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
buf.seek(0)
plt.close(fig) # Close the figure to free memory
return PIL.Image.open(buf)
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
# Update to use 'token' instead of 'use_auth_token' to handle deprecation warning
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
iface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Textbox(lines=2, placeholder="Enter base model name"),
gr.Textbox(lines=2, placeholder="Enter chat model name"),
gr.Textbox(lines=2, placeholder="Enter Hugging Face token", type="password"), # Hide token input
gr.Checkbox(label="Load one layer at a time")
],
outputs=gr.Image(type="pil"), # Specify PIL image output
title="Model Weight Difference Visualizer"
)
iface.launch()