SmoLLMv2 / inference.py
Shilpaj's picture
Feat: Upload app files
f42f624 verified
raw
history blame
3.23 kB
#! /usr/bin/env python
"""
Inference script for SmollmV2 model
Author: Shilpaj Bhalerao
Date: 2025-01-25
"""
# Third-Party Imports
import torch
from transformers import GPT2Tokenizer
# Local Imports
from smollv2_lightning import LitSmollmv2
from config import SmollmConfig, DataConfig
def load_model(checkpoint_path):
    """
    Restore a LitSmollmv2 module from a checkpoint and switch it to eval mode.

    Args:
        checkpoint_path: Location of the checkpoint to restore from.

    Returns:
        The restored LitSmollmv2 instance, after `eval()` has been called on it.
    """
    # Extra options forwarded to the loader alongside the checkpoint path
    # (strict=False presumably relaxes state-dict key matching -- see the
    # Lightning docs for the exact semantics).
    load_options = {"model_config": SmollmConfig, "strict": False}
    lit_module = LitSmollmv2.load_from_checkpoint(checkpoint_path, **load_options)

    # Inference only: put the module in evaluation mode before handing it back
    lit_module.eval()
    return lit_module
def _apply_top_p(probs, top_p):
    """
    Restrict a probability distribution to its top-p nucleus and renormalize.

    Tokens are ranked by probability. Every token whose cumulative mass stays
    within `top_p` is kept, plus the one token that crosses the threshold; the
    top-ranked token is always kept. All other entries are zeroed and the
    result is rescaled to sum to one along the last dimension.
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep_sorted = cumulative <= top_p
    # Shift the mask right by one so the token that crosses the threshold
    # is still part of the nucleus.
    keep_sorted[..., 1:] = keep_sorted[..., :-1].clone()
    # Guarantee a non-empty nucleus.
    keep_sorted[..., 0] = True
    # Map the mask from sorted order back to vocabulary order.
    keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(
        -1, sorted_indices, keep_sorted
    )
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    return filtered / filtered.sum(dim=-1, keepdim=True)


def generate_text(
    model,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_p=0.9,
    tokenizer=None,
):
    """
    Generate text using the loaded model.

    Args:
        model: Module exposing `model.model(input_ids)` that returns a
            `(logits, extra)` pair. It is moved to CUDA when available,
            otherwise to CPU.
        prompt: Text the generation starts from.
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Divisor applied to the last-position logits before the
            softmax. A value <= 0 selects greedy decoding (argmax) instead of
            dividing by zero or inverting the distribution.
        top_p: Nucleus-sampling threshold; a value <= 0 disables the filter.
            Ignored when decoding greedily.
        tokenizer: Optional pre-loaded tokenizer providing `encode`, `decode`
            and `eos_token_id`. When omitted, a GPT2Tokenizer is loaded from
            `DataConfig.tokenizer_path` with its pad token set to its EOS
            token. Pass one in to avoid reloading it on every call.

    Returns:
        The decoded prompt plus generated continuation, with special tokens
        skipped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if tokenizer is None:
        # Initialize tokenizer the same way as in CosmopediaDataModule
        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # A non-positive temperature cannot be used as a divisor, so it is
    # treated as a request for deterministic argmax decoding.
    greedy = temperature <= 0

    # Gradients are never needed here, so disable them for the whole loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, _ = model.model(input_ids)
            # Only the prediction for the final position matters.
            last_logits = logits[:, -1, :]

            if greedy:
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = torch.nn.functional.softmax(
                    last_logits / temperature, dim=-1
                )
                if top_p > 0:
                    probs = _apply_top_p(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

            # Append the chosen token so it conditions the next step.
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop if we generate an EOS token (assumes a single sequence).
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode and return the generated text
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
def main():
    """
    Load the checkpointed model and print one generated sample per demo prompt.
    """
    # Checkpoint to restore the model from
    ckpt_file = "./checkpoints/last.ckpt"

    lit_model = load_model(ckpt_file)
    print("Model loaded successfully!")

    # Seed texts used to showcase generation
    demo_prompts = (
        "Once upon a time",
        "The future of artificial intelligence",
        "In the distant galaxy",
    )

    # Produce and display a continuation for every seed text
    for seed_text in demo_prompts:
        print("\nPrompt:", seed_text)
        completion = generate_text(model=lit_model, prompt=seed_text)
        print("Generated:", completion)
        print("-" * 50)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()