juanesvelez commited on
Commit
77be14e
·
verified ·
1 Parent(s): ad05766

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -38
app.py CHANGED
@@ -1,40 +1,96 @@
1
- import solara as sol
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- model_name = "datificate/gpt2-small-spanish"
6
- model = AutoModelForCausalLM.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
-
9
- def predict_next_token(text):
10
- inputs = tokenizer(text, return_tensors="pt")
11
- outputs = model(**inputs)
12
- next_token_logits = outputs.logits[:, -1, :]
13
- next_token_probs = torch.softmax(next_token_logits, dim=-1)
14
- top_k_probs, top_k_indices = torch.topk(next_token_probs, 10)
15
- top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices[0])
16
- return list(zip(top_k_tokens, top_k_probs[0].tolist()))
17
-
18
- @sol.component
19
- def NextTokenPredictionApp():
20
- text = sol.reactive("")
21
- predictions = sol.reactive([])
22
-
23
- def on_text_change(new_text):
24
- text.set(new_text)
25
- preds = predict_next_token(new_text)
26
- predictions.set(preds)
27
-
28
- sol.Markdown("# Predicción del Próximo Token")
29
- sol.Markdown("Ingrese un texto en español y vea las predicciones para el próximo token.")
30
-
31
- sol.InputText(value=text.value, on_change=on_text_change, placeholder="Escribe algo en español...", fullwidth=True)
32
- sol.Button("Predecir", on_click=lambda: on_text_change(text.value))
33
-
34
- if predictions.value:
35
- sol.Markdown("## Predicciones de tokens:")
36
- for token, prob in predictions.value:
37
- sol.Markdown(f"- **{token}**: {prob:.4f}")
38
-
39
- # Iniciar la aplicación en modo de desarrollo
40
- app = sol.App(NextTokenPredictionApp, title="Next Token Prediction App")
 
1
+ Hugging Face's logo
2
+ Hugging Face
3
+ Search models, datasets, users...
4
+ Models
5
+ Datasets
6
+ Spaces
7
+ Posts
8
+ Docs
9
+ Pricing
10
+
11
+
12
+
13
+ Spaces:
14
+
15
+ alonsosilva
16
+ /
17
+ NextTokenPrediction
18
+
19
+
20
+ like
21
+ 3
22
+ App
23
+ Files
24
+ Community
25
+ NextTokenPrediction
26
+ /
27
+ app.py
28
+
29
+ alonsosilva's picture
30
+ alonsosilva
31
+ Change reactive text
32
+ a4869ab
33
+ 7 months ago
34
+ raw
35
+ history
36
+ blame
37
+ contribute
38
+ delete
39
+ No virus
40
+ 2.63 kB
41
+ import solara
42
+ import random
43
  import torch
44
+ import torch.nn.functional as F
45
+ import pandas as pd
46
+ from transformers import AutoTokenizer, AutoModelForCausalLM
47
+
48
+ tokenizer = AutoTokenizer.from_pretrained('gpt2')
49
+ model = AutoModelForCausalLM.from_pretrained('gpt2')
50
+ text1 = solara.reactive("Never gonna give you up, never gonna let you")
51
+ @solara.component
52
+ def Page():
53
+ with solara.Column(margin=10):
54
+ solara.Markdown("#Next token prediction visualization")
55
+ solara.Markdown("I built this tool to help me understand autoregressive language models. For any given text, it gives the top 10 candidates to be the next token with their respective probabilities. The language model I'm using is the smallest version of GPT-2, with 124M parameters.")
56
+ def on_action_cell(column, row_index):
57
+ text1.value += tokenizer.decode(top_10.indices[0][row_index])
58
+ cell_actions = [solara.CellAction(icon="mdi-thumb-up", name="Select", on_click=on_action_cell)]
59
+ solara.InputText("Enter text:", value=text1, continuous_update=True)
60
+ if text1.value != "":
61
+ tokens = tokenizer.encode(text1.value, return_tensors="pt")
62
+ spans1 = ""
63
+ spans2 = ""
64
+ for i, token in enumerate(tokens[0]):
65
+ random.seed(i)
66
+ random_color = ''.join([random.choice('0123456789ABCDEF') for k in range(6)])
67
+ spans1 += " " + f"<span style='font-family: helvetica; color: #{random_color}'>{token}</span>"
68
+ spans2 += " " + f"""<span style="
69
+ padding: 6px;
70
+ border-right: 3px solid white;
71
+ line-height: 3em;
72
+ font-family: courier;
73
+ background-color: #{random_color};
74
+ color: white;
75
+ position: relative;
76
+ "><span style="
77
+ position: absolute;
78
+ top: 5.5ch;
79
+ line-height: 1em;
80
+ left: -0.5px;
81
+ font-size: 0.45em"> {token}</span>{tokenizer.decode([token])}</span>"""
82
+ solara.Markdown(f'{spans2}')
83
+ solara.Markdown(f'{spans1}')
84
+ outputs = model.generate(tokens, max_new_tokens=1, output_scores=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
85
+ scores = F.softmax(outputs.scores[0], dim=-1)
86
+ top_10 = torch.topk(scores, 10)
87
+ df = pd.DataFrame()
88
+ df["probs"] = top_10.values[0]
89
+ df["probs"] = [f"{value:.2%}" for value in df["probs"].values]
90
+ df["next token ID"] = [top_10.indices[0][i].numpy() for i in range(10)]
91
+ df["predicted next token"] = [tokenizer.decode(top_10.indices[0][i]) for i in range(10)]
92
+ solara.Markdown("###Prediction")
93
+ solara.DataFrame(df, items_per_page=10, cell_actions=cell_actions)
94
+ Page()
95
 
96
+