"""Smoke test for a pre-trained GEM checkpoint.

Loads the checkpoint on CPU, runs one forward pass on a short prompt and
then exercises ``model.generate``. The tokenizer used here is a placeholder;
replace it with the tokenizer the model was actually trained with.
"""

import os
import sys

import torch

# Configuration parameters for GEM
vocab_size = 50001  # Example vocab size, adjust if necessary
d_model = 1024  # Dimension of the model
n_heads = 32  # Number of attention heads
d_ff = 4096  # Dimension of the feedforward network
n_layers = 32  # Number of transformer layers
dropout = 0.1  # Dropout rate

# Location of the pre-trained weights
model_path = '/content/drive/MyDrive/GEM_Project/GEM_1o_Aug_15.pt'


def text_to_ids(tokenizer, text):
    """Convert ``text`` to a list of token IDs using ``tokenizer``.

    ``tokenizer`` must provide ``tokenize(text)`` and
    ``convert_tokens_to_ids(tokens)``; adapt this function if your
    tokenizer exposes a different interface.
    """
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


class DummyTokenizer:
    """Placeholder tokenizer; replace with the model's real tokenizer."""

    def tokenize(self, text):
        """Split ``text`` on whitespace and return the list of tokens."""
        return text.split()

    def convert_tokens_to_ids(self, tokens):
        """Map each token to the code point of its first character, modulo 50000.

        Tokens must be non-empty (``tokenize`` never yields empty tokens).
        """
        return [ord(token[0]) % 50000 for token in tokens]


def encode_prompt(tokenizer, prompt):
    """Return ``prompt`` as a ``torch.long`` ID tensor with a batch dimension.

    The result has shape ``(1, number_of_tokens)``.
    """
    ids = text_to_ids(tokenizer, prompt)
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)  # Add batch dimension


def load_model():
    """Build GEM from the module-level configuration and load the checkpoint.

    Returns the model in evaluation mode with its weights mapped to the CPU.
    """
    # Add the models directory (sibling of this file's directory) to the
    # system path so that ``gem_model`` can be imported.
    models_dir = os.path.abspath(
        os.path.join(os.path.dirname(__file__), '../models')
    )
    if models_dir not in sys.path:
        sys.path.append(models_dir)
    # Imported here because it only resolves after the path tweak above.
    from gem_model import GEM

    model = GEM(vocab_size, d_model, n_heads, d_ff, n_layers, dropout)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    # Set the model to evaluation mode
    model.eval()
    return model


def main():
    """Run a forward pass and a short generation, printing both results."""
    model = load_model()
    tokenizer = DummyTokenizer()

    # Test input
    test_prompt = "This is a test."
    test_input_ids = encode_prompt(tokenizer, test_prompt)
    attention_mask = torch.ones(test_input_ids.shape, dtype=torch.bool)

    # Perform a forward pass without tracking gradients
    with torch.no_grad():
        outputs = model(test_input_ids, attention_mask)
    print("Model outputs:")
    print(outputs)

    # Test the generate method
    generation_prompt = "Once upon a time"
    input_ids = encode_prompt(tokenizer, generation_prompt)
    generated_ids = model.generate(input_ids, max_length=10, temperature=1.0)
    print("Generated IDs:")
    print(generated_ids)


if __name__ == "__main__":
    main()