import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio
from types import SimpleNamespace

# Try to import spaces; fall back to a no-op decorator if it is not available
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    # No-op stand-in that supports both @spaces.GPU and @spaces.GPU(duration=...)
    def dummy_gpu_decorator(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # bare usage: @spaces.GPU
        def decorator(func):  # parameterized usage: @spaces.GPU(duration=...)
            return func
        return decorator

    spaces = SimpleNamespace(GPU=dummy_gpu_decorator)

# ... (keep the model architecture classes as they are)

# Update the load_model function
@spaces.GPU
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=device)
    # Accept checkpoints saved as {'model_state_dict': ...} or as a raw state dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model

# Load the model
model = load_model('gpt_model.pth')  # Replace with the actual path to your checkpoint file
enc = tiktoken.get_encoding('gpt2')

# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            outputs, _ = model(input_ids)
            next_token_logits = outputs[:, -1, :]
            # Temperature scaling, then sample from the top-k filtered distribution
            next_token_logits = next_token_logits / temperature
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            next_token_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token_index)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated.append(next_token.item())
            next_token_str = enc.decode([next_token.item()])
            yield next_token_str
            # Stop at a newline once at least 100 tokens have been generated
            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                break
            await asyncio.sleep(0.02)  # Slightly faster typing effect
    if len(generated) == max_length:
        yield "... (output truncated due to length)"

# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def gradio_generate(prompt, max_length, temperature, top_k):
    output = ""
    # Cast slider values defensively: some Gradio versions deliver them as floats
    async for token in generate_text(prompt, int(max_length), temperature, int(top_k)):
        output += token
        yield output

# The rest of your Gradio interface code remains the same
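
# --- Minimal interface sketch (assumption, not the original UI code) ---
# The script above only says the Gradio interface code "remains the same"
# without showing it, so the wiring below is a hedged sketch of how
# gradio_generate could be hooked up. The component labels, slider ranges,
# and steps are illustrative assumptions; only the defaults (432, 0.8, 40)
# come from generate_text's signature.
demo = gr.Interface(
    fn=gradio_generate,  # async generator, so the output textbox streams tokens
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(16, 1024, value=432, step=1, label="Max length"),      # assumed range
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),  # assumed range
        gr.Slider(1, 200, value=40, step=1, label="Top-k"),              # assumed range
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()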