# scriptify-api / main.py
# henok3878
# remove unused vars
# 9221ec9
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from inference_utils import (
    PrimingData,
    construct_alphabet_list,
    convert_offsets_to_absolute_coords,
    encode_text,
    get_alphabet_map,
    load_priming_data,
)
# Root logging at INFO so startup/model-loading messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Packaged model artifacts -------------------------------------------------
# Directory is resolved relative to the process working directory.
MODEL_DIR = Path("./packaged_models")
SCRIPTED_MODEL_NAME = "model.scripted.pt"      # TorchScript model used for inference
METADATA_MODEL_NAME = "model.pt"               # checkpoint holding the training config
QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"  # only referenced by commented-out code

# --- Runtime state ---------------------------------------------------------------
# Every slot below starts as None and is filled in by the startup lifespan hook;
# request handlers treat None as "not loaded yet".
SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None
MODEL_METADATA: Optional[dict] = None
DEVICE: Optional[torch.device] = None

# Alphabet / text-encoding settings read from the metadata's dataset config.
ALPHABET_MAP: Optional[dict[str, int]] = None
ALPHABET_LIST: Optional[list[str]] = None
ALPHABET_SIZE: Optional[int] = None
MAX_TEXT_LEN: Optional[int] = None

# Model hyper-parameters read from the metadata's model_params config.
output_mixture_components: Optional[int] = None  # To store num_mixtures for GMM sampling
lstm_size: Optional[int] = None
attention_mixture_components: Optional[int] = None
class HandwritingRequest(BaseModel):
    # Validated request body accepted by POST /generate.
    # `default=...` marks `text` as required (same as a positional Ellipsis).
    text: str = Field(
        default=...,
        description="Text to generate handwriting for",
        min_length=1,
        max_length=40,
    )
    # Upper bound on generated stroke points, forwarded to the sampler.
    max_length: int = Field(
        default=1000,
        description="Maximum number of stroke points",
        ge=50,
        le=1200,
    )
    # Sampling bias forwarded to the model's sample() call.
    bias: float = Field(
        default=0.75,
        description="Sampling bias for generation",
        ge=0.1,
        le=10.0,
    )
class HandwritingResponse(BaseModel):
    # Payload returned by POST /generate. Declaration order is the key order of
    # the serialized JSON, so fields stay in their original sequence.
    success: bool = Field(default=True)
    input_text: str
    generation_time_ms: float
    num_points: int
    # One inner list per point; /generate sets num_points == len(strokes).
    strokes: list[list[float]]
    message: str = Field(default="Successfully generated handwriting.")
class HealthResponse(BaseModel):
    # Payload returned by GET /health.
    status: str  # "healthy" or "unhealthy"
    model_loaded: bool
    device: str  # str() of the torch device, or "unknown" before startup
    # Top-level keys of the loaded metadata checkpoint; None when not loaded.
    model_metadata_keys: Optional[list[str]] = Field(default=None)
