Abhishek Kumar commited on
Commit
9d6a6eb
·
1 Parent(s): a1b23c0

Add data preprocessing and better error handling

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
3
  import joblib
4
  import pandas as pd
5
  import numpy as np
 
6
 
7
  app = FastAPI()
8
 
@@ -18,20 +19,39 @@ app.add_middleware(
18
  # Load the model
19
  model = joblib.load('superkart_sales_model.joblib')
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @app.get("/")
22
  async def root():
23
  return {"message": "SuperKart Sales Prediction API"}
24
 
25
  @app.post("/predict")
26
  async def predict(file: UploadFile = File(...)):
27
- # Read the uploaded CSV file
28
- df = pd.read_csv(file.file)
29
-
30
- # Make predictions
31
- predictions = model.predict(df)
32
-
33
- return {"predictions": predictions.tolist()}
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
  import uvicorn
37
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
3
  import joblib
4
  import pandas as pd
5
  import numpy as np
6
+ from datetime import datetime
7
 
8
  app = FastAPI()
9
 
 
19
  # Load the model
20
  model = joblib.load('superkart_sales_model.joblib')
21
 
22
+ def preprocess_data(df):
23
+ # Calculate Price_Weight_Ratio
24
+ df['Price_Weight_Ratio'] = df['Product_MRP'] / df['Product_Weight']
25
+
26
+ # Calculate Store_Age
27
+ current_year = datetime.now().year
28
+ df['Store_Age'] = current_year - df['Store_Establishment_Year']
29
+
30
+ # Calculate Product_Year (assuming it's the same as Store_Establishment_Year for this example)
31
+ df['Product_Year'] = df['Store_Establishment_Year']
32
+
33
+ return df
34
+
35
  @app.get("/")
36
  async def root():
37
  return {"message": "SuperKart Sales Prediction API"}
38
 
39
  @app.post("/predict")
40
  async def predict(file: UploadFile = File(...)):
41
+ try:
42
+ # Read the uploaded CSV file
43
+ df = pd.read_csv(file.file)
44
+
45
+ # Preprocess the data
46
+ df = preprocess_data(df)
47
+
48
+ # Make predictions
49
+ predictions = model.predict(df)
50
+
51
+ return {"predictions": predictions.tolist()}
52
+ except Exception as e:
53
+ return {"error": str(e)}, 500
54
 
55
  if __name__ == "__main__":
56
  import uvicorn
57
+ uvicorn.run(app, host="0.0.0.0", port=7860)