import torch
from pathlib import Path
from tokenizers import Tokenizer

from config import get_config, get_weights_file_path
from train import get_model


# Load tokenizer
def get_tokenizer(config) -> Tokenizer:
    tokenizers_path = Path(config['tokenizer_file'])
    if tokenizers_path.exists():
        print("Loading tokenizer from", tokenizers_path)
        return Tokenizer.from_file(str(tokenizers_path))
    raise FileNotFoundError(f"Can't find tokenizer file: {tokenizers_path}")


# Setup config
config = get_config("./openweb.config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = get_tokenizer(config)

# Special-token IDs (the strings must match the tokens registered when the tokenizer was trained)
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("<eos>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")

# Load model weights, then switch to evaluation mode
model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config, config['preload'])
state = torch.load(model_path, map_location='cpu')
model.load_state_dict(state['model_state_dict'])
model.eval()


# Streaming text generation
def generate_response(prompt: str, history):
    # The model was not trained on conversation history, so `history` is ignored.
    input_tokens = tokenizer.encode(prompt).ids
    input_tokens = [user_token_id] + input_tokens + [ai_token_id]
    if len(input_tokens) > config['seq_len']:
        yield "Prompt too long."
        return
    input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device)

    temperature = 0.5
    top_k = 20

    generated_text = ""
    i = 0
    with torch.no_grad():  # inference only: avoid building autograd graphs
        while input_tokens.shape[1] < 2000:
            out = model.decode(input_tokens)
            logits = model.project(out[:, -1])

            # Temperature-scaled top-k sampling
            logits = logits / temperature
            top_k_logits, top_k_indices = torch.topk(logits, top_k)
            probs = torch.softmax(top_k_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token)  # map the sampled top-k index back to a vocabulary id

            word = tokenizer.decode([next_token.item()])
            generated_text += word
            yield generated_text  # plain string with the full response so far, for ChatInterface streaming

            input_tokens = torch.cat([input_tokens, next_token], dim=1)
            # Keep the context within the model's maximum sequence length
            if input_tokens.shape[1] > config['seq_len']:
                input_tokens = input_tokens[:, -config['seq_len']:]
            if next_token.item() == eos_token_id or i >= 1024:
                break
            i += 1
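
# --- Optional: serve the generator through a chat UI --------------------------
# A minimal launcher sketch, assuming the `gradio` package is installed: the
# generator above yields the full response string on every step, which is the
# streaming contract `gr.ChatInterface` expects from its fn. The title is just
# a placeholder.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(fn=generate_response, title="OpenWeb chat demo")
    demo.launch()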