File size: 3,227 Bytes
f42f624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#! /usr/bin/env python
"""
Inference script for SmollmV2 model
Author: Shilpaj Bhalerao
Date: 2025-01-25
"""
# Standard Library Imports
import functools

# Third-Party Imports
import torch
from transformers import GPT2Tokenizer

# Local Imports
from smollv2_lightning import LitSmollmv2
from config import SmollmConfig, DataConfig


def load_model(checkpoint_path):
    """
    Restore a LitSmollmv2 module from a Lightning checkpoint, ready for inference.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file to restore.

    Returns:
        The restored LitSmollmv2 module, already switched to eval mode.
    """
    # Extra keyword args are forwarded by Lightning to the module constructor;
    # strict=False tolerates state-dict keys that don't match the module.
    restore_kwargs = dict(model_config=SmollmConfig, strict=False)
    lit_module = LitSmollmv2.load_from_checkpoint(checkpoint_path, **restore_kwargs)
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return lit_module.eval()


@functools.lru_cache(maxsize=8)
def _load_tokenizer(tokenizer_path):
    """
    Load the GPT-2 tokenizer at ``tokenizer_path`` once and reuse it afterwards.

    It is set up the same way as in CosmopediaDataModule: the EOS token doubles
    as the padding token.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _top_p_filter(probs, top_p):
    """
    Zero out tokens outside the top-p (nucleus) set along the last dim and renormalise.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Shifting the mask right by one keeps every token whose *preceding* mass is
    # still <= top_p, i.e. it includes the token that crosses the threshold.
    # The most likely token is always kept so the distribution is never empty.
    keep_sorted = cumulative <= top_p
    keep_sorted[..., 1:] = keep_sorted[..., :-1].clone()
    keep_sorted[..., 0] = True
    keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, keep_sorted)
    filtered = probs.masked_fill(~keep, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)


def generate_text(model, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, tokenizer=None):
    """
    Autoregressively extend ``prompt`` by up to ``max_new_tokens`` tokens.

    Args:
        model: Wrapper whose ``model.model(input_ids)`` returns a ``(logits, _)``
            tuple. It is moved to CUDA when available, otherwise CPU.
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to append.
        temperature: Softmax temperature. Values ``<= 0`` select greedy (argmax)
            decoding instead of dividing the logits by zero.
        top_p: Nucleus-sampling threshold. Values ``<= 0`` disable the filter.
        tokenizer: Optional pre-loaded tokenizer. When omitted, the GPT-2
            tokenizer at ``DataConfig.tokenizer_path`` is loaded once and cached.

    Returns:
        The decoded prompt plus continuation, with special tokens stripped.
        Generation stops early once the EOS token is produced.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if tokenizer is None:
        tokenizer = _load_tokenizer(DataConfig.tokenizer_path)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # No gradients are needed anywhere in the decoding loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # NOTE(review): the whole sequence is re-run each step and is never
            # cropped to the model's context length - confirm long generations
            # stay within the limit configured in SmollmConfig.
            logits, _ = model.model(input_ids)
            next_logits = logits[:, -1, :]

            if temperature <= 0:
                # Greedy decoding; keepdim gives the same (batch, 1) shape as multinomial.
                next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                if top_p > 0:
                    probs = _top_p_filter(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # .item() assumes a single prompt (batch size 1), as produced by encode() above.
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


def main():
    """
    Load the default checkpoint and print a sample continuation for a few fixed prompts.
    """
    ckpt = "./checkpoints/last.ckpt"
    lm = load_model(ckpt)
    print("Model loaded successfully!")

    demo_prompts = (
        "Once upon a time",
        "The future of artificial intelligence",
        "In the distant galaxy",
    )
    separator = "-" * 50

    for text in demo_prompts:
        print("\nPrompt:", text)
        # generate_text runs before "Generated:" is printed, same as before.
        print("Generated:", generate_text(model=lm, prompt=text))
        print(separator)

if __name__ == "__main__":
    main()