import gradio as gr
import torch
import transformers
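
# Interactive demo of Rainier, a knowledge introspector: it samples background
# knowledge statements for a question, then lets a frozen UnifiedQA model answer
# the question both with and without each statement and ensembles the results.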

def reduce_sum(value, mask, axis=None):
    # Sum of `value` over positions where `mask` is nonzero.
    if axis is None:
        return torch.sum(value * mask)
    return torch.sum(value * mask, axis)

def reduce_mean(value, mask, axis=None):
    # Mean of `value` over positions where `mask` is nonzero.
    if axis is None:
        return torch.sum(value * mask) / torch.sum(mask)
    return reduce_sum(value, mask, axis) / torch.sum(mask, axis)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

max_input_len = 256  # max token length for tokenized questions / prompts
max_output_len = 32  # max token length of a generated knowledge statement
m = 10  # number of knowledge statements sampled per question
top_p = 0.5  # nucleus sampling threshold for knowledge generation

class InteractiveRainier:

    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
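        # reduction='none' keeps per-token losses; padding is masked out manually in run().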
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    def parse_choices(self, s):
        '''
        s: serialized_choices '(A) ... (B) ... (C) ...'
        '''
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'  # markers may be upper- or lowercase
        while True:
            # Locate the next marker, e.g. '(B)' while consuming '(A)'s text.
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            choice = s[3:pos]  # skip the 3-character marker '(X)'
            s = s[pos:]
            choice = choice.strip(' ')
            choices.append(choice)
            key = chr(ord(key) + 1)
        choice = s[3:]  # last choice: text after the final marker
        choice = choice.strip(' ')
        choices.append(choice)
        return choices

    def run(self, question):
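        # Pipeline: (1) sample knowledge statements from Rainier, (2) score every
        # answer choice with UnifiedQA conditioned on each statement, (3) ensemble
        # by taking the max probability over statements.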
        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
        knowledges_ids = self.rainier_model.generate(
            input_ids=tokenized.input_ids,
            max_length=max_output_len + 1,
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        ) # (K, L); begins with 0 (<pad>, T5's decoder start token); ends with 1 (</s>)
        knowledges_ids = knowledges_ids[:, 1:].contiguous() # drop the decoder start token; still ends with </s>
        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        knowledges = list(set(knowledges))  # deduplicate sampled statements
        knowledges = [''] + knowledges  # prepend '' as the no-knowledge baseline

        # UnifiedQA's serialization joins fields with a literal '\n' (backslash + n,
        # not a newline) and expects lowercased text.
        prompts = [question + (f' \\n {knowledge}' if knowledge != '' else '') for knowledge in knowledges]
        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
        prompts = [prompt.lower() for prompt in prompts]
        choices = [choice.lower() for choice in choices]
        # Prompts do not depend on the choice being scored, so tokenize them once.
        tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1+K, L)
        answer_logitss = []
        for choice in choices:
            tokenized_choices = self.tokenizer([choice], return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
            pad_mask = (tokenized_choices.input_ids == self.tokenizer.pad_token_id)
            tokenized_choices.input_ids[pad_mask] = -100  # ignored by the loss below
            tokenized_choices.input_ids = tokenized_choices.input_ids.repeat(len(knowledges), 1) # (1+K, L)

            # Teacher-force the choice as decoder labels and read off per-token logits.
            with torch.no_grad():
                logits = self.qa_model(
                    input_ids=tokenized_prompts.input_ids,
                    attention_mask=tokenized_prompts.attention_mask,
                    labels=tokenized_choices.input_ids,
                ).logits # (1+K, L, V)

            # Per-token cross-entropy averaged over non-pad positions; negated so
            # that a higher score means a more likely choice.
            losses = self.loss_fct(logits.view(-1, logits.size(-1)), tokenized_choices.input_ids.view(-1))
            losses = losses.view(tokenized_choices.input_ids.shape) # (1+K, L)
            losses = reduce_mean(losses, ~pad_mask, axis=-1) # (1+K)
            answer_logitss.append(-losses)
        answer_logitss = torch.stack(answer_logitss, dim=1) # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1) # (1+K, C)

        # Ensemble. Row 0 corresponds to the empty string, i.e. the no-knowledge baseline.
        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
        knowless_pred = choices[knowless_pred]

        # Knowledge-informed prediction: per choice, take the max probability across statements.
        answer_probs = answer_probss.max(dim=0).values # (C)
        knowful_pred = answer_probs.argmax(dim=0).item()
        knowful_pred = choices[knowful_pred]
        # Report the statement under which some choice received its highest probability.
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]

        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }

# Instantiate once at import time so model weights are loaded before the UI starts.
rainier = InteractiveRainier()

def predict(question, choices):
    result = rainier.run(f'{question} \\n {choices}')
    output = ''
    output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
    output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
    output += '\n'
    output += 'All generated knowledge statements:\n'
    for knowledge in result['knowledges']:
        output += f'  {knowledge}\n'
    output += '\n'
    output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
    return output

input_question = gr.Textbox(label='Question:')
input_choices = gr.Textbox(label='Choices:')
output_text = gr.Textbox(label='Output')

gr.Interface(
    fn=predict,
    inputs=[input_question, input_choices],
    outputs=output_text,
    title="Rainier",
).launch()
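
# Example inputs (question and choices follow the UnifiedQA serialization above;
# the concrete values here are illustrative only):
#   Question: Where would you find a seahorse?
#   Choices:  (A) desert (B) ocean (C) forest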