from transformers import AutoModelForCausalLM
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
import io
import PIL.Image

def calculate_weight_diff(base_weight, chat_weight):
    # Mean absolute element-wise difference between two weight tensors.
    return torch.abs(base_weight - chat_weight).mean().item()
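
# Quick sanity check of the metric above (illustrative only, not part of the app):
# calculate_weight_diff(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 1.0]))
# -> mean(|[-0.5, 1.0]|) = 0.75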

def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    layer_diffs = []
    layers = zip(base_model.model.layers, chat_model.model.layers)
    for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
        # Mean absolute difference for every weight matrix in the transformer block.
        layer_diff = {
            'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
            'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
            'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
            'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
            'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
            'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
            'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
            'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
            'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
        }
        layer_diffs.append(layer_diff)
        if load_one_at_a_time:
            # Drop the local references between iterations. Note that the tensors are
            # still held by the models themselves, so this alone does not free memory.
            base_layer, chat_layer = None, None
            del base_layer, chat_layer
    return layer_diffs
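
# Optional helper (a minimal sketch, not part of the original app): rank layers by
# their mean difference across all tracked components, largest first. The name
# summarize_layer_diffs is ours, purely for illustration.
def summarize_layer_diffs(layer_diffs):
    means = [sum(diff.values()) / len(diff) for diff in layer_diffs]
    # Returns (layer_index, mean_diff) tuples, sorted in descending order of difference.
    return sorted(enumerate(means), key=lambda pair: pair[1], reverse=True)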

def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    num_layers = len(layer_diffs)
    num_components = len(layer_diffs[0])
    fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
    for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=num_components):
        component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
        axs[i].set_title(component)
        axs[i].set_xlabel("Difference")
        axs[i].set_ylabel("Layer")
        axs[i].set_xticks([])
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers))
        axs[i].invert_yaxis()
    plt.tight_layout()
    # Convert the plot to an in-memory PNG image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # Close the figure to free memory
    return PIL.Image.open(buf)

def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    # Use 'token' rather than the deprecated 'use_auth_token' argument.
    # Treat an empty textbox as "no token" so public models load without authentication.
    token = hf_token or None
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=token)
    chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=token)
    layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
    return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, label="Base model", placeholder="Enter base model name"),
        gr.Textbox(lines=2, label="Chat model", placeholder="Enter chat model name"),
        gr.Textbox(lines=2, label="Hugging Face token", placeholder="Enter Hugging Face token", type="password"),  # Hide token input
        gr.Checkbox(label="Load one layer at a time")
    ],
    outputs=gr.Image(type="pil"),  # Return the heatmap as a PIL image
    title="Model Weight Difference Visualizer"
)
iface.launch()
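
# Example of using the helpers directly, without the Gradio UI (a minimal sketch;
# the model names below are placeholders, not part of the app):
#
# base = AutoModelForCausalLM.from_pretrained("org/base-model", torch_dtype=torch.bfloat16)
# chat = AutoModelForCausalLM.from_pretrained("org/chat-model", torch_dtype=torch.bfloat16)
# diffs = calculate_layer_diffs(base, chat)
# visualize_layer_diffs(diffs, "org/base-model", "org/chat-model").save("layer_diffs.png")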