import os
import sys

import torch

# Put the sibling ``../models`` directory (resolved relative to this file) on
# the import path; this must run before the ``gem_model`` import below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../models')))

from gem_model import GEM
					
						
						|  | vocab_size = 50001 | 
					
						
						|  | d_model = 1024 | 
					
						
						|  | n_heads = 32 | 
					
						
						|  | d_ff = 4096 | 
					
						
						|  | n_layers = 32 | 
					
						
						|  | dropout = 0.1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model = GEM(vocab_size, d_model, n_heads, d_ff, n_layers, dropout) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_path = '/content/drive/MyDrive/GEM_Project/GEM_1o_Aug_15.pt' | 
					
						
						|  | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model.eval() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def text_to_ids(tokenizer, text): | 
					
						
						|  |  | 
					
						
						|  | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DummyTokenizer: | 
					
						
						|  | def tokenize(self, text): | 
					
						
						|  |  | 
					
						
						|  | return text.split() | 
					
						
						|  |  | 
					
						
						|  | def convert_tokens_to_ids(self, tokens): | 
					
						
						|  |  | 
					
						
						|  | return [ord(token[0]) % 50000 for token in tokens] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tokenizer = DummyTokenizer() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | test_prompt = "This is a test." | 
					
						
						|  | test_input_ids = torch.tensor(text_to_ids(tokenizer, test_prompt), dtype=torch.long).unsqueeze(0) | 
					
						
						|  | attention_mask = torch.ones(test_input_ids.shape, dtype=torch.bool) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | outputs = model(test_input_ids, attention_mask) | 
					
						
						|  | print("Model outputs:") | 
					
						
						|  | print(outputs) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | generation_prompt = "Once upon a time" | 
					
						
						|  | input_ids = torch.tensor(text_to_ids(tokenizer, generation_prompt), dtype=torch.long).unsqueeze(0) | 
					
						
						|  | generated_ids = model.generate(input_ids, max_length=10, temperature=1.0) | 
					
						
						|  | print("Generated IDs:") | 
					
						
						|  | print(generated_ids) | 
					
						
						|  |  |