eziokittu commited on
Commit
207bcfa
·
verified ·
1 Parent(s): 7bdc45c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -12
main.py CHANGED
@@ -11,17 +11,15 @@ from projects.DL_CatDog.DL_CatDog import preprocess_image, read_image, model_DL_
11
  from projects.ML_StudentPerformance.ML_StudentPerformace import predict_student_performance, create_custom_data, form1
12
  from projects.ML_DiabetesPrediction.ML_DiabetesPrediction import model_ML_DiabetesPrediction, form2
13
 
14
- # Set cache directory to /tmp/mycache
15
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/mycache'
16
- if not os.path.exists('/tmp/mycache'):
17
- os.makedirs('/tmp/mycache')
18
-
19
- # Initialize the pipeline
20
- pipe = pipeline(
21
- "image-classification",
22
- model="wambugu71/crop_leaf_diseases_vit",
23
- cache_dir='/tmp/mycache'
24
- )
25
 
26
  app = FastAPI()
27
 
@@ -52,7 +50,6 @@ async def predict_DL_CatDog(file: UploadFile = File(...)):
52
  return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
53
 
54
  # Classification route for DL_PlantDisease
55
- pipe = pipeline("image-classification", model="wambugu71/crop_leaf_diseases_vit", cache_dir="/tmp/mycache")
56
  @app.post("/api/classify")
57
  async def classify_image(file: UploadFile = File(...)):
58
  try:
 
11
  from projects.ML_StudentPerformance.ML_StudentPerformace import predict_student_performance, create_custom_data, form1
12
  from projects.ML_DiabetesPrediction.ML_DiabetesPrediction import model_ML_DiabetesPrediction, form2
13
 
14
+ # Set the cache directory to a writable location
15
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
16
+
17
+ # Make sure the directory exists
18
+ os.makedirs('/tmp/.cache', exist_ok=True)
19
+
20
+ # Initialize the pipeline with the new cache directory
21
+ pipe = pipeline("image-classification", model="wambugu71/crop_leaf_diseases_vit", cache_dir="/tmp/.cache")
22
+
 
 
23
 
24
  app = FastAPI()
25
 
 
50
  return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
51
 
52
  # Classification route for DL_PlantDisease
 
53
  @app.post("/api/classify")
54
  async def classify_image(file: UploadFile = File(...)):
55
  try: