adrienbrdne committed
Commit 58d1f23 · verified · 1 Parent(s): 9d85d89

Update scoring/specificity.py

Files changed (1):
  scoring/specificity.py (+17 -12)
scoring/specificity.py CHANGED
@@ -19,6 +19,10 @@ class PredictionResponse(BaseModel):
 class PredictionsResponse(BaseModel):
     results: List[Dict[str, Union[str, float]]]
 
+class BatchPredictionScoreItem(BaseModel):
+    problem_description: str
+    score: float
+
 # Model environment variables
 MODEL_NAME = os.getenv("MODEL_NAME")
 LABEL_0 = os.getenv("LABEL_0")
@@ -102,19 +106,20 @@ def predict_batch(items: ProblematicList):
             with torch.no_grad():
                 outputs = model(**inputs)
             probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
-            predicted_classes = torch.argmax(probabilities, dim=1).tolist()
-            confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
-
-            # Converting numerical predictions into labels
-            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
-                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
-                results.append({
-                    "text": batch_texts[j],
-                    "class": predicted_label,
-                    "score": score
-                })
+            # predicted_classes = torch.argmax(probabilities, dim=1).tolist()
+            # confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
+
+            for j in range(len(batch_texts)):
+                score_specific_class = probabilities[j][1].item()
+
+                results.append(
+                    BatchPredictionScoreItem(
+                        problem_description=batch_texts[j],
+                        score=score_specific_class
+                    )
+                )
 
-        return PredictionsResponse(results=results)
+        return results
 
     except Exception as e:
         print(f"Error during prediction: {str(e)}")
 