|
from dataclasses import dataclass |
|
|
|
import __main__ |
|
|
|
import os |
|
import ujson |
|
from huggingface_hub import hf_hub_download |
|
import dataclasses |
|
import datetime |
|
from typing import Any |
|
from dataclasses import dataclass, fields |
|
import socket |
|
import git |
|
import time |
|
import torch |
|
import sys |
|
|
|
def torch_load_dnn(path):
    """Load a (deprecated) .dnn checkpoint from a local path or an HTTP(S) URL.

    Returns whatever object torch serialized (typically a dict), always
    mapped onto CPU.
    """
    remote = path.startswith(("http:", "https:"))

    if remote:
        return torch.hub.load_state_dict_from_url(path, map_location='cpu')

    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    return torch.load(path, map_location='cpu')
|
|
|
class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181

    Fix over the classic recipe: a missing attribute raises AttributeError
    (as the attribute protocol requires) instead of KeyError, so hasattr(),
    copy.copy(), and pickling behave correctly.
    """

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails, so real dict
        # attributes/methods are unaffected.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
|
|
|
def get_metadata_only():
    """Collect lightweight run metadata as a dotdict.

    Includes: hostname, git branch/commit info (when running inside a git
    checkout), the current local datetime string, and the launch command line.
    """
    args = dotdict()

    args.hostname = socket.gethostname()
    try:
        # FIX: build the Repo once instead of three times (each construction
        # walks parent directories to locate the repository).
        repo = git.Repo(search_parent_directories=True)
        args.git_branch = repo.active_branch.name
        args.git_hash = repo.head.object.hexsha
        args.git_commit_datetime = str(repo.head.object.committed_datetime)
    except git.exc.InvalidGitRepositoryError:
        # Not inside a git checkout; skip the git fields entirely.
        pass

    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args
|
|
|
def timestamp(daydir=False):
    """Return the current local time as a filesystem-friendly string.

    daydir=False -> "YYYY-MM-DD_HH.MM.SS" (flat run name)
    daydir=True  -> "YYYY-MM/DD/HH.MM.SS" (nested per-day directories)
    """
    date_sep, time_sep = ('/', '/') if daydir else ('-', '_')
    fmt = "%Y-%m" + date_sep + "%d" + time_sep + "%H.%M.%S"
    return datetime.datetime.now().strftime(fmt)
|
|
|
@dataclass
class DefaultVal:
    """Sentinel wrapper marking a field value as a soft default.

    Hash and equality are defined on the wrapped value so two wrappers of
    equal defaults compare equal and can share dict/set slots.
    """

    val: Any

    def __hash__(self):
        # repr() makes unhashable wrapped values (lists, dicts) hashable too.
        return hash(repr(self.val))

    def __eq__(self, other):
        # BUG FIX: the comparison result was computed but never returned,
        # so every equality check silently evaluated to None (falsy).
        return self.val == other.val
|
|
|
@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.
    """

    # Allow overwriting an existing experiment/run directory.
    overwrite: bool = DefaultVal(False)

    # Root directory for all experiments; runs live under root/experiment/...
    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    # When None, index_root_ derives the index root from root/experiment.
    index_root: str = DefaultVal(None)
    # Run name; default is a per-day nested timestamp (evaluated at import time).
    name: str = DefaultVal(timestamp(daydir=True))

    # Distributed rank of this process and the total number of ranks.
    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    # Enable automatic mixed precision.
    amp: bool = DefaultVal(True)

    # Plain class attribute (not a dataclass field): GPU count at import time.
    total_visible_gpus = torch.cuda.device_count()
    # Accepts an int (use first N GPUs), a comma-separated string, or a list of ids.
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        """Normalize `gpus` into a sorted, de-duplicated list of device ids."""
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            value = value.split(',')

        value = list(map(int, value))
        value = sorted(set(value))

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        """Index root: explicit `index_root` or root/experiment/indexes/."""
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        """Dotted, module-style name of the entry script, or 'none' when there
        is no __main__ file (e.g. interactive sessions)."""
        if '__file__' in dir(__main__):
            cwd = os.path.abspath(os.getcwd())
            script_path = os.path.abspath(__main__.__file__)
            root_path = os.path.abspath(self.root)

            if script_path.startswith(cwd):
                script_path = script_path[len(cwd):]
            else:
                try:
                    commonpath = os.path.commonpath([script_path, root_path])
                    script_path = script_path[len(commonpath):]
                except ValueError:
                    # FIX: was a bare `except:`. os.path.commonpath raises
                    # ValueError for mixed absolute/relative paths or paths on
                    # different drives; fall back to the full script path.
                    pass

            assert script_path.endswith('.py')
            script_name = script_path.replace('/', '.').strip('.')[:-3]

            assert len(script_name) > 0, (script_name, script_path, cwd)

            return script_name

        return 'none'

    @property
    def path_(self):
        """Directory for this run's artifacts."""
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        # NOTE(review): indexes gpus_ by rank % nranks — assumes
        # nranks <= len(gpus_); otherwise this raises IndexError.
        return self.gpus_[self.rank % self.nranks]
