from omegaconf import OmegaConf
from torch.cuda import is_available as use_cuda

# Model description shared by the algorithm configs in this module; each of
# them embeds it as its "model" entry when calling OmegaConf.create.
model_config = dict(
    name="google/t5-large-ssm-nq",
    class_name="AutoModelForSeq2SeqLM",
    tokenizer_class="AutoTokenizer",
    tokenizer_name="google/t5-large-ssm-nq",
    # Parameter names the editors operate on: the DenseReluDense wi/wo weights
    # of encoder blocks 22-23 (layer 1) and decoder blocks 22-23 (layer 2).
    # NOTE(review): presumably the feed-forward sublayers of the final two
    # blocks of t5-large -- confirm against the checkpoint's module names.
    inner_params=[
        "encoder.block.22.layer.1.DenseReluDense.wi.weight",
        "encoder.block.22.layer.1.DenseReluDense.wo.weight",
        "encoder.block.23.layer.1.DenseReluDense.wi.weight",
        "encoder.block.23.layer.1.DenseReluDense.wo.weight",
        "decoder.block.22.layer.2.DenseReluDense.wi.weight",
        "decoder.block.22.layer.2.DenseReluDense.wo.weight",
        "decoder.block.23.layer.2.DenseReluDense.wi.weight",
        "decoder.block.23.layer.2.DenseReluDense.wo.weight",
    ],
    # NOTE(review): presumably an optional checkpoint path; left unset here.
    pt=None,
    small_name="t5-small",
)

# Config for the `ft` editor (presumably plain fine-tuning -- confirm against
# the consumer). Runs on CPU and does not train the base model.
ft_config = OmegaConf.create(dict(
    device="cpu",
    edit_lr=5e-6,
    train_base=False,
    grad_clip=100,
    ft=dict(
        verbose=False,
        max_edit_steps=100,
        time_limit=None,
        # Locality settings; the "enabled" flag is off in this config.
        locality=dict(
            enabled=False,
            oracle=True,
            cedit=1e-2,
            batch_size=1,
        ),
        rank=None,
        opt="RMSprop",  # presumably the optimizer class name used for edits
        init_std=0.01,
    ),
    model=model_config,
))

# Config for the `lu` editor, on CPU, sharing the common model description.
# NOTE(review): meaning of "threshold" / "onehot_logit" is defined by the
# consuming editor, not visible here.
lu_config = OmegaConf.create(dict(
    device="cpu",
    lu=dict(
        threshold=2.75,
        onehot_logit=1,
    ),
    model=model_config,
))

# Config for the `ke` editor: CPU, base model frozen (train_base=False),
# learning rate 1e-5.
ke_config = OmegaConf.create(dict(
    device="cpu",
    train_base=False,
    lr=1e-5,
    model=model_config,
))

# Config for the `enn` editor. "archive" holds the numeric run ID of the trained
# editor checkpoint; the trailing comment records the full path it came from.
enn_config = OmegaConf.create(dict(
    device="cpu",
    lr=1e-5,
    edit_lr=1e-2,
    lr_lr=1e-3,
    train_base=True,
    grad_clip=100,
    dropout=0,
    no_grad_layers=None,
    enn=dict(
        first_order=False,
        n_edit_steps=1,
    ),
    model=model_config,
    archive=8684705655,  # "/iris/u/clin/code/efk/outputs/2022-02-09_05-48-20_8684705655/models/t5-large-ssm-nq.2022-02-09_05-48-20_8684705655",
))

# Config for the MEND editor; its editor-network hyperparameters live under the
# "gtn" section. "archive" is the numeric run ID of the trained checkpoint; the
# trailing comment records the full path it came from.
mend_config = OmegaConf.create(dict(
    device="cpu",
    lr=1e-6,
    edit_lr=1e-4,
    lr_lr=1e-4,
    train_base=True,
    grad_clip=100,
    dropout=0,
    no_grad_layers=None,
    gtn=dict(
        one_sided=False,
        n_hidden=1,
        hidden_dim=None,
        init="id",
        norm=True,
        combine=True,
        x_only=False,
        delta_only=False,
        act="relu",
        rank=1920,
        mlp_class="IDMLP",
        shared=True,
        descent=False,
    ),
    model=model_config,
    archive=5940349945,  # "/iris/u/clin/code/efk/outputs/2022-02-09_11-47-28_5940349945/models/t5-large-ssm-nq.2022-02-09_11-47-28_5940349945",
))

# Config for the SERAC editor; the "rep" section configures its
# distilbert-base-cased representation/classifier model. "archive" is the
# numeric run ID of the trained checkpoint; the trailing comment records the
# full path it came from.
serac_config = OmegaConf.create({
  # Pinned to CPU. To use a GPU when present: "cuda" if use_cuda() else "cpu".
  "device": "cpu",
  "lr": 1e-5,
  "edit_lr": 1e-2,
  "lr_lr": 0,
  "train_base": False,
  "grad_clip": 100,
  "dropout": 0,
  "no_grad_layers": None,
  "rep": {
    "cls_name": "distilbert-base-cased",
    "cls_class": "AutoModel",
    # Must be a real bool like every other flag here: the previous string
    # value "true" was truthy, and so would "false" have been.
    "supervised": True,
    "cos": False,
    "freeze": None,
    "square": True,
    "bound_embeds": False,
    "use_all_negatives": False,
    "freeze_cntr": False,
    "dist_heads": 1,
    "cross_attend": False,
    "lora": None,
    "soft_weighting": False,
    "checkpoint_grad": False,
    "cache_embeds": True,
  },
  "model": model_config,
  "archive": 4719776130, # "/iris/u/clin/code/efk/outputs/2022-02-09_14-05-56_4719776130/models/t5-large-ssm-nq.2022-02-09_14-05-56_4719776130",
})