Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
+
import gradio as gr
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
# Load model and tokenizer
# Maps the display label used throughout the app to the identifier handed to
# `from_pretrained` (these look like Hugging Face Hub repo ids).
model_ids = dict(
    [
        ("ERNIE-4.5-PT", "baidu/ERNIE-4.5-0.3B-PT"),
        ("ERNIE-4.5-Base-PT", "baidu/ERNIE-4.5-0.3B-Base-PT"),
    ]
)

# One tokenizer per label, loaded eagerly in the same order as `model_ids`.
tokenizers = dict(
    zip(model_ids, map(AutoTokenizer.from_pretrained, model_ids.values()))
)

# One causal LM per label, switched to evaluation mode right after loading.
models = {
    label: AutoModelForCausalLM.from_pretrained(model_ids[label]).eval()
    for label in model_ids
}
|
22 |
+
|
23 |
+
# Main function: compute token-wise log probabilities and top-k predictions
@torch.no_grad()
def compare_models(text, top_k=5, max_tokens=20):
    """Score ``text`` under every model in ``model_ids`` and tabulate the results.

    For each model the text is tokenized, passed through the model, and the
    log-probability assigned to each token given the tokens before it is read
    off the logits. The first token has no preceding context, so it is neither
    scored nor listed.

    Args:
        text: Sentence to score.
        top_k: Number of highest-probability next-token candidates to list per
            position. Cast to ``int`` because a UI slider may deliver a float.
        max_tokens: Maximum number of scored positions shown in the table. The
            total log probability always covers every scored position.

    Returns:
        A ``(merged, summary)`` tuple. ``merged`` is a ``pandas.DataFrame``
        with a ``Token`` column (taken from the first model in ``model_ids``)
        followed by a ``"<name> LogProb"`` and a ``"<name> Top-k"`` column per
        model. ``summary`` is a multi-line string listing each model's total
        log probability.
    """
    # torch.topk needs an integer k; the cast also keeps the column label as
    # "Top-5" rather than "Top-5.0" when the caller passes a float.
    top_k = int(top_k)
    max_tokens = max(0, int(max_tokens))
    topk_column = f"Top-{top_k} Predictions"
    results = {}

    for model_name in model_ids:
        tokenizer = tokenizers[model_name]
        model = models[model_name]

        # Tokenize input
        inputs = tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"]

        if input_ids.shape[-1] < 2:
            # With fewer than two tokens there is no (context, target) pair to
            # score, so skip the forward pass and report an empty table.
            results[model_name] = {
                "df": pd.DataFrame({"Token": [], "LogProb": [], topk_column: []}),
                "total_log_prob": 0.0,
            }
            continue

        # Get model output logits
        outputs = model(**inputs)
        shift_logits = outputs.logits[:, :-1, :]  # Logits at i predict token i + 1
        shift_labels = input_ids[:, 1:]  # Targets aligned with shift_logits

        # Compute log probabilities
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        total_log_prob = token_log_probs.sum().item()

        # Drop the first token: nothing precedes it, so it has no score.
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[1:]

        # Generate top-k predictions for each displayed position
        shown = min(max_tokens, shift_logits.shape[1])
        topk_list = []
        for position in range(shown):
            topk = torch.topk(log_probs[0, position], k=top_k)
            candidates = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
            # Exponentiate while the values are still a tensor: torch.exp
            # rejects the plain Python floats that .tolist() produces.
            probabilities = topk.values.exp().tolist()
            topk_list.append(
                ", ".join(
                    f"{tok} ({round(prob, 4)})"
                    for tok, prob in zip(candidates, probabilities)
                )
            )

        # Prepare dataframe for display
        df = pd.DataFrame({
            "Token": tokens[:shown],
            "LogProb": [round(value, 4) for value in token_log_probs[0, :shown].tolist()],
            topk_column: topk_list,
        })

        results[model_name] = {
            "df": df,
            "total_log_prob": total_log_prob,
        }

    # Merge all model results into one table; the first model supplies "Token".
    columns = {}
    for model_name in model_ids:
        model_df = results[model_name]["df"]
        if not columns:
            columns["Token"] = model_df["Token"]
        columns[f"{model_name} LogProb"] = model_df["LogProb"]
        columns[f"{model_name} Top-k"] = model_df[topk_column]
    merged = pd.DataFrame(columns)

    # Summarize total log probability for each model
    summary_lines = ["🧠 Total Log Prob:"]
    summary_lines.extend(
        f"- {model_name}: {results[model_name]['total_log_prob']:.2f}"
        for model_name in model_ids
    )
    summary = "\n".join(summary_lines)

    return merged, summary
|
88 |
+
|
89 |
+
# Gradio interface
# Input widgets, in the positional order compare_models receives them.
_sentence_box = gr.Textbox(
    lines=2,
    placeholder="Type a sentence here...",
    label="Input Sentence",
)
_top_k_slider = gr.Slider(
    minimum=1,
    maximum=10,
    value=5,
    step=1,
    label="Top-k Predictions",
)

# Output widgets, matching the (table, summary) pair compare_models returns.
_result_table = gr.Dataframe(label="Token LogProbs and Top-k Predictions")
_summary_box = gr.Textbox(label="Sentence Total Log Probability", lines=3)

demo = gr.Interface(
    fn=compare_models,
    inputs=[_sentence_box, _top_k_slider],
    outputs=[_result_table, _summary_box],
    title="🧪 ERNIE 4.5 Model Comparison with Top-k Predictions",
    description="Compare ERNIE-4.5-0.3B Instruct and Base model by computing token logprobs and Top-k predictions",
)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()
|