def _require_dict(container: dict, key: str, where: str) -> dict:
    """Return ``container[key]`` when it is a non-empty dict.

    Uses ``.get`` so that a missing key surfaces as the descriptive ValueError
    below instead of a bare KeyError.

    Raises:
        ValueError: if *key* is missing, empty, or not a dict.
    """
    value = container.get(key)
    if not value or not isinstance(value, dict):
        raise ValueError(f"Key `{key}` not found or not a dict in {where}")
    return value


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events.

    Startup:
    - Picks CUDA when available, otherwise CPU.
    - Loads the TorchScript model (``MODEL_DIR / SCRIPTED_MODEL_NAME``) onto that
      device.
    - Loads the metadata checkpoint (``MODEL_DIR / METADATA_MODEL_NAME``) onto CPU.
    - Fills the module-level globals read by the request handlers: alphabet,
      max text length, and model hyper-parameters.

    Shutdown: drops the references to the model and its metadata.

    Raises:
        FileNotFoundError: if either model file is missing.
        ValueError: if the metadata is empty or not a dict, or if its
            ``config_full``, ``dataset`` or ``model_params`` sections are missing
            or not dicts.
        KeyError: if a required leaf key (e.g. ``alphabet_string``) is missing.

    Any startup error is logged and re-raised, which aborts application startup.
    """
    global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
    logger.info("Attempting to load model resources during startup")
    try:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {DEVICE}")
        scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
        metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
        # Disabled: serving the quantized variant when running on CPU.
        # if DEVICE.type == "cpu":
        #     scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME
        if not scripted_model_path.exists():
            logger.error(f"Traced model not found at {scripted_model_path}")
            raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
            logger.error(f"Metadata model file not found at {metadata_model_path}")
            raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")

        # Load the TorchScript model straight onto the inference device.
        SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
        SCRIPTED_MODEL.eval()
        logger.info(f"Traced model loaded successfully from {scripted_model_path}")

        # torch.load unpickles the file: only point it at trusted, locally packaged artifacts.
        MODEL_METADATA = torch.load(metadata_model_path, map_location="cpu")
        if not MODEL_METADATA or not isinstance(MODEL_METADATA, dict):
            raise ValueError("Failed to load content from metadata file")
        logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
        logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")

        config_full = _require_dict(MODEL_METADATA, "config_full", "model metadata")
        dataset_config = _require_dict(config_full, "dataset", "config_full")
        model_params = _require_dict(config_full, "model_params", "config_full")

        alphabet_str = dataset_config["alphabet_string"]
        MAX_TEXT_LEN = dataset_config["max_text_len"]
        output_mixture_components = model_params["output_mixture_components"]
        lstm_size = model_params["lstm_size"]
        attention_mixture_components = model_params["attention_mixture_components"]

        ALPHABET_LIST = construct_alphabet_list(alphabet_str)
        ALPHABET_SIZE = len(ALPHABET_LIST)
        ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)
        logger.info(f"Alphabet created. Size: {ALPHABET_SIZE}")
        logger.info("Model resources are loaded and ready")
    except Exception as e:
        logger.error(f"Error loading model resources: {e}", exc_info=True)
        SCRIPTED_MODEL = None
        MODEL_METADATA = None
        raise
    try:
        yield
    finally:
        # Cleanup on shutdown; runs even if shutdown is triggered by an error.
        logger.info("Shutting down API and cleaning up resources")
        SCRIPTED_MODEL = None
        MODEL_METADATA = None
app = FastAPI(
    title="Scriptify API",
    description="API to generate handwriting from text using a PyTorch model.",
    version="0.1.0",
    lifespan=lifespan,
)


def _cors_allowed_origins() -> list[str]:
    """Return the CORS origin allow-list.

    Reads a comma-separated list from the ``CORS_ALLOW_ORIGINS`` environment
    variable, for example ``"https://a.example,https://b.example"``. Blank
    entries are ignored.

    When the variable is unset or yields no origins, falls back to the original
    local development origins on port 5173. This keeps the previous behaviour
    as the default.
    """
    raw = os.environ.get("CORS_ALLOW_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["http://localhost:5173", "http://127.0.0.1:5173"]


# add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins(),
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
@app.get("/", tags=["General"])
async def read_root():
    # Static greeting; a cheap way to confirm the API is reachable.
    # (Documented with a comment, not a docstring, so the generated
    # OpenAPI description for this route stays unchanged.)
    welcome = "Welcome to the Scriptify Handwriting Generation API!"
    return dict(message=welcome)
@app.get("/health", response_model=HealthResponse, tags=["General"])
async def health_check():
    # Reports "healthy" only when every startup-populated resource is truthy.
    # The globals are only read here, so no `global` declaration is needed.
    # (Comment rather than docstring keeps the OpenAPI description unchanged.)
    startup_resources = (SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST)
    ready = all(startup_resources)

    if MODEL_METADATA:
        metadata_keys = list(MODEL_METADATA.keys())
    else:
        metadata_keys = None

    return HealthResponse(
        status="healthy" if ready else "unhealthy",
        model_loaded=bool(SCRIPTED_MODEL),
        device=str(DEVICE) if DEVICE else "unknown",
        model_metadata_keys=metadata_keys,
    )
def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode *text* into the (character-index tensor, length tensor) pair fed to the model.

    Args:
        text: String to encode with the alphabet map built at startup.
        max_text_length: Padding length passed to ``encode_text``.
        add_eos: Forwarded to ``encode_text``. Priming text is encoded with
            ``add_eos=False``.

    Returns:
        A pair of tensors, both of dtype ``torch.long`` and placed on ``DEVICE``:
        - The padded encoding converted from the NumPy array that
          ``encode_text`` returns. Its exact shape depends on ``encode_text``
          (TODO confirm).
        - A one-element tensor holding the unpadded length.

    Raises:
        ValueError: if startup has not populated the alphabet map.
    """
    char_map = ALPHABET_MAP
    if char_map is None:
        raise ValueError("Alphabet map not initialized during api startup")

    encoded, unpadded_len = encode_text(
        text=text,
        char_to_index_map=char_map,
        max_length=max_text_length,
        add_eos=add_eos,
    )
    sequence = torch.from_numpy(encoded).to(device=DEVICE, dtype=torch.long)
    lengths = torch.tensor([unpadded_len], device=DEVICE, dtype=torch.long)
    return sequence, lengths
