"""Configuration dataclasses (settings groups, CoreConfig, BaseConfig, ColBERTConfig) and helpers."""

import __main__
import dataclasses
import datetime
import os
import socket
import sys
import time
from dataclasses import dataclass, fields
from typing import Any

import git
import torch
import ujson
from huggingface_hub import hf_hub_download


def torch_load_dnn(path):
    """
    Load a serialized checkpoint onto the CPU.

    `path` may be a local file path or an http(s) URL; URLs are fetched through
    `torch.hub.load_state_dict_from_url`, everything else through `torch.load`.
    """
    # NOTE(review): both loaders unpickle the file. Only load checkpoints from trusted sources.
    if path.startswith(("http:", "https:")):
        return torch.hub.load_state_dict_from_url(path, map_location='cpu')

    return torch.load(path, map_location='cpu')


class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181

    Because attribute access is routed straight to the dict item methods, a missing
    attribute raises KeyError (not AttributeError).
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def get_metadata_only():
    """
    Collect metadata about the current process into a `dotdict`.

    Always records `hostname`, `current_datetime` and `cmd`. When the working directory
    is inside a git repository, also records `git_branch`, `git_hash` and
    `git_commit_datetime`; outside a repository those keys are simply absent.
    """
    args = dotdict()

    args.hostname = socket.gethostname()

    try:
        # Open the repository once and reuse it for all three lookups.
        repo = git.Repo(search_parent_directories=True)
        args.git_branch = repo.active_branch.name

        head_commit = repo.head.object
        args.git_hash = head_commit.hexsha
        args.git_commit_datetime = str(head_commit.committed_datetime)
    except git.exc.InvalidGitRepositoryError:
        # Not running from a git checkout: best effort, leave the git_* keys unset.
        pass

    # NOTE(review): `%l` is not among the strftime directives documented by Python; it is
    # handed to the platform's C library, so its output may vary by OS -- confirm if portability matters.
    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args


def timestamp(daydir=False):
    """
    Format the current local time as `YYYY-MM-DD_HH.MM.SS`.

    With `daydir=True` the separators after the month and after the day become `/`,
    giving `YYYY-MM/DD/HH.MM.SS`.
    """
    month_day_sep = '/' if daydir else '-'
    day_time_sep = '/' if daydir else '_'
    format_str = f"%Y-%m{month_day_sep}%d{day_time_sep}%H.%M.%S"

    return datetime.datetime.now().strftime(format_str)


@dataclass
class DefaultVal:
    """
    Wrapper marking a dataclass field value as "the default, not explicitly assigned".

    CoreConfig.__post_init__ unwraps these and uses their presence to decide which
    fields count as assigned.
    """

    val: Any

    def __hash__(self):
        # Hash via repr() so that wrappers around unhashable values are still hashable.
        return hash(repr(self.val))

    def __eq__(self, other):
        # The comparison result must be returned (it used to be computed and dropped),
        # and foreign types are deferred to Python instead of raising AttributeError.
        if not isinstance(other, DefaultVal):
            return NotImplemented

        return self.val == other.val


@dataclass
class RunSettings:
    """
        The defaults here have a special status in Run(), which initially calls assign_defaults(),
        so these aren't soft defaults in that specific context.
    """

    overwrite: bool = DefaultVal(False)

    # NOTE: os.getcwd() and timestamp() below are evaluated once, when this class body executes.
    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    index_root: str = DefaultVal(None)
    name: str = DefaultVal(timestamp(daydir=True))

    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    amp: bool = DefaultVal(True)

    # Un-annotated on purpose: a plain class attribute, not a dataclass field.
    total_visible_gpus = torch.cuda.device_count()
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        """
        Return `gpus` normalized to a sorted list of unique integer device indices.

        An int `n` means devices `0..n-1`; a string is split on commas; any other value
        is treated as an iterable of indices. Asserts that every index is below
        `total_visible_gpus`.
        """
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            value = value.split(',')

        value = sorted({int(device_idx) for device_idx in value})

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        """Return `index_root` if set, otherwise `<root>/<experiment>/indexes/`."""
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        """
        Return a dotted name for the running script, or 'none' if `__main__` has no `__file__`.

        The script path is made relative to the current directory when it lives under it,
        otherwise to the common prefix it shares with `root`; slashes become dots and the
        trailing `.py` is dropped.
        """
        if '__file__' not in dir(__main__):
            return 'none'

        cwd = os.path.abspath(os.getcwd())
        script_path = os.path.abspath(__main__.__file__)
        root_path = os.path.abspath(self.root)

        if script_path.startswith(cwd):
            script_path = script_path[len(cwd):]
        else:
            try:
                commonpath = os.path.commonpath([script_path, root_path])
                script_path = script_path[len(commonpath):]
            except ValueError:
                # commonpath() raises ValueError when the paths cannot be compared
                # (e.g., different drives); keep the absolute script path in that case.
                pass

        assert script_path.endswith('.py')
        script_name = script_path.replace('/', '.').strip('.')[:-3]

        assert len(script_name) > 0, (script_name, script_path, cwd)

        return script_name

    @property
    def path_(self):
        """Return `<root>/<experiment>/<script_name_>/<name>`."""
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        """Return the entry of `gpus_` at position `rank % nranks`."""
        return self.gpus_[self.rank % self.nranks]


