# ConstBERT / colbert_configuration.py
# caesar-one's picture
# Upload ConstBERT
# 9d3ebbc verified
from dataclasses import dataclass
import __main__
import os
import ujson
from huggingface_hub import hf_hub_download
import dataclasses
import datetime
from typing import Any
from dataclasses import dataclass, fields
import socket
import git
import time
import torch
import sys
def torch_load_dnn(path):
    """
    Load a checkpoint object onto the CPU from a local file or an HTTP(S) URL.

    A path starting with "http:" or "https:" is fetched with
    torch.hub.load_state_dict_from_url; any other path is read with torch.load.
    Both loaders are called with map_location='cpu'.

    NOTE(review): torch.load unpickles the file, so only point this at
    checkpoints from a trusted source.
    """
    is_remote = path.startswith(("http:", "https:"))
    loader = torch.hub.load_state_dict_from_url if is_remote else torch.load
    return loader(path, map_location='cpu')
class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181

    Attribute reads, writes and deletes are forwarded to the dict item
    operations, so `d.key` is `d['key']`. A missing key therefore raises
    KeyError (not AttributeError) on attribute access.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the item.
        return dict.__getitem__(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
def get_metadata_only():
    """
    Collect provenance metadata about the current process.

    Returns:
        A dotdict that always holds `hostname`, `current_datetime` and `cmd`.
        When GitPython finds an enclosing repository (searching parent
        directories) it also holds `git_hash` and `git_commit_datetime`, and
        `git_branch` when HEAD points at a branch.
    """
    args = dotdict()

    args.hostname = socket.gethostname()

    try:
        # Discover the repository once and reuse it for every attribute.
        repo = git.Repo(search_parent_directories=True)

        try:
            args.git_branch = repo.active_branch.name
        except TypeError:
            # GitPython raises TypeError for a detached HEAD (typical on CI or
            # after checking out a tag/commit). There is no branch name to
            # record, but the commit hash and date below are still valid.
            pass

        head_commit = repo.head.object
        args.git_hash = head_commit.hexsha
        args.git_commit_datetime = str(head_commit.committed_datetime)
    except git.exc.InvalidGitRepositoryError:
        # Not running from inside a git checkout: leave the git_* keys unset.
        pass

    # NOTE(review): `%l` is a platform strftime extension rather than one of the
    # directives Python documents as portable -- confirm on non-glibc platforms.
    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args
def timestamp(daydir=False):
    """
    Format the current local time as a name-friendly string.

    With daydir=False the result looks like "YYYY-MM-DD_HH.MM.SS"; with
    daydir=True both separators become "/", giving "YYYY-MM/DD/HH.MM.SS".
    """
    if daydir:
        day_sep, time_sep = '/', '/'
    else:
        day_sep, time_sep = '-', '_'

    pattern = f"%Y-%m{day_sep}%d{time_sep}%H.%M.%S"
    return datetime.datetime.now().strftime(pattern)
@dataclass
class DefaultVal:
    """
    Wrapper that marks a value as a default rather than a user-assigned value.

    The settings dataclasses in this file use DefaultVal instances as field
    defaults, and CoreConfig.__post_init__ unwraps them via `.val`.
    """

    val: Any

    def __hash__(self):
        # Hash the repr so the wrapper stays hashable even when `val` itself is
        # not (e.g. a list).
        return hash(repr(self.val))

    def __eq__(self, other):
        """Return True when `other` is a DefaultVal wrapping an equal value."""
        if not isinstance(other, DefaultVal):
            # Let Python fall back to the reflected/identity comparison rather
            # than failing on a missing `.val` attribute.
            return NotImplemented
        # The original evaluated this comparison but never returned it, so
        # `==` always produced None.
        return self.val == other.val
@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.

    Properties whose names end in an underscore are values derived from the
    plain fields.
    """

    overwrite: bool = DefaultVal(False)

    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    index_root: str = DefaultVal(None)
    name: str = DefaultVal(timestamp(daydir=True))

    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    amp: bool = DefaultVal(True)

    # Intentionally un-annotated: this makes it a plain class attribute that is
    # evaluated once at class creation, not a dataclass field.
    total_visible_gpus = torch.cuda.device_count()
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        """
        Return `gpus` as a sorted list of unique device indices.

        An int n expands to [0, ..., n-1]; a string is split on commas; any
        other value is treated as an iterable of indices. Every index must be
        below `total_visible_gpus` (checked with an assert).
        """
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            value = value.split(',')

        value = sorted(set(map(int, value)))

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        """Return `index_root`, or `<root>/<experiment>/indexes/` when it is unset."""
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        """
        Return a dotted name for the running script, or 'none' without one.

        The script path is made relative to the current working directory when
        it lives under it, otherwise to the prefix it shares with `root`. Path
        separators ('/') then become dots and the '.py' suffix is dropped.
        """
        if '__file__' not in dir(__main__):
            # E.g. an interactive session: there is no script file to name.
            return 'none'

        cwd = os.path.abspath(os.getcwd())
        script_path = os.path.abspath(__main__.__file__)
        root_path = os.path.abspath(self.root)

        if script_path.startswith(cwd):
            script_path = script_path[len(cwd):]
        else:
            try:
                commonpath = os.path.commonpath([script_path, root_path])
                script_path = script_path[len(commonpath):]
            except ValueError:
                # commonpath raises ValueError when the paths share no common
                # base (e.g. different drives); keep the absolute path then.
                # Catching only ValueError no longer hides KeyboardInterrupt.
                pass

        assert script_path.endswith('.py')
        script_name = script_path.replace('/', '.').strip('.')[:-3]

        assert len(script_name) > 0, (script_name, script_path, cwd)

        return script_name

    @property
    def path_(self):
        """Return `<root>/<experiment>/<script_name_>/<name>`."""
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        """Return the entry of `gpus_` selected for this rank."""
        # NOTE(review): the index is `rank % nranks`, not `rank % len(gpus_)` --
        # confirm this is intended when nranks exceeds the number of GPUs.
        return self.gpus_[self.rank % self.nranks]
@dataclass
class TokenizerSettings:
    """
    Settings naming the marker tokens used for queries and documents.

    Every default is wrapped in DefaultVal.
    """

    # Presumably the vocabulary entries that stand in for the markers below --
    # confirm against the tokenizer code that consumes these settings.
    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    # Marker strings for queries ("[Q]") and documents ("[D]").
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")
@dataclass
class ResourceSettings:
    """
    Settings that name external resources.

    Every field defaults to None (wrapped in DefaultVal), i.e. nothing is
    configured until a value is assigned.
    """

    # Set by BaseConfig.load_from_checkpoint to the checkpoint location.
    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    # Joined onto the index root by IndexingSettings.index_path_.
    index_name: str = DefaultVal(None)
    # Set by BaseConfig.load_from_checkpoint to the name/path it was originally given.
    name_or_path: str = DefaultVal(None)
@dataclass
class DocSettings:
    """
    Document-side settings; every default is wrapped in DefaultVal.
    """

    # Default 128; presumably the embedding dimension -- confirm with the model code.
    dim: int = DefaultVal(128)
    # Default 220; presumably the maximum document length in tokens -- confirm.
    doc_maxlen: int = DefaultVal(220)
    mask_punctuation: bool = DefaultVal(True)
@dataclass
class QuerySettings:
    """
    Query-side settings; every default is wrapped in DefaultVal.
    """

    # Default 32; presumably the maximum query length in tokens -- confirm.
    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens : bool = DefaultVal(False)
    # Default 'colbert'; the other accepted values are not visible in this file.
    interaction: str = DefaultVal('colbert')
@dataclass
class TrainingSettings:
    """
    Training settings; every default is wrapped in DefaultVal.

    What each field means is decided by the training code that reads this
    config, which is not part of this file.
    """

    similarity: str = DefaultVal('cosine')

    bsize: int = DefaultVal(32)

    accumsteps: int = DefaultVal(1)

    lr: float = DefaultVal(3e-06)

    maxsteps: int = DefaultVal(500_000)

    # None by default, i.e. no value configured.
    save_every: int = DefaultVal(None)

    # IndexingSettings declares a field with the same name and default.
    resume: bool = DefaultVal(False)

    ## NEW:
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    nway: int = DefaultVal(2)

    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    model_name: str = DefaultVal(None) # DefaultVal('bert-base-uncased')
@dataclass
class IndexingSettings:
    """
    Indexing settings; every default is wrapped in DefaultVal.

    `index_path_` reads `index_root_` and `index_name` from `self`. This class
    does not define them: in this file they come from RunSettings and
    ResourceSettings, which ColBERTConfig combines with this class.
    """

    index_path: str = DefaultVal(None)

    index_bsize: int = DefaultVal(64)
    nbits: int = DefaultVal(1)

    kmeans_niters: int = DefaultVal(4)

    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        """Return `index_path` when it is set, else `<index_root_>/<index_name>`."""
        explicit_path = self.index_path
        if explicit_path:
            return explicit_path

        return os.path.join(self.index_root_, self.index_name)
@dataclass
class SearchSettings:
    """
    Search settings; every default is wrapped in DefaultVal.

    The first three fields default to None, i.e. no value is configured here.
    """

    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    # Off by default.
    load_index_with_mmap: bool = DefaultVal(False)
@dataclass
class CoreConfig:
    """
    Base behaviour shared by the config dataclasses.

    Tracks which fields were explicitly assigned (in `self.assigned`), as
    opposed to left at their DefaultVal default, and offers helpers to set and
    export values.
    """

    def __post_init__(self):
        """
        Source: https://stackoverflow.com/a/58081120/1493011

        Unwrap DefaultVal defaults and record which fields were assigned.

        A field still holding a DefaultVal, or holding None, is reset to the
        value inside its declared default. Any field that did not hold a
        DefaultVal -- including one passed as None -- is marked as assigned.
        """
        self.assigned = {}

        for spec in fields(self):
            current = getattr(self, spec.name)
            left_at_default = isinstance(current, DefaultVal)

            if left_at_default or current is None:
                setattr(self, spec.name, spec.default.val)

            if not left_at_default:
                self.assigned[spec.name] = True

    def assign_defaults(self):
        """Reset every field to its default value and mark it as assigned."""
        for spec in fields(self):
            setattr(self, spec.name, spec.default.val)
            self.assigned[spec.name] = True

    def configure(self, ignore_unrecognized=True, **kw_args):
        """
        Apply each keyword argument through `set` and return the ignored keys.

        Returns:
            The set of keys for which `set` returned a falsy result.
        """
        ignored = set()

        for key, value in kw_args.items():
            if not self.set(key, value, ignore_unrecognized):
                ignored.add(key)

        # TODO: Take a config object, not kw_args.
        #
        # for key in config.assigned:
        #     value = getattr(config, key)
        return ignored

    def set(self, key, value, ignore_unrecognized=False):
        """
        Assign `value` to the existing attribute `key` and mark it as assigned.

        Returns:
            True when the attribute exists and was set; None when it does not
            exist and `ignore_unrecognized` is true.

        Raises:
            Exception: if the attribute does not exist and
                `ignore_unrecognized` is false.
        """
        if not hasattr(self, key):
            if ignore_unrecognized:
                return None
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

        setattr(self, key, value)
        self.assigned[key] = True
        return True

    def help(self):
        """Print the exported configuration as indented JSON."""
        exported = self.export()
        print(ujson.dumps(exported, indent=4))

    def __export_value(self, v):
        """
        Return a compact, export-friendly form of one value.

        Objects with a `provenance()` method are replaced by its result. Lists
        and dicts with more than 100 entries are replaced by a
        (description, first three items/keys) tuple.
        """
        if hasattr(v, 'provenance'):
            v = v.provenance()

        if isinstance(v, list) and len(v) > 100:
            return (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            return (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        """Return the dataclass fields as a dict with every value passed through the export filter."""
        raw = dataclasses.asdict(self)
        return {key: self.__export_value(val) for key, val in raw.items()}
@dataclass
class BaseConfig(CoreConfig):
    """
    Adds construction from other configs, files, checkpoints and indexes, and
    saving to disk, on top of CoreConfig.
    """

    @classmethod
    def from_existing(cls, *sources):
        """
        Build a config from the assigned fields of the given configs.

        None entries are skipped. Later sources override earlier ones for any
        field both have assigned.
        """
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            source_values = dataclasses.asdict(source)
            kw_args.update({k: source_values[k] for k in source.assigned})

        return cls(**kw_args)

    @classmethod
    def from_deprecated_args(cls, args):
        """
        Build a config from a plain mapping of settings.

        Returns:
            (config, ignored): the new config and the set of keys in `args`
            that it did not recognize.
        """
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """
        Build a config from a JSON file.

        If the top-level object has a "config" key, only that sub-object is
        used. Returns the same (config, ignored) pair as from_deprecated_args.
        """
        with open(name) as f:
            args = ujson.load(f)

        if "config" in args:
            args = args["config"]

        # the new, non-deprecated version functions the same at this level.
        return cls.from_deprecated_args(args)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """
        Load the config stored alongside a checkpoint.

        Args:
            checkpoint_path: a deprecated `*.dnn` checkpoint file, a directory
                holding `artifact.metadata`, or a Hugging Face Hub repo id
                providing that file.

        Returns:
            The loaded config, or None when no `artifact.metadata` can be
            found (e.g. for a plain model name like 'bert-base-uncased').
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
            config.set("checkpoint", checkpoint_path)

            return config

        name_or_path = checkpoint_path

        try:
            metadata_file = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            )
        except Exception:
            # Best effort: `checkpoint_path` may simply be a local directory
            # (or the Hub may be unreachable); fall through to the local check.
            pass
        else:
            # Take the directory containing the downloaded file. Splitting the
            # path on the substring "artifact" (as before) truncated it too
            # early whenever the repo id or cache directory itself contained
            # "artifact". The trailing separator matches the old result.
            checkpoint_path = os.path.join(os.path.dirname(metadata_file), "")

        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")

        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)
            loaded_config.set("name_or_path", name_or_path)

            return loaded_config

        return None  # can happen if checkpoint_path is something like 'bert-base-uncased'

    @classmethod
    def load_from_index(cls, index_path):
        """
        Load the config stored in an index directory.

        Reads `metadata.json`; if that fails for any reason, falls back to
        `plan.json` (whose errors propagate to the caller).
        """
        # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
        # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)

        # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
        # index_path = os.path.join(default_index_root, index_path)

        # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.

        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt and
            # SystemExit. Ordinary failures still trigger the fallback.
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """
        Write the exported config plus a "meta" section to `path` as JSON.

        Raises:
            AssertionError: if `path` exists and `overwrite` is false.
        """
        assert overwrite or not os.path.exists(path), path

        # Build the whole payload before opening the file, so a failure while
        # exporting or collecting metadata cannot leave a truncated file behind.
        args = self.export()  # dict(self.__config)
        args["meta"] = get_metadata_only()
        args["meta"]["version"] = "colbert-v0.4"
        # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!
        payload = ujson.dumps(args, indent=4) + "\n"

        with open(path, "w") as f:
            f.write(payload)

    def save_for_checkpoint(self, checkpoint_path):
        """
        Save this config as `artifact.metadata` inside `checkpoint_path`,
        overwriting any existing file.

        Raises:
            AssertionError: if `checkpoint_path` ends with ".dnn".
        """
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)
@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    """
    The full configuration: every settings group in this file combined with
    the loading/saving behaviour of BaseConfig.

    The class adds no fields or methods of its own.
    """
    pass