File size: 1,260 Bytes
3b13f40 c6fe3c5 24d96ab c6fe3c5 3b13f40 c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 24d96ab c6fe3c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from src import config
from src import data
from src import loss
from src import models
from src import tokenizer as tk
from src import vision_model
from src import utils
from src.lightning_module import LightningModule
def train(trainer_config: config.TrainerConfig):
    """Build the data pipeline and TinyCLIP encoders, then fit them.

    Steps, in order:
      1. Build the image transform and the text tokenizer from the vision
         and text sub-configs of ``trainer_config._model_config``.
      2. Obtain the training and validation data loaders from
         ``data.get_dataset``.
      3. Instantiate the TinyCLIP vision and text encoders and wrap them,
         together with the loss selected by ``loss_type``, in a
         ``LightningModule``.
      4. Create the trainer via ``utils.get_trainer`` and call ``fit`` with
         the module and both loaders.

    Args:
        trainer_config: Training configuration. Its ``_model_config``
            supplies the vision, text and loss settings. The object itself
            is forwarded as ``hyper_parameters`` to both the dataset factory
            and the ``LightningModule``, and is passed to ``utils.get_trainer``.
    """
    # NOTE(review): this reaches into the private ``_model_config`` attribute;
    # consider exposing a public accessor on TrainerConfig.
    model_cfg = trainer_config._model_config
    vision_cfg = model_cfg.vision_config
    text_cfg = model_cfg.text_config

    # Preprocessing handed to the dataset: image transform + text tokenizer.
    image_transform = vision_model.get_vision_transform(vision_cfg)
    text_tokenizer = tk.Tokenizer(text_cfg)

    # NOTE(review): the type-ignore presumably silences a mismatch between
    # get_dataset's parameter annotations and these arguments — confirm.
    train_loader, valid_loader = data.get_dataset(
        transform=image_transform, tokenizer=text_tokenizer, hyper_parameters=trainer_config  # type: ignore
    )

    # Encoders are built after the loaders, matching the original order
    # (relevant if either step consumes random state).
    vision_encoder = models.TinyCLIPVisionEncoder(config=vision_cfg)
    text_encoder = models.TinyCLIPTextEncoder(config=text_cfg)
    loss_fn = loss.get_loss(model_cfg.loss_type)

    # len(train_loader) is forwarded as len_train_dl; presumably used for
    # step-based scheduling inside the module — confirm in LightningModule.
    clip_module = LightningModule(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        loss_fn=loss_fn,
        hyper_parameters=trainer_config,
        len_train_dl=len(train_loader),
    )

    lightning_trainer = utils.get_trainer(trainer_config)
    lightning_trainer.fit(clip_module, train_loader, valid_loader)
if __name__ == "__main__":
    # Running this file directly starts a training run with debug=True.
    trainer_config = config.TrainerConfig(debug=True)
    train(trainer_config=trainer_config)
|