Spaces:
Running
Running
from typing import Dict, Any | |
from torch import nn | |
from data.datasets.ab_dataset import ABDataset | |
from abc import ABC, abstractmethod | |
from utils.common.log import logger | |
import json | |
import os | |
from utils.common.others import backup_key_codes | |
from .model import BaseModel | |
from data import Scenario | |
from schema import Schema | |
from utils.common.data_record import write_json | |
class BaseAlg(ABC):
    """Base class for algorithms that run against a ``Scenario``.

    Subclasses must override :meth:`get_required_models_schema` and
    :meth:`get_required_hyp_schema`. The first is called unconditionally in
    ``__init__`` and the second in ``run``. A subclass that does not override
    them therefore fails with ``NotImplementedError``.
    """

    def __init__(self, models: Dict[str, BaseModel], res_save_dir):
        """Store the models and create the result directory.

        Args:
            models: Mapping of model name to model. It is validated against
                ``get_required_models_schema()`` (``schema`` raises
                ``SchemaError`` on mismatch).
            res_save_dir: Directory that results are written to. It is created
                with ``os.makedirs`` without ``exist_ok``, so an existing
                directory raises ``FileExistsError``.

        Raises:
            NotImplementedError: If the subclass does not override
                ``get_required_models_schema``.
            FileExistsError: If ``res_save_dir`` already exists.
        """
        self.models = models
        self.res_save_dir = res_save_dir
        self.get_required_models_schema().validate(models)
        # Without exist_ok an existing directory is an error (presumably so
        # results of an earlier run are never overwritten).
        os.makedirs(res_save_dir)
        logger.info(f'[alg] init alg: {self.__class__.__name__}, res saved in {res_save_dir}')

    def get_required_models_schema(self) -> Schema:
        """Return the ``Schema`` that ``models`` must satisfy. Must be overridden."""
        raise NotImplementedError

    def get_required_hyp_schema(self) -> Schema:
        """Return the ``Schema`` that ``hyps`` must satisfy. Must be overridden."""
        raise NotImplementedError

    def run(self,
            scenario: Scenario,
            hyps: Dict) -> Dict[str, Any]:
        """Validate and record the run configuration before running.

        The base implementation only does bookkeeping:

        - it validates ``hyps``;
        - it saves ``hyps`` and the scenario into ``res_save_dir``;
        - it backs up key source files.

        It returns ``None``. Subclasses are presumably expected to call it
        first and then return their metrics dict.

        Args:
            scenario: Scenario being run. Its ``to_json()`` output is written
                to ``scenario.json``.
            hyps: Hyperparameters, validated against
                ``get_required_hyp_schema()``. They are written to
                ``hyps.json``, or to ``hyps.txt`` as ``str(hyps)`` if the JSON
                write fails.

        Returns:
            Metrics when implemented by a subclass; ``None`` from this base
            implementation.

        Raises:
            NotImplementedError: If the subclass does not override
                ``get_required_hyp_schema``.
        """
        self.get_required_hyp_schema().validate(hyps)

        hyps_json_path = os.path.join(self.res_save_dir, 'hyps.json')
        try:
            write_json(hyps_json_path, hyps, ensure_obj_serializable=True)
        except Exception as e:
            # Best-effort fallback: hyps presumably may hold values write_json
            # cannot serialize, so dump a plain-text repr instead of aborting.
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt / SystemExit still propagate.
            logger.warning(f'[alg] failed to write {hyps_json_path} ({e!r}), saving hyps as text instead')
            with open(os.path.join(self.res_save_dir, 'hyps.txt'), 'w', encoding='utf-8') as f:
                f.write(str(hyps))

        write_json(os.path.join(self.res_save_dir, 'scenario.json'), scenario.to_json())
        logger.info(f'[alg] alg {self.__class__.__name__} start running')
        backup_key_codes(os.path.join(self.res_save_dir, 'backup_codes'))