def generate_strokes(
    char_seq: torch.Tensor,
    char_lengths: torch.Tensor,
    max_gen_len: int,
    api_bias: float,
    style: Optional[int] = None
) -> list[list[float]]:
    """Sample relative stroke offsets from the scripted model's ``sample`` method.

    Args:
        char_seq: Encoded text tensor, as produced by ``text_to_tensor``.
        char_lengths: Matching length tensor from ``text_to_tensor``.
        max_gen_len: Forwarded as ``max_length`` to ``sample``.
        api_bias: Forwarded as ``bias`` to ``sample``.
        style: When given, the stored priming sample for this style is loaded
            and passed as ``prime``.

    Returns:
        The single batch element's strokes as nested Python lists. Returns
        ``[]`` if sampling raises or the model returns other than exactly one
        batch element.

    Raises:
        ValueError: if the scripted model has not been loaded. Errors from
            loading priming data are also not caught here.
    """
    model = SCRIPTED_MODEL
    if model is None:
        raise ValueError("Scripted model not initialized.")

    prime = None if style is None else _build_priming_data(style)

    with torch.inference_mode():
        try:
            batch = model.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias,
                prime=prime
            )
            # The API always submits a batch of size 1.
            if len(batch) != 1:
                logger.warning(f"Expected single batch, but got {len(batch)}")
                return []
            return batch[0].cpu().numpy().tolist()
        except Exception as e:
            # Best-effort: the caller treats an empty list as "no strokes generated".
            logger.error(f"Error in model sampling: {e}", exc_info=True)
            return []


def _build_priming_data(style: int) -> PrimingData:
    """Load the stored priming sample for *style* and wrap it for ``sample(prime=...)``."""
    priming_text, priming_strokes = load_priming_data(style)
    # Encode at exactly the text's own length and without an end-of-sequence marker.
    text_tensor, text_len_tensor = text_to_tensor(
        priming_text, max_text_length=len(priming_text), add_eos=False
    )
    # Add a leading batch dimension of 1.
    stroke_tensor = torch.tensor(
        priming_strokes, dtype=torch.float32, device=DEVICE
    ).unsqueeze(dim=0)
    return PrimingData(
        stroke_tensor,
        char_seq_tensors=text_tensor,
        char_seq_lengths=text_len_tensor,
    )
@app.post("/generate", response_model=HandwritingResponse, tags=["Generation"])
async def generate_handwriting_endpoint(request: HandwritingRequest):
    """Generate handwriting strokes for the submitted text.

    Returns absolute stroke coordinates together with the generation time.
    When the model produces no strokes, the response has ``success=False`` and
    an empty ``strokes`` list.
    \f
    Raises:
        HTTPException(503): startup resources are not loaded.
        HTTPException(400): a ``ValueError`` was raised while encoding the text
            or generating strokes.
        HTTPException(500): any other unexpected error.
    """
    if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
        logger.error("API not fully initialized. Check /health endpoint.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model or required resources not loaded.",
        )
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    try:
        char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN)  # type: ignore
        # Model sampling is blocking work. Running it in a worker thread keeps
        # the event loop free to serve other requests (e.g. /health) meanwhile.
        # NOTE(review): this allows concurrent sample() calls on the shared
        # scripted model. Confirm sample() keeps no mutable per-call state on
        # the module.
        relative_stroke_offsets = await asyncio.to_thread(
            generate_strokes,
            char_seq_tensor,
            char_lengths_tensor,
            request.max_length,
            request.bias,
            # style=1  # TODO: style is hardcoded since the current version is hosted on CPU
        )
        if not relative_stroke_offsets:
            return HandwritingResponse(
                success=False,
                input_text=request.text,
                strokes=[],
                num_points=0,
                generation_time_ms=(time.perf_counter() - start_time) * 1000,
                message="No strokes generated.",
            )
        absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
        generation_time_ms = (time.perf_counter() - start_time) * 1000
        return HandwritingResponse(
            input_text=request.text,
            strokes=absolute_stroke_coords,
            num_points=len(absolute_stroke_coords),
            generation_time_ms=generation_time_ms,
        )
    except ValueError as ve:
        logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e
if __name__ == "__main__":
    # Development entry point.
    import uvicorn

    logger.info("Starting Uvicorn server for Scriptify API...")
    # The app is passed as an import string because uvicorn requires that form
    # when reload is enabled.
    server_options = dict(host="0.0.0.0", port=8000, reload=True, app_dir=".")
    uvicorn.run("main:app", **server_options)