import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio
# Try to import spaces (available on Hugging Face ZeroGPU hardware);
# fall back to a no-op decorator when running elsewhere
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    # Dummy decorator in case spaces is not available; supports both
    # @spaces.GPU and @spaces.GPU(duration=...) usage
    def dummy_gpu_decorator(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

    import types
    spaces = types.SimpleNamespace(GPU=dummy_gpu_decorator)
# ... (keep the model architecture classes as they are)
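# NOTE: the code below assumes the usual GPT-2 style interface for those
# classes (an assumption; they are not shown in this snippet):
#   - GPTConfig() carries hyperparameters such as vocab_size, block_size,
#     n_layer, n_head, and n_embd
#   - GPT(config)(input_ids) returns a (logits, loss) tuple, where logits
#     has shape (batch, sequence_length, vocab_size)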
# Update the load_model function
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-trained checkpoint load cleanly on CPU
    checkpoint = torch.load(model_path, map_location=device)
    # The file may hold a raw state_dict or a full training checkpoint
    # that wraps it under 'model_state_dict'
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model
# Load the model and the GPT-2 BPE tokenizer
model = load_model('gpt_model.pth')  # Replace with the actual path to your checkpoint file
enc = tiktoken.get_encoding('gpt2')
# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass; keep only the logits for the last position
            outputs, _ = model(input_ids)
            next_token_logits = outputs[:, -1, :]
            # Temperature scaling followed by top-k sampling
            next_token_logits = next_token_logits / temperature
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            next_token_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token_index)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated.append(next_token.item())
            next_token_str = enc.decode([next_token.item()])
            yield next_token_str
            # Stop at a newline token once at least 100 tokens have been generated
            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                break
            await asyncio.sleep(0.02)  # Slightly faster typing effect
    if len(generated) == max_length:
        yield "... (output truncated due to length)"
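# Quick local smoke test (a sketch, not part of the app): consume the
# async generator directly and stream tokens to stdout.
#
#   async def _demo():
#       async for tok in generate_text("Once upon a time", max_length=50):
#           print(tok, end="", flush=True)
#
#   asyncio.run(_demo())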
# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def gradio_generate(prompt, max_length, temperature, top_k):
    # Accumulate tokens so Gradio re-renders the growing text on each yield
    output = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
        output += token
        yield output
# The rest of your Gradio interface code remains the same
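# For reference, a minimal sketch of the interface wiring. Everything below
# (component choices, labels, default values) is an assumption, not the
# original interface code. Gradio streams the output incrementally because
# gradio_generate is an async generator.
demo = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(32, 1024, value=432, step=1, label="Max length"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(1, 200, value=40, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()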