File size: 593 Bytes
3455d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#!/usr/bin/env python
# coding=utf-8
""" BaseAligner: a subclass of BasePipeline.
"""

from lmflow.pipeline.base_pipeline import BasePipeline


class BaseAligner(BasePipeline):
    """Pipeline base class that exposes an alignment entry point.

    Only the interface is defined here: ``align`` raises
    ``NotImplementedError``, so concrete aligners are presumably expected to
    override it (confirm against the subclasses in the project).
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary positional/keyword arguments and ignore them.

        No instance state is set up, and the parent initializer is not
        invoked from here.
        """

    def _check_if_alignable(self, model, dataset, reward_model):
        """Placeholder validation hook; currently performs no checks.

        All three arguments are accepted but unused, and the method
        returns ``None``.
        """
        # TODO: check if the model is alignable and dataset is compatible
        # TODO: add reward_model
        return None

    def align(self, model, dataset, reward_model):
        """Alignment entry point; not implemented in this base class.

        Raises:
            NotImplementedError: always, regardless of the arguments.
        """
        message = ".align is not implemented"
        raise NotImplementedError(message)