import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio
# Try to import spaces; fall back to a no-op decorator if it is unavailable
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    # No-op replacement that supports both @spaces.GPU and @spaces.GPU(duration=...)
    def dummy_gpu_decorator(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

    spaces = type('spaces', (), {'GPU': staticmethod(dummy_gpu_decorator)})()
# ... (keep the model architecture classes as they are)
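# ---------------------------------------------------------------------------
# Hedged sketch (assumption): GPTConfig and GPT are the classes kept from the
# original training script and are not reproduced here. The minimal,
# nanoGPT-style stand-ins below only illustrate the interface the rest of this
# file relies on: GPT(config) builds the model and model(input_ids) returns
# (logits, loss). The real classes, field names, and sizes will differ, so the
# actual checkpoint will not load into this sketch.
# ---------------------------------------------------------------------------
# from dataclasses import dataclass
#
# @dataclass
# class GPTConfig:
#     block_size: int = 1024     # maximum sequence length
#     vocab_size: int = 50257    # GPT-2 BPE vocabulary size
#     n_layer: int = 12
#     n_head: int = 12
#     n_embd: int = 768
#
# class GPT(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config
#         self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
#         self.wpe = nn.Embedding(config.block_size, config.n_embd)  # position embeddings
#         layer = nn.TransformerEncoderLayer(
#             d_model=config.n_embd, nhead=config.n_head,
#             dim_feedforward=4 * config.n_embd, batch_first=True)
#         self.blocks = nn.TransformerEncoder(layer, num_layers=config.n_layer)
#         self.ln_f = nn.LayerNorm(config.n_embd)
#         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
#
#     def forward(self, idx, targets=None):
#         B, T = idx.shape
#         pos = torch.arange(T, device=idx.device)
#         x = self.wte(idx) + self.wpe(pos)
#         # Causal mask: True entries are positions a query may NOT attend to
#         causal_mask = torch.triu(
#             torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1)
#         x = self.blocks(x, mask=causal_mask)
#         logits = self.lm_head(self.ln_f(x))
#         loss = None
#         if targets is not None:
#             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
#         return logits, loss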
# Update the load_model function
@spaces.GPU
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model
# Load the model
model = load_model('gpt_model.pth')  # Replace with the actual path to your checkpoint (.pth) file
enc = tiktoken.get_encoding('gpt2')
# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            outputs, _ = model(input_ids)
            next_token_logits = outputs[:, -1, :]
            next_token_logits = next_token_logits / temperature
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            next_token_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token_index)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated.append(next_token.item())
            next_token_str = enc.decode([next_token.item()])
            yield next_token_str
            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                break
            await asyncio.sleep(0.02)  # Slightly faster typing effect
    if len(generated) == max_length:
        yield "... (output truncated due to length)"
# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def gradio_generate(prompt, max_length, temperature, top_k):
    output = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
        output += token
        yield output
# The rest of your Gradio interface code remains the same
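# ---------------------------------------------------------------------------
# Hedged sketch (assumption): one common way to wire the streaming generator
# above into Gradio. The Space's actual interface is kept unchanged and may
# differ; the component labels, ranges, and defaults below are illustrative.
# Streaming output requires the queue to be enabled.
# ---------------------------------------------------------------------------
# demo = gr.Interface(
#     fn=gradio_generate,
#     inputs=[
#         gr.Textbox(label="Prompt", lines=3),
#         gr.Slider(16, 1024, value=432, step=1, label="Max new tokens"),
#         gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
#         gr.Slider(1, 200, value=40, step=1, label="Top-k"),
#     ],
#     outputs=gr.Textbox(label="Generated text"),
#     title="GPT Text Generator",
# )
#
# if __name__ == "__main__":
#     demo.queue().launch()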