File size: 3,174 Bytes
e59a831
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import os
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import PreTrainedModel, AutoConfig
from huggingface_hub import hf_hub_download
import tiktoken
from model import GPT, GPTConfig
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")

MODEL_ID = "sagargurujula/text-generator"

# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")

# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the model file from HF Hub
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth"
        )
        
        # Initialize our custom GPT model
        model = GPT(GPTConfig())
        
        # Load the state dict
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        model.to(device)
        model.eval()
        return model
    
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Load the model
model = load_model()

# Define the request body
class TextInput(BaseModel):
    text: str

@app.post("/generate/")
async def generate_text(input: TextInput):
    # Prepare input tensor
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)
    
    # Generate multiple tokens
    generated_tokens = []
    num_tokens_to_generate = 50  # Generate 20 new tokens
    
    with torch.no_grad():
        current_ids = input_ids
        
        for _ in range(num_tokens_to_generate):
            # Get model predictions
            logits, _ = model(current_ids)
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            
            # Add the new token to our current sequence
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)
    
    # Decode all generated tokens
    generated_text = enc.decode(generated_tokens)
    
    # Return both input and generated text
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }

# Modify the root route to serve the template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html", 
        {"request": request, "title": "GPT Text Generator"}
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)

# To run the app, use the command: uvicorn app:app --reload