import os

import gradio as gr
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer


# Load the model and tokenizer
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
        print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))

        # Model parameters - adjust these to match your trained model
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        # Load the model
        u_model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
        )

        # Determine which model file to use
        if custom_model_path and os.path.exists(custom_model_path):
            model_path = custom_model_path
            print(f"🎯 Using uploaded model: {model_path}")
        else:
            model_path = "v1/u_model.pth"
            if not os.path.exists(model_path):
                print("❌ Model file not found at", model_path)
                # Download the model file from GitHub
                try:
                    print("📥 Downloading model weights from GitHub...")
                    import requests

                    url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model_4000.pth"
                    headers = {
                        "Accept": "application/octet-stream",
                        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
                    }
                    response = requests.get(url, headers=headers)
                    response.raise_for_status()  # Raise an exception for bad status codes

                    # Check if we got a proper binary file (PyTorch checkpoints are
                    # zip archives starting with the magic bytes b"PK\x03\x04"; a
                    # failed download typically returns an HTML error page instead)
                    if response.content[:4] != b"PK\x03\x04" and b"<html" in response.content[:100].lower():
                        raise Exception("Downloaded content is an HTML page, not a model file")

                    with open(model_path, "wb") as f:
                        f.write(response.content)
                    print("✅ Model weights downloaded and saved to", model_path)
                except Exception as e:
                    print("❌ Failed to download model weights:", e)
                    raise

        # Load the trained weights onto the CPU and switch to inference mode
        state_dict = torch.load(model_path, map_location="cpu")
        u_model.load_state_dict(state_dict)
        u_model.eval()

        return u_model, u_tokenizer, f"✅ Model loaded successfully from {model_path}"
    except Exception as e:
        print("❌ Error loading model:", e)
        return None, None, f"❌ Error loading model: {e}"


# Load the default model and tokenizer at startup
model, tokenizer, model_status = load_model()


def update_model(uploaded_file):
    """Reload the model from a user-uploaded checkpoint and report the status."""
    global model, tokenizer, model_status
    if uploaded_file is None:
        return "❌ No file uploaded; keeping the current model."
    # Depending on the Gradio version, gr.File yields a path string or a
    # tempfile-like object with a .name attribute
    path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
    new_model, new_tokenizer, status = load_model(path)
    if new_model is not None:
        model, tokenizer, model_status = new_model, new_tokenizer, status
    return status


def respond(message, history, system_message, max_tokens, temperature, top_p):
    if model is None or tokenizer is None:
        yield "❌ Model is not loaded. Please upload a checkpoint or check the logs."
        return

    try:
        # Encode the user message into token ids
        tokens = tokenizer.encode(message)

        # The model's context length is 32 tokens
        if len(tokens) > 25:  # Leave some room for generation
            tokens = tokens[-25:]

        # Generate response
        with torch.no_grad():
            # Use max_tokens parameter, but cap it at a reasonable limit for this model
            actual_max_tokens = min(max_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, actual_max_tokens)

        # Decode the generated tokens
        response = tokenizer.decode(generated_tokens)

        # Clean up the response (remove the original input)
        original_text = tokenizer.decode(tokens.tolist())
        if response.startswith(original_text):
            response = response[len(original_text):]

        # Strip special tokens from the output ("<unk>" and "<eos>" are assumed
        # names for this tokenizer's special tokens)
        response = response.replace("<unk>", "").replace("<eos>", "").strip()

        if not response:
            response = "I'm not sure how to respond to that with my geographical knowledge."

        # Yield the response (to maintain compatibility with the streaming interface)
        yield response
    except Exception as e:
        yield f"Sorry, I encountered an error: {str(e)}"


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""

# Create the interface with file upload
with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Usta Model Chat")
    gr.Markdown(
        "Chat with a custom transformer language model built from scratch! "
        "This model specializes in geographical knowledge including countries, "
        "capitals, and cities."
    )

    with gr.Row():
        with gr.Column(scale=2):
            # Model upload section
            with gr.Group():
                gr.Markdown("### 📁 Model Upload (Optional)")
                model_file = gr.File(
                    label="Upload your own model.pth file",
                    file_types=[".pth", ".pt"],
                    info="Upload a custom UstaModel checkpoint to use instead of the default model",
                )
                upload_btn = gr.Button("Load Model", variant="primary")
                model_status_display = gr.Textbox(
                    label="Model Status",
                    value=model_status,
                    interactive=False,
                    info="Shows the current model loading status",
                )

        with gr.Column(scale=1):
            # Settings
            with gr.Group():
                gr.Markdown("### ⚙️ Generation Settings")
                system_msg = gr.Textbox(
                    value="You are Usta, a geographical knowledge assistant trained from scratch.",
                    label="System message",
                    info="Note: This model focuses on geographical knowledge",
                )
                max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                    info="Note: This parameter is not used by UstaModel",
                )

    # Chat interface
    chatbot = gr.ChatInterface(
        respond,
        additional_inputs=[system_msg, max_tokens, temperature, top_p],
        chatbot=gr.Chatbot(height=400),
        title=None,  # We already have a title above
        description=None,  # We already have a description above
    )

    # Event handlers
    upload_btn.click(
        update_model,
        inputs=[model_file],
        outputs=[model_status_display],
    )

if __name__ == "__main__":
    demo.launch()
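

# --- Example usage ---------------------------------------------------------
# A minimal smoke test for the generation path (an illustrative sketch: the
# prompt, system message, and settings below are examples, not part of the
# app itself). Import this module in a Python shell and call smoke_test() to
# exercise respond() without launching the Gradio UI.
def smoke_test(prompt="the capital of france is"):
    """Print the model's reply to a single prompt using default settings."""
    # Arguments mirror respond(message, history, system_message, max_tokens,
    # temperature, top_p); history is empty for a single-turn test.
    for chunk in respond(prompt, [], "You are Usta.", 20, 1.0, 0.95):
        print(chunk)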