|
|
|
|
|
@dataclass
class TokenizerSettings:
    """Special marker tokens prepended to query and document inputs.

    NOTE(review): despite the *_token_id names, the values are token *strings*
    (BERT's unused vocab slots), not integer ids — confirm against the tokenizer.
    """

    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")
|
|
|
|
|
@dataclass
class ResourceSettings:
    """Paths/names of external artifacts; None means "not provided"."""

    # Model checkpoint (path or, presumably, an HF Hub id — see load_from_checkpoint).
    checkpoint: str = DefaultVal(None)
    # Training triples file.
    triples: str = DefaultVal(None)
    # Passage collection file.
    collection: str = DefaultVal(None)
    # Queries file.
    queries: str = DefaultVal(None)
    # Name of the index to build or search.
    index_name: str = DefaultVal(None)
    # Original model name or path, as given by the user.
    name_or_path: str = DefaultVal(None)
|
|
|
|
|
@dataclass
class DocSettings:
    """Document-side encoding settings."""

    # Output embedding dimension.
    dim: int = DefaultVal(128)
    # Maximum document length in tokens; longer inputs presumably truncated — confirm in encoder.
    doc_maxlen: int = DefaultVal(220)
    # Mask punctuation tokens on the document side.
    mask_punctuation: bool = DefaultVal(True)
|
|
|
|
|
@dataclass
class QuerySettings:
    """Query-side encoding settings."""

    # Maximum query length in tokens.
    query_maxlen: int = DefaultVal(32)
    # Whether [MASK] padding tokens participate in attention — presumably; verify in encoder.
    attend_to_mask_tokens : bool = DefaultVal(False)
    # Interaction mechanism identifier (default 'colbert').
    interaction: str = DefaultVal('colbert')
|
|
|
|
|
@dataclass
class TrainingSettings:
    """Hyperparameters for training."""

    # Vector-similarity function ('cosine' by default).
    similarity: str = DefaultVal('cosine')

    # Training batch size.
    bsize: int = DefaultVal(32)

    # Gradient-accumulation steps per optimizer update.
    accumsteps: int = DefaultVal(1)

    # Learning rate.
    lr: float = DefaultVal(3e-06)

    # Maximum number of training steps.
    maxsteps: int = DefaultVal(500_000)

    # Checkpoint interval in steps; None presumably disables periodic saving — confirm in trainer.
    save_every: int = DefaultVal(None)

    # Resume training from an existing checkpoint.
    resume: bool = DefaultVal(False)

    # LR warmup steps; None presumably disables warmup — confirm in trainer.
    warmup: int = DefaultVal(None)

    # Separate warmup for the underlying BERT weights — TODO confirm semantics in trainer.
    warmup_bert: int = DefaultVal(None)

    # Apply a ReLU in scoring — TODO confirm where this is consumed.
    relu: bool = DefaultVal(False)

    # Passages per training example (2 = one positive + one negative).
    nway: int = DefaultVal(2)

    # Use in-batch negatives.
    use_ib_negatives: bool = DefaultVal(False)

    # Train as a reranker rather than a retriever — presumably; verify against trainer.
    reranker: bool = DefaultVal(False)

    # Weight of the distillation term — presumably; verify against the loss code.
    distillation_alpha: float = DefaultVal(1.0)

    # Ignore provided target scores — TODO confirm effect in trainer.
    ignore_scores: bool = DefaultVal(False)

    # Base model name; None falls back to the project default.
    model_name: str = DefaultVal(None)
|
|
|
@dataclass
class IndexingSettings:
    """Settings for building an index."""

    # Full index path; when None, index_path_ derives it from index_root_/index_name.
    index_path: str = DefaultVal(None)

    # Batch size used while encoding documents for indexing.
    index_bsize: int = DefaultVal(64)

    # Bits per dimension for compressed residuals — presumably; confirm in indexer.
    nbits: int = DefaultVal(1)

    # Number of k-means iterations for centroid training.
    kmeans_niters: int = DefaultVal(4)

    # Resume a partially built index.
    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        # NOTE(review): index_root_ and index_name come from sibling settings
        # classes — this property only works on the combined ColBERTConfig.
        return self.index_path or os.path.join(self.index_root_, self.index_name)
|
|
|
@dataclass
class SearchSettings:
    """Settings for querying an index; None values presumably mean
    "choose automatically at search time" — confirm in the searcher."""

    # Number of IVF cells to probe per query.
    ncells: int = DefaultVal(None)
    # Threshold for pruning candidates by centroid score.
    centroid_score_threshold: float = DefaultVal(None)
    # Number of candidate documents to fully score.
    ndocs: int = DefaultVal(None)
    # Memory-map the index instead of loading it into RAM.
    load_index_with_mmap: bool = DefaultVal(False)
