"""Gradio app for inspecting how a tokenizer splits text.

The user enters either raw text or a Python-style list of token IDs, picks a
model, and gets back token/character/word statistics, the token strings, the
token IDs, and three histograms (most common words, special characters and
numbers).
"""

import ast
import re
from collections import Counter
from functools import lru_cache

import gradio as gr
import plotly.graph_objs as go
from transformers import AutoTokenizer

model_path = "models/"

# Available models
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"]

# Returned by process_input when the input cannot be processed. It has exactly
# one entry per output component wired to the "Process" button (three
# textboxes followed by three plots); None leaves a plot empty.
_ERROR_RESULT = ("Error", "Invalid input", "", None, None, None)


def create_vertical_histogram(data, title):
    """Build a vertical Plotly bar chart from ``(label, count)`` pairs.

    Args:
        data: Sequence of ``(label, count)`` pairs, e.g. the result of
            ``Counter.most_common``. May be empty.
        title: Chart title.

    Returns:
        A ``plotly.graph_objs.Figure`` containing a single bar trace.
    """
    # zip(*[]) yields nothing to unpack, so empty data needs explicit axes.
    labels, values = zip(*data) if data else ([], [])

    fig = go.Figure(go.Bar(x=labels, y=values))
    fig.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        xaxis=dict(tickangle=-45),
    )
    return fig


@lru_cache(maxsize=len(MODELS))
def _load_tokenizer(model_name):
    """Load the tokenizer for *model_name* once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(model_path + model_name)


def _parse_token_ids(raw):
    """Parse user-entered token IDs into a list of ints.

    Accepts a Python literal list/tuple of integers, or a single integer.

    Raises:
        ValueError, SyntaxError, TypeError, RecursionError: If *raw* is not a
            valid Python literal (these come from ``ast.literal_eval``).
        ValueError: If the literal is not an int or a list/tuple of ints.
    """
    parsed = ast.literal_eval(raw)
    if isinstance(parsed, int) and not isinstance(parsed, bool):
        return [parsed]
    if not isinstance(parsed, (list, tuple)):
        raise ValueError("token IDs must be a list of integers")
    if any(isinstance(item, bool) or not isinstance(item, int) for item in parsed):
        raise ValueError("token IDs must be a list of integers")
    return list(parsed)


def process_input(input_type, input_value, model_name):
    """Tokenize the input and compute statistics and histograms for it.

    Args:
        input_type: ``"Text"`` to treat *input_value* as raw text, or
            ``"Token IDs"`` to treat it as a Python-style list of token IDs
            that is decoded to text first.
        input_value: The text, or the token-ID list written as a string.
        model_name: Name of the model directory under ``model_path`` whose
            tokenizer is used.

    Returns:
        A 6-tuple ``(analysis, tokens, token_ids, words_hist,
        special_chars_hist, numbers_hist)``: a multi-line summary string, the
        space-joined token strings, the token IDs as a string, and three
        Plotly figures. On invalid input the tuple is
        ``("Error", "Invalid input", "", None, None, None)``.
    """
    tokenizer = _load_tokenizer(model_name)

    if input_type == "Text":
        text = input_value
    elif input_type == "Token IDs":
        try:
            text = tokenizer.decode(_parse_token_ids(input_value))
        except (ValueError, SyntaxError, TypeError, OverflowError, RecursionError):
            # literal_eval raises SyntaxError (not just ValueError) on empty or
            # malformed input; all of these mean "the user typed bad IDs".
            return _ERROR_RESULT
    else:
        # Unknown input type: report it instead of failing on an unbound name.
        return _ERROR_RESULT

    character_count = len(text)
    word_count = len(text.split())

    # NOTE(review): in "Token IDs" mode the decoded text is re-encoded with
    # special tokens enabled, so the IDs shown may differ from the IDs the
    # user entered (e.g. an extra leading special token) - confirm intended.
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # NOTE(review): only the '▁' marker is counted as a space token;
    # tokenizers that use a different space marker will report 0 here -
    # confirm against the tokenizers of the models in MODELS.
    space_count = sum(1 for token in tokens if token == '▁')
    special_char_count = sum(
        1 for token in tokens if not token.isalnum() and token != '▁'
    )

    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    most_common_words = Counter(words).most_common(10)
    most_common_special_chars = Counter(special_chars).most_common(10)
    most_common_numbers = Counter(numbers).most_common(10)

    words_hist = create_vertical_histogram(most_common_words, "Most Common Words")
    special_chars_hist = create_vertical_histogram(
        most_common_special_chars, "Most Common Special Characters"
    )
    numbers_hist = create_vertical_histogram(
        most_common_numbers, "Most Common Numbers"
    )

    analysis = "\n".join(
        [
            f"Token count: {len(tokens)}",
            f"Character count: {character_count}",
            f"Word count: {word_count}",
            f"Space tokens: {space_count}",
            f"Special character tokens: {special_char_count}",
            f"Other tokens: {len(tokens) - space_count - special_char_count}",
        ]
    )

    return (
        analysis,
        " ".join(tokens),
        str(token_ids),
        words_hist,
        special_chars_hist,
        numbers_hist,
    )


def text_example():
    """Return the sample text loaded by the "Load Text Example" button."""
    return "Hello, world! This is an example text input for tokenization."


def token_ids_example():
    """Return the sample ID list loaded by the "Load Token IDs Example" button."""
    return "[128000, 9906, 11, 1917, 0, 1115, 374, 459, 3187, 1495, 1988, 369, 4037, 2065, 13]"


# UI layout: input controls on top, text results below, histograms last.
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")

    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        model_name = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])

    input_text = gr.Textbox(lines=5, label="Input")

    with gr.Row():
        text_example_button = gr.Button("Load Text Example")
        token_ids_example_button = gr.Button("Load Token IDs Example")

    submit_button = gr.Button("Process")

    analysis_output = gr.Textbox(label="Analysis", lines=6)
    tokens_output = gr.Textbox(label="Tokens", lines=3)
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)

    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # Each example button fills the input box and switches the input type
    # so the example is interpreted correctly when processed.
    text_example_button.click(
        lambda: (text_example(), "Text"),
        outputs=[input_text, input_type],
    )

    token_ids_example_button.click(
        lambda: (token_ids_example(), "Token IDs"),
        outputs=[input_text, input_type],
    )

    # process_input must always return one value per component listed here.
    submit_button.click(
        process_input,
        inputs=[input_type, input_text, model_name],
        outputs=[
            analysis_output,
            tokens_output,
            token_ids_output,
            words_plot,
            special_chars_plot,
            numbers_plot,
        ],
    )

if __name__ == "__main__":
    iface.launch()