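"""Generate text from a prompt using a trained GEM model checkpoint."""
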
import torch
from models.gem_model import GEM
from utils.data_preprocessing import load_tokenizer
from configs.config import MODEL_CONFIG
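
# MODEL_CONFIG is defined in configs/config.py, which is not shown here. As a
# rough sketch of the shape this script expects (the values below are
# illustrative assumptions, not the repository's actual settings):
#
# MODEL_CONFIG = {
#     'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
#     'VOCAB_SIZE': 32000,
#     'D_MODEL': 512,
#     'N_HEADS': 8,
#     'D_FF': 2048,
#     'N_LAYERS': 6,
#     'MAX_SEQ_LEN': 512,
#     'DROPOUT': 0.1,
# }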

def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7):
    """Encode `prompt`, generate a continuation with GEM, and decode it to text."""
    device = torch.device(MODEL_CONFIG['DEVICE'])
    # Encode the prompt to token IDs and move them to the model's device.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Run generation without gradient tracking to save memory at inference time.
    with torch.no_grad():
        generated = model.generate(input_ids, max_length=max_length, temperature=temperature)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
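
# models/gem_model.py is not shown here, so GEM.generate's exact behavior is
# unknown. The helper below is a hypothetical sketch of the temperature
# sampling such a method typically performs, NOT the repository's actual
# implementation. It assumes calling the model on a [batch, seq] tensor of
# token IDs returns [batch, seq, vocab] logits.
def _temperature_sampling_sketch(model, input_ids, max_length=100, temperature=0.7):
    ids = input_ids
    with torch.no_grad():
        while ids.size(1) < max_length:
            logits = model(ids)[:, -1, :] / temperature  # scale last-step logits
            probs = torch.softmax(logits, dim=-1)        # convert to probabilities
            next_id = torch.multinomial(probs, 1)        # sample the next token
            ids = torch.cat([ids, next_id], dim=1)       # append and continue
    return ids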

def main():
    device = torch.device(MODEL_CONFIG['DEVICE'])

    # Load the tokenizer that was used during training.
    tokenizer = load_tokenizer()

    # Rebuild the model with the same hyperparameters it was trained with.
    model = GEM(
        vocab_size=MODEL_CONFIG['VOCAB_SIZE'],
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        max_seq_len=MODEL_CONFIG['MAX_SEQ_LEN'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load('final_model/model.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for inference

    prompt = "Once upon a time"
    generated_text = generate_text(model, tokenizer, prompt, max_length=100)
    print(f"Generated text:\n{generated_text}")

if __name__ == "__main__":
    main()
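
# Example invocation (assuming this file is saved as generate.py and a trained
# checkpoint exists at final_model/model.pt):
#
#     python generate.py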