deveix commited on
Commit
19a4c6d
·
1 Parent(s): 2d7bdd7

fix accuracy

Browse files
Files changed (1) hide show
  1. app/main.py +9 -5
app/main.py CHANGED
@@ -140,8 +140,8 @@ async def get_answer(item: Item, token: str = Depends(verify_token)):
140
  raise HTTPException(status_code=500, detail=str(e))
141
 
142
  # mlp
143
- mlp_model = joblib.load('app/mlp_model.pkl')
144
- mlp_pca = joblib.load('app/pca.pkl')
145
  mlp_scaler = joblib.load('app/scaler.pkl')
146
  mlp_label_encoder = joblib.load('app/label_encoder.pkl')
147
 
@@ -267,13 +267,17 @@ async def handle_audio(file: UploadFile = File(...)):
267
  features = mlp_scaler.transform(features)
268
  features = mlp_pca.transform(features)
269
 
270
- # Dummy example to proceed with an inference
271
  results = mlp_model.predict(features)
272
 
273
- # Clean up (optional, especially if dealing with large files or sensitive data)
 
 
 
274
  os.remove(temp_filename)
275
 
276
- return {"message": "File processed successfully", "prediction": results.tolist()}
 
277
  except Exception as e:
278
  # Handle possible exceptions
279
  raise HTTPException(status_code=500, detail=str(e))
 
140
  raise HTTPException(status_code=500, detail=str(e))
141
 
142
  # mlp
143
+ mlp_model = joblib.load('app/mlp_model2.pkl')
144
+ mlp_pca = joblib.load('app/pca2.pkl')
145
  mlp_scaler = joblib.load('app/scaler.pkl')
146
  mlp_label_encoder = joblib.load('app/label_encoder.pkl')
147
 
 
267
  features = mlp_scaler.transform(features)
268
  features = mlp_pca.transform(features)
269
 
270
+ # proceed with an inference
271
  results = mlp_model.predict(features)
272
 
273
+ # Decode the predictions using the label encoder
274
+ decoded_predictions = mlp_label_encoder.inverse_transform(results)
275
+
276
+ # Clean up the temporary file
277
  os.remove(temp_filename)
278
 
279
+ # Return a successful response with decoded predictions
280
+ return {"message": "File processed successfully", "prediction": decoded_predictions.tolist()}
281
  except Exception as e:
282
  # Handle possible exceptions
283
  raise HTTPException(status_code=500, detail=str(e))