fine-tuning-service / src /image_classification /image_classification_parameters.py
fashxp's picture
initial commit
7c4332a
raw
history blame
1.78 kB
from typing import Annotated, Optional

from fastapi import Form
from pydantic import BaseModel

class ImageClassificationTrainingParameters(BaseModel):
    """Hyperparameters for an image-classification training run."""

    # Number of training epochs.
    epochs: int
    # Learning rate to train with.
    learning_rate: float

def map_image_classification_training_parameters(
    # `Form()` rather than `Form(...)`: the default supplied with `=` is the
    # one that applies, and `...` conventionally signals "required".
    epocs: Annotated[int, Form()] = 3,
    learning_rate: Annotated[float, Form()] = 5e-5,
    # Appended last so existing positional callers are unaffected.
    epochs: Annotated[Optional[int], Form()] = None,
) -> ImageClassificationTrainingParameters:
    """Build training parameters from form fields.

    The parameters are declared with ``Form``, so this is presumably used as
    a FastAPI dependency. In that case each parameter name is the name of a
    form field the client sends.

    Args:
        epocs: Number of training epochs (default 3). The misspelled field
            name is kept for backward compatibility with existing clients.
        learning_rate: Learning rate (default 5e-5).
        epochs: Correctly spelled alternative to ``epocs``. When supplied,
            it takes precedence over ``epocs``.

    Returns:
        An ``ImageClassificationTrainingParameters`` holding the resolved
        epoch count and learning rate.
    """
    # Prefer the correctly spelled field whenever a client sends it.
    resolved_epochs = epocs if epochs is None else epochs
    return ImageClassificationTrainingParameters(
        epochs=resolved_epochs,
        learning_rate=learning_rate,
    )

class ImageClassificationParameters:
    """Bundle the inputs of one image-classification training job.

    Holds the training file locations, the source and result model names,
    and the training hyperparameters. Values are stored as given, with no
    validation, and exposed through getter methods.
    """

    # Path to the training files (presumably the extracted contents of the
    # ZIP below -- TODO confirm against the caller).
    __training_files_path: str
    # Path to the training ZIP file.
    __training_zip_file_path: str
    # Name for the model produced by training.
    __result_model_name: str
    # Name of the model training starts from.
    __source_model_name: str
    # Hyperparameters for the run. The annotation is quoted so this class
    # does not require the model class to be defined first.
    __training_parameters: "ImageClassificationTrainingParameters"

    def __init__(
        self,
        training_files_path: str,
        training_zip_file_path: str,
        result_model_name: str,
        source_model_name: str,
        training_parameters: "ImageClassificationTrainingParameters",
    ):
        """Store the job inputs exactly as given."""
        self.__training_files_path = training_files_path
        self.__training_zip_file_path = training_zip_file_path
        self.__result_model_name = result_model_name
        self.__source_model_name = source_model_name
        self.__training_parameters = training_parameters

    def get_training_files_path(self) -> str:
        """Return the path to the training files."""
        return self.__training_files_path

    def get_training_zip_file(self) -> str:
        """Return the path to the training ZIP file (kept for existing callers)."""
        return self.__training_zip_file_path

    def get_training_zip_file_path(self) -> str:
        """Return the path to the training ZIP file.

        This is a consistently named alias of ``get_training_zip_file``.
        """
        return self.__training_zip_file_path

    def get_result_model_name(self) -> str:
        """Return the name for the model produced by training."""
        return self.__result_model_name

    def get_source_model_name(self) -> str:
        """Return the name of the model training starts from."""
        return self.__source_model_name

    def get_training_parameters(self) -> "ImageClassificationTrainingParameters":
        """Return the training hyperparameters."""
        return self.__training_parameters

    def __repr__(self) -> str:
        """Return a constructor-style representation for logging and debugging."""
        return (
            f"{type(self).__name__}("
            f"training_files_path={self.__training_files_path!r}, "
            f"training_zip_file_path={self.__training_zip_file_path!r}, "
            f"result_model_name={self.__result_model_name!r}, "
            f"source_model_name={self.__source_model_name!r}, "
            f"training_parameters={self.__training_parameters!r})"
        )