import os
from pydantic import BaseModel
from typing import List, Dict, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


# Pydantic data models for requests and responses
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]

class BatchPredictionScoreItem(BaseModel):
    problematic: str
    score: float

# Model configuration from environment variables
MODEL_NAME = os.getenv("MODEL_NAME")
LABEL_0 = os.getenv("LABEL_0")
LABEL_1 = os.getenv("LABEL_1")

if not MODEL_NAME:
    raise ValueError("Environment variable MODEL_NAME is not set.")
if not LABEL_0 or not LABEL_1:
    raise ValueError("Environment variables LABEL_0 and LABEL_1 must be set.")
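
# Example configuration (illustrative values only; any two-label
# sequence-classification checkpoint works here):
#   export MODEL_NAME="distilbert-base-uncased-finetuned-sst-2-english"
#   export LABEL_0="NEGATIVE"
#   export LABEL_1="POSITIVE"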

# The model and tokenizer are loaded lazily on first use (see load_model)
tokenizer = None
model = None


def load_model():
    """Load the tokenizer and model once; return True on success."""
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        model.eval()  # inference mode: disables dropout
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False


def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        if not load_model():
            return {"status": "error", "model": MODEL_NAME, "detail": "Model not available"}
    return {"status": "ok", "model": MODEL_NAME}


def predict_single(item: ProblematicItem):
    global model, tokenizer

    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model is not available.")

    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()
        
        # Associate the correct label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise  # re-raise so callers do not silently receive None

def predict_batch(items: ProblematicList):
    global model, tokenizer

    if model is None or tokenizer is None:
        if not load_model():
            raise RuntimeError("Model is not available.")

    try:
        results = []
        if not items.problematics:
            return []
        
        # Batch processing
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]
            
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

            for j in range(len(batch_texts)):
                # Report the probability of class index 1 (LABEL_1) for every
                # text, rather than the probability of the argmax class.
                score_specific_class = probabilities[j][1].item()
                
                results.append(
                    BatchPredictionScoreItem(
                        problematic=batch_texts[j],
                        score=score_specific_class
                    )
                )
        return results
    
    except AttributeError as ae:
        print(f"AttributeError during prediction in predict_batch (likely wrong input type): {str(ae)}")
        return []
    except Exception as e:
        print(f"Generic error during prediction in predict_batch: {str(e)}")
        return []
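

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal smoke test, assuming the environment variables above are set
# and the checkpoint is a two-label sequence classifier.
if __name__ == "__main__":
    if load_model():
        single = predict_single(ProblematicItem(text="Example input sentence."))
        print(single.predicted_class, single.score)

        batch = predict_batch(
            ProblematicList(problematics=["First text.", "Second text."])
        )
        for entry in batch:
            print(entry.problematic, entry.score)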