Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer | |
| import ast | |
| from collections import Counter | |
| import re | |
| import plotly.graph_objs as go | |
# Prefix prepended to the selected model name when a tokenizer is loaded.
model_path = "models/"

# Available models
MODELS = [
    "Meta-Llama-3.1-8B",
    "gemma-2b",
]
def create_vertical_histogram(data, title):
    """Build a vertical bar chart from (label, count) pairs.

    Args:
        data: Sequence of two-item pairs; the first items become the x-axis
            labels and the second items the bar heights. May be empty, in
            which case an empty chart is produced.
        title: Title displayed above the chart.

    Returns:
        A plotly ``Figure`` holding a single bar trace.
    """
    if data:
        # Transpose [(label, value), ...] into parallel label/value tuples.
        labels, values = zip(*data)
    else:
        labels, values = [], []

    bars = go.Bar(x=labels, y=values)
    figure = go.Figure(bars)
    figure.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        # Slant the tick labels so longer items stay readable.
        xaxis={"tickangle": -45},
    )
    return figure
def _parse_token_ids(raw):
    """Parse user-supplied text into a list of integer token IDs.

    Args:
        raw: Text expected to hold a Python literal: either a single integer
            or a list/tuple of integers.

    Returns:
        A list of integers.

    Raises:
        ValueError: If ``raw`` is not a valid literal, or the literal is not
            an integer or a list/tuple of integers.
    """
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as exc:
        # literal_eval reports malformed input through several exception
        # types; collapse them into one for the caller.
        raise ValueError("token IDs must be a literal list of integers") from exc

    if isinstance(parsed, int):
        return [parsed]
    if isinstance(parsed, (list, tuple)) and all(isinstance(item, int) for item in parsed):
        return list(parsed)
    raise ValueError("token IDs must be a literal list of integers")


def process_input(input_type, input_value, model_name):
    """Tokenize the input and compute text/token statistics and histograms.

    Args:
        input_type: ``"Text"`` to analyse ``input_value`` directly, or
            ``"Token IDs"`` to first decode ``input_value`` (a literal list of
            integers) back into text.
        input_value: The raw text, or the textual list of token IDs.
        model_name: Name appended to ``model_path`` to locate the tokenizer.

    Returns:
        A 6-tuple ``(analysis, tokens, token_ids, words_hist,
        special_chars_hist, numbers_hist)``: a multi-line summary string, the
        space-joined tokens, the string form of the token-ID list, and three
        histogram figures. On invalid input the tuple is
        ``("Error", "Invalid input", "", None, None, None)``.
    """
    # The error result must carry exactly as many values as the success result
    # (three text fields followed by three plots).
    invalid = ("Error", "Invalid input", "", None, None, None)

    # Validate the input before loading the tokenizer so bad input fails fast.
    decoded_ids = None
    if input_type == "Text":
        text = input_value or ""
    elif input_type == "Token IDs":
        try:
            decoded_ids = _parse_token_ids(input_value)
        except ValueError:
            return invalid
    else:
        # Unknown input type: there is no text to analyse.
        return invalid

    tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)

    if decoded_ids is not None:
        try:
            text = tokenizer.decode(decoded_ids)
        except (ValueError, TypeError, OverflowError, IndexError):
            # IDs the tokenizer cannot decode are reported as invalid input.
            return invalid

    character_count = len(text)
    word_count = len(text.split())

    token_ids = tokenizer.encode(text, add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # NOTE(review): 'β' looks like a mis-encoded tokenizer space marker
    # (possibly 'Ġ' or '▁') -- confirm the intended character against the
    # tokenizers in MODELS before changing it.
    space_marker = 'β'
    space_count = sum(1 for token in tokens if token == space_marker)
    special_char_count = sum(
        1 for token in tokens if not token.isalnum() and token != space_marker
    )

    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    words_hist = create_vertical_histogram(
        Counter(words).most_common(10), "Most Common Words"
    )
    special_chars_hist = create_vertical_histogram(
        Counter(special_chars).most_common(10), "Most Common Special Characters"
    )
    numbers_hist = create_vertical_histogram(
        Counter(numbers).most_common(10), "Most Common Numbers"
    )

    other_count = len(tokens) - space_count - special_char_count
    analysis = "\n".join([
        f"Token count: {len(tokens)}",
        f"Character count: {character_count}",
        f"Word count: {word_count}",
        f"Space tokens: {space_count}",
        f"Special character tokens: {special_char_count}",
        f"Other tokens: {other_count}",
    ])

    return (
        analysis,
        " ".join(tokens),
        str(token_ids),
        words_hist,
        special_chars_hist,
        numbers_hist,
    )
def text_example():
    """Return the sample sentence used to pre-fill the input box in text mode."""
    sample_text = "Hello, world! This is an example text input for tokenization."
    return sample_text
def token_ids_example():
    """Return a sample token-ID list, as text, to pre-fill the input box."""
    sample_ids = "[128000, 9906, 11, 1917, 0, 1115, 374, 459, 3187, 1495, 1988, 369, 4037, 2065, 13]"
    return sample_ids
# Gradio UI: input controls on top, textual results below, histograms last.
# NOTE(review): the nesting of the rows below was reconstructed from a listing
# that had lost its indentation -- confirm the layout against the running app.
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")

    # Input mode and model selection share one row.
    with gr.Row():
        input_type = gr.Radio(
            ["Text", "Token IDs"],
            label="Input Type",
            value="Text",
        )
        model_name = gr.Dropdown(
            choices=MODELS,
            label="Select Model",
            value=MODELS[0],
        )

    input_text = gr.Textbox(lines=5, label="Input")

    # Buttons that pre-fill the input box with a sample.
    with gr.Row():
        text_example_button = gr.Button("Load Text Example")
        token_ids_example_button = gr.Button("Load Token IDs Example")

    submit_button = gr.Button("Process")

    # Textual results.
    analysis_output = gr.Textbox(label="Analysis", lines=6)
    tokens_output = gr.Textbox(label="Tokens", lines=3)
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)

    # Frequency histograms.
    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # Each example button fills the input box and switches the input mode
    # to match the sample it loaded.
    text_example_button.click(
        fn=lambda: (text_example(), "Text"),
        outputs=[input_text, input_type],
    )
    token_ids_example_button.click(
        fn=lambda: (token_ids_example(), "Token IDs"),
        outputs=[input_text, input_type],
    )

    # Run the analysis; process_input's results map onto the outputs in order.
    submit_button.click(
        fn=process_input,
        inputs=[input_type, input_text, model_name],
        outputs=[
            analysis_output,
            tokens_output,
            token_ids_output,
            words_plot,
            special_chars_plot,
            numbers_plot,
        ],
    )

if __name__ == "__main__":
    iface.launch()