AlexTransformer commited on
Commit
707db97
·
verified ·
1 Parent(s): 80bae41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import gradio as gr
5
+ import pandas as pd
6
+
7
+ # Load model and tokenizer
8
+ model_ids = {
9
+ "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
10
+ "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT"
11
+ }
12
+
13
+ tokenizers = {
14
+ name: AutoTokenizer.from_pretrained(path)
15
+ for name, path in model_ids.items()
16
+ }
17
+
18
+ models = {
19
+ name: AutoModelForCausalLM.from_pretrained(path).eval()
20
+ for name, path in model_ids.items()
21
+ }
22
+
23
+ # Main function: compute token-wise log probabilities and top-k predictions
24
+ @torch.no_grad()
25
+ def compare_models(text, top_k=5):
26
+ results = {}
27
+
28
+ for model_name in model_ids:
29
+ tokenizer = tokenizers[model_name]
30
+ model = models[model_name]
31
+
32
+ # Tokenize input
33
+ inputs = tokenizer(text, return_tensors="pt")
34
+ input_ids = inputs["input_ids"]
35
+
36
+ # Get model output logits
37
+ outputs = model(**inputs)
38
+ shift_logits = outputs.logits[:, :-1, :] # Align prediction with target
39
+ shift_labels = input_ids[:, 1:] # Shift labels to match predictions
40
+
41
+ # Compute log probabilities
42
+ log_probs = F.log_softmax(shift_logits, dim=-1)
43
+ token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
44
+
45
+ total_log_prob = token_log_probs.sum().item()
46
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:] # Skip BOS token
47
+
48
+ # Generate top-k predictions for each position (up to first 20 tokens)
49
+ topk_list = []
50
+ for i in range(min(20, shift_logits.shape[1])):
51
+ topk = torch.topk(log_probs[0, i], k=top_k)
52
+ topk_ids = topk.indices.tolist()
53
+ topk_scores = topk.values.tolist()
54
+ topk_tokens = tokenizer.convert_ids_to_tokens(topk_ids)
55
+ topk_probs = [round(float(torch.exp(s)), 4) for s in topk_scores]
56
+ pair_list = [f"{tok} ({prob})" for tok, prob in zip(topk_tokens, topk_probs)]
57
+ topk_list.append(", ".join(pair_list))
58
+
59
+ # Prepare dataframe for display
60
+ df = pd.DataFrame({
61
+ "Token": tokens[:20],
62
+ "LogProb": [round(float(x), 4) for x in token_log_probs[0][:20]],
63
+ f"Top-{top_k} Predictions": topk_list
64
+ })
65
+
66
+ results[model_name] = {
67
+ "df": df,
68
+ "total_log_prob": total_log_prob
69
+ }
70
+
71
+ # Merge two model results into one table
72
+ merged = pd.DataFrame({
73
+ "Token": results["ERNIE-4.5-PT"]["df"]["Token"],
74
+ "ERNIE-4.5-PT LogProb": results["ERNIE-4.5-PT"]["df"]["LogProb"],
75
+ "ERNIE-4.5-PT Top-k": results["ERNIE-4.5-PT"]["df"][f"Top-{top_k} Predictions"],
76
+ "ERNIE-4.5-Base-PT LogProb": results["ERNIE-4.5-Base-PT"]["df"]["LogProb"],
77
+ "ERNIE-4.5-Base-PT Top-k": results["ERNIE-4.5-Base-PT"]["df"][f"Top-{top_k} Predictions"],
78
+ })
79
+
80
+ # Summarize total log probability for each model
81
+ summary = (
82
+ f"🧠 Total Log Prob:\n"
83
+ f"- ERNIE-4.5-PT: {results['ERNIE-4.5-PT']['total_log_prob']:.2f}\n"
84
+ f"- ERNIE-4.5-Base-PT: {results['ERNIE-4.5-Base-PT']['total_log_prob']:.2f}"
85
+ )
86
+
87
+ return merged, summary
88
+
89
+ # Gradio interface
90
+ demo = gr.Interface(
91
+ fn=compare_models,
92
+ inputs=[
93
+ gr.Textbox(lines=2, placeholder="Type a sentence here...", label="Input Sentence"),
94
+ gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-k Predictions")
95
+ ],
96
+ outputs=[
97
+ gr.Dataframe(label="Token LogProbs and Top-k Predictions"),
98
+ gr.Textbox(label="Sentence Total Log Probability", lines=3)
99
+ ],
100
+ title="🧪 ERNIE 4.5 Model Comparison with Top-k Predictions",
101
+ description="Compare ERNIE-4.5-0.3B Instruct and Base model by computing token logprobs and Top-k predictions"
102
+ )
103
+
104
+ if __name__ == "__main__":
105
+ demo.launch()