import json
import re

import torch

from context_cite import ContextCiter


def all_normalize(obj):
    """Min-max normalize span scores globally, across all sentences and docs."""
    all_values = []
    for output_sent_result in obj:
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
    max_val = max(all_values)
    min_val = min(all_values)
    denom = (max_val - min_val) or 1.0  # avoid ZeroDivisionError when all scores are equal
    for output_sent_result in obj:
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / denom)
    return obj


def all_normalize_in(obj):
    """Min-max normalize span scores within each output sentence separately."""
    for output_sent_result in obj:
        all_values = []
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
        max_val = max(all_values)
        min_val = min(all_values)
        denom = (max_val - min_val) or 1.0  # avoid ZeroDivisionError when all scores are equal
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / denom)
    return obj


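# Both normalizers above operate on the nested structure built in
# parse_attribution_results below (values illustrative):
#   obj[sentence_idx][doc_idx] == [(span_start, score), ...]
# all_normalize rescales scores over the whole object at once, while
# all_normalize_in rescales within each output sentence separately.

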
def load_json(file_path):
    """Load a .json file, or a .jsonl file whose records are separated by '}\\n{'."""
    with open(file_path, 'r') as file:
        data = file.read()
    if file_path.endswith('.jsonl'):
        # Rejoin the newline-separated objects into one JSON array. The split
        # keeps the outer braces of the first and last objects, so only plain
        # brackets are needed around the rejoined string.
        joined = "},{".join(data.split("}\n{"))
        data = f'[{joined}]'
    objects = json.loads(data)
    return objects


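# A worked example of the .jsonl repair above (input string illustrative):
#   '{"a": 1}\n{"b": 2}'  ->  '[{"a": 1},{"b": 2}]'  ->  [{'a': 1}, {'b': 2}]

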
def ma(text):
    """Return the offset just past a 'Document [k](Title:...)' header, else 0."""
    pattern = r"Document \[\d+\]\(Title:[^)]+\)"
    match = re.search(pattern, text)
    if match:
        return match.end()
    return 0


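# Illustrative call (document text hypothetical):
#   ma("Document [1](Title: Hamlet) The play opens at Elsinore.")  ->  27
# i.e. the offset just past the header, where the document body begins.

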
def write_json(file_path, data):
    with open(file_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)


def load_model(model_name_or_path):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        token='your token',  # placeholder: substitute your Hugging Face access token
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer


def compute_log_prob(model, tokenizer, input_text, output_text):
    """Sum of log-probabilities the model assigns to output_text after input_text."""
    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    # No special tokens: the output continues the input, it is not a new sequence.
    output_ids = tokenizer(output_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    # The model must see input and output together to score the output tokens.
    full_ids = torch.cat([input_ids, output_ids], dim=1).to(model.device)

    with torch.no_grad():
        # Logits at position t predict token t + 1, hence the shift by one.
        logits = model(full_ids).logits[:, -output_ids.shape[1] - 1:-1, :]

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    output_log_probs = log_probs.gather(2, output_ids.to(log_probs.device).unsqueeze(-1)).squeeze(-1)
    return output_log_probs.sum().item()


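# A minimal scoring sketch (model name and strings hypothetical):
#   model, tokenizer = load_model('gpt2')
#   lp = compute_log_prob(model, tokenizer, 'Q: Who wrote Hamlet?\nA:', ' Shakespeare')
#   # lp is the summed log-probability of the answer tokens given the prompt.

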
def compute_contributions(model, tokenizer, question, docs, output):
    """Leave-one-out attribution: each doc's effect on the output's log-prob."""
    full_input = question + '\n\n' + '\n'.join(docs)
    base_prob = compute_log_prob(model, tokenizer, full_input, output)

    contributions = []
    for i in range(len(docs)):
        # Drop doc i and measure how much the output's log-prob falls.
        reduced_docs = docs[:i] + docs[i + 1:]
        reduced_input = question + '\n\n' + '\n'.join(reduced_docs)
        reduced_prob = compute_log_prob(model, tokenizer, reduced_input, output)
        contributions.append(base_prob - reduced_prob)

    return contributions


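# Leave-one-out usage sketch (question, docs, and output hypothetical):
#   scores = compute_contributions(model, tokenizer,
#                                  'Who wrote Hamlet?',
#                                  ['Doc 1 ...', 'Doc 2 ...'],
#                                  'Shakespeare wrote Hamlet.')
#   # scores[i] > 0 means removing doc i lowered the log-prob of the output,
#   # i.e. doc i supported the answer.

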
class InterpretableAttributer:

    def __init__(self, levels=['doc', 'span', 'word'], model='gpt2'):
        for level in levels:
            assert level in ['doc', 'span', 'word'], f'Invalid level: {level}'

        # Run span-level attribution first so that doc-level scores can be
        # derived from it when needed.
        self.levels = sorted(levels, key=lambda x: ['span', 'doc', 'word'].index(x))
        # span_level_attribution needs self.model / self.tokenizer, so load them here.
        self.model, self.tokenizer = load_model(model)

    def attribute(self, question, docs, output):
        """Attribute every sentence of the output at each requested level."""
        attribute_results = {}
        for level in self.levels:
            attribute_result = []
            for sentence in output:
                attribute_result.append(self._attribute(question, docs, sentence, level))
            attribute_results[level] = attribute_result
        return attribute_results

    def _attribute(self, question, docs, output, level):
        # Only span-level attribution is implemented in this file; the 'doc'
        # and 'word' branches expect methods supplied elsewhere.
        if level == 'doc':
            return self.doc_level_attribution(question, docs, output)
        elif level == 'span':
            return self.span_level_attribution(question, docs, output)
        elif level == 'word':
            return self.word_level_attribution(question, docs, output)
        else:
            raise ValueError(f'Invalid level: {level}')

    def span_level_attribution(self, question, docs, output):
        context = '\n\n'.join(docs)
        response = output

        cc = ContextCiter(self.model, self.tokenizer, context, question)
        # Inject the precomputed response into the citer's cache (via its private
        # API) so get_attributions scores our output instead of generating anew.
        _, prompt = cc._get_prompt_ids(return_prompt=True)
        cc._cache["output"] = prompt + response
        result = cc.get_attributions(as_dataframe=True, top_k=1000).data.to_dict(orient='records')
        return result

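    # Each record returned above follows context_cite's dataframe layout and
    # carries (at least) the two keys consumed by parse_attribution_results:
    #   {"Source": "<span text from the context>", "Score": <float>}
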
    def parse_attribution_results(self, docs, results):
        """Map ContextCiter span records back to (start, score) pairs per doc."""
        context = '\n\n'.join(docs)
        lens = [len(doc) for doc in docs]
        len_sep = len('\n\n')
        final_results = {}
        for level, result in results.items():
            if level == 'span':
                ordered_all_sents = []
                for output_sent_result in result:
                    # Remember where each span text was last found so repeated
                    # span strings resolve to successive occurrences in context.
                    final_end_for_span = {}
                    all_span_results = []
                    for each_span in output_sent_result:
                        span_text = each_span["Source"]
                        span_score = each_span["Score"]
                        start = 0
                        if span_text in final_end_for_span:
                            start = final_end_for_span[span_text]
                        span_start = context.find(span_text, start)
                        span_end = span_start + len(span_text)
                        final_end_for_span[span_text] = span_end

                        # Convert the offset in the joined context into an
                        # offset within the individual doc it falls in.
                        doc_idx = 0
                        while span_start > lens[doc_idx]:
                            span_start -= lens[doc_idx] + len_sep
                            span_end -= lens[doc_idx] + len_sep
                            doc_idx += 1
                        all_span_results.append((span_start, span_score, doc_idx))
                    ordered = [[] for _ in range(len(docs))]
                    for span_start, span_score, doc_idx in all_span_results:
                        ordered[doc_idx].append((span_start, span_score))
                    for i in range(len(docs)):
                        doc = docs[i]
                        # Snap the first span past the 'Document [k](Title:...)' header.
                        real_start = ma(doc)
                        ordered[i] = sorted(ordered[i], key=lambda x: x[0])
                        if ordered[i]:  # a doc may have received no spans at all
                            ordered[i][0] = (real_start, ordered[i][0][1])

                    ordered_all_sents.append(ordered)
                final_results[level + '_level'] = all_normalize_in(ordered_all_sents)
            elif level == 'doc':
                # Doc-level scores are derived from the span-level results
                # afterwards (see span_to_doc in attribute_for_result).
                pass
            else:
                raise NotImplementedError(f'Parsing for {level} not implemented yet')
        return final_results

    def span_to_doc(self, results):
        """Collapse span-level scores into one doc score: the mean over its spans."""
        import numpy as np
        span_level = results['span_level']
        doc_level = []
        for output_sent_result in span_level:
            doc_level.append([np.mean([span[1] for span in doc]) for doc in output_sent_result])
        results['doc_level'] = doc_level

    def attribute_for_result(self, result):
        docs = result['doc_cache']
        question = result['data']['question']
        output = result['output']
        attribution_results = self.attribute(question, docs, output)
        parsed_results = self.parse_attribution_results(docs, attribution_results)
        result.update(parsed_results)

        if 'doc' not in self.levels:
            # Derive doc-level scores from the span-level results just added.
            print('Converting span level to doc level...')
            try:
                self.span_to_doc(result)
                print('Conversion successful')
            except Exception as e:
                print(f'Error converting span level to doc level: {e}')

    def attribute_for_results(self, results):
        for result in results:
            self.attribute_for_result(result)
        return results


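# Expected shape of each result record (inferred from the accessors in
# attribute_for_result; field contents illustrative):
#   {
#       "doc_cache": ["Document [1](Title: ...) ...", ...],
#       "data": {"question": "..."},
#       "output": ["first answer sentence", "second answer sentence", ...],
#   }

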
if __name__ == '__main__':
    attributer = InterpretableAttributer(levels=['span'])
    results = load_json('res_attr.json')
    attributer.attribute_for_results(results)
    write_json('res_attr_span.json', results)