Spaces:
Running
Running
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import torch | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| import logging | |
| import time | |
| from contextlib import asynccontextmanager | |
| from inference_utils import PrimingData, construct_alphabet_list, convert_offsets_to_absolute_coords, encode_text, get_alphabet_map, load_priming_data | |
# Root logging at INFO; this module logs through its own named logger.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Packaged model artifacts, resolved relative to the current working directory.
MODEL_DIR = Path("./packaged_models")
QUANTIZED_MODEL_NAME = "model.scripted.quantized.pt"  # quantized TorchScript variant
SCRIPTED_MODEL_NAME = "model.scripted.pt"  # TorchScript model whose .sample() generates strokes
METADATA_MODEL_NAME = "model.pt"  # checkpoint holding the "config_full" training config

# Mutable runtime state: every name below stays None until the startup
# lifespan handler populates it.
SCRIPTED_MODEL: Optional[torch.jit.ScriptModule] = None
MODEL_METADATA: Optional[dict] = None
DEVICE: Optional[torch.device] = None

# Alphabet / text-encoding state derived from the dataset config.
ALPHABET_MAP: Optional[dict[str, int]] = None
ALPHABET_LIST: Optional[list[str]] = None
ALPHABET_SIZE: Optional[int] = None
MAX_TEXT_LEN: Optional[int] = None

# Model hyper-parameters read from the metadata's model_params section.
output_mixture_components: Optional[int] = None  # number of mixtures used for GMM sampling
lstm_size: Optional[int] = None
attention_mixture_components: Optional[int] = None
# Request body for handwriting generation; pydantic enforces the bounds
# declared on each Field when the model is constructed.
class HandwritingRequest(BaseModel):
    # Text to render: 1 to 40 characters, required.
    text: str = Field(
        ...,
        description="Text to generate handwriting for",
        min_length=1,
        max_length=40,
    )
    # Upper bound on generated stroke points: 50..1200, default 1000.
    max_length: int = Field(
        default=1000,
        description="Maximum number of stroke points",
        ge=50,
        le=1200,
    )
    # Sampling bias forwarded to the model: 0.1..10.0, default 0.75.
    bias: float = Field(
        default=0.75,
        description="Sampling bias for generation",
        ge=0.1,
        le=10.0,
    )
# Response body for a generation request (also used for the "no strokes" case).
class HandwritingResponse(BaseModel):
    success: bool = Field(default=True)  # set to False by the endpoint when no strokes were produced
    input_text: str  # echo of the requested text
    generation_time_ms: float  # elapsed generation time in milliseconds
    num_points: int  # number of rows in `strokes`
    strokes: list[list[float]]  # per-point rows returned after offset-to-absolute conversion
    message: str = Field(default="Successfully generated handwriting.")
# Payload returned by the health check.
class HealthResponse(BaseModel):
    status: str  # "healthy" or "unhealthy"
    model_loaded: bool  # whether the TorchScript model is present
    device: str  # torch device string, or "unknown" before startup
    model_metadata_keys: Optional[list[str]] = Field(default=None)  # top-level metadata keys, if loaded
def _require_dict(container: dict, key: str, where: str) -> dict:
    """Return ``container[key]`` when it is a non-empty dict.

    Args:
        container: Mapping to read from.
        key: Key expected to hold a nested dict.
        where: Human-readable location used in the error message.

    Raises:
        ValueError: If ``key`` is missing, maps to an empty value, or is not a dict.
    """
    value = container.get(key)
    if not value or not isinstance(value, dict):
        raise ValueError(f"Key `{key}` not found or not a dict in {where}")
    return value


