import torch
import os
from fastapi import FastAPI, Request
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
import tiktoken
from model import GPT, GPTConfig
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
MODEL_ID = "sagargurujula/text-generator"
# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")
# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the checkpoint file from the Hugging Face Hub
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth"
        )

        # Initialize our custom GPT model
        model = GPT(GPTConfig())

        # Load the saved weights (checkpoint is a dict holding 'model_state_dict')
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
# Load the model
model = load_model()
# Define the request body
class TextInput(BaseModel):
    text: str
@app.post("/generate/")
async def generate_text(input: TextInput):
    # Encode the prompt with the GPT-2 tokenizer
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)

    # Generate new tokens one at a time (greedy decoding)
    generated_tokens = []
    num_tokens_to_generate = 50  # Generate 50 new tokens

    with torch.no_grad():
        current_ids = input_ids
        for _ in range(num_tokens_to_generate):
            # Get model predictions and pick the most likely next token
            logits, _ = model(current_ids)
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            # Append the new token to the current sequence
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)

    # Decode all generated tokens
    generated_text = enc.decode(generated_tokens)

    # Return both the input and the generated continuation
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }
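# Example request against the endpoint above (a sketch; assumes the server is
# running locally via `python app.py`, i.e. on http://127.0.0.1:8080, and the
# prompt text is illustrative):
#
#   curl -X POST http://127.0.0.1:8080/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time"}'
#
# Expected response shape:
#   {"input_text": "Once upon a time", "generated_text": "..."}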
# Root route serves the HTML template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "title": "GPT Text Generator"}
    )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)

# Alternatively, run with auto-reload during development: uvicorn app:app --reload
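# A minimal Python client sketch (an assumption, not part of this app: it
# requires the third-party `requests` package and a server on port 8080):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8080/generate/",
#       json={"text": "Once upon a time"},
#   )
#   print(resp.json()["generated_text"])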