# citekit/attribute/attribute.py
import json
import re

import numpy as np
import torch
from context_cite import ContextCiter
from transformers import AutoModelForCausalLM, AutoTokenizer

def all_normalize(obj):
    # Min-max normalize every span score jointly across all output sentences.
    all_values = []
    for output_sent_result in obj:
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
    max_val = max(all_values)
    min_val = min(all_values)
    scale = (max_val - min_val) or 1.0  # guard against all-equal scores
    for output_sent_result in obj:
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / scale)
    return obj

def all_normalize_in(obj):
    # Min-max normalize span scores within each output sentence separately.
    for output_sent_result in obj:
        all_values = []
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
        max_val = max(all_values)
        min_val = min(all_values)
        scale = (max_val - min_val) or 1.0  # guard against all-equal scores
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / scale)
    return obj
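
# A minimal doctest-style sketch (with hypothetical scores) of the nested
# structure both normalizers expect: obj[sentence][doc] is a list of
# (start, score) span tuples.
#
# >>> all_normalize_in([[[(0, 2.0), (40, 6.0)], [(12, 4.0)]]])
# [[[(0, 0.0), (40, 1.0)], [(12, 0.5)]]]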

def load_json(file_path):
    # Load a .json file, or a .jsonl file of newline-separated JSON objects.
    with open(file_path, 'r') as file:
        data = file.read()
    if file_path.endswith('.jsonl'):
        # Splice the objects into a single JSON array before parsing.
        joined = "},{".join(data.split("}\n{"))
        data = f'[{joined}]'
    objects = json.loads(data)
    return objects

def ma(text):
    # Return the index just past a leading 'Document [N](Title: ...)' prefix,
    # or 0 if no such prefix is found.
    pattern = r"Document \[\d+\]\(Title:[^)]+\)"
    match = re.search(pattern, text)
    if match:
        return match.end()
    return 0
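
# For example, ma('Document [1](Title: Hamlet) To be, or not to be.') returns
# the offset just past the closing parenthesis of the title prefix.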

def write_json(file_path, data):
    with open(file_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)

def load_model(model_name_or_path):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        token='your token',  # replace with your Hugging Face access token
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer

def compute_log_prob(model, tokenizer, input_text, output_text):
    # Return log P(output_text | input_text) under a causal LM. The model must
    # see input and output together: logits at position t predict token t + 1.
    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    output_ids = tokenizer(output_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    full_ids = torch.cat([input_ids, output_ids], dim=1).to(model.device)
    with torch.no_grad():
        logits = model(full_ids).logits[:, -output_ids.shape[1] - 1:-1, :]
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    output_log_probs = log_probs.gather(2, output_ids.to(model.device).unsqueeze(-1)).squeeze(-1)
    return output_log_probs.sum().item()

def compute_contributions(model, tokenizer, question, docs, output):
    # Leave-one-out attribution: a document's contribution is the drop in the
    # output's log-likelihood when that document is removed from the context.
    full_input = question + '\n\n' + '\n'.join(docs)
    base_prob = compute_log_prob(model, tokenizer, full_input, output)
    contributions = []
    for i in range(len(docs)):
        reduced_docs = docs[:i] + docs[i + 1:]
        reduced_input = question + '\n\n' + '\n'.join(reduced_docs)
        reduced_prob = compute_log_prob(model, tokenizer, reduced_input, output)
        contributions.append(base_prob - reduced_prob)
    return contributions
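
# A minimal usage sketch of the leave-one-out scoring above. 'gpt2' is an
# arbitrary stand-in checkpoint and the question/docs/output are made up;
# this assumes load_model's token placeholder has been filled in.
def _demo_compute_contributions():
    model, tokenizer = load_model('gpt2')
    question = 'Who wrote Hamlet?'
    docs = ['Hamlet is a tragedy by William Shakespeare.',
            'Macbeth was first performed around 1606.']
    output = 'William Shakespeare wrote Hamlet.'
    # Larger values mean removing that document hurts the output more.
    print(compute_contributions(model, tokenizer, question, docs, output))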

class InterpretableAttributer:
    def __init__(self, levels=['doc', 'span', 'word'], model='gpt2'):
        for level in levels:
            assert level in ['doc', 'span', 'word'], f'Invalid level: {level}'
        # Process 'span' before 'doc' so doc-level scores can be derived from
        # span-level ones in parse_attribution_results.
        self.levels = sorted(levels, key=lambda x: ['span', 'doc', 'word'].index(x))
        self.model, self.tokenizer = load_model(model)

    def attribute(self, question, docs, output):
        # Attribute each output sentence at every requested granularity.
        attribute_results = {}
        for level in self.levels:
            attribute_result = []
            for sentence in output:
                attribute_result.append(self._attribute(question, docs, sentence, level))
            attribute_results[level] = attribute_result
        return attribute_results

    def _attribute(self, question, docs, output, level):
        # Only span-level attribution is implemented below; doc-level scores
        # are derived from span-level ones during parsing.
        if level == 'doc':
            return self.doc_level_attribution(question, docs, output)
        elif level == 'span':
            return self.span_level_attribution(question, docs, output)
        elif level == 'word':
            return self.word_level_attribution(question, docs, output)
        else:
            raise ValueError(f'Invalid level: {level}')

    def span_level_attribution(self, question, docs, output):
        # Use ContextCite, injecting the precomputed response into its cache so
        # it attributes our existing output instead of generating a new one.
        context = '\n\n'.join(docs)
        response = output
        cc = ContextCiter(self.model, self.tokenizer, context, question)
        _, prompt = cc._get_prompt_ids(return_prompt=True)
        cc._cache["output"] = prompt + response
        result = cc.get_attributions(as_dataframe=True, top_k=1000).data.to_dict(orient='records')
        return result
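
    # get_attributions(as_dataframe=True) yields records shaped like
    # {"Source": "<a context span>", "Score": <float>}; the parser below relies
    # on exactly those two keys to map each span back to its document.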

    def parse_attribution_results(self, docs, results):
        context = '\n\n'.join(docs)
        lens = [len(doc) for doc in docs]
        len_sep = len('\n\n')
        final_results = {}
        for level, result in results.items():
            if level == 'span':
                ordered_all_sents = []
                for output_sent_result in result:
                    # Track where each span text was last found, so repeated
                    # spans resolve to successive occurrences in the context.
                    final_end_for_span = {}
                    all_span_results = []
                    for each_span in output_sent_result:
                        span_text = each_span["Source"]
                        span_score = each_span["Score"]
                        start = 0
                        if span_text in final_end_for_span:
                            start = final_end_for_span[span_text]
                        span_start = context.find(span_text, start)
                        span_end = span_start + len(span_text)
                        final_end_for_span[span_text] = span_end
                        # Convert the context-wide offset into a document-local one.
                        doc_idx = 0
                        while span_start > lens[doc_idx]:
                            span_start -= lens[doc_idx] + len_sep
                            span_end -= lens[doc_idx] + len_sep
                            doc_idx += 1
                        all_span_results.append((span_start, span_score, doc_idx))
                    # Group the spans by document and sort them by position.
                    ordered = [[] for _ in range(len(docs))]
                    for span_start, span_score, doc_idx in all_span_results:
                        ordered[doc_idx].append((span_start, span_score))
                    for i in range(len(docs)):
                        doc = docs[i]
                        # Snap the first span to just past the 'Document [N](Title: ...)' prefix.
                        real_start = ma(doc)
                        ordered[i] = sorted(ordered[i], key=lambda x: x[0])
                        if ordered[i]:
                            ordered[i][0] = (real_start, ordered[i][0][1])
                    ordered_all_sents.append(ordered)
                final_results[level + '_level'] = all_normalize_in(ordered_all_sents)
            elif level == 'doc':
                # 'span' is sorted ahead of 'doc' in self.levels, so span-level
                # results already exist and can be averaged into doc-level scores.
                self.span_to_doc(final_results)
            else:
                raise NotImplementedError(f'Parsing for {level} not implemented yet')
        return final_results

    def span_to_doc(self, results):
        # Average the span scores within each document to get one doc-level
        # score per document, per output sentence.
        span_level = results['span_level']
        doc_level = []
        for output_sent_result in span_level:
            doc_level.append([np.mean([span[1] for span in doc]) if doc else 0.0
                              for doc in output_sent_result])
        results['doc_level'] = doc_level
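
    # For example, a span level of [[[(0, 0.2), (40, 0.8)], [(12, 0.5)]]]
    # becomes a doc level of [[0.5, 0.5]]: the mean span score of each document.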

    def attribute_for_result(self, result):
        docs = result['doc_cache']
        question = result['data']['question']
        output = result['output']
        attribution_results = self.attribute(question, docs, output)
        parsed_results = self.parse_attribution_results(docs, attribution_results)
        result.update(parsed_results)
        if 'doc' not in self.levels:
            # Without a dedicated doc level, derive it from the span level.
            print('Converting span level to doc level...')
            try:
                self.span_to_doc(result)
                print('Conversion successful')
            except Exception as e:
                print(f'Error converting span level to doc level: {e}')

    def attribute_for_results(self, results):
        for result in results:
            self.attribute_for_result(result)
        return results
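
# Each record in res_attr.json is expected to provide 'doc_cache' (the retrieved
# documents), 'data' -> 'question', and 'output' (a list of answer sentences),
# matching the keys read in attribute_for_result above.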

if __name__ == '__main__':
    attributer = InterpretableAttributer(levels=['span'])
    results = load_json('res_attr.json')
    attributer.attribute_for_results(results)
    write_json('res_attr_span.json', results)