File size: 2,840 Bytes
6c2fd08
4ca155b
 
02cf0bb
85585aa
25893d0
324c98e
181cfde
 
 
 
 
 
 
 
 
 
4ca155b
181cfde
4ca155b
181cfde
e60730b
02cf0bb
4ca155b
02cf0bb
5ae24e1
181cfde
 
5ae24e1
 
 
 
 
 
02cf0bb
181cfde
02cf0bb
6c2fd08
85585aa
a075fee
02cf0bb
6c2fd08
e60730b
 
324c98e
181cfde
 
85585aa
02cf0bb
 
 
4ca155b
85585aa
 
 
 
 
 
 
02cf0bb
85585aa
02cf0bb
324c98e
 
c78be87
324c98e
02cf0bb
0be31e9
324c98e
6c2fd08
324c98e
 
e60730b
 
 
0be31e9
bdc217e
0be31e9
bdc217e
 
85585aa
181cfde
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import asyncio
import codecs

import gradio as gr
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F

# Try to import spaces; fall back to a no-op stand-in for spaces.GPU if it is
# not installed (e.g. when running locally instead of on Hugging Face Spaces).
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    def dummy_gpu_decorator(func=None, **kwargs):
        """No-op replacement for ``spaces.GPU``.

        Handles both decorator forms used in this file:

        * bare ``@spaces.GPU`` -> called with the function, returns it as-is;
        * parameterised ``@spaces.GPU(duration=60)`` -> called with keyword
          arguments only, returns an identity decorator.

        Any keyword arguments (such as ``duration``) are accepted and ignored.
        """
        if func is None:
            return lambda f: f
        return func

    # staticmethod prevents GPU from becoming a bound method of the instance
    # below; otherwise spaces.GPU(fn) would receive the instance as `func`.
    spaces = type('', (), {'GPU': staticmethod(dummy_gpu_decorator)})()

# ... (the model architecture classes, GPTConfig and GPT, go here unchanged)

# Update the load_model function
@spaces.GPU
def load_model(model_path):
    """Build a GPT from the default GPTConfig and load its weights.

    Args:
        model_path: Path to a checkpoint file. The file may contain either a
            bare state dict or a dict holding the weights under the
            ``'model_state_dict'`` key; both layouts are accepted.

    Returns:
        The model in eval mode, placed on CUDA when available, else CPU.
    """
    net = GPT(GPTConfig())
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): depending on the torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects; only load trusted files.
    ckpt = torch.load(model_path, map_location=target)
    weights = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
    net.load_state_dict(weights)

    net.eval()
    net.to(target)
    return net

# Load the model
model = load_model('gpt_model.pth')  # Replace with the actual path to your .pth checkpoint
# GPT-2 BPE tokenizer used to encode prompts and decode generated tokens;
# presumably must match the tokenizer the checkpoint was trained with — confirm.
enc = tiktoken.get_encoding('gpt2')

# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    """Stream a sampled continuation of *prompt* as text fragments.

    Uses the module-level ``model`` and ``enc`` tokenizer. Each step runs the
    model on the running token sequence, picks the next token from the
    temperature-scaled top-k distribution, and yields the newly decoded text.
    Generation stops after ``max_length`` tokens, or earlier at the first
    newline token once more than 100 tokens have been generated.

    Args:
        prompt: Text to continue. An empty prompt starts from the tokenizer's
            end-of-text token.
        max_length: Maximum number of tokens to generate (cast to int).
        temperature: Logit divisor; values <= 0 select the most likely token
            (greedy decoding) instead of sampling.
        top_k: Sample only among the ``top_k`` most likely tokens. Values
            <= 0 or None disable the cut-off; values larger than the
            vocabulary are clamped to it.

    Yields:
        Decoded text fragments. Characters whose UTF-8 bytes span several
        tokens are emitted once complete. If the token limit is reached
        without the newline stop, a final truncation notice is yielded.
    """
    device = next(model.parameters()).device
    max_length = int(max_length)

    # An empty prompt would otherwise give a (1, 0) float tensor.
    prompt_ids = enc.encode(prompt) or [enc.eot_token]
    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

    newline_id = enc.encode('\n')[0]
    # Crop to the model's context window if it exposes config.block_size
    # (nanoGPT-style); otherwise the whole sequence is fed every step.
    block_size = getattr(getattr(model, "config", None), "block_size", None)
    # Per-token decoding breaks multi-byte characters split across tokens;
    # an incremental decoder buffers partial UTF-8 sequences instead.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    generated = []
    hit_limit = True
    for _ in range(max_length):
        # Scope no_grad to the compute only: holding it across yield/await
        # would leave grad mode disabled for other code on this thread.
        with torch.no_grad():
            context = input_ids if block_size is None else input_ids[:, -block_size:]
            outputs, _ = model(context)
            logits = outputs[:, -1, :]
            if temperature <= 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                vocab_size = logits.size(-1)
                if top_k is None or top_k <= 0:
                    k = vocab_size
                else:
                    k = min(int(top_k), vocab_size)
                top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
                probs = F.softmax(top_k_logits, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)
                next_token = top_k_indices.gather(-1, choice)

        input_ids = torch.cat([input_ids, next_token], dim=-1)
        token_id = next_token.item()  # single host sync per step
        generated.append(token_id)

        piece = decoder.decode(enc.decode_single_token_bytes(token_id))
        if piece:
            yield piece

        if token_id == newline_id and len(generated) > 100:
            hit_limit = False
            break

        await asyncio.sleep(0.02)  # Slightly faster typing effect

    # Flush any incomplete UTF-8 tail (rendered as a replacement char).
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail

    if hit_limit:
        yield "... (output truncated due to length)"

# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed
async def gradio_generate(prompt, max_length, temperature, top_k):
    """Stream the running generated text for the Gradio UI.

    Wraps ``generate_text`` and, after each fragment it yields, yields the
    cumulative text produced so far rather than the fragment alone.
    """
    pieces = []
    async for fragment in generate_text(prompt, max_length, temperature, top_k):
        pieces.append(fragment)
        yield "".join(pieces)

# The rest of your Gradio interface code remains the same