import os
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


# FastAPI application (needed by the route decorators below and by uvicorn)
app = FastAPI()


# Pydantic data models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
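
# Illustrative example payloads for the models above (field values are made
# up for documentation purposes):
#   ProblematicItem:     {"text": "Some text to classify"}
#   ProblematicList:     {"problematics": ["first text", "second text"]}
#   PredictionResponse:  {"predicted_class": "Class A", "score": 0.97}
#   PredictionsResponse: {"results": [{"text": "first text", "class": "Class A", "score": 0.97}]}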

# Model environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")

# Loading the model and tokenizer
tokenizer = None
model = None

def load_model():
    """Load the tokenizer and model from the Hub and cache them in globals."""
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
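
# Optional: eagerly load the model when the server starts instead of lazily
# on the first request. A sketch; recent FastAPI versions prefer lifespan
# handlers over on_event.
@app.on_event("startup")
def _load_model_at_startup():
    load_model()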


@app.get("/health")  # route path is an assumption; adjust to your deployment
def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    return {"status": "ok", "model": MODEL_NAME}


@app.post("/predict", response_model=PredictionResponse)  # route path is an assumption
def predict_single(item: ProblematicItem):
    global model, tokenizer
    
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    
    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()
        
        # Associate the correct label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {e}")

@app.post("/predict_batch", response_model=PredictionsResponse)  # route path is an assumption
def predict_batch(items: ProblematicList):
    global model, tokenizer
    
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    
    try:
        results = []
        
        # Batch processing
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]
            
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
            
            # Converting numerical predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score
                })
        
        return PredictionsResponse(results=results)
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {e}")