from constants import (
    TAESD_MODEL,
    TAESDXL_MODEL,
    TAESD_MODEL_OPENVINO,
    TAESDXL_MODEL_OPENVINO,
)


def get_tiny_decoder_vae_model(pipeline_class: str) -> str:
    """Return the tiny-autoencoder model constant matching a pipeline class name.

    Args:
        pipeline_class: Name of the pipeline class, e.g.
            ``"StableDiffusionXLPipeline"``. It is compared against a fixed
            set of supported class-name strings.

    Returns:
        One of ``TAESD_MODEL``, ``TAESDXL_MODEL``, ``TAESD_MODEL_OPENVINO`` or
        ``TAESDXL_MODEL_OPENVINO`` from the ``constants`` module (presumably a
        model id or path -- see that module to confirm).

    Raises:
        ValueError: If ``pipeline_class`` is not a supported name.
            ``ValueError`` subclasses ``Exception``, so callers that catch
            ``Exception`` keep working.
    """
    print(f"Pipeline class : {pipeline_class}")
    # Each family maps its txt2img and img2img variants to the same model.
    # Constants are referenced only inside the matching branch, so they are
    # looked up lazily, as in the original if/elif chain.
    if pipeline_class in (
        "LatentConsistencyModelPipeline",
        "StableDiffusionPipeline",
        "StableDiffusionImg2ImgPipeline",
    ):
        return TAESD_MODEL
    if pipeline_class in (
        "StableDiffusionXLPipeline",
        "StableDiffusionXLImg2ImgPipeline",
    ):
        return TAESDXL_MODEL
    if pipeline_class in (
        "OVStableDiffusionPipeline",
        "OVStableDiffusionImg2ImgPipeline",
    ):
        return TAESD_MODEL_OPENVINO
    # The OpenVINO SDXL img2img variant previously fell through to the error,
    # unlike every other family; it now mirrors the non-OpenVINO SDXL branch.
    if pipeline_class in (
        "OVStableDiffusionXLPipeline",
        "OVStableDiffusionXLImg2ImgPipeline",
    ):
        return TAESDXL_MODEL_OPENVINO
    raise ValueError(f"No valid pipeline class found! Got {pipeline_class!r}")