""" |
|
Inference script for SmollmV2 model |
|
Author: Shilpaj Bhalerao |
|
Date: 2025-01-25 |
|
""" |

import torch
from transformers import GPT2Tokenizer

from smollv2_lightning import LitSmollmv2
from config import SmollmConfig, DataConfig


def load_model(checkpoint_path):
    """
    Load the trained model from a checkpoint and switch it to evaluation mode.
    """
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    model.eval()
    return model


def generate_text(model, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9):
    """
    Generate text from the given prompt using temperature-scaled sampling
    with nucleus (top-p) filtering.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
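    # The GPT-2 tokenizer has no dedicated pad token, so the EOS token is reused.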
    tokenizer.pad_token = tokenizer.eos_token

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
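
    # Autoregressive sampling loop: one new token is appended per iteration and
    # the full sequence is fed back to the model each step. Nothing here truncates
    # the sequence, so prompt length plus max_new_tokens should stay within the
    # model's context window (an assumption about SmollmConfig).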
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, _ = model.model(input_ids)

        # Take the logits for the last position, apply temperature scaling,
        # and convert them to a probability distribution
        logits = logits[:, -1, :] / temperature
        probs = torch.nn.functional.softmax(logits, dim=-1)

        if top_p > 0:
            # Nucleus (top-p) filtering: keep the smallest set of most-probable
            # tokens whose cumulative probability reaches top_p
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_keep = cumsum_probs <= top_p
            # Shift the mask right by one so the token that crosses the top_p
            # threshold is also kept, and always keep the most probable token
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = True
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(
                -1, sorted_indices, sorted_indices_to_keep
            )
            # Zero out the filtered tokens and renormalize the rest
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample the next token and append it to the running sequence
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop early if the model emits the end-of-sequence token
        if next_token.item() == tokenizer.eos_token_id:
            break

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    """
    Load the trained checkpoint and print generations for a few sample prompts.
    """
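    # Default checkpoint location written during training (assumed path; adjust
    # if your checkpoints are stored elsewhere)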
    checkpoint_path = "./checkpoints/last.ckpt"

    model = load_model(checkpoint_path)
    print("Model loaded successfully!")

    prompts = [
        "Once upon a time",
        "The future of artificial intelligence",
        "In the distant galaxy"
    ]

    for prompt in prompts:
        print("\nPrompt:", prompt)
        generated = generate_text(model=model, prompt=prompt)
        print("Generated:", generated)
        print("-" * 50)


if __name__ == "__main__":
    main()