import json
import re

import numpy as np
import torch
from context_cite import ContextCiter
from transformers import AutoModelForCausalLM, AutoTokenizer
def all_normalize(obj):
    """Min-max normalize all span scores globally, across every output sentence."""
    all_values = []
    for output_sent_result in obj:
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
    max_val = max(all_values)
    min_val = min(all_values)
    span_range = max_val - min_val or 1.0  # avoid division by zero when all scores are equal
    for output_sent_result in obj:
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / span_range)
    return obj
def all_normalize_in(obj):
    """Min-max normalize span scores within each output sentence independently."""
    for output_sent_result in obj:
        all_values = []
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
        max_val = max(all_values)
        min_val = min(all_values)
        span_range = max_val - min_val or 1.0  # avoid division by zero
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / span_range)
    return obj
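# Illustrative shape of the object both normalizers consume (hypothetical
# numbers): a list over output sentences, each a list over documents, each a
# list of (char_offset, score) spans. For example:
#   obj = [[[(0, 2.0), (40, 6.0)], [(0, 4.0)]]]
#   all_normalize(obj)   # -> [[[(0, 0.0), (40, 1.0)], [(0, 0.5)]]]
# all_normalize_in gives the same result here because there is only one output
# sentence; with several sentences it rescales each sentence independently.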
def load_json(file_path):
    with open(file_path, 'r') as file:
        data = file.read()
    if file_path.endswith('.jsonl'):
        # Stitch newline-separated objects into a single JSON array.
        joined = "},{".join(data.split("}\n{"))
        data = f'[{joined}]'
    return json.loads(data)
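# Example (hypothetical file contents): a .jsonl file such as
#   {"id": 1}
#   {"id": 2}
# is rewritten to the string '[{"id": 1},{"id": 2}]' before json.loads, so the
# function returns a list of objects for both .json and .jsonl inputs.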
def ma(text):
    """Return the character offset just past a 'Document [i](Title:...)' header, or 0."""
    pattern = r"Document \[\d+\]\(Title:[^)]+\)"
    match = re.search(pattern, text)
    if match:
        return match.end()
    return 0
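# Example (hypothetical header): ma("Document [1](Title:Foo)Bar") == 23, the
# offset where the document body ("Bar") begins; text without such a header
# yields 0.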
def write_json(file_path, data):
    with open(file_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)
def load_model(model_name_or_path):
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        token='your token',  # placeholder: replace with your Hugging Face access token
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer
def compute_log_prob(model, tokenizer, input_text, output_text):
    """Sum of log-probabilities the model assigns to output_text given input_text."""
    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    output_ids = tokenizer(output_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    # Score the output conditioned on the input: run the model on the concatenation.
    full_ids = torch.cat([input_ids, output_ids], dim=1).to(model.device)
    with torch.no_grad():
        # Logits at position t predict token t + 1, hence the shift by one.
        logits = model(full_ids).logits[:, input_ids.shape[1] - 1:-1, :]
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    output_log_probs = log_probs.gather(2, output_ids.to(model.device).unsqueeze(-1)).squeeze(-1)
    return output_log_probs.sum().item()
def compute_contributions(model, tokenizer, question, docs, output):
    """Leave-one-out attribution: each document's contribution is the drop in the
    output's log-probability when that document is removed from the context."""
    full_input = question + '\n\n' + '\n'.join(docs)
    base_prob = compute_log_prob(model, tokenizer, full_input, output)
    contributions = []
    for i in range(len(docs)):
        reduced_docs = docs[:i] + docs[i + 1:]
        reduced_input = question + '\n\n' + '\n'.join(reduced_docs)
        reduced_prob = compute_log_prob(model, tokenizer, reduced_input, output)
        contributions.append(base_prob - reduced_prob)
    return contributions
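# Example usage (hypothetical model and inputs):
#   model, tokenizer = load_model("meta-llama/Llama-2-7b-hf")
#   docs = ["Document [1](Title:Sun)The Sun is a star.",
#           "Document [2](Title:Moon)The Moon orbits the Earth."]
#   compute_contributions(model, tokenizer, "What is the Sun?", docs,
#                         "The Sun is a star.")
#   # -> e.g. [3.2, 0.1]; a large positive value means removing that document
#   #    makes the output noticeably less likely under the model.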
class InterpretableAttributer:
    def __init__(self, levels=['doc', 'span', 'word'], model='gpt2'):
        for level in levels:
            assert level in ['doc', 'span', 'word'], f'Invalid level: {level}'
        # Attribute spans before docs: doc-level scores are derived from spans.
        self.levels = sorted(levels, key=lambda x: ['span', 'doc', 'word'].index(x))
        self.model, self.tokenizer = load_model(model)
    def attribute(self, question, docs, output):
        attribute_results = {}
        for level in self.levels:
            attribute_result = []
            # `output` is a list of answer sentences; attribute each one separately.
            for sentence in output:
                attribute_result.append(self._attribute(question, docs, sentence, level))
            attribute_results[level] = attribute_result
        return attribute_results
    def _attribute(self, question, docs, output, level):
        if level == 'doc':
            return self.doc_level_attribution(question, docs, output)
        elif level == 'span':
            return self.span_level_attribution(question, docs, output)
        elif level == 'word':
            return self.word_level_attribution(question, docs, output)
        else:
            raise ValueError(f'Invalid level: {level}')
    def span_level_attribution(self, question, docs, output):
        # Use ContextCite for span-level attribution.
        context = '\n\n'.join(docs)
        response = output
        cc = ContextCiter(self.model, self.tokenizer, context, question)
        # Inject the pre-generated response into ContextCiter's cache so that
        # get_attributions() scores our response instead of generating a new one.
        _, prompt = cc._get_prompt_ids(return_prompt=True)
        cc._cache["output"] = prompt + response
        result = cc.get_attributions(as_dataframe=True, top_k=1000).data.to_dict(orient='records')
        return result
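    # Each record returned above is a dict holding a span of the context and its
    # attribution score, e.g. (hypothetical values):
    #   {'Source': 'The Sun is a star.', 'Score': 5.1}
    # parse_attribution_results relies on exactly these two keys.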
    def parse_attribution_results(self, docs, results):
        context = '\n\n'.join(docs)
        lens = [len(doc) for doc in docs]
        len_sep = len('\n\n')
        final_results = {}
        for level, result in results.items():
            if level == 'span':
                ordered_all_sents = []
                for output_sent_result in result:
                    # Track the end of the previous occurrence of each span text so
                    # that repeated spans are matched to successive positions.
                    final_end_for_span = {}
                    all_span_results = []
                    for each_span in output_sent_result:
                        span_text = each_span["Source"]
                        span_score = each_span["Score"]
                        start = final_end_for_span.get(span_text, 0)
                        span_start = context.find(span_text, start)
                        span_end = span_start + len(span_text)
                        final_end_for_span[span_text] = span_end
                        # Convert the global offset into a per-document offset.
                        doc_idx = 0
                        while span_start > lens[doc_idx]:
                            span_start -= lens[doc_idx] + len_sep
                            span_end -= lens[doc_idx] + len_sep
                            doc_idx += 1
                        all_span_results.append((span_start, span_score, doc_idx))
                    ordered = [[] for _ in range(len(docs))]
                    for span_start, span_score, doc_idx in all_span_results:
                        ordered[doc_idx].append((span_start, span_score))
                    for i in range(len(docs)):
                        real_start = ma(docs[i])
                        ordered[i] = sorted(ordered[i], key=lambda x: x[0])
                        if ordered[i]:
                            # Snap the first span to the start of the document body,
                            # just past the "Document [i](Title:...)" header.
                            ordered[i][0] = (real_start, ordered[i][0][1])
                    ordered_all_sents.append(ordered)
                final_results[level + '_level'] = all_normalize_in(ordered_all_sents)
            elif level == 'doc':
                # Doc-level scores are derived by averaging the already-parsed
                # span-level scores (levels are sorted so 'span' is parsed first).
                self.span_to_doc(final_results)
            else:
                raise NotImplementedError(f'Parsing for {level} not implemented yet')
        return final_results
    def span_to_doc(self, results):
        """Average span scores within each document to get doc-level scores."""
        span_level = results['span_level']
        doc_level = []
        for output_sent_result in span_level:
            doc_level.append([np.mean([span[1] for span in doc]) if doc else 0.0
                              for doc in output_sent_result])
        results['doc_level'] = doc_level
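    # Example (hypothetical scores): with
    #   results = {'span_level': [[[(0, 0.2), (10, 0.8)], [(5, 0.4)]]]}
    # span_to_doc adds results['doc_level'] == [[0.5, 0.4]]: the mean span
    # score per document, for each output sentence.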
    def attribute_for_result(self, result):
        docs = result['doc_cache']
        question = result['data']['question']
        output = result['output']
        attribution_results = self.attribute(question, docs, output)
        parsed_results = self.parse_attribution_results(docs, attribution_results)
        result.update(parsed_results)
        if 'doc' not in self.levels:
            # If 'doc' is not among the requested levels, derive doc-level
            # scores from the span-level results.
            print('Converting span level to doc level...')
            try:
                self.span_to_doc(result)
                print('Conversion successful')
            except Exception as e:
                print(f'Error converting span level to doc level: {e}')
    def attribute_for_results(self, results):
        for result in results:
            self.attribute_for_result(result)
        return results
if __name__ == '__main__':
    attributer = InterpretableAttributer(levels=['span'])
    results = load_json('res_attr.json')
    attributer.attribute_for_results(results)
    write_json('res_attr_span.json', results)