import torch

from models.gem_model import GEM
from utils.data_preprocessing import load_tokenizer
from configs.config import MODEL_CONFIG


def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Generate a continuation of `prompt` with the trained GEM model."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Disable gradient tracking during inference to save memory and time.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def main():
    device = torch.device(MODEL_CONFIG['DEVICE'])
    tokenizer = load_tokenizer()

    # Rebuild the model with the same hyperparameters used during training;
    # the state dict will not load if these differ from the checkpoint.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location lets the checkpoint load even if it was saved on a
    # different device (e.g. trained on GPU, sampled on CPU).
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for inference

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")


if __name__ == "__main__":
    main()
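
# The script above assumes GEM exposes a `generate(input_ids, max_length,
# temperature)` method. If your model class only implements `forward` and
# returns logits of shape (batch, seq_len, vocab_size), a minimal
# temperature-sampling loop like the sketch below could stand in. This is
# an illustrative assumption, not part of the original codebase; it is
# appended here for reference only and is not called by main().
def sample_autoregressive(model, input_ids, max_length=100, temperature=0.7):
    """Minimal autoregressive sampling sketch (assumes forward -> logits)."""
    model.eval()
    generated = input_ids
    with torch.no_grad():
        # Append one sampled token per step until max_length is reached.
        while generated.size(1) < max_length:
            logits = model(generated)                      # (batch, seq, vocab)
            next_logits = logits[:, -1, :] / temperature   # scale final step
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
    return generated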