#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Third-party and standard-library imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces
import os
from pathlib import Path

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def _part_sort_key(part):
    """Order checkpoint parts by the suffix after ``part_``, numerically when it is a number."""
    suffix = part.name.rpartition("part_")[2]
    if suffix.isdecimal():
        # Numeric suffixes compare as integers so part_10 follows part_9, not part_1.
        return (0, int(suffix), "")
    # Alphabetic suffixes (e.g. "aa", "ab") keep plain lexicographic order.
    return (1, 0, suffix)


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.

    Files matching ``last.ckpt.part_*`` inside ``model_dir`` are concatenated,
    in suffix order, into ``output_file``. If ``output_file`` already exists it
    is reused and nothing is written.

    Args:
        model_dir: Directory that holds the ``last.ckpt.part_*`` files.
        output_file: Path of the combined checkpoint to create.

    Returns:
        The ``output_file`` path.

    Raises:
        FileNotFoundError: If ``model_dir`` does not exist or holds no parts.
    """
    import shutil  # function-scope import keeps this block self-contained

    # Create the output directory if needed; a bare filename has no directory
    # component and os.makedirs("") would raise.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"), key=_part_sort_key)
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    # Assemble into a temporary file and rename at the end, so an interrupted
    # run never leaves a truncated checkpoint that the "already combined"
    # check above would accept on the next call.
    temp_file = f"{output_file}.tmp"
    try:
        with open(temp_file, 'wb') as outfile:
            for part in parts:
                print(f"Processing part: {part}")
                with open(part, 'rb') as infile:
                    # Stream in chunks instead of loading a whole part in memory.
                    shutil.copyfileobj(infile, outfile)
        os.replace(temp_file, output_file)
    except BaseException:
        # Remove the partial file, then let the original error propagate.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    print(f"Model combined successfully: {output_file}")
    return output_file

def load_model():
    """
    Prepare everything needed for inference.

    Returns:
        A ``(model, tokenizer, device)`` tuple: the Lightning module restored
        from the combined checkpoint (moved to ``device`` and set to eval
        mode), the GPT-2 tokenizer with its pad token set to the EOS token,
        and the device string (``'cuda'`` or ``'cpu'``).
    """
    # Prefer the GPU whenever one is available
    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'

    # Build (or reuse) the single-file checkpoint from its split parts
    ckpt_file = combine_model_parts()

    # Restore the Lightning wrapper from that checkpoint
    lit_module = LitSmollmv2.load_from_checkpoint(
        ckpt_file, model_config=SmollmConfig, strict=False
    )
    lit_module.to(target_device)
    lit_module.eval()

    # The tokenizer reuses EOS as its padding token
    tok = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
    tok.pad_token = tok.eos_token

    return lit_module, tok, target_device


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.

    Tokens are sampled one at a time with temperature scaling and optional
    top-p (nucleus) filtering, and appended to the prompt until ``num_tokens``
    tokens have been produced or the EOS token is sampled.

    Args:
        prompt: Text to continue. An empty prompt is seeded with the EOS token.
        num_tokens: Maximum number of tokens to generate; capped at
            ``SmollmConfig.block_size``.
        temperature: Divisor applied to the logits before softmax. Values
            at or below zero are clamped to a small positive number.
        top_p: Cumulative-probability threshold for nucleus sampling. Filtering
            is applied only when ``0 < top_p < 1``.

    Returns:
        The decoded prompt plus generated continuation, with special tokens
        stripped.
    """
    block_size = SmollmConfig.block_size

    # UI sliders can deliver floats; range() below requires an int.
    num_tokens = min(int(num_tokens), block_size)
    # Guard the division below against zero or negative temperatures.
    temperature = max(float(temperature), 1e-5)

    # Tokenize input prompt; fall back to a lone EOS token so the model never
    # receives a zero-length sequence.
    prompt_ids = tokenizer.encode(prompt) or [tokenizer.eos_token_id]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        # Generate tokens one at a time
        for _step in range(num_tokens):
            # Only the most recent block_size tokens are fed to the model:
            # prompt + generated tokens can otherwise outgrow the context
            # window (presumably the model cannot handle longer inputs).
            context = input_ids[:, -block_size:]
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, _unused = model.model(context)

            # Sample in float32: autocast may hand back reduced-precision
            # logits, which is too coarse for softmax and renormalisation.
            logits = logits[:, -1, :].float() / temperature
            probs = F.softmax(logits, dim=-1)

            # Apply top-p sampling (top_p == 1 means "keep everything")
            if 0 < top_p < 1:
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_keep = cumsum_probs <= top_p
                # Shift right so the token that crosses the threshold is kept
                # too, and always keep the single most likely token.
                sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
                sorted_indices_to_keep[..., 0] = True
                indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(
                    -1, sorted_indices, sorted_indices_to_keep
                )
                probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
                probs = probs / probs.sum(dim=-1, keepdim=True)

            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop if we generate an EOS token
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode and return the generated text
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Inference state shared with generate_text, which reads these module globals
model, tokenizer, device = load_model()

# Input widgets, listed in the positional order generate_text takes them
_input_widgets = [
    gr.Textbox(label="Enter your prompt", value="Once upon a time"),
    gr.Slider(
        minimum=1,
        maximum=SmollmConfig.block_size,
        value=100,
        step=1,
        label="Number of tokens to generate",
    ),
    gr.Slider(
        minimum=0.1,
        maximum=2.0,
        value=0.8,
        step=0.1,
        label="Temperature (higher = more random)",
    ),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.9,
        step=0.1,
        label="Top-p (nucleus sampling)",
    ),
]

# Single text area that shows the prompt plus its continuation
_output_widget = gr.Textbox(label="Generated Text")

# Assemble the web UI around generate_text
demo = gr.Interface(
    fn=generate_text,
    inputs=_input_widgets,
    outputs=_output_widget,
    title="SmollmV2 Text Generator",
    description="Generate text using the SmollmV2 model",
    allow_flagging="never",
    cache_examples=True,
)

# Start the web server only when run as a script
if __name__ == "__main__":
    demo.launch()