def _clear_model_state() -> None:
    """Reset every module-level resource populated by ``lifespan`` back to None."""
    global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, ALPHABET_SIZE
    global output_mixture_components, lstm_size, attention_mixture_components
    SCRIPTED_MODEL = None
    MODEL_METADATA = None
    DEVICE = None
    ALPHABET_MAP = None
    ALPHABET_LIST = None
    ALPHABET_SIZE = None
    MAX_TEXT_LEN = None
    output_mixture_components = None
    lstm_size = None
    attention_mixture_components = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model resources on startup and release them on shutdown.

    On startup this picks the device, loads the TorchScript model and the
    metadata checkpoint from ``MODEL_DIR``, and derives the alphabet and model
    hyper-parameters from ``metadata["config_full"]``. On any failure the
    error is logged, all module-level state is cleared, and the exception is
    re-raised so the app does not start half-initialised.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).

    Raises:
        FileNotFoundError: If the scripted model or metadata file is missing.
        ValueError: If the metadata is empty/not a dict, or a config section
            (``config_full``, ``dataset``, ``model_params``) is missing or not a dict.
        KeyError: If a required entry inside ``dataset``/``model_params`` is missing.
    """
    global SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST, output_mixture_components, lstm_size, attention_mixture_components, ALPHABET_SIZE
    logger.info("Attempting to load model resources during startup")
    try:
        DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {DEVICE}")

        scripted_model_path = MODEL_DIR / SCRIPTED_MODEL_NAME
        metadata_model_path = MODEL_DIR / METADATA_MODEL_NAME
        # if DEVICE.type == "cpu":
        #     scripted_model_path = MODEL_DIR / QUANTIZED_MODEL_NAME

        if not scripted_model_path.exists():
            logger.error(f"Traced model not found at {scripted_model_path}")
            raise FileNotFoundError(f"Traced model not found at {scripted_model_path}")
        if not metadata_model_path.exists():
            logger.error(f"Metadata model file not found at {metadata_model_path}")
            raise FileNotFoundError(f"Metadata model file not found at {metadata_model_path}")

        # Load the TorchScript model directly onto the selected device;
        # torch.jit.load either returns a module or raises.
        SCRIPTED_MODEL = torch.jit.load(scripted_model_path, map_location=DEVICE)
        SCRIPTED_MODEL.eval()
        logger.info(f"Traced model loaded successfully from {scripted_model_path}")

        # Metadata is read on CPU; only its config sections are used below.
        # NOTE(review): torch.load unpickles this file. It is a local, trusted
        # artifact; on PyTorch >= 2.6 `weights_only` defaults to True, which may
        # reject non-tensor objects in the checkpoint — confirm it still loads.
        MODEL_METADATA = torch.load(metadata_model_path, map_location='cpu')
        if not MODEL_METADATA or not isinstance(MODEL_METADATA, dict):
            raise ValueError("Failed to load content from metadata file")
        logger.info(f"Model metadata loaded successfully from {metadata_model_path}")
        logger.info(f"Model metadata keys: {list(MODEL_METADATA.keys())}")

        # Validate every config section up front so a malformed checkpoint
        # fails with a clear message instead of a bare KeyError.
        config_full = _require_dict(MODEL_METADATA, 'config_full', 'model metadata')
        dataset_config = _require_dict(config_full, 'dataset', 'config_full')
        model_params = _require_dict(config_full, 'model_params', 'config_full')

        alphabet_str = dataset_config['alphabet_string']
        MAX_TEXT_LEN = dataset_config['max_text_len']
        output_mixture_components = model_params['output_mixture_components']
        lstm_size = model_params['lstm_size']
        attention_mixture_components = model_params['attention_mixture_components']

        ALPHABET_LIST = construct_alphabet_list(alphabet_str)
        ALPHABET_SIZE = len(ALPHABET_LIST)
        ALPHABET_MAP = get_alphabet_map(ALPHABET_LIST)
        logger.info(f"Alphabet created. Size: {ALPHABET_SIZE}")
        logger.info("Model resources are loaded and ready")
    except Exception as e:
        logger.error(f"Error loading model resources: {e}", exc_info=True)
        # Clear everything (not just the model) so no partially-derived state survives.
        _clear_model_state()
        raise

    try:
        yield
    finally:
        # Runs on normal shutdown and if the app body exits with an exception.
        logger.info("Shutting down API and cleaning up resources")
        _clear_model_state()
# Application instance; model loading/unloading is delegated to `lifespan`.
app = FastAPI(
    title="Scriptify API",
    version="0.1.0",
    description="API to generate handwriting from text using a PyTorch model.",
    lifespan=lifespan,
)

# Browser origins allowed by CORS: localhost / 127.0.0.1 on port 5173
# (presumably a local front-end dev server — confirm before deploying).
_CORS_ALLOWED_ORIGINS = ["http://localhost:5173", "http://127.0.0.1:5173"]

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_credentials=True,
    allow_origins=_CORS_ALLOWED_ORIGINS,
)
async def read_root():
    # Static welcome payload for the API root.
    # NOTE(review): no route decorator is visible on this function in this copy
    # of the file; presumably it is registered with `@app.get("/")` — confirm.
    greeting = "Welcome to the Scriptify Handwriting Generation API!"
    return {"message": greeting}
async def health_check():
    # Report readiness: "healthy" only when every startup-loaded resource is
    # truthy. Globals are only read here, so no `global` declaration is needed.
    # NOTE(review): no route decorator is visible on this function in this copy
    # of the file; presumably it is registered as a GET health route — confirm.
    required_resources = (SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN, ALPHABET_LIST)
    healthy = all(required_resources)
    device_label = str(DEVICE) if DEVICE else "unknown"
    metadata_keys = list(MODEL_METADATA.keys()) if MODEL_METADATA else None
    return HealthResponse(
        status="healthy" if healthy else "unhealthy",
        model_loaded=bool(SCRIPTED_MODEL),
        device=device_label,
        model_metadata_keys=metadata_keys,
    )
def text_to_tensor(text: str, max_text_length: int, add_eos: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode ``text`` with the startup alphabet and return model-ready tensors.

    Args:
        text: String to encode.
        max_text_length: Passed to ``encode_text`` as ``max_length``.
        add_eos: Forwarded unchanged to ``encode_text``.

    Returns:
        ``(char_seq, char_len)``: the encoded sequence from ``encode_text``
        (expected to be a NumPy array, since it goes through ``torch.from_numpy``)
        as a long tensor, and a one-element long tensor holding the length that
        ``encode_text`` reported. Both are placed on ``DEVICE``.

    Raises:
        ValueError: If the alphabet map was not initialised at startup.
    """
    alphabet = ALPHABET_MAP
    if alphabet is None:
        raise ValueError("Alphabet map not initialized during api startup")

    encoded, length = encode_text(
        text=text,
        char_to_index_map=alphabet,
        max_length=max_text_length,
        add_eos=add_eos,
    )

    seq_tensor = torch.from_numpy(encoded).to(device=DEVICE, dtype=torch.long)
    len_tensor = torch.tensor([length], dtype=torch.long, device=DEVICE)
    return seq_tensor, len_tensor
