|
from typing import Dict, Literal |
|
|
|
from finetune.trainer import Trainer |
|
|
|
|
|
# Module-level registry, keyed first by model name and then by training type.
# It starts empty and is filled in by `register()`; `show_supported_models()`
# and `get_model_cls()` read from it.
# NOTE(review): the stored values are named `trainer_cls` and documented as
# "Trainer class", so they are presumably classes rather than instances - the
# value annotation may want to be Type[Trainer]; confirm against callers.
SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
|
|
|
|
|
def register(
    model_name: str,
    training_type: Literal["lora", "sft", "controlnet"],
    trainer_cls: Trainer,
):
    """Add a trainer to the registry under a model name and training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b").
        training_type (Literal["lora", "sft", "controlnet"]): Kind of training the
            trainer implements.
        trainer_cls (Trainer): Trainer class to store for this combination.

    Raises:
        ValueError: If ``training_type`` is already registered for ``model_name``.
    """
    # Fetch the per-model mapping, creating an empty one on first registration.
    trainers_by_type = SUPPORTED_MODELS.setdefault(model_name, {})

    # Refuse to silently overwrite an existing registration.
    if training_type in trainers_by_type:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    trainers_by_type[training_type] = trainer_cls
|
|
|
|
|
def show_supported_models():
    """Print each registered model and the training types registered for it."""
    # Banner for the whole listing.
    print("\nSupported Models:")
    print("================")

    for name in SUPPORTED_MODELS:
        # Model name, underlined with dashes of the same length.
        print(f"\n{name}")
        print("-" * len(name))

        # One bullet per training type registered for this model.
        for kind in SUPPORTED_MODELS[name]:
            print(f" • {kind}")
|
|
|
|
|
def get_model_cls(model_type: str, training_type: Literal["lora", "sft", "controlnet"]) -> Trainer:
    """Get the trainer class for a specific model and training type.

    Args:
        model_type (str): Name of a registered model (a key of ``SUPPORTED_MODELS``).
        training_type (Literal["lora", "sft", "controlnet"]): Training type to look
            up for that model. Accepts the same values as ``register``.

    Returns:
        Trainer: The trainer class stored in the registry for
        ``(model_type, training_type)``.

    Raises:
        ValueError: If ``model_type`` is not registered, or if ``training_type`` is
            not registered for that model. The available options are printed to
            stdout before the exception is raised.
    """
    if model_type not in SUPPORTED_MODELS:
        # List what is available so the user can correct the model name.
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f" • {supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    trainers_by_type = SUPPORTED_MODELS[model_type]

    if training_type not in trainers_by_type:
        # List the training types that are registered for this model.
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in trainers_by_type:
            print(f" • {supported_type}")
        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")

    return trainers_by_type[training_type]
|
|