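"""Context attribution for retrieval-augmented generation outputs.

Given a question, a list of retrieved documents, and a generated answer split
into sentences, this script uses ContextCite to score which context spans each
answer sentence relies on, and aggregates span scores into per-document scores.
"""
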
import json
import re

import torch
from context_cite import ContextCiter


def all_normalize(obj):
    """Min-max normalize all span scores globally, across every output sentence and document."""
    all_values = []
    for output_sent_result in obj:
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
    max_val = max(all_values)
    min_val = min(all_values)
    value_range = (max_val - min_val) or 1.0  # avoid division by zero when all scores are equal
    for output_sent_result in obj:
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / value_range)
    return obj

def all_normalize_in(obj):
    """Min-max normalize span scores separately within each output sentence."""
    for output_sent_result in obj:
        all_values = []
        for each_doc in output_sent_result:
            for each_span in each_doc:
                all_values.append(each_span[1])
        max_val = max(all_values)
        min_val = min(all_values)
        value_range = (max_val - min_val) or 1.0  # avoid division by zero when all scores are equal
        for i, each_doc in enumerate(output_sent_result):
            for j, each_span in enumerate(each_doc):
                output_sent_result[i][j] = (each_span[0], (each_span[1] - min_val) / value_range)
    return obj
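
# Both normalizers expect the nested structure produced by parse_attribution_results:
# obj[sentence_idx][doc_idx] is a list of (char_start, score) tuples, e.g.
#     [[[(0, 2.3), (57, 0.4)], [(12, 1.1)]], ...]
# all_normalize rescales scores over the whole result set, while all_normalize_in
# rescales within each output sentence.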

def load_json(file_path):
    """Load a .json file, or a .jsonl file with one JSON object per line."""
    with open(file_path, 'r') as file:
        data = file.read()
    if file_path.endswith('.jsonl'):
        # Rejoin the newline-separated objects into a single JSON array.
        joined = "},{".join(data.split("}\n{"))
        data = f'[{joined}]'
    objects = json.loads(data)
    return objects

def ma(text):
    """Return the character index just after the 'Document [N](Title:...)' header, or 0 if absent."""
    pattern = r"Document \[\d+\]\(Title:[^)]+\)"
    match = re.search(pattern, text)
    if match:
        return match.end()
    return 0

def write_json(file_path, data):
    with open(file_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)



def load_model(model_name_or_path):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        token='your token',  # replace with your Hugging Face access token (or drop for public models)
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer


def compute_log_prob(model, tokenizer, input_text, output_text):
    """Sum of log-probabilities the model assigns to output_text when it follows input_text."""
    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    output_ids = tokenizer(output_text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    full_ids = torch.cat([input_ids, output_ids], dim=1).to(model.device)

    with torch.no_grad():
        logits = model(full_ids).logits

    # Logits at position t predict token t+1, so the positions that predict the
    # output tokens start at the last input position.
    output_logits = logits[:, input_ids.shape[1] - 1:-1, :]
    log_probs = torch.nn.functional.log_softmax(output_logits, dim=-1)
    output_log_probs = log_probs.gather(2, output_ids.to(model.device).unsqueeze(-1)).squeeze(-1)
    return output_log_probs.sum().item()

def compute_contributions(model, tokenizer, question, docs, output):
    """Leave-one-out contribution of each document to the log-probability of the output."""
    full_input = question + '\n\n' + '\n'.join(docs)
    base_prob = compute_log_prob(model, tokenizer, full_input, output)

    contributions = []
    for i in range(len(docs)):
        # Drop document i and measure how much the output's log-probability falls.
        reduced_docs = docs[:i] + docs[i+1:]
        reduced_input = question + '\n\n' + '\n'.join(reduced_docs)
        reduced_prob = compute_log_prob(model, tokenizer, reduced_input, output)
        contributions.append(base_prob - reduced_prob)

    return contributions
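
# Example usage (hypothetical inputs); higher scores mean removing that document
# hurts the answer's likelihood more:
#     model, tokenizer = load_model("meta-llama/Llama-2-7b-hf")  # any causal LM path
#     scores = compute_contributions(
#         model, tokenizer,
#         "Who wrote Hamlet?",
#         ["Document [1](Title: Hamlet) Hamlet is a tragedy by William Shakespeare.",
#          "Document [2](Title: Macbeth) Macbeth is another of his tragedies."],
#         "Hamlet was written by William Shakespeare.",
#     )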
    
class InterpretableAttributer:

    def __init__(self, levels=['doc', 'span', 'word'], model='gpt-2'):
        for level in levels:
            assert level in ['doc', 'span', 'word'], f'Invalid level: {level}'
        # Attribute spans before documents so doc-level scores can be derived from span scores.
        self.levels = sorted(levels, key=lambda x: ['span', 'doc', 'word'].index(x))
        # 'model' is a placeholder default; pass the HF id or local path of the
        # causal LM that generated the outputs being attributed.
        self.model, self.tokenizer = load_model(model)


    def attribute(self, question, docs, output):
        """Run attribution at every requested level for each sentence of the output."""
        attribute_results = {}
        for level in self.levels:
            attribute_result = []
            for sentence in output:
                attribute_result.append(self._attribute(question, docs, sentence, level))
            attribute_results[level] = attribute_result
        return attribute_results
    

    def _attribute(self, question, docs, output, level):
        if level == 'doc':
            # Note: doc_level_attribution is not defined yet; doc-level scores are
            # currently derived from span scores via span_to_doc instead.
            return self.doc_level_attribution(question, docs, output)
        elif level == 'span':
            return self.span_level_attribution(question, docs, output)
        elif level == 'word':
            # Note: word_level_attribution is not defined yet.
            return self.word_level_attribution(question, docs, output)
        else:
            raise ValueError(f'Invalid level: {level}')
    
    def span_level_attribution(self, question, docs, output):
        # Use ContextCite to score context spans for the given (already generated) response.
        context = '\n\n'.join(docs)
        response = output

        cc = ContextCiter(self.model, self.tokenizer, context, question)
        _, prompt = cc._get_prompt_ids(return_prompt=True)
        # Pre-fill ContextCite's output cache so it attributes our response
        # instead of generating a new one.
        cc._cache["output"] = prompt + response
        result = cc.get_attributions(as_dataframe=True, top_k=1000).data.to_dict(orient='records')
        return result
    
    
    def parse_attribution_results(self, docs, results):
        """Map ContextCite span attributions back to per-document character offsets and scores."""
        context = '\n\n'.join(docs)
        lens = [len(doc) for doc in docs]
        len_sep = len('\n\n')
        final_results = {}
        for level, result in results.items():
            if level == 'span':
                ordered_all_sents = []
                for output_sent_result in result:
                    # Track where each span text was last found so repeated spans
                    # resolve to successive occurrences in the context.
                    final_end_for_span = {}
                    all_span_results = []
                    for each_span in output_sent_result:
                        span_text = each_span["Source"]
                        span_score = each_span["Score"]
                        start = 0
                        if span_text in final_end_for_span:
                            start = final_end_for_span[span_text]
                        span_start = context.find(span_text, start)
                        span_end = span_start + len(span_text)
                        final_end_for_span[span_text] = span_end
                        # Convert the offset in the concatenated context into an
                        # offset within the document that contains the span.
                        doc_idx = 0
                        while span_start > lens[doc_idx]:
                            span_start -= lens[doc_idx] + len_sep
                            span_end -= lens[doc_idx] + len_sep
                            doc_idx += 1
                        all_span_results.append((span_start, span_score, doc_idx))
                    # Group spans by document and sort them by position.
                    ordered = [[] for _ in range(len(docs))]
                    for span_start, span_score, doc_idx in all_span_results:
                        ordered[doc_idx].append((span_start, span_score))
                    for i in range(len(docs)):
                        doc = docs[i]
                        # Snap the first span to the end of the "Document [N](Title:...)" header.
                        real_start = ma(doc)
                        ordered[i] = sorted(ordered[i], key=lambda x: x[0])
                        if ordered[i]:
                            ordered[i][0] = (real_start, ordered[i][0][1])

                    ordered_all_sents.append(ordered)
                final_results[level + '_level'] = all_normalize_in(ordered_all_sents)
            elif level == 'doc':
                # Doc-level scores are derived from the already-parsed span scores.
                self.span_to_doc(final_results)
            else:
                raise NotImplementedError(f'Parsing for {level} not implemented yet')
        return final_results
    
    def span_to_doc(self, results):
        """Add a 'doc_level' entry by averaging span scores within each document."""
        import numpy as np
        span_level = results['span_level']
        doc_level = []
        for output_sent_result in span_level:
            doc_level.append([
                np.mean([span[1] for span in doc]) if doc else 0.0  # no attributed spans -> 0
                for doc in output_sent_result
            ])
        results['doc_level'] = doc_level
        

    def attribute_for_result(self, result):
        """Attach span-level (and derived doc-level) attribution scores to one result record."""
        docs = result['doc_cache']
        question = result['data']['question']
        output = result['output']
        attribution_results = self.attribute(question, docs, output)
        parsed_results = self.parse_attribution_results(docs, attribution_results)
        result.update(parsed_results)

        if 'doc' not in self.levels:
            # If doc was not among the requested levels, derive doc-level scores from the span level.
            print('Converting span level to doc level...')
            try:
                self.span_to_doc(result)
                print('Conversion successful')
            except Exception as e:
                print(f'Error converting span level to doc level: {e}')

    def attribute_for_results(self, results):
        for result in results:
            self.attribute_for_result(result)
        return results
    

if __name__ == '__main__':
    attributer = InterpretableAttributer(levels=['span'])
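    # Expected shape of each record in res_attr.json (inferred from attribute_for_result):
    #     {
    #         "data": {"question": "..."},
    #         "doc_cache": ["Document [1](Title: ...) ...", "Document [2](Title: ...) ..."],
    #         "output": ["answer sentence 1", "answer sentence 2", ...]
    #     }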
    results = load_json('res_attr.json')
    attributer.attribute_for_results(results)
    write_json('res_attr_span.json', results)