import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model

# Load tokenizer
def get_tokenizer(config) -> Tokenizer:
    tokenizers_path = Path(config['tokenizer_file'])
    if tokenizers_path.exists():
        print("Loading tokenizer from", tokenizers_path)
        tokenizer = Tokenizer.from_file(str(tokenizers_path))
        return tokenizer
    else:
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizers_path}")

# Setup config
config = get_config("./openweb.config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = get_tokenizer(config)

# Token IDs
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")

# Load model
model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config, config['preload'])
state = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state['model_state_dict'])
model.eval()  # inference mode: disables dropout etc.

# Streaming text generation
@torch.no_grad()  # no gradients are needed at inference time
def generate_response(prompt: str, history):
    # The model was not trained with conversation history, so `history` is ignored.
    input_tokens = tokenizer.encode(prompt).ids
    input_tokens = [user_token_id] + input_tokens + [ai_token_id]
    if len(input_tokens) > config['seq_len']:
        yield "Prompt too long."
        return
    input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)

    # Sampling hyperparameters
    temperature = 0.5
    top_k = 20
    generated_text = ""
    i = 0
    while input_tokens.shape[1] < 2000:
        # Forward pass: decode the current context, project the last position to vocab logits
        out = model.decode(input_tokens)
        logits = model.project(out[:, -1])
        logits = logits / temperature

        # Top-k sampling: sample among the k most likely tokens, then map the
        # sampled index back to its vocabulary id
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        word = tokenizer.decode([next_token.item()])
        generated_text += word
        yield generated_text  # yield the cumulative string so ChatInterface can stream it

        # Append the new token; keep only the last seq_len tokens as context
        input_tokens = torch.cat([input_tokens, next_token], dim=1)
        if input_tokens.shape[1] > config['seq_len']:
            input_tokens = input_tokens[:, -config['seq_len']:]

        # Stop on end-of-sequence or after 1024 generated tokens
        if next_token.item() == eos_token_id or i >= 1024:
            break
        i += 1
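

# Usage sketch (not part of the original script): a minimal example of wiring the
# streaming generator into a Gradio ChatInterface, which the yield above is written
# for. Assumes `gradio` is installed; the title and launch options are illustrative.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(
        fn=generate_response,  # called as fn(message, history); history is ignored above
        title="Transformer chat demo",
    )
    demo.queue().launch()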