Update model_keeper.py
Browse files- model_keeper.py +25 -21
model_keeper.py
CHANGED
@@ -30,32 +30,32 @@ class KeeperModelForCausalLM(PreTrainedModel):
|
|
30 |
self.bert = None
|
31 |
self.llm = None
|
32 |
|
33 |
-
if cfg:
|
34 |
-
|
35 |
-
|
36 |
|
37 |
-
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
|
52 |
-
|
53 |
-
|
54 |
|
55 |
-
|
56 |
-
else:
|
57 |
-
|
58 |
-
|
59 |
|
60 |
self.n_cands = n_cands
|
61 |
self.update_both = update_both
|
@@ -81,6 +81,10 @@ class KeeperModelForCausalLM(PreTrainedModel):
|
|
81 |
self.prompt_right = state_dict["prompt_right"].to(device)
|
82 |
if "respuesta" in state_dict:
|
83 |
self.respuesta = state_dict["respuesta"].to(device)
|
|
|
|
|
|
|
|
|
84 |
else:
|
85 |
# Optionally handle the case where CUDA is not available
|
86 |
print("CUDA is not available. Tensors will remain on CPU.")
|
|
|
30 |
self.bert = None
|
31 |
self.llm = None
|
32 |
|
33 |
+
# if cfg:
|
34 |
+
# print("Initializing KeeperModelForCausalLM from cfg")
|
35 |
+
# # Inicialización con configuración
|
36 |
|
37 |
+
# self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])
|
38 |
|
39 |
+
# bnb_config = BitsAndBytesConfig(
|
40 |
+
# load_in_4bit=True,
|
41 |
+
# bnb_4bit_quant_type="nf4",
|
42 |
+
# bnb_4bit_compute_dtype=torch.bfloat16
|
43 |
+
# )
|
44 |
|
45 |
+
# self.llm = AutoModelForCausalLM.from_pretrained(
|
46 |
+
# cfg.model_config['_name_or_path'],
|
47 |
+
# device_map=cfg.device_map,
|
48 |
+
# torch_dtype=torch.bfloat16,
|
49 |
+
# quantization_config=bnb_config
|
50 |
+
# )
|
51 |
|
52 |
+
# # Almacena kwargs para la serialización y carga futura
|
53 |
+
# # self.init_kwargs = {'cfg': cfg}
|
54 |
|
55 |
+
# print("Initialization complete")
|
56 |
+
# else:
|
57 |
+
# # Si cfg no se proporciona, esto se manejará en el método from_pretrained
|
58 |
+
# print("Initializing KeeperTokenizer without cfg")
|
59 |
|
60 |
self.n_cands = n_cands
|
61 |
self.update_both = update_both
|
|
|
81 |
self.prompt_right = state_dict["prompt_right"].to(device)
|
82 |
if "respuesta" in state_dict:
|
83 |
self.respuesta = state_dict["respuesta"].to(device)
|
84 |
+
if "bert" in state_dict:
|
85 |
+
self.bert = state_dict["bert"].to(device)
|
86 |
+
if "llm" in state_dict:
|
87 |
+
self.llm = state_dict["llm"].to(device)
|
88 |
else:
|
89 |
# Optionally handle the case where CUDA is not available
|
90 |
print("CUDA is not available. Tensors will remain on CPU.")
|