Spaces:
Build error
Commit · 7802ab3
1 Parent(s): 39aed69
Barebone demo
app.py CHANGED
@@ -1,15 +1,132 @@
 import gradio as gr
-
+import torch
+import transformers
 
-
+def reduce_sum(value, mask, axis=None):
+    if axis is None:
+        return torch.sum(value * mask)
+    return torch.sum(value * mask, axis)
+def reduce_mean(value, mask, axis=None):
+    if axis is None:
+        return torch.sum(value * mask) / torch.sum(mask)
+    return reduce_sum(value, mask, axis) / torch.sum(mask, axis)
 
-
-
-
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+max_input_len = 256
+max_output_len = 32
+m = 10
+top_p = 0.5
+
+class InteractiveRainier:
+
+    def __init__(self):
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
+        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
+        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
+        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
+
+    def parse_choices(self, s):
+        '''
+        s: serialized_choices '(A) ... (B) ... (C) ...'
+        '''
+        choices = []
+        key = 'A' if s.find('(A)') != -1 else 'a'
+        while True:
+            pos = s.find(f'({chr(ord(key) + 1)})')
+            if pos == -1:
+                break
+            choice = s[3:pos]
+            s = s[pos:]
+            choice = choice.strip(' ')
+            choices.append(choice)
+            key = chr(ord(key) + 1)
+        choice = s[3:]
+        choice = choice.strip(' ')
+        choices.append(choice)
+        return choices
+
+    def run(self, question):
+        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
+        knowledges_ids = self.rainier_model.generate(
+            input_ids=tokenized.input_ids,
+            max_length=max_output_len + 1,
+            min_length=3,
+            do_sample=True,
+            num_return_sequences=m,
+            top_p=top_p,
+        ) # (K, L); begins with 0 ([BOS]); ends with 1 ([EOS])
+        knowledges_ids = knowledges_ids[:, 1:].contiguous() # no beginning; ends with 1 ([EOS])
+        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        knowledges = list(set(knowledges))
+        knowledges = [''] + knowledges
+
+        prompts = [question + (f' \\n {knowledge}' if knowledge != '' else '') for knowledge in knowledges]
+        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
+        prompts = [prompt.lower() for prompt in prompts]
+        choices = [choice.lower() for choice in choices]
+        answer_logitss = []
+        for choice in choices:
+            tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1+K, L)
+            tokenized_choices = self.tokenizer([choice], return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
+            pad_mask = (tokenized_choices.input_ids == self.tokenizer.pad_token_id)
+            tokenized_choices.input_ids[pad_mask] = -100
+            tokenized_choices.input_ids = tokenized_choices.input_ids.repeat(len(knowledges), 1) # (1+K, L)
+
+            with torch.no_grad():
+                logits = self.qa_model(
+                    input_ids=tokenized_prompts.input_ids,
+                    attention_mask=tokenized_prompts.attention_mask,
+                    labels=tokenized_choices.input_ids,
+                ).logits # (1+K, L, V)
+
+            losses = self.loss_fct(logits.view(-1, logits.size(-1)), tokenized_choices.input_ids.view(-1))
+            losses = losses.view(tokenized_choices.input_ids.shape) # (1+K, L)
+            losses = reduce_mean(losses, ~pad_mask, axis=-1) # (1+K)
+            answer_logitss.append(-losses)
+        answer_logitss = torch.stack(answer_logitss, dim=1) # (1+K, C)
+        answer_probss = answer_logitss.softmax(dim=1) # (1+K, C)
+
+        # Ensemble
+        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
+        knowless_pred = choices[knowless_pred]
+
+        answer_probs = answer_probss.max(dim=0).values # (C)
+        knowful_pred = answer_probs.argmax(dim=0).item()
+        knowful_pred = choices[knowful_pred]
+        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
+        selected_knowledge = knowledges[selected_knowledge_ix]
+
+        return {
+            'question': question,
+            'knowledges': knowledges,
+            'knowless_pred': knowless_pred,
+            'knowful_pred': knowful_pred,
+            'selected_knowledge': selected_knowledge,
+        }
+
+rainier = InteractiveRainier()
+
+def predict(question, choices):
+    result = rainier.run(f'{question} \\n {choices}')
+    output = ''
+    output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
+    output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
+    output += '\n'
+    output += f'All generated knowledges:\n'
+    for knowledge in result['knowledges']:
+        output += f'  {knowledge}\n'
+    output += '\n'
+    output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
+    return output
+
+input_question = gr.inputs.Textbox(label='Question:')
+input_choices = gr.inputs.Textbox(label='Choices:')
+output_text = gr.outputs.Textbox(label='Output')
 
 gr.Interface(
-    predict,
-    inputs=
-    outputs=
-    title="
-).launch()
+    fn=predict,
+    inputs=[input_question, input_choices],
+    outputs=output_text,
+    title="Rainier",
+).launch()
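
For reference, the pipeline expects UnifiedQA-style serialization: the question and the choices string are joined with a literal ' \n ' separator, and parse_choices walks the '(A) ... (B) ...' markers, so a string like '(A) desert (B) rainforest (C) antarctica' parses to ['desert', 'rainforest', 'antarctica']. A minimal smoke test of the demo, assuming app.py imports cleanly; the question and choices below are illustrative examples, not part of the commit:

# Hypothetical smoke test; the example strings are made up.
question = 'Where is a penguin most likely to live?'
choices = '(A) desert (B) rainforest (C) antarctica'
print(predict(question, choices))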
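The per-choice scoring is a masked sequence loss: CrossEntropyLoss with reduction='none' yields a loss per token, ignore_index=-100 zeroes out padding positions, and reduce_mean averages only over real tokens so longer choices are not penalized. A self-contained sketch with toy shapes (vocabulary size and token ids are arbitrary stand-ins):

import torch

loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
logits = torch.randn(1, 4, 10)               # (batch, length, vocab), toy values
labels = torch.tensor([[3, 7, -100, -100]])  # choice tokens; padding set to -100
mask = (labels != -100)

losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)).view(labels.shape)
score = -(losses * mask).sum(-1) / mask.sum(-1)  # mask-weighted mean, negated: higher is better
print(score)  # one score per (prompt, choice) pair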
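The ensemble step treats each generated knowledge as a voter: answer_probss has one row per knowledge (row 0 is the knowledge-free prompt) and one column per choice. The knowledge-aware prediction takes the best probability per choice across rows, and the "selected knowledge" is the row whose best column is highest. A toy walk-through with made-up probabilities:

import torch

answer_probss = torch.tensor([
    [0.6, 0.4],  # row 0: no knowledge
    [0.3, 0.7],  # row 1: knowledge 1
    [0.5, 0.5],  # row 2: knowledge 2
])
knowless_pred = answer_probss[0].argmax().item()                # choice 0
knowful_pred = answer_probss.max(dim=0).values.argmax().item()  # choice 1
selected_ix = answer_probss.max(dim=1).values.argmax().item()   # knowledge 1
print(knowless_pred, knowful_pred, selected_ix)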