eziokittu commited on
Commit
7bdc45c
·
verified ·
1 Parent(s): 504700a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -6
main.py CHANGED
@@ -11,11 +11,17 @@ from projects.DL_CatDog.DL_CatDog import preprocess_image, read_image, model_DL_
11
  from projects.ML_StudentPerformance.ML_StudentPerformace import predict_student_performance, create_custom_data, form1
12
  from projects.ML_DiabetesPrediction.ML_DiabetesPrediction import model_ML_DiabetesPrediction, form2
13
 
14
- cache_dir = './mycache'
15
- if not os.path.exists(cache_dir):
16
- os.makedirs(cache_dir)
17
-
18
- os.environ['TRANSFORMERS_CACHE'] = cache_dir
 
 
 
 
 
 
19
 
20
  app = FastAPI()
21
 
@@ -46,7 +52,7 @@ async def predict_DL_CatDog(file: UploadFile = File(...)):
46
  return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
47
 
48
  # Classification route for DL_PlantDisease
49
- pipe = pipeline("image-classification", model="wambugu71/crop_leaf_diseases_vit", cache_dir=cache_dir)
50
  @app.post("/api/classify")
51
  async def classify_image(file: UploadFile = File(...)):
52
  try:
 
11
  from projects.ML_StudentPerformance.ML_StudentPerformace import predict_student_performance, create_custom_data, form1
12
  from projects.ML_DiabetesPrediction.ML_DiabetesPrediction import model_ML_DiabetesPrediction, form2
13
 
14
+ # Set cache directory to /tmp/mycache
15
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/mycache'
16
+ if not os.path.exists('/tmp/mycache'):
17
+ os.makedirs('/tmp/mycache')
18
+
19
+ # Initialize the pipeline
20
+ pipe = pipeline(
21
+ "image-classification",
22
+ model="wambugu71/crop_leaf_diseases_vit",
23
+ cache_dir='/tmp/mycache'
24
+ )
25
 
26
  app = FastAPI()
27
 
 
52
  return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
53
 
54
  # Classification route for DL_PlantDisease
55
+ pipe = pipeline("image-classification", model="wambugu71/crop_leaf_diseases_vit", cache_dir="/tmp/mycache")
56
  @app.post("/api/classify")
57
  async def classify_image(file: UploadFile = File(...)):
58
  try: