"""Gradio app that plots per-layer weight differences between two causal LMs.

The user supplies a base model and a fine-tuned ("chat") model; for every
decoder layer the mean absolute difference of each compared weight tensor is
computed and rendered as one single-column heatmap per component.
"""

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# (result key, attribute path from a decoder layer to the module whose
# ``.weight`` is compared).  Order here is the left-to-right plot order.
_LAYER_COMPONENTS = (
    ("input_layernorm", ("input_layernorm",)),
    ("mlp_down_proj", ("mlp", "down_proj")),
    ("mlp_gate_proj", ("mlp", "gate_proj")),
    ("mlp_up_proj", ("mlp", "up_proj")),
    ("post_attention_layernorm", ("post_attention_layernorm",)),
    ("self_attn_q_proj", ("self_attn", "q_proj")),
    ("self_attn_k_proj", ("self_attn", "k_proj")),
    ("self_attn_v_proj", ("self_attn", "v_proj")),
    ("self_attn_o_proj", ("self_attn", "o_proj")),
)

# Dtypes whose arithmetic is too coarse for the nine decimals shown in the plot.
_LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16)


def _upcast(tensor):
    """Return *tensor* as float32 if it is fp16/bf16, otherwise unchanged."""
    if tensor.dtype in _LOW_PRECISION_DTYPES:
        return tensor.float()
    return tensor


def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference of two tensors.

    Half-precision inputs (float16 / bfloat16) are upcast to float32 before
    subtracting, so the result is not rounded to the few significant digits
    those dtypes can hold.  Other dtypes are used as given.

    Args:
        base_weight: Tensor from the base model.
        chat_weight: Tensor from the fine-tuned model; must be
            broadcast-compatible with ``base_weight``.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python float.
    """
    # Model parameters normally require grad; no_grad avoids building an
    # autograd graph for a value that is immediately turned into a float.
    with torch.no_grad():
        delta = _upcast(base_weight) - _upcast(chat_weight)
        return torch.abs(delta).mean().item()


def _get_weight(layer, path):
    """Follow the attribute names in *path* from *layer* and return ``.weight``."""
    module = layer
    for attr_name in path:
        module = getattr(module, attr_name)
    return module.weight


def _diff_layer(base_layer, chat_layer):
    """Return ``{component key: mean abs diff}`` for one pair of decoder layers."""
    return {
        key: calculate_weight_diff(
            _get_weight(base_layer, path), _get_weight(chat_layer, path)
        )
        for key, path in _LAYER_COMPONENTS
    }


def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
    """Compute per-layer, per-component weight differences between two models.

    Both models must expose their decoder layers as ``model.model.layers`` and
    each layer must provide the modules listed in ``_LAYER_COMPONENTS``.

    Args:
        base_model: The reference model.
        chat_model: The model compared against ``base_model``.
        load_one_at_a_time: Kept for interface compatibility.  The models are
            passed in already instantiated, so this flag does not change how
            they are held in memory or what is returned.

    Returns:
        A list with one dict per layer pair, mapping each component key to the
        mean absolute difference of that weight.  If the models have different
        depths, only the layers both models have are compared.
    """
    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers
    # zip stops at the shorter model, so report that count to the progress bar.
    pair_count = min(len(base_layers), len(chat_layers))
    return [
        _diff_layer(base_layer, chat_layer)
        for base_layer, chat_layer in tqdm(
            zip(base_layers, chat_layers), total=pair_count
        )
    ]


def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
    """Draw one single-column heatmap per component, with one row per layer.

    Args:
        layer_diffs: Non-empty list of dicts as returned by
            ``calculate_layer_diffs``; the keys of the first dict define the
            components that are plotted.
        base_model_name: Name shown on the left of the figure title.
        chat_model_name: Name shown on the right of the figure title.

    Returns:
        The matplotlib figure containing all heatmaps.

    Raises:
        ValueError: If ``layer_diffs`` is empty.
    """
    if not layer_diffs:
        raise ValueError("layer_diffs is empty; there is nothing to visualize")

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0])

    # squeeze=False keeps ``axs`` two-dimensional even for a single component.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    # Heatmap cell i spans [i, i + 1], so its label belongs at i + 0.5.
    tick_positions = [index + 0.5 for index in range(num_layers)]
    tick_labels = list(range(num_layers))

    for ax, component in tqdm(zip(axs[0], components), total=len(components)):
        column = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(column, annot=True, fmt=".9f", cmap="YlGnBu", ax=ax, cbar=False)
        ax.set_title(component)
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)
        # seaborn draws row 0 at the top; flip so layer 0 sits at the bottom.
        ax.invert_yaxis()

    fig.tight_layout()
    return fig


def _load_model(model_name, token):
    """Load *model_name* as a bfloat16 causal LM, authenticating with *token*."""
    # ``token`` is the current keyword; ``use_auth_token`` is deprecated.
    return AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, token=token
    )


def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
    """Load both models, diff their layers and return the resulting figure.

    Args:
        base_model_name: Hub id or local path of the base model.
        chat_model_name: Hub id or local path of the fine-tuned model.
        hf_token: Hugging Face access token; blank or ``None`` means no
            explicit token is passed.
        load_one_at_a_time: Forwarded to ``calculate_layer_diffs``.

    Returns:
        The matplotlib figure produced by ``visualize_layer_diffs``.

    Raises:
        ValueError: If either model name is blank.
    """
    # The text boxes are multi-line, so drop stray whitespace and newlines.
    base_model_name = (base_model_name or "").strip()
    chat_model_name = (chat_model_name or "").strip()
    if not base_model_name or not chat_model_name:
        raise ValueError(
            "Both a base model name and a finetuned model name are required."
        )
    token = (hf_token or "").strip() or None

    base_model = _load_model(base_model_name, token)
    chat_model = _load_model(chat_model_name, token)
    layer_diffs = calculate_layer_diffs(
        base_model, chat_model, load_one_at_a_time=load_one_at_a_time
    )
    # layer_diffs holds only floats; drop the weights before plotting.
    del base_model, chat_model

    fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
    # Unregister the figure from pyplot so repeated requests do not accumulate
    # open figures; the Figure object itself can still be rendered.
    plt.close(fig)
    return fig


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter base model name"),
        gr.Textbox(lines=2, placeholder="Enter finetuned model name"),
        # NOTE(review): the token is shown in plain text; consider a masked
        # single-line textbox if the installed Gradio version supports it.
        gr.Textbox(lines=2, placeholder="Enter Hugging Face token"),
        gr.Checkbox(label="Load one layer at a time"),
    ],
    # The callback returns a matplotlib figure, which the Plot component renders.
    outputs=gr.Plot(),
    title="Model Weight Difference Visualizer",
)

if __name__ == "__main__":
    iface.launch()