import logging
import os
import secrets
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import Annotated, Optional

import requests
import torch
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    File,
    Form,
    Header,
    HTTPException,
    UploadFile,
    status,
)
from fastapi.responses import FileResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from .environment_variable_checker import EnvironmentVariableChecker
from .image_classification.image_classification_parameters import (
    ImageClassificationParameters,
    ImageClassificationTrainingParameters,
    map_image_classification_training_parameters,
)
from .image_classification.image_classification_trainer import ImageClassificationTrainer
from .task_manager import TaskManager
from .training_manager import TrainingManager
from .training_status import Status

app = FastAPI()

# Validate configuration once, at import time (presumably raises if a
# required variable is missing -- see EnvironmentVariableChecker).
environmentVariableChecker = EnvironmentVariableChecker()
environmentVariableChecker.validate_environment_variables()

logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# One module-level training manager, shared by every endpoint in this module.
classification_trainer: TrainingManager = TrainingManager(ImageClassificationTrainer())

security = HTTPBearer()

# Task states during which a new training run must not be started.
_BUSY_STATUSES = (Status.IN_PROGRESS, Status.CANCELLING)

# Number of bytes read per iteration when streaming an upload to disk (1 MiB).
_UPLOAD_CHUNK_SIZE = 1024 * 1024


def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the request's bearer token against the configured token.

    Args:
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        ``{"token": <presented token>}``, used as the dependency value.

    Raises:
        HTTPException: 401 when no token is configured or the presented
            token does not match it.
    """
    expected_token = environmentVariableChecker.get_authentication_token()
    presented_token = credentials.credentials
    # compare_digest takes time independent of where the inputs differ, so the
    # secret cannot be recovered byte-by-byte from response timing. Both sides
    # are encoded because compare_digest rejects non-ASCII ``str`` arguments.
    # An empty/missing configured token rejects every request (as before).
    if not expected_token or not secrets.compare_digest(
        presented_token.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"token": presented_token}


class ResponseModel(BaseModel):
    """Response body returned by ``/upload``."""

    message: str  # human-readable outcome
    success: bool = True


async def _save_upload(upload: UploadFile, destination: str) -> None:
    """Stream *upload* to *destination* chunk by chunk instead of buffering it all in memory."""
    with open(destination, "wb") as out_file:
        while chunk := await upload.read(_UPLOAD_CHUNK_SIZE):
            out_file.write(chunk)


@app.post(
    "/upload",
    summary="Upload a zip file containing training data",
    response_model=ResponseModel,
)
async def upload_file(
    training_params: Annotated[
        ImageClassificationTrainingParameters,
        Depends(map_image_classification_training_parameters),
    ],
    data_files_training: Annotated[UploadFile, File(...)],
    token_data: dict = Depends(verify_token),
    result_model_name: str = Form(...),
    source_model_name: str = Form('google/vit-base-patch16-224-in21k'),
):
    """Save an uploaded training-data zip and start an image-classification run.

    The zip is written to ``<system tempdir>/training_data/image_classification_data.zip``
    and passed, together with the form parameters, to the shared training manager.

    Raises:
        HTTPException: 405 if a run is in progress or cancelling, 422 if the
            upload's filename does not end in ``.zip``, 500 for any other
            failure while saving the file or starting training.
    """
    # Refuse to start a second run while one is active or still shutting down.
    task_status = classification_trainer.get_task_status()
    if task_status.get_status() in _BUSY_STATUSES:
        # NOTE(review): 409 Conflict would describe this better than 405
        # (Method Not Allowed); kept as 405 for existing clients.
        raise HTTPException(status_code=405, detail="Training is already in progress")

    # UploadFile.filename may be None; treat a missing name as "not a zip".
    filename = data_files_training.filename or ""
    if not filename.lower().endswith(".zip"):
        raise HTTPException(status_code=422, detail="Uploaded file is not a zip file")

    try:
        # NOTE(review): this directory is shared across uploads and is not
        # cleaned here; confirm the trainer clears stale files from earlier runs.
        tmp_path = os.path.join(tempfile.gettempdir(), 'training_data')
        Path(tmp_path).mkdir(parents=True, exist_ok=True)

        zip_path = os.path.join(tmp_path, 'image_classification_data.zip')
        await _save_upload(data_files_training, zip_path)

        parameters = ImageClassificationParameters(
            training_files_path=tmp_path,
            training_zip_file_path=zip_path,
            result_model_name=result_model_name,
            source_model_name=source_model_name,
            training_parameters=training_params,
        )
        await classification_trainer.start_training(parameters)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Failed to start training")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    # TODO add more return parameters and information
    return ResponseModel(message="training started")


@app.get("/get_task_status")
async def get_task_status(token_data: dict = Depends(verify_token)):
    """Return the shared training manager's progress, task and status value."""
    task_status = classification_trainer.get_task_status()
    return {
        "progress": task_status.get_progress(),
        "task": task_status.get_task(),
        "status": task_status.get_status().value,
    }


@app.get("/stop_task")
async def stop_task(token_data: dict = Depends(verify_token)):
    """Ask the shared training manager to stop its current task.

    NOTE(review): this changes server state but is exposed as GET; kept for
    compatibility with existing clients.

    Raises:
        HTTPException: 500 if stopping the task fails.
    """
    try:
        classification_trainer.stop_task()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to stop task")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    return {"success": True}
@app.get("/gpu_check")
async def gpu_check():
    """Report whether PyTorch can see a CUDA device (unauthenticated health check)."""
    if torch.cuda.is_available():
        gpu = 'GPU is available'
        # Use the module logger (configured format/level) rather than print().
        logger.info("GPU is available")
    else:
        gpu = 'GPU not available'
        logger.info("GPU is not available")
    return {'success': True, 'response': 'hello world 3', 'gpu': gpu}