deveix commited on
Commit
2756870
·
1 Parent(s): b80ec6f

fix argmax

Browse files
Files changed (1) hide show
  1. app/main.py +3 -2
app/main.py CHANGED
@@ -327,8 +327,9 @@ async def handle_cnn(file: UploadFile = File(...)):
327
  print('predictions', predictions)
328
 
329
  # Convert predictions to label indexes
330
- predicted_label_indexes = np.argmax(predictions)
331
-
 
332
  # Convert label indexes to actual label names
333
  predicted_labels = cnn_label_encoder.inverse_transform([predicted_label_indexes])
334
 
 
327
  print('predictions', predictions)
328
 
329
  # Convert predictions to label indexes
330
+ predicted_label_indexes = np.argmax(predictions, axis=1)
331
+ print(predicted_label_indexes)
332
+
333
  # Convert label indexes to actual label names
334
  predicted_labels = cnn_label_encoder.inverse_transform([predicted_label_indexes])
335