def _build_priming_data(style: int):
    """Load priming sample ``style`` and wrap it as ``PrimingData`` on ``DEVICE``."""
    priming_text, priming_strokes = load_priming_data(style)
    text_tensor, text_len_tensor = text_to_tensor(
        priming_text, max_text_length=len(priming_text), add_eos=False)
    # Prepend a batch dimension of size 1 to the priming strokes.
    stroke_tensor = torch.tensor(priming_strokes, dtype=torch.float32, device=DEVICE).unsqueeze(dim=0)
    return PrimingData(stroke_tensor,
                       char_seq_tensors=text_tensor,
                       char_seq_lengths=text_len_tensor)


def generate_strokes(
    char_seq: torch.Tensor,
    char_lengths: torch.Tensor,
    max_gen_len: int,
    api_bias: float,
    style: Optional[int] = None
) -> list[list[float]]:
    """Sample stroke offsets from the TorchScript model for one input sequence.

    Args:
        char_seq: Encoded text tensor (as produced by ``text_to_tensor``).
        char_lengths: One-element length tensor for ``char_seq``.
        max_gen_len: Forwarded to ``sample`` as ``max_length``.
        api_bias: Forwarded to ``sample`` as ``bias``.
        style: Optional priming-style id; when given, priming data is loaded
            and passed to ``sample`` as ``prime``.

    Returns:
        The single batch item returned by ``sample`` converted to nested Python
        lists, or ``[]`` if sampling raised or did not return exactly one item.

    Raises:
        ValueError: If the scripted model is not loaded. Errors while preparing
            priming data propagate; errors during sampling are logged and
            produce ``[]`` instead.
    """
    model = SCRIPTED_MODEL
    if model is None:
        raise ValueError("Scripted model not initialized.")

    prime = None if style is None else _build_priming_data(style)

    with torch.inference_mode():
        try:
            batch = model.sample(
                char_seq,
                char_lengths,
                max_length=max_gen_len,
                bias=api_bias,
                prime=prime
            )
            # Only batch size 1 is expected here.
            if len(batch) != 1:
                logger.warning(f"Expected single batch, but got {len(batch)}")
                return []
            return batch[0].cpu().numpy().tolist()
        except Exception as e:
            logger.error(f"Error in model sampling: {e}", exc_info=True)
            return []
