import torch
import os
from fastapi import FastAPI, Request
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
import tiktoken
from model import GPT, GPTConfig
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
MODEL_ID = "sagargurujula/text-generator"
# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")
# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the model file from HF Hub
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth"
        )
        # Initialize our custom GPT model
        model = GPT(GPTConfig())
        # Load the state dict
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Load the model
model = load_model()
# Define the request body
class TextInput(BaseModel):
    text: str
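# Example payload for POST /generate/ (the prompt text is illustrative only):
#   request:  {"text": "Once upon a time"}
#   response: {"input_text": "...", "generated_text": "..."}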
@app.post("/generate/")
async def generate_text(input: TextInput):
    # Prepare input tensor: encode the prompt into a (1, T) batch of token ids
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)

    # Generate multiple tokens
    generated_tokens = []
    num_tokens_to_generate = 50  # Generate 50 new tokens

    with torch.no_grad():
        current_ids = input_ids
        for _ in range(num_tokens_to_generate):
            # Get model predictions; the second return value (loss) is unused here
            logits, _ = model(current_ids)
            # Greedy decoding: pick the highest-scoring token at the last position
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            # Add the new token to our current sequence
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)

    # Decode only the newly generated tokens
    generated_text = enc.decode(generated_tokens)

    # Return both input and generated text
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }
# Modify the root route to serve the template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "title": "GPT Text Generator"}
    )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)
# To run the app, use the command: uvicorn app:app --reload
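# Example request once the server is up (the prompt text is illustrative only;
# port 8080 applies when started via `python app.py`, uvicorn defaults to 8000):
#   curl -X POST http://127.0.0.1:8080/generate/ \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Once upon a time"}'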