|
|
|
import torch |
|
import sys |
|
import os |
|
|
|
|
|
# Make the sibling ``../models`` directory importable so that ``gem_model``
# resolves regardless of the directory this script is launched from.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../models')))

from gem_model import GEM

# Hyperparameters passed positionally to the GEM constructor below.
# NOTE(review): these presumably have to match the configuration the
# checkpoint was trained with, otherwise load_state_dict() will report
# missing/unexpected keys or shape mismatches -- confirm against training.
vocab_size = 50001
d_model = 1024
n_heads = 32
d_ff = 4096
n_layers = 32
dropout = 0.1

model = GEM(vocab_size, d_model, n_heads, d_ff, n_layers, dropout)

# Checkpoint location. NOTE(review): this looks like a Colab Google Drive
# mount path, so the script will only find it in that environment.
model_path = '/content/drive/MyDrive/GEM_Project/GEM_1o_Aug_15.pt'

# map_location forces every stored tensor onto the CPU, so the checkpoint
# loads even on a machine without the device it was saved from.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source (and consider
# weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Switch the module to evaluation mode before running inference.
model.eval()
|
|
|
|
|
def text_to_ids(tokenizer, text):
    """Convert *text* into token ids using *tokenizer*.

    Args:
        tokenizer: Any object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        text: The string to encode.

    Returns:
        Whatever ``tokenizer.convert_tokens_to_ids`` returns for the tokens
        produced by ``tokenizer.tokenize(text)``.
    """
    tokens = tokenizer.tokenize(text)
    return tokenizer.convert_tokens_to_ids(tokens)
|
|
|
|
|
class DummyTokenizer:
    """Minimal stand-in tokenizer used to smoke-test the model.

    It mirrors the two methods ``text_to_ids`` relies on (``tokenize`` and
    ``convert_tokens_to_ids``) without needing a real vocabulary.
    """

    def tokenize(self, text):
        """Split *text* on runs of whitespace and return the list of pieces."""
        return text.split()

    def convert_tokens_to_ids(self, tokens):
        """Map each token to a pseudo id derived from its first character.

        The id is ``ord(token[0]) % 50000``, so every id lies in
        ``[0, 50000)`` and tokens sharing a first character collide.

        Raises:
            IndexError: If a token is the empty string (``tokenize`` never
                produces one, but a caller-supplied list could).
        """
        return [ord(token[0]) % 50000 for token in tokens]
|
|
|
|
|
tokenizer = DummyTokenizer()

# --- Forward pass -----------------------------------------------------------
test_prompt = "This is a test."
# unsqueeze(0) adds a leading dimension of size 1, giving shape (1, n_tokens).
test_input_ids = torch.tensor(text_to_ids(tokenizer, test_prompt), dtype=torch.long).unsqueeze(0)
# All-True boolean mask with the same shape as the input ids.
attention_mask = torch.ones(test_input_ids.shape, dtype=torch.bool)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(test_input_ids, attention_mask)
    print("Model outputs:")
    print(outputs)

# --- Generation -------------------------------------------------------------
generation_prompt = "Once upon a time"
input_ids = torch.tensor(text_to_ids(tokenizer, generation_prompt), dtype=torch.long).unsqueeze(0)
# Run generation without gradient tracking as well, consistent with the
# forward pass above; nesting is harmless if generate() already disables it.
with torch.no_grad():
    generated_ids = model.generate(input_ids, max_length=10, temperature=1.0)
print("Generated IDs:")
print(generated_ids)
|
|