Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,8 @@ import matplotlib.pyplot as plt
|
|
4 |
import seaborn as sns
|
5 |
from tqdm import tqdm
|
6 |
import gradio as gr
|
|
|
|
|
7 |
|
8 |
def calculate_weight_diff(base_weight, chat_weight):
|
9 |
return torch.abs(base_weight - chat_weight).mean().item()
|
@@ -46,7 +48,7 @@ def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
|
|
46 |
|
47 |
return layer_diffs
|
48 |
|
49 |
-
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
|
50 |
num_layers = len(layer_diffs)
|
51 |
num_components = len(layer_diffs[0])
|
52 |
|
@@ -65,26 +67,31 @@ def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name): # Add
|
|
65 |
axs[i].invert_yaxis()
|
66 |
|
67 |
plt.tight_layout()
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
|
71 |
-
|
72 |
-
|
|
|
73 |
|
74 |
layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
|
75 |
-
|
76 |
-
|
77 |
-
return fig
|
78 |
|
79 |
iface = gr.Interface(
|
80 |
fn=gradio_interface,
|
81 |
inputs=[
|
82 |
gr.Textbox(lines=2, placeholder="Enter base model name"),
|
83 |
-
gr.Textbox(lines=2, placeholder="Enter
|
84 |
-
gr.Textbox(lines=2, placeholder="Enter Hugging Face token"),
|
85 |
gr.Checkbox(label="Load one layer at a time")
|
86 |
],
|
87 |
-
outputs="
|
88 |
title="Model Weight Difference Visualizer"
|
89 |
)
|
90 |
|
|
|
4 |
import seaborn as sns
|
5 |
from tqdm import tqdm
|
6 |
import gradio as gr
|
7 |
+
import io
|
8 |
+
import PIL.Image
|
9 |
|
10 |
def calculate_weight_diff(base_weight, chat_weight):
|
11 |
return torch.abs(base_weight - chat_weight).mean().item()
|
|
|
48 |
|
49 |
return layer_diffs
|
50 |
|
51 |
+
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
|
52 |
num_layers = len(layer_diffs)
|
53 |
num_components = len(layer_diffs[0])
|
54 |
|
|
|
67 |
axs[i].invert_yaxis()
|
68 |
|
69 |
plt.tight_layout()
|
70 |
+
|
71 |
+
# Convert plot to image
|
72 |
+
buf = io.BytesIO()
|
73 |
+
fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
|
74 |
+
buf.seek(0)
|
75 |
+
plt.close(fig) # Close the figure to free memory
|
76 |
+
return PIL.Image.open(buf)
|
77 |
|
78 |
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
|
79 |
+
# Update to use 'token' instead of 'use_auth_token' to handle deprecation warning
|
80 |
+
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
81 |
+
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
|
82 |
|
83 |
layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
|
84 |
+
return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
|
|
|
|
|
85 |
|
86 |
iface = gr.Interface(
|
87 |
fn=gradio_interface,
|
88 |
inputs=[
|
89 |
gr.Textbox(lines=2, placeholder="Enter base model name"),
|
90 |
+
gr.Textbox(lines=2, placeholder="Enter chat model name"),
|
91 |
+
gr.Textbox(lines=2, placeholder="Enter Hugging Face token", type="password"), # Hide token input
|
92 |
gr.Checkbox(label="Load one layer at a time")
|
93 |
],
|
94 |
+
outputs=gr.Image(type="pil"), # Specify PIL image output
|
95 |
title="Model Weight Difference Visualizer"
|
96 |
)
|
97 |
|