# 10M-LLM / inference_fine_tune.py
# Author: abancp — commit 28553d3 ("Update inference_fine_tune.py")
import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model
# Load tokenizer
def get_tokenizer(config) -> Tokenizer:
    """Load the trained tokenizer referenced by ``config['tokenizer_file']``.

    Args:
        config: Configuration mapping; only the ``'tokenizer_file'`` key is read.

    Returns:
        The ``tokenizers.Tokenizer`` deserialized from that file.

    Raises:
        FileNotFoundError: If nothing exists at the configured path.
    """
    tokenizers_path = Path(config['tokenizer_file'])
    # Check up front so callers get a FileNotFoundError with a readable message.
    # Build a single message string: passing two positional args to an OSError
    # subclass makes them (errno, strerror) and garbles the message.
    if not tokenizers_path.exists():
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizers_path}")
    print("Loading tokenizer from", tokenizers_path)
    return Tokenizer.from_file(str(tokenizers_path))
# Setup: read the run configuration, choose where the model runs, and load the tokenizer.
config = get_config("./openweb.config.json")
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = get_tokenizer(config)
# Token IDs of the special tokens used to frame a chat turn and detect its end.
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")
# token_to_id() returns None for tokens absent from the vocabulary. Fail fast for the
# tokens generation relies on. Otherwise a missing <user>/<ai> would only surface later
# as an opaque error inside torch.tensor(), and a missing </s> would silently make the
# end-of-sequence check never match.
_missing_special_tokens = [
    token
    for token, token_id in (("</s>", eos_token_id), ("<user>", user_token_id), ("<ai>", ai_token_id))
    if token_id is None
]
if _missing_special_tokens:
    raise ValueError(f"Tokenizer is missing required special tokens: {_missing_special_tokens}")
# Load model: build the architecture for this tokenizer's vocabulary on the chosen
# device, then restore the trained weights from the configured checkpoint.
model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config, config['preload'])
# Deserialize the checkpoint onto the CPU. load_state_dict() then copies each tensor
# into the matching parameter on whatever device the model already lives on.
state = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state['model_state_dict'])
# Inference only: put every submodule into evaluation mode (e.g. disables any dropout).
model.eval()
# Streaming text generation
def _sample_next_token(logits, temperature, top_k):
    """Pick the next token id from ``logits``; returns a ``(batch, 1)`` index tensor.

    Assumes 2-D ``(batch, vocab)`` logits, as produced by ``model.project`` in
    ``generate_response``. ``temperature <= 0`` selects greedy (argmax) decoding;
    otherwise sample from the temperature-scaled top-k distribution.
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    # torch.topk raises if k exceeds the vocabulary size, so clamp it.
    k = max(1, min(top_k, logits.size(-1)))
    top_k_logits, top_k_indices = torch.topk(logits / temperature, k)
    probs = torch.softmax(top_k_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    # Map the sampled position inside the top-k slice back to a vocabulary id.
    return top_k_indices.gather(-1, choice)


def generate_response(prompt: str, history, *, temperature: float = 0.5, top_k: int = 20,
                      max_new_tokens: int = 1024):
    """Stream a model reply to ``prompt``, yielding the cumulative reply text.

    After every sampled token this yields the full text generated so far, as a
    plain string (the form a streaming chat UI expects). The last value yielded is
    therefore the complete reply.

    Args:
        prompt: The user's message.
        history: Previous chat turns. Ignored: the model was not trained with
            multi-turn history, so earlier turns are not added to the context.
        temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: Sample only among the ``top_k`` most likely tokens.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        ``"Prompt too long."`` if the framed prompt exceeds ``config['seq_len']``.
        Otherwise, the decoded reply text after each new token.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    # Frame the prompt as a single chat turn: <user> prompt <ai>.
    context_ids = [user_token_id] + prompt_ids + [ai_token_id]
    if len(context_ids) > config['seq_len']:
        yield "Prompt too long."
        return

    seq_len = config['seq_len']
    context = torch.tensor(context_ids).unsqueeze(0).to(device)
    generated_ids = []
    for _ in range(max_new_tokens):
        # Inference needs no autograd graph. The no_grad scope deliberately ends
        # before any yield, so grad mode is never left disabled in the caller's
        # thread while this generator is suspended.
        with torch.no_grad():
            out = model.decode(context)
            logits = model.project(out[:, -1])
        next_token = _sample_next_token(logits, temperature, top_k)
        token_id = next_token.item()
        # Stop before the end-of-sequence token is added to the reply.
        if token_id == eos_token_id:
            break
        generated_ids.append(token_id)
        # Decode the whole reply each step instead of concatenating per-token
        # decodes. Decoding tokens in isolation drops word boundaries and splits
        # multi-byte characters for many tokenizer types.
        yield tokenizer.decode(generated_ids)
        context = torch.cat([context, next_token], dim=1)
        # Keep only the most recent seq_len tokens so the context fits the model.
        if context.shape[1] > seq_len:
            context = context[:, -seq_len:]
    if not generated_ids:
        # The model ended immediately; still yield once so the caller gets a reply.
        yield ""