#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard Library Imports
import os
from pathlib import Path

# Third-Party Imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local Imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    # Create the checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Skip the merge if the combined checkpoint already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the directory containing the split parts exists
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Collect the parts in order and concatenate them into one file
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Combine model parts and get the checkpoint path
    checkpoint_path = combine_model_parts()

    # Load the model from the combined checkpoint using the Lightning module
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    model.to(device)
    model.eval()

    # Initialize the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
""" # Ensure num_tokens doesn't exceed model's block size num_tokens = min(num_tokens, SmollmConfig.block_size) # Tokenize input prompt input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) # Generate tokens one at a time for _ in range(num_tokens): # Get the model's predictions with torch.no_grad(): with torch.autocast(device_type=device, dtype=torch.bfloat16): logits, _ = model.model(input_ids) # Get the next token probabilities logits = logits[:, -1, :] / temperature probs = F.softmax(logits, dim=-1) # Apply top-p sampling if top_p > 0: sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumsum_probs = torch.cumsum(sorted_probs, dim=-1) sorted_indices_to_keep = cumsum_probs <= top_p sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone() sorted_indices_to_keep[..., 0] = 1 indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep) probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs)) probs = probs / probs.sum(dim=-1, keepdim=True) # Sample next token next_token = torch.multinomial(probs, num_samples=1) # Append to input_ids input_ids = torch.cat([input_ids, next_token], dim=-1) # Stop if we generate an EOS token if next_token.item() == tokenizer.eos_token_id: break # Decode and return the generated text generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True) return generated_text # Load the model globally model, tokenizer, device = load_model() # Create the Gradio interface demo = gr.Interface( fn=generate_text, inputs=[ gr.Textbox(label="Enter your prompt", value="Once upon a time"), gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"), gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"), gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)") ], outputs=gr.Textbox(label="Generated Text"), title="SmollmV2 Text Generator", description="Generate text using the SmollmV2 model", allow_flagging="never", cache_examples=True ) if __name__ == "__main__": demo.launch()