File size: 3,672 Bytes
2d0a0f5
5dca0b0
 
e88497a
 
 
2d0a0f5
e88497a
5dca0b0
 
e88497a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dca0b0
 
e88497a
5dca0b0
 
e88497a
5dca0b0
 
 
e88497a
5dca0b0
e88497a
5dca0b0
e88497a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d0a0f5
86eb1ac
e88497a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import ast
import re
from collections import Counter
from functools import lru_cache

import gradio as gr
import plotly.graph_objs as go
from transformers import AutoTokenizer

# Directory prefix for local models; a model name from MODELS is appended
# directly to it (hence the trailing slash) when loading a tokenizer.
model_path = "models/"

# Model names offered in the UI dropdown.
MODELS = [
    "Meta-Llama-3.1-8B",
    "gemma-2b",
]

def create_vertical_histogram(data, title):
    """Build a vertical bar chart from ``(label, count)`` pairs.

    Args:
        data: Sequence of ``(label, count)`` pairs, e.g. the output of
            ``Counter.most_common``. May be empty (or None), in which case an
            empty chart is returned.
        title: Chart title.

    Returns:
        A plotly ``Figure`` with one bar per pair, in the order given.
    """
    pairs = list(data or [])
    labels = [str(label) for label, _ in pairs]
    values = [count for _, count in pairs]

    fig = go.Figure(go.Bar(x=labels, y=values))
    fig.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        # Force a categorical x-axis. Without it Plotly auto-types
        # numeric-looking string labels (e.g. "1", "2024" in the numbers
        # chart) as a linear axis, positioning bars by numeric value instead
        # of keeping them in the given (frequency) order.
        xaxis=dict(tickangle=-45, type="category"),
    )
    return fig

@lru_cache(maxsize=8)
def _load_tokenizer(model_name):
    """Load the tokenizer at ``model_path + model_name``, memoized per model name."""
    # Loading from disk is slow; build each model's tokenizer once per process
    # instead of on every button click.
    return AutoTokenizer.from_pretrained(model_path + model_name)


def _error_result(message):
    """Return an error result with the same 6-slot shape as a successful one."""
    return "Error", message, "", None, None, None


def _parse_token_ids(raw):
    """Parse a Python-literal token-ID string into a list of non-negative ints.

    Accepts a single int (``"5"``) or a list/tuple of ints (``"[1, 2, 3]"``).
    Returns None when ``raw`` is not a literal of that shape.
    """
    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # literal_eval raises SyntaxError for malformed text (e.g. "[1, 2" or
        # ""), ValueError for non-literal expressions, and may exhaust the
        # parser on pathologically nested input.
        return None
    if isinstance(value, int) and not isinstance(value, bool):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return None
    if not all(isinstance(v, int) and not isinstance(v, bool) and v >= 0 for v in value):
        return None
    return list(value)


# Characters tokenizers use to mark a leading space: SentencePiece ('▁') and
# byte-level BPE ('Ġ').
_SPACE_MARKERS = "▁Ġ"


def _count_token_kinds(tokens):
    """Return ``(space_count, special_char_count)`` for a list of token strings.

    Space markers are stripped before classifying, so a token made only of
    markers is a space token, a word piece like '▁hello' counts as an ordinary
    token, and anything else that is not alphanumeric is a special-character
    token.
    """
    space_count = 0
    special_char_count = 0
    for token in tokens:
        core = token.strip(_SPACE_MARKERS)
        if not core:
            space_count += 1
        elif not core.isalnum():
            special_char_count += 1
    return space_count, special_char_count


def process_input(input_type, input_value, model_name):
    """Tokenize user input with the selected model and summarize the result.

    Args:
        input_type: "Text" to tokenize ``input_value`` directly, or
            "Token IDs" to first decode ``input_value`` (a Python literal such
            as "[1, 2, 3]") back into text.
        input_value: Raw string from the input textbox.
        model_name: Model name appended to ``model_path`` to locate the
            tokenizer.

    Returns:
        A 6-tuple matching the UI outputs: (analysis text, space-joined
        tokens, token-ID list as a string, words figure, special-characters
        figure, numbers figure). On invalid input the tuple is
        ``("Error", <message>, "", None, None, None)``.
    """
    if not model_name:
        # The dropdown can be submitted with nothing selected.
        return _error_result("Please select a model.")

    if input_type == "Text":
        tokenizer = _load_tokenizer(model_name)
        text = input_value or ""
    elif input_type == "Token IDs":
        # Validate before touching the tokenizer so bad input fails fast.
        ids = _parse_token_ids(input_value)
        if ids is None:
            return _error_result("Invalid input")
        tokenizer = _load_tokenizer(model_name)
        try:
            text = tokenizer.decode(ids)
        except (IndexError, OverflowError, TypeError, ValueError):
            # Defensive: some tokenizer implementations may reject IDs outside
            # their vocabulary.
            return _error_result("Invalid input")
    else:
        return _error_result(f"Unsupported input type: {input_type}")

    character_count = len(text)
    word_count = len(text.split())

    token_ids = tokenizer.encode(text, add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    space_count, special_char_count = _count_token_kinds(tokens)

    # Frequency tables are computed on the raw text, not on tokens.
    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    words_hist = create_vertical_histogram(
        Counter(words).most_common(10), "Most Common Words")
    special_chars_hist = create_vertical_histogram(
        Counter(special_chars).most_common(10), "Most Common Special Characters")
    numbers_hist = create_vertical_histogram(
        Counter(numbers).most_common(10), "Most Common Numbers")

    analysis = "\n".join([
        f"Token count: {len(tokens)}",
        f"Character count: {character_count}",
        f"Word count: {word_count}",
        f"Space tokens: {space_count}",
        f"Special character tokens: {special_char_count}",
        f"Other tokens: {len(tokens) - space_count - special_char_count}",
    ])

    return analysis, " ".join(tokens), str(token_ids), words_hist, special_chars_hist, numbers_hist

# Build the Gradio UI: input controls, a "Process" button, and text/plot
# outputs wired to process_input.
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")

    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        # Preselect a model so clicking "Process" right away does not submit
        # an empty model choice.
        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])

    input_text = gr.Textbox(lines=5, label="Input")

    submit_button = gr.Button("Process")

    analysis_output = gr.Textbox(label="Analysis", lines=6)
    tokens_output = gr.Textbox(label="Tokens", lines=3)
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)

    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # The outputs list must line up, in order, with the 6-tuple that
    # process_input returns.
    submit_button.click(
        process_input,
        inputs=[input_type, input_text, model_name],
        outputs=[analysis_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
    )

if __name__ == "__main__":
    iface.launch()