import html
import os

import streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer
# --- Page configuration and global styling ---------------------------------
st.set_page_config(page_title='Gujju Llama Tokenizer Playground', layout="wide")

# Force a single text colour across the common Streamlit widgets.
st.markdown(
    """
    <style>
    /* Add your custom CSS here */
    .stApp *, .stMarkdown *, .stTextInput *, .stTextArea *, .stSelectbox *, .stCheckbox *, .stRadio *, .stButton *, .stProgress *, .stSlider *, .stNumberInput * {
        color: #015b66 !important;
    }
    </style>
    """,
    unsafe_allow_html=True
)

# Authenticate with the Hugging Face Hub so gated repos (e.g. meta-llama)
# can be downloaded. NOTE(review): the env var name is lowercase "hf_token";
# confirm this matches the deployment config (conventional name is HF_TOKEN).
token = os.environ.get("hf_token")
if token:
    login(token=token)
else:
    # login(token=None) would fall back to an interactive prompt, which
    # fails in a headless server — warn in the UI instead of crashing.
    st.warning(
        "Environment variable 'hf_token' is not set; gated models may fail to load."
    )
class TokenizationVisualizer:
    """Loads Hugging Face tokenizers and renders their output as colour-coded HTML."""

    def __init__(self):
        # Mapping of display name -> loaded tokenizer instance.
        self.tokenizers = {}

    def add_tokenizer(self, name, model_name):
        """Download (or load from local cache) the tokenizer for *model_name*
        and register it under the display name *name*."""
        self.tokenizers[name] = AutoTokenizer.from_pretrained(model_name)

    def visualize_tokens(self, text, tokenizer):
        """Tokenize *text* and return ``(html_markup, token_ids)``.

        ``html_markup`` is a run of ``<mark>`` spans, one per token, cycling
        through a fixed colour palette; ``token_ids`` is the parallel list of
        vocabulary ids for the same tokens.
        """
        tokens = tokenizer.tokenize(text)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Convert each token back to display text individually so byte-level /
        # subword markers are rendered the way the tokenizer intends.
        str_tokens = [tokenizer.convert_tokens_to_string([tok]) for tok in tokens]
        colors = ['#ffdab9', '#e6ee9c', '#9cddc8', '#bcaaa4', '#c5b0d5']
        parts = []
        for i, token in enumerate(str_tokens):
            color = colors[i % len(colors)]
            # Escape the token text: this markup is rendered with
            # unsafe_allow_html=True, so raw user input containing <, >, or "
            # would otherwise be injected straight into the page.
            safe = html.escape(token)
            parts.append(
                f'<mark title="{safe}" style="background-color: {color};">{safe}</mark>'
            )
        return "".join(parts), token_ids
def playground_tab(visualizer):
    """Render the side-by-side tokenizer comparison playground."""
    st.title("Tokenization Visualizer for Language Models")
    st.markdown("""
    You can use this playground to visualize Llama2 tokens & Gujarati Llama tokens generated by the tokenizers.
    """)
    text_input = st.text_area("Enter text below to visualize tokens:", height=300)

    # Guard clauses: render nothing further until the button is pressed,
    # and reject empty / whitespace-only input with a visible error.
    if not st.button("Tokenize"):
        return
    st.divider()
    if not text_input.strip():
        st.error("Please enter some text.")
        return

    left, right = st.columns(2)
    panels = (
        ("Llama2 Tokenizer", "Llama2", left),
        ("Gujju Llama Tokenizer", "Gujju Llama", right),
    )
    for heading, registry_key, column in panels:
        rendered, ids = visualizer.visualize_tokens(
            text_input, visualizer.tokenizers[registry_key]
        )
        column.title(heading)
        column.container(height=200, border=True).markdown(
            rendered, unsafe_allow_html=True
        )
        with column.expander(f"Token IDs (Token Counts = {len(ids)})"):
            st.markdown(ids)
def main():
    """Entry point: load both tokenizers, then render the playground tab."""
    # Display name -> Hugging Face Hub repo id.
    tokenizer_repos = {
        "Gujju Llama": "sampoorna42/Gujju-Llama-Instruct-v0.1",
        "Llama2": "meta-llama/Llama-2-7b-hf",
    }
    visualizer = TokenizationVisualizer()
    for display_name, repo_id in tokenizer_repos.items():
        visualizer.add_tokenizer(display_name, repo_id)
    playground_tab(visualizer)


if __name__ == "__main__":
    main()