#! /usr/bin/env python3
"""
This script is a simple text generator using the SmollmV2 model.
It uses Gradio to create a web interface for generating text.
"""
# Standard library imports
import os
from pathlib import Path

# Third-party imports
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import GPT2Tokenizer
import spaces

# Local imports
from smollmv2 import SmollmV2
from config import SmollmConfig, DataConfig
from smollv2_lightning import LitSmollmv2


def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
    """
    Combine split model parts into a single checkpoint file.
    """
    # Create the checkpoints directory if it doesn't exist
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Check if the combined model already exists
    if os.path.exists(output_file):
        print(f"Model already combined at: {output_file}")
        return output_file

    # Ensure the model parts exist
    if not os.path.exists(model_dir):
        raise FileNotFoundError(f"Model directory {model_dir} not found")

    # Combine the parts
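    # Note: sorted() orders the parts lexicographically, which assumes the split
    # files use suffixes that sort correctly (e.g. zero-padded indices such as
    # last.ckpt.part_00, last.ckpt.part_01, ...); the exact naming scheme comes
    # from however the checkpoint was originally split.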
    parts = sorted(Path(model_dir).glob("last.ckpt.part_*"))
    if not parts:
        raise FileNotFoundError("No model parts found")

    print("Combining model parts...")
    with open(output_file, 'wb') as outfile:
        for part in parts:
            print(f"Processing part: {part}")
            with open(part, 'rb') as infile:
                outfile.write(infile.read())

    print(f"Model combined successfully: {output_file}")
    return output_file


def load_model():
    """
    Load the SmollmV2 model and tokenizer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Combine model parts and get the checkpoint path
    checkpoint_path = combine_model_parts()

    # Load the model from the combined checkpoint using the Lightning module
    model = LitSmollmv2.load_from_checkpoint(
        checkpoint_path,
        model_config=SmollmConfig,
        strict=False
    )
    model.to(device)
    model.eval()

    # Initialize the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
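    # The GPT-2 tokenizer has no dedicated padding token, so the EOS token is
    # reused as the pad token.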
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, device
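

# The spaces.GPU decorator requests GPU time for each call when the app runs on
# Hugging Face Spaces (ZeroGPU); outside of a Space it effectively acts as a
# plain wrapper around the function.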
@spaces.GPU(enable_queue=True)
def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
    """
    Generate text using the SmollmV2 model.
    """
    # Ensure num_tokens is an integer (the Gradio slider may pass a float) and
    # doesn't exceed the model's block size
    num_tokens = min(int(num_tokens), SmollmConfig.block_size)
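    # Note: only the number of newly generated tokens is capped here; this
    # assumes the prompt plus the generated continuation stays within the
    # model's context window (SmollmConfig.block_size).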

    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate tokens one at a time
    for _ in range(num_tokens):
        # Get the model's predictions
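        # Inference-only forward pass: no_grad skips gradient tracking, and
        # autocast runs the forward pass in bfloat16 where the hardware
        # supports it.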
        with torch.no_grad():
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, _ = model.model(input_ids)

        # Get the next-token probabilities
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)

        # Apply top-p (nucleus) sampling
        if top_p > 0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
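            # Keep tokens whose cumulative probability is within top_p; the
            # mask is shifted right by one so the first token that crosses the
            # threshold is still kept, and the top-1 token is always retained.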
            sorted_indices_to_keep = cumsum_probs <= top_p
            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
            sorted_indices_to_keep[..., 0] = 1
            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
            probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample the next token
        next_token = torch.multinomial(probs, num_samples=1)

        # Append the sampled token to the running sequence
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if we generate an EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode and return the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return generated_text


# Load the model globally
model, tokenizer, device = load_model()

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", value="Once upon a time"),
        gr.Slider(minimum=1, maximum=SmollmConfig.block_size, value=100, step=1, label="Number of tokens to generate"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more random)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmollmV2 Text Generator",
    description="Generate text using the SmollmV2 model",
    allow_flagging="never",
    cache_examples=True
)

if __name__ == "__main__":
    demo.launch()