import torch

from models.gem_model import GEM
from utils.data_preprocessing import load_tokenizer
from configs.config import MODEL_CONFIG


def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Encode a prompt, autoregressively generate a continuation, and decode it."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    # Encode the prompt and move the token ids to the model's device.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Disable gradient tracking during inference to save memory and time.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    # Decode the first (and only) sequence in the batch back into text.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
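
# For reference, MODEL_CONFIG (defined in configs/config.py) is assumed to be a
# plain dict exposing the keys used below. The values shown here are
# illustrative placeholders, not taken from this repository:
#
#   MODEL_CONFIG = {
#       'DEVICE': 'cuda',        # or 'cpu'
#       'VOCAB_SIZE': 50257,
#       'D_MODEL': 512,
#       'N_HEADS': 8,
#       'D_FF': 2048,
#       'N_LAYERS': 6,
#       'MAX_SEQ_LEN': 1024,
#       'DROPOUT': 0.1,
#   }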

def main():
    device = torch.device(MODEL_CONFIG['DEVICE'])
    tokenizer = load_tokenizer()

    # Rebuild the model with the same hyperparameters used during training.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location lets the checkpoint load even if it was saved on a different
    # device (e.g. a GPU checkpoint restored on a CPU-only machine).
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout and other training-only behavior

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")


if __name__ == "__main__":
    main()
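
# Usage sketch (script name is hypothetical; adjust to this file's actual path):
#   python generate.py
# This assumes a trained checkpoint already exists at final_model/model.pt.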