from .configuration_keeper import KeeperConfig

from typing import Dict

import torch
import numpy as np
from einops import rearrange

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

class KeeperModelForCausalLM(PreTrainedModel):
    """
    Retrieval-augmented model: a ColBERT-style retriever
    (https://arxiv.org/pdf/2004.12832.pdf) selects candidate documents and a
    causal LLM generates the answer from them.
    We use a dot-product instead of cosine per term (slightly better).
    """

    config_class = KeeperConfig
    base_model_prefix = "keeper_model"

    def __init__(self, cfg, n_cands=8, update_both=False) -> None:
        super().__init__(cfg)

        # Sub-modules are attached after construction (e.g. when the state dict
        # is loaded): `bert` is the retriever encoder, `llm` is the generator.
        self.bert = None
        self.llm = None

        self.n_cands = n_cands
        self.update_both = update_both
        print(f"Model n_cands: {self.n_cands}")

    def _load_from_state_dict(self, state_dict, *args, **kwargs):
        super()._load_from_state_dict(state_dict, *args, **kwargs)

        if torch.cuda.is_available():
            device = torch.device('cuda')
            # Move the serialized document/prompt tensors and sub-modules to the GPU.
            for name in (
                "document_retriever_text",
                "document_retriever_mask",
                "document_retriever_type",
                "document_model_text",
                "prompt_left",
                "prompt_right",
                "respuesta",
                "bert",
                "llm",
            ):
                if name in state_dict:
                    setattr(self, name, state_dict[name].to(device))
        else:
            print("CUDA is not available. Tensors will remain on CPU.")

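    # NOTE on the `query` argument used by `generate` and `document_extractor`:
    # it is expected to be a dict with two entries (as produced by the project's
    # tokenizer wrapper) -- 'tokens_retriever' with input_ids / attention_mask /
    # token_type_ids tensors for the retriever, and 'tokens_model' with the
    # corresponding tensors for the LLM.
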
    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, max_new_tokens=256, repetition_penalty=1.15, temperature=0.1, do_sample=True, **kwargs):

        # Tokens for the LLM side of the query.
        query_model = {k: v.to("cuda") for k, v in query['tokens_model'].items()}

        # Retrieve the top-k stored documents (already tokenized for the LLM).
        topk_texts = self.document_extractor(query, k)
        concatenated_texts = torch.cat(topk_texts, dim=0)

        # Prompt layout: left prompt | retrieved context | right prompt | question | answer cue.
        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)
        prompt_length = T.shape[1]

        outputs = self.llm.generate(input_ids=T, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, do_sample=do_sample)

        # Return only the newly generated tokens, without the prompt.
        return outputs[0][prompt_length:].unsqueeze(0)

    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):

        # Tokens for the retriever side of the query.
        query_retriever = {k: v.to("cuda") for k, v in query['tokens_retriever'].items()}
        query_vecs = self.forward_representation(query_retriever)

        # Stored document tokens (set by `save_docs` or loaded from the state dict).
        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")

        # Late-interaction (ColBERT-style) score of the query against every stored document.
        self.score = self.forward_aggregation(query_vecs, query['tokens_retriever']["attention_mask"], document_vecs, self.document_retriever_mask)

        k_val = min(k_val, self.score.numel())
        topk_scores, topk_indices = torch.topk(self.score, k_val)

        # Return the LLM-tokenized text of the top-k documents.
        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]

    def forward_representation(self,
                               tokens,
                               max_seq_len=128,
                               sequence_type=None) -> torch.Tensor:

        # Both query and document sequences are encoded with a frozen retriever;
        # no gradients are propagated through `self.bert` here.
        with torch.no_grad():
            vecs = self.bert(**tokens)[0]  # last_hidden_state: per-token embeddings

        return vecs

    def forward_aggregation(self, query_vecs, query_mask, document_vecs, document_mask):

        # ColBERT-style late interaction: every query token is matched against
        # every document token with a dot-product, max-pooled over the document
        # dimension, then summed over the (unpadded) query tokens.
        _bsz = query_vecs.shape[0]
        n_cands = document_vecs.shape[0] // _bsz
        query_vecs_dup = query_vecs.repeat_interleave(n_cands, dim=0).contiguous()

        # (bsz * n_cands, query_len, doc_len) token-to-token similarity matrix.
        score = torch.bmm(query_vecs_dup, document_vecs.transpose(1, 2))
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1, score.shape[1], -1)
        score[~exp_mask] = -10000

        # Max over document tokens, then zero out padded query tokens and sum.
        score = score.max(-1).values
        query_mask_dup = query_mask.repeat_interleave(n_cands, dim=0).contiguous()

        score[~(query_mask_dup.bool())] = 0
        score = rearrange(score.sum(-1), '(b n) -> b n', n=n_cands)
        return score

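    # Illustrative shapes for `forward_aggregation` above (sizes are assumptions
    # chosen for the example, not values fixed by the model): with one query of
    # 32 tokens, 8 stored documents of 128 tokens each, and 768-dim embeddings:
    #   query_vecs          (1, 32, 768)  -> repeat_interleave -> (8, 32, 768)
    #   document_vecs       (8, 128, 768)
    #   torch.bmm(...)      (8, 32, 128)  per-token dot products
    #   max over doc dim    (8, 32), sum over query dim -> (8,)
    #   rearrange           (1, 8)        one relevance score per candidate
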
    def prompt(self, left_p=None, right_p=None):
        # Default prompt (in Spanish). Roughly: "You are an expert on Paraguayan
        # culture who answers clearly, kindly and concisely. Given the following
        # context: ... You may only answer from the context above; if it is not
        # in the context, say 'No tengo informacion sobre eso'. If you find the
        # answer you may copy it. Answer only in Spanish. Question:"
        if left_p is None:
            left_p = """ <bos><start_of_turn>user
Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
Segun el siguiente contexto:
-------------------------------
"""
        if right_p is None:
            right_p = """
-------------------------------
- Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
- Si encuentras la respuesta puedes copiarla.
- Debes responder solamente en Espanol.
Pregunta: """
        return left_p, right_p

    def save_docs(self, docs: list, tokenizer, max_seq_len=128):

        prompt_left, prompt_right = self.prompt()
        prompt_left_output = tokenizer.encode(prompt_left)
        prompt_right_output = tokenizer.encode(prompt_right)

        doc_outputs = tokenizer.encode(docs, max_length=max_seq_len, padding='max_length', truncation=True)

        # Each encode() result nests the token dicts under 'tokens_retriever' /
        # 'tokens_model', so the tensors are moved to the GPU one level down.
        def to_cuda(enc):
            return {k: {kk: vv.to("cuda") for kk, vv in v.items()} for k, v in enc.items()}

        doc_outputs = to_cuda(doc_outputs)
        prompt_left_output = to_cuda(prompt_left_output)
        prompt_right_output = to_cuda(prompt_right_output)

        # "Respuesta:" closes the user turn and opens the model turn.
        resp = tokenizer.encode("""
Respuesta: <end_of_turn>
<start_of_turn>model """)
        resp_model = {k: v.to("cuda") for k, v in resp['tokens_model'].items()}

        # Document tokens for the retriever (scoring) and for the LLM (context).
        self.document_retriever_text = doc_outputs['tokens_retriever']['input_ids']
        self.document_retriever_mask = doc_outputs['tokens_retriever']['attention_mask']
        self.document_retriever_type = doc_outputs['tokens_retriever']['token_type_ids']
        self.document_model_text = doc_outputs['tokens_model']['input_ids']

        # Prompt pieces for the LLM, used by `generate` to assemble the full prompt.
        self.prompt_left = prompt_left_output['tokens_model']['input_ids']
        self.prompt_right = prompt_right_output['tokens_model']['input_ids']
        self.respuesta = resp_model['input_ids']
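
# Example usage (a sketch, not part of the module's API). It assumes a project
# tokenizer wrapper -- `KeeperTokenizer` is a hypothetical name here -- whose
# encode() returns the {'tokens_retriever': ..., 'tokens_model': ...} dicts that
# the methods above expect, plus a CUDA device:
#
#   tokenizer = KeeperTokenizer.from_pretrained("path/to/keeper")   # hypothetical class
#   model = KeeperModelForCausalLM.from_pretrained("path/to/keeper").cuda()
#   model.save_docs(["Documento uno...", "Documento dos..."], tokenizer)
#   query = tokenizer.encode("Que es el terere?")
#   answer_ids = model.generate(query, k=3, max_new_tokens=128)
#   print(tokenizer.decode(answer_ids[0]))                          # assumes a decode() helper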