import os
from typing import Optional

from findfile import find_file

from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
from anonymous_demo.utils.demo_utils import retry
class CheckpointManager:
    """Common base class for checkpoint managers.

    Defines no attributes or methods of its own; subclasses supply the
    task-specific loaders.
    """
class TADCheckpointManager(CheckpointManager):
    """Checkpoint manager that builds ``TADTextClassifier`` objects."""

    # Declared static because the function takes no ``self``/``cls``; without
    # the decorator, calling it on an instance would bind the instance to
    # ``checkpoint``.
    @staticmethod
    def get_tad_text_classifier(
        checkpoint: Optional[str] = None,
        eval_batch_size: int = 128,
        **kwargs,
    ) -> "TADTextClassifier":
        """Construct and return a ``TADTextClassifier``.

        Args:
            checkpoint: Passed as the first positional argument to
                ``TADTextClassifier``. Presumably a checkpoint name or
                path; confirm against ``TADTextClassifier``.
            eval_batch_size: Forwarded as the ``eval_batch_size`` keyword.
            **kwargs: Any further keyword arguments, forwarded unchanged.

        Returns:
            The newly created ``TADTextClassifier`` instance.
        """
        return TADTextClassifier(
            checkpoint,
            eval_batch_size=eval_batch_size,
            **kwargs,
        )