@dataclass
class TokenizerSettings:
    """Default query/document marker tokens and the token ids they are associated with."""

    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")


@dataclass
class ResourceSettings:
    """Names/paths of external resources; all default to None."""

    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    index_name: str = DefaultVal(None)
    name_or_path: str = DefaultVal(None)


@dataclass
class DocSettings:
    """Document-side settings."""

    dim: int = DefaultVal(128)
    doc_maxlen: int = DefaultVal(220)
    mask_punctuation: bool = DefaultVal(True)


@dataclass
class QuerySettings:
    """Query-side settings."""

    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens: bool = DefaultVal(False)
    interaction: str = DefaultVal('colbert')


@dataclass
class TrainingSettings:
    """Training settings."""

    similarity: str = DefaultVal('cosine')

    bsize: int = DefaultVal(32)

    accumsteps: int = DefaultVal(1)

    lr: float = DefaultVal(3e-06)

    maxsteps: int = DefaultVal(500_000)

    save_every: int = DefaultVal(None)

    resume: bool = DefaultVal(False)

    ## NEW:
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    nway: int = DefaultVal(2)

    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    model_name: str = DefaultVal(None)  # DefaultVal('bert-base-uncased')


@dataclass
class IndexingSettings:
    """Indexing settings."""

    index_path: str = DefaultVal(None)

    index_bsize: int = DefaultVal(64)

    nbits: int = DefaultVal(1)

    kmeans_niters: int = DefaultVal(4)

    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        """
        Return `index_path` if set, otherwise `<index_root_>/<index_name>`.

        `index_root_` and `index_name` are not defined on this class; they are expected
        from the sibling settings classes it is combined with (RunSettings, ResourceSettings).
        """
        return self.index_path or os.path.join(self.index_root_, self.index_name)


@dataclass
class SearchSettings:
    """Search settings."""

    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    load_index_with_mmap: bool = DefaultVal(False)


