Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 3,012 Bytes
			
import sys
import traceback
from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger
from finetrainers.config import _get_model_specifiction_cls
from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig
# Module-level logger used by main() to report status and errors.
logger = get_logger()
def main():
    """Parse command-line arguments, build the model specification, and run training.

    The ``--training_type`` flag is read from ``sys.argv`` before full argument
    parsing, because its value decides which SFT config class contributes
    extra arguments to ``BaseArgs``.

    Failures are logged rather than propagated: a ``KeyboardInterrupt`` is
    logged at info level, and any other exception is logged together with its
    traceback. Nothing is re-raised from this function.
    """
    try:
        import multiprocessing

        multiprocessing.set_start_method("fork")
    except Exception as e:
        logger.error(
            f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
            f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
            f"Error: {e}"
        )

    try:
        args = BaseArgs()

        # Flatten argv on whitespace so a single argument that carries several
        # space-separated tokens is still scanned token by token.
        argv = [y.strip() for x in sys.argv for y in x.split()]

        # list.index raises ValueError when the flag is absent (it never
        # returns -1), and indexing past the end raises IndexError when the
        # flag is the last token. Both mean the same thing to the user.
        try:
            training_type = argv[argv.index("--training_type") + 1]
        except (ValueError, IndexError):
            raise ValueError("Training type not provided in command line arguments.") from None

        training_cls = None
        if training_type == TrainingType.LORA:
            training_cls = SFTLowRankConfig
        elif training_type == TrainingType.FULL_FINETUNE:
            training_cls = SFTFullRankConfig
        else:
            raise ValueError(f"Training type {training_type} not supported.")

        # Register the training-type-specific arguments before parsing.
        training_config = training_cls()
        args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args)
        args = args.parse_args()

        model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
        model_specification = model_specification_cls(
            pretrained_model_name_or_path=args.pretrained_model_name_or_path,
            tokenizer_id=args.tokenizer_id,
            tokenizer_2_id=args.tokenizer_2_id,
            tokenizer_3_id=args.tokenizer_3_id,
            text_encoder_id=args.text_encoder_id,
            text_encoder_2_id=args.text_encoder_2_id,
            text_encoder_3_id=args.text_encoder_3_id,
            transformer_id=args.transformer_id,
            vae_id=args.vae_id,
            text_encoder_dtype=args.text_encoder_dtype,
            text_encoder_2_dtype=args.text_encoder_2_dtype,
            text_encoder_3_dtype=args.text_encoder_3_dtype,
            transformer_dtype=args.transformer_dtype,
            vae_dtype=args.vae_dtype,
            revision=args.revision,
            cache_dir=args.cache_dir,
        )

        # Re-check against the parsed value, which is what the trainer receives.
        if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
            trainer = SFTTrainer(args, model_specification)
        else:
            raise ValueError(f"Training type {args.training_type} not supported.")

        trainer.run()

    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt. Exiting...")
    except Exception as e:
        logger.error(f"An error occurred during training: {e}")
        logger.error(traceback.format_exc())
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
 | 
