import os
from typing import Dict, List, Union

import torch
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Definition of Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
# Model configuration via environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")
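# The defaults above are placeholders. In a deployment (e.g. a Hugging Face
# Space) they would typically be overridden through environment settings,
# for instance (hypothetical values):
#   MODEL_NAME=my-org/my-binary-classifier  LABEL_0=negative  LABEL_1=positive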
# Loading the model and tokenizer (lazily, on first use)
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        if not load_model():
            return {"status": "error", "model": MODEL_NAME, "detail": "Model not available"}
    return {"status": "ok", "model": MODEL_NAME}
def predict_single(item: ProblematicItem):
    global model, tokenizer
    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model not available")
    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()
        # Associate the correct label with the predicted class index
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise
def predict_batch(items: ProblematicList):
    global model, tokenizer
    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model not available")
    try:
        results = []
        # Process the texts in mini-batches
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i + batch_size]
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [
                    probabilities[j][pred].item() for j, pred in enumerate(predicted_classes)
                ]
            # Convert numerical predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score,
                })
        return PredictionsResponse(results=results)
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise