# LLM Tokenization Explorer — Hugging Face Space app.
# NOTE(review): the original lines here ("Spaces: / Sleeping / Sleeping") were
# Hugging Face Space status text captured during extraction, not program code.
# Standard library imports.
import ast
import re
from collections import Counter

# Third-party imports.
import gradio as gr
import plotly.graph_objs as go
from transformers import AutoTokenizer

# Directory containing locally stored model/tokenizer files.
model_path = "models/"

# Available models (subdirectories under `model_path`).
MODELS = ["Meta-Llama-3.1-8B", "gemma-2b"]
def create_vertical_histogram(data, title):
    """Build a vertical bar chart from (label, count) pairs.

    Args:
        data: list of (label, count) tuples, e.g. from Counter.most_common().
              An empty list produces an empty (but valid) figure.
        title: chart title string.

    Returns:
        A plotly Figure with labels on the x-axis and counts on the y-axis.
    """
    if data:
        labels, values = zip(*data)
    else:
        # zip(*[]) would raise; fall back to an empty chart.
        labels, values = [], []
    fig = go.Figure(go.Bar(
        x=labels,
        y=values
    ))
    fig.update_layout(
        title=title,
        xaxis_title="Item",
        yaxis_title="Count",
        height=400,
        xaxis=dict(tickangle=-45)  # slant labels so long words stay readable
    )
    return fig
def process_input(input_type, input_value, model_name):
    """Tokenize text (or decode token IDs) with the selected model and analyze it.

    Args:
        input_type: "Text" or "Token IDs" (which Gradio radio option is active).
        input_value: raw text, or a Python-literal list of token IDs.
        model_name: model subdirectory name under `model_path`.

    Returns:
        A 6-tuple matching the Gradio outputs:
        (analysis text, space-joined tokens, token-ID list as str,
         words histogram, special-chars histogram, numbers histogram).
        On invalid token-ID input the three figures are None.
    """
    # NOTE: the tokenizer is reloaded on every call; acceptable for a demo,
    # but a per-model cache would avoid repeated disk reads.
    tokenizer = AutoTokenizer.from_pretrained(model_path + model_name)

    if input_type == "Text":
        text = input_value
    elif input_type == "Token IDs":
        try:
            token_ids = ast.literal_eval(input_value)
            text = tokenizer.decode(token_ids)
        except (ValueError, SyntaxError):
            # literal_eval raises SyntaxError (not ValueError) on malformed
            # syntax; catch both. Return the same arity as the success path
            # (6 outputs) so Gradio can unpack the result.
            return "Error: Invalid input", "", "", None, None, None
    else:
        # Defensive: unknown input type would otherwise leave `text` unbound.
        return "Error: Unknown input type", "", "", None, None, None

    character_count = len(text)
    word_count = len(text.split())

    token_ids = tokenizer.encode(text, add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # '▁' is the SentencePiece space marker. NOTE(review): the original source
    # had a mojibake 'β' here; BPE tokenizers (e.g. Llama-3) use 'Ġ' instead —
    # confirm against the actual tokenizers in use.
    space_marker = '▁'
    space_count = sum(1 for token in tokens if token == space_marker)
    special_char_count = sum(
        1 for token in tokens if not token.isalnum() and token != space_marker
    )

    # Frequency analysis is done on the raw text, not the tokens.
    words = re.findall(r'\b\w+\b', text.lower())
    special_chars = re.findall(r'[^\w\s]', text)
    numbers = re.findall(r'\d+', text)

    most_common_words = Counter(words).most_common(10)
    most_common_special_chars = Counter(special_chars).most_common(10)
    most_common_numbers = Counter(numbers).most_common(10)

    words_hist = create_vertical_histogram(most_common_words, "Most Common Words")
    special_chars_hist = create_vertical_histogram(
        most_common_special_chars, "Most Common Special Characters")
    numbers_hist = create_vertical_histogram(
        most_common_numbers, "Most Common Numbers")

    analysis = f"Token count: {len(tokens)}\n"
    analysis += f"Character count: {character_count}\n"
    analysis += f"Word count: {word_count}\n"
    analysis += f"Space tokens: {space_count}\n"
    analysis += f"Special character tokens: {special_char_count}\n"
    analysis += f"Other tokens: {len(tokens) - space_count - special_char_count}"

    return analysis, " ".join(tokens), str(token_ids), words_hist, special_chars_hist, numbers_hist
# Build the Gradio UI: input controls on top, analysis text and token views
# in the middle, three histogram plots in a row at the bottom.
with gr.Blocks() as iface:
    gr.Markdown("# LLM Tokenization - Convert Text to tokens and vice versa!")
    gr.Markdown("Enter text or token IDs and select a model to see the results, including word count, token analysis, and histograms of most common elements.")

    with gr.Row():
        input_type = gr.Radio(["Text", "Token IDs"], label="Input Type", value="Text")
        model_name = gr.Dropdown(choices=MODELS, label="Select Model")

    input_text = gr.Textbox(lines=5, label="Input")
    submit_button = gr.Button("Process")

    analysis_output = gr.Textbox(label="Analysis", lines=6)
    tokens_output = gr.Textbox(label="Tokens", lines=3)
    token_ids_output = gr.Textbox(label="Token IDs", lines=2)

    with gr.Row():
        words_plot = gr.Plot(label="Most Common Words")
        special_chars_plot = gr.Plot(label="Most Common Special Characters")
        numbers_plot = gr.Plot(label="Most Common Numbers")

    # Wire the button to the processing function; outputs must match the
    # 6-tuple returned by process_input, in order.
    submit_button.click(
        process_input,
        inputs=[input_type, input_text, model_name],
        outputs=[analysis_output, tokens_output, token_ids_output, words_plot, special_chars_plot, numbers_plot]
    )

if __name__ == "__main__":
    iface.launch()