@dataclass
class CoreConfig:
    """Base behavior for configs whose field defaults are wrapped in DefaultVal."""

    def __post_init__(self):
        """
        Unwrap DefaultVal placeholders and record which fields were explicitly assigned.

        A field counts as assigned when the value passed in is not a DefaultVal instance
        (this includes an explicit None, which is nevertheless replaced by the default).

        Source: https://stackoverflow.com/a/58081120/1493011
        """

        self.assigned = {}

        for field in fields(self):
            field_val = getattr(self, field.name)

            if isinstance(field_val, DefaultVal) or field_val is None:
                setattr(self, field.name, field.default.val)

            if not isinstance(field_val, DefaultVal):
                self.assigned[field.name] = True

    def assign_defaults(self):
        """Reset every field to its default value and mark all fields as assigned."""
        for field in fields(self):
            setattr(self, field.name, field.default.val)
            self.assigned[field.name] = True

    def configure(self, ignore_unrecognized=True, **kw_args):
        """
        Apply `kw_args` through `set()` and return the set of keys that were not applied.

        With `ignore_unrecognized=False`, the first unknown key raises (see `set`).
        """
        # TODO: Take a config object, not kw_args.
        #
        # for key in config.assigned:
        #     value = getattr(config, key)

        ignored = set()

        for key, value in kw_args.items():
            if not self.set(key, value, ignore_unrecognized):
                ignored.add(key)

        return ignored

    def set(self, key, value, ignore_unrecognized=False):
        """
        Set attribute `key` to `value` and mark it assigned, if this object has such an attribute.

        Returns True when the attribute was set and False when the key is unknown and
        `ignore_unrecognized` is true. Raises Exception for an unknown key otherwise.
        """
        if hasattr(self, key):
            setattr(self, key, value)
            self.assigned[key] = True
            return True

        if not ignore_unrecognized:
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

        return False

    def help(self):
        """Print the exported configuration as indented JSON."""
        print(ujson.dumps(self.export(), indent=4))

    def __export_value(self, v):
        """
        Make one field value compact for export.

        Objects exposing `provenance()` are replaced by its result; lists and dicts with
        more than 100 entries are replaced by a short summary plus their first three items/keys.
        """
        v = v.provenance() if hasattr(v, 'provenance') else v

        if isinstance(v, list) and len(v) > 100:
            v = (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        """Return the dataclass fields as a dict, with each value passed through `__export_value`."""
        d = dataclasses.asdict(self)

        for k, v in d.items():
            d[k] = self.__export_value(v)

        return d


@dataclass
class BaseConfig(CoreConfig):
    """Constructors and persistence (JSON metadata files) on top of CoreConfig."""

    @classmethod
    def from_existing(cls, *sources):
        """
        Build a config from the explicitly assigned fields of `sources`.

        `None` sources are skipped; later sources override earlier ones.
        """
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            source_values = dataclasses.asdict(source)
            kw_args.update({k: source_values[k] for k in source.assigned})

        return cls(**kw_args)

    @classmethod
    def from_deprecated_args(cls, args):
        """Build a config from a mapping of arguments; returns `(config, ignored_keys)`."""
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """
        Load a config from the JSON file at `name`; returns `(config, ignored_keys)`.

        If the JSON object has a top-level "config" key, its value is used as the arguments.
        """
        with open(name) as f:
            args = ujson.load(f)

            if "config" in args:
                args = args["config"]

        return cls.from_deprecated_args(
            args
        )  # the new, non-deprecated version functions the same at this level.

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """
        Load the config stored with a checkpoint.

        `*.dnn` paths are read through `torch_load_dnn` using their "arguments" entry.
        Otherwise `checkpoint_path` is first tried as a Hugging Face Hub repo id and then
        used as a directory containing `artifact.metadata`. Returns None when no metadata
        file is found.
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
            config.set("checkpoint", checkpoint_path)

            return config

        name_or_path = checkpoint_path
        try:
            metadata_file = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            )
            # Use the file's parent directory (keeping a trailing separator). Splitting the
            # path on the substring "artifact" would truncate it too early whenever the
            # repo id or cache directory itself contains "artifact".
            checkpoint_path = os.path.join(os.path.dirname(metadata_file), "")
        except Exception:
            # Best effort: not a Hub repo (or no access); fall back to treating it as a local path.
            pass

        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)
            loaded_config.set("name_or_path", name_or_path)

            return loaded_config

        return (
            None  # can happen if checkpoint_path is something like 'bert-base-uncased'
        )

    @classmethod
    def load_from_index(cls, index_path):
        """
        Load the config saved with an index.

        Reads `<index_path>/metadata.json`, falling back to `<index_path>/plan.json` if that fails.
        """
        # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
        # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)

        # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
        # index_path = os.path.join(default_index_root, index_path)

        # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.

        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except Exception:
            # Any failure to read metadata.json falls back to plan.json
            # (KeyboardInterrupt/SystemExit are no longer swallowed).
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """
        Write the exported config plus a "meta" section (process metadata and version) to `path` as JSON.

        Asserts that `path` does not already exist unless `overwrite` is true.
        """
        assert overwrite or not os.path.exists(path), path

        with open(path, "w") as f:
            args = self.export()  # dict(self.__config)
            args["meta"] = get_metadata_only()
            args["meta"]["version"] = "colbert-v0.4"
            # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!

            f.write(ujson.dumps(args, indent=4) + "\n")

    def save_for_checkpoint(self, checkpoint_path):
        """Save this config as `<checkpoint_path>/artifact.metadata`, overwriting any existing file."""
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)


@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    """The full configuration: every settings group combined with BaseConfig's behavior."""
    pass