#!/usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard library imports
import os
from pathlib import Path

# Third-party imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file and return its path.
    """
    # Create checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    # Check if combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file
    
    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")
    
    # Combine the parts
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")
    
    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())
    
    print(f"Model combined successfully: {output_file}")
    return output_file

def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Combine model parts and get the checkpoint path
    checkpoint_path = combine_model_parts()
    
    # Load the model from combined checkpoint using Lightning module
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    
    model.to(device)
    model.eval()
    
    # Initialize tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
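    # The GPT-2 tokenizer ships without a dedicated pad token, so reuse EOS for padding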
    tokenizer.pad_token = tokenizer.eos_token
    
    return model, tokenizer, device


@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
    """
    # Ensure num_tokens doesn't exceed model's block size
    num_tokens = min(num_tokens, SmollmConfig.block_size)
    
    # Tokenize input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    # Generate tokens one at a time
    for _ in range(num_tokens):
        # Crop the running context to the model's block size so the prompt plus
        # generated tokens never exceed the supported context length
        idx_cond = input_ids[:, -SmollmConfig.block_size:]
        # Get the model's predictions
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, _ = model.model(idx_cond)
        
        # Get the next token probabilities
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        
        # Apply top-p sampling
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
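            # Shift the keep-mask right by one so the first token that pushes the
            # cumulative probability past top_p is still kept; otherwise a highly
            # peaked distribution could leave nothing to sample from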
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)
        
        # Sample next token
        next_token = torch.multinomial(probs, num_samples=1)
        
        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        
        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text

# Load the model globally
model, tokenizer, device = load_model()
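# NOTE: loading at import time means the checkpoint (or its split parts) must be
# available before the Gradio app starts; otherwise the app fails at startup.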

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmollmV2 Text Generator",
    description="Generate text using the SmollmV2 model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()