|
|
|
|
|
@dataclass
class CoreConfig:
    def __post_init__(self):
        """
        Resolve DefaultVal placeholders and record which fields the caller
        explicitly assigned.

        Source: https://stackoverflow.com/a/58081120/1493011
        """

        # Maps field name -> True for every caller-assigned field.
        self.assigned = {}

        for field in fields(self):
            field_val = getattr(self, field.name)

            # Placeholders (and explicit None) collapse to the unwrapped default.
            if isinstance(field_val, DefaultVal) or field_val is None:
                setattr(self, field.name, field.default.val)

            # An explicit None still counts as assigned — preserved from the
            # original logic; only untouched DefaultVal placeholders don't.
            if not isinstance(field_val, DefaultVal):
                self.assigned[field.name] = True

    def assign_defaults(self):
        """Force every field to its default and mark all fields as assigned."""
        for field in fields(self):
            setattr(self, field.name, field.default.val)
            self.assigned[field.name] = True

    # TODO: Take a config object, not kw_args; iterate its `assigned` keys.
    def configure(self, ignore_unrecognized=True, **kw_args):
        """Assign the given key/value pairs; return the set of ignored keys."""
        ignored = set()

        for key, value in kw_args.items():
            # set() returns True on success, None otherwise (possibly raising
            # first when ignore_unrecognized is False).
            if not self.set(key, value, ignore_unrecognized):
                ignored.add(key)

        return ignored

    def set(self, key, value, ignore_unrecognized=False):
        """Set attribute `key` if it exists and mark it assigned.

        Returns True on success; returns None (falsy) for an unrecognized key
        when ignore_unrecognized is True, otherwise raises.
        """
        if hasattr(self, key):
            setattr(self, key, value)
            self.assigned[key] = True
            return True

        if not ignore_unrecognized:
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

    def help(self):
        """Pretty-print the exported config to stdout."""
        print(ujson.dumps(self.export(), indent=4))

    def __export_value(self, v):
        # Prefer an object's provenance record over the raw object when offered.
        v = v.provenance() if hasattr(v, 'provenance') else v

        # Summarize huge containers so exports stay readable/serializable.
        if isinstance(v, list) and len(v) > 100:
            v = (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        """Return a plain dict of all fields, with oversized values summarized."""
        d = dataclasses.asdict(self)

        for k, v in d.items():
            d[k] = self.__export_value(v)

        return d
|
|
|
@dataclass
class BaseConfig(CoreConfig):
    @classmethod
    def from_existing(cls, *sources):
        """Build a config by merging only the *explicitly assigned* fields of
        each source config; later sources win. None sources are skipped."""
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            local_kw_args = dataclasses.asdict(source)
            local_kw_args = {k: local_kw_args[k] for k in source.assigned}
            kw_args = {**kw_args, **local_kw_args}

        obj = cls(**kw_args)

        return obj

    @classmethod
    def from_deprecated_args(cls, args):
        """Build from a legacy args dict; returns (config, ignored_key_set)."""
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """Load a JSON config file, unwrapping a top-level "config" key if present."""
        with open(name) as f:
            args = ujson.load(f)

            if "config" in args:
                args = args["config"]

        return cls.from_deprecated_args(args)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """Load the config stored alongside a checkpoint.

        Handles (a) deprecated .dnn checkpoints, (b) HF Hub repo ids, and
        (c) local directories containing artifact.metadata. Returns None
        when no config can be found.
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            config.set("checkpoint", checkpoint_path)

            return config

        name_or_path = checkpoint_path
        try:
            # Resolve an HF Hub repo id to its local snapshot directory.
            checkpoint_path = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            ).split("artifact")[0]
        except Exception:
            # Best-effort: fall back to treating checkpoint_path as a local dir.
            pass

        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)
            loaded_config.set("name_or_path", name_or_path)

            return loaded_config

        return None

    @classmethod
    def load_from_index(cls, index_path):
        """Load the config saved with an index, falling back from
        metadata.json to the older plan.json layout."""
        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except Exception:
            # FIX: was a bare `except:` — don't swallow KeyboardInterrupt/SystemExit.
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """Write the exported config (plus run metadata) to `path` as JSON."""
        assert overwrite or not os.path.exists(path), path

        with open(path, "w") as f:
            args = self.export()
            args["meta"] = get_metadata_only()
            args["meta"]["version"] = "colbert-v0.4"

            f.write(ujson.dumps(args, indent=4) + "\n")

    def save_for_checkpoint(self, checkpoint_path):
        """Save this config as artifact.metadata inside a checkpoint directory."""
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)
|
|
|
|
|
@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    # The single user-facing config: merges every settings group plus
    # BaseConfig's load/save helpers via dataclass multiple inheritance.
    # Duplicated field/property names resolve left-to-right along the MRO.
    pass