# GPT-2-with_gpu / app.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio
# Try to import spaces; fall back to a no-op decorator if it's unavailable
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    # No-op stand-in supporting both @spaces.GPU and @spaces.GPU(duration=...)
    def dummy_gpu_decorator(func=None, **kwargs):
        if func is None:
            return lambda f: f  # called with arguments: return an identity decorator
        return func  # used bare: return the function unchanged

    # staticmethod keeps GPU from being bound to the instance (which would
    # pass the instance as the first argument and break the decorator)
    spaces = type('spaces', (), {'GPU': staticmethod(dummy_gpu_decorator)})()
# ... (keep the model architecture classes as they are; a minimal sketch of them follows below for reference)
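# The classes below are NOT the author's originals (those are elided above).
# This is a hedged, nanoGPT-style sketch so the excerpt is self-contained:
# GPTConfig's field names and GPT's (logits, loss) return signature are
# inferred from how load_model and generate_text use them. Replace with the
# original classes before loading a real checkpoint. Requires PyTorch >= 2.0
# for F.scaled_dot_product_attention.
from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # reshape to (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # fused scaled dot-product attention with a causal mask
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x)))

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm residual blocks, as in GPT-2
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying between token embedding and output head, as in GPT-2
        self.lm_head.weight = self.transformer.wte.weight

    def forward(self, idx, targets=None):
        B, T = idx.size()
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.wpe(pos)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        # generate_text below expects a (logits, loss) tuple
        return logits, loss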
# Update the load_model function
@spaces.GPU
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=device)
    # Support both full-checkpoint and bare-state-dict formats
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model
# Load the model
model = load_model('gpt_model.pth')  # Replace with the actual path to your checkpoint file
enc = tiktoken.get_encoding('gpt2')
# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            outputs, _ = model(input_ids)
            # Sample the next token from the temperature-scaled top-k distribution
            next_token_logits = outputs[:, -1, :]
            next_token_logits = next_token_logits / temperature
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            next_token_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token_index)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated.append(next_token.item())
            next_token_str = enc.decode([next_token.item()])
            yield next_token_str
            # Stop at a newline once a reasonable amount of text has been produced
            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                break
            await asyncio.sleep(0.02)  # Slightly faster typing effect
    if len(generated) == max_length:
        yield "... (output truncated due to length)"
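# Hypothetical standalone check (outside Gradio) of the streaming generator;
# illustrative only, not part of the original app:
#
#   async def _demo():
#       async for token in generate_text("Once upon a time", max_length=50):
#           print(token, end="", flush=True)
#
#   asyncio.run(_demo())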
# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def gradio_generate(prompt, max_length, temperature, top_k):
    # Accumulate tokens so Gradio re-renders the growing text on each yield
    output = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
        output += token
        yield output
# The rest of your Gradio interface code remains the same (elided here); a hedged sketch follows:
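# This wiring is an assumption, not the author's original interface: the
# component choices, ranges, labels, and title below are illustrative, matching
# only generate_text's parameters.
demo = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=50, maximum=432, value=432, step=1, label="Max length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    title="GPT-2 Text Generator",
)

if __name__ == "__main__":
    demo.queue().launch()  # queue() enables streaming from the async generator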