import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
# FastAPI Configuration
app = FastAPI(
    title="Problematic Specificity Classification API",
    description="This API classifies problematics using a fine-tuned model hosted on Hugging Face.",
    version="1.0.0"
)
# Model environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")
# Model and tokenizer placeholders (loaded lazily at startup or on first request)
tokenizer = None
model = None
def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
# API state check route
@app.get("/")
def read_root():
return {"status": "ok", "model": MODEL_NAME}
# Route for checking model status
@app.get("/health")
def health_check():
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
return {"status": "ok", "model": MODEL_NAME}
# Route to predict a single problem at a time
@app.post("/predict", response_model=PredictionResponse)
def predict_single(item: ProblematicItem):
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
try:
# Tokenization
inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
# Prediction
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=1).item()
confidence_score = probabilities[0][predicted_class].item()
# Associate the correct label
predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
# Route for predicting several problems at once
@app.post("/predict-batch", response_model=PredictionsResponse)
def predict_batch(items: ProblematicList):
global model, tokenizer
if model is None or tokenizer is None:
success = load_model()
if not success:
raise HTTPException(status_code=503, detail="Model not available")
try:
results = []
# Batch processing
batch_size = 16
for i in range(0, len(items.problematics), batch_size):
batch_texts = items.problematics[i:i+batch_size]
# Tokenization
inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
# Prediction
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_classes = torch.argmax(probabilities, dim=1).tolist()
confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
# Converting numerical predictions into labels
for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
results.append({
"text": batch_texts[j],
"class": predicted_label,
"score": score
})
return PredictionsResponse(results=results)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
# Model loading at startup
@app.on_event("startup")
async def startup_event():
load_model()
# Entry point for uvicorn
if __name__ == "__main__":
    # Starting the server with uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)