async def generate_handwriting_endpoint(request: HandwritingRequest):
    # Generate handwriting for `request.text` and return absolute stroke coordinates.
    #
    # Outcomes:
    #   * 503 if startup did not populate the model resources.
    #   * 200 with success=False when sampling produced no strokes.
    #   * 400 for a ValueError raised while encoding or generating.
    #   * 500 for any other unexpected exception.
    #
    # NOTE(review): no route decorator is visible on this function in this copy
    # of the file; presumably it is registered with `@app.post(...)` — confirm.
    import asyncio

    if not all([SCRIPTED_MODEL, MODEL_METADATA, DEVICE, ALPHABET_MAP, MAX_TEXT_LEN]):
        logger.error("API not fully initialized. Check /health endpoint.")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model or required resources not loaded."
        )

    # perf_counter is monotonic, so the reported duration is unaffected by
    # wall-clock adjustments (time.time() is not).
    start_time = time.perf_counter()
    try:
        char_seq_tensor, char_lengths_tensor = text_to_tensor(request.text, max_text_length=MAX_TEXT_LEN)  # type: ignore
        # Model sampling is synchronous and compute-heavy; running it in a
        # worker thread keeps the event loop (and other requests such as health
        # checks) responsive while generation is in progress.
        relative_stroke_offsets = await asyncio.to_thread(
            generate_strokes,
            char_seq_tensor,
            char_lengths_tensor,
            request.max_length,
            request.bias,
            # style=1 #TODO: style is hardcode since the current version is hosted on cpu
        )
        if not relative_stroke_offsets:
            return HandwritingResponse(
                success=False,
                input_text=request.text,
                strokes=[],
                num_points=0,
                generation_time_ms=(time.perf_counter() - start_time) * 1000,
                message="No strokes generated."
            )
        absolute_stroke_coords = convert_offsets_to_absolute_coords(relative_stroke_offsets)
        generation_time_ms = (time.perf_counter() - start_time) * 1000
        return HandwritingResponse(
            input_text=request.text,
            strokes=absolute_stroke_coords,
            num_points=len(absolute_stroke_coords),
            generation_time_ms=generation_time_ms
        )
    except ValueError as ve:
        logger.error(f"ValueError during generation for '{request.text}': {ve}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ve)) from ve
    except Exception as e:
        logger.error(f"Unexpected error for '{request.text}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.") from e
if __name__ == "__main__":
    # Direct-run entry point: serve "main:app" with auto-reload on all
    # interfaces, port 8000. Assumes this file is importable as `main` from the
    # current directory — TODO confirm the module name matches the filename.
    import uvicorn

    logger.info("Starting Uvicorn server for Scriptify API...")
    uvicorn.run(
        "main:app",
        app_dir=".",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )