# rainier / app.py
# liujch1998's picture
# Barebone demo
# 7802ab3
# raw
# history blame
# 5.61 kB
import gradio as gr
import torch
import transformers
def reduce_sum(value, mask, axis=None):
    """Sum ``value`` after multiplying it elementwise by ``mask``.

    With ``axis`` left as ``None`` the sum runs over every element; otherwise
    it runs along the given axis only.
    """
    masked = value * mask
    return torch.sum(masked) if axis is None else torch.sum(masked, axis)
def reduce_mean(value, mask, axis=None):
    """Average ``value`` over the positions selected by ``mask``.

    The masked sum is divided by the sum of the mask, either over all
    elements (``axis=None``) or along ``axis``. No guard is applied when the
    mask sums to zero.
    """
    masked = value * mask
    if axis is not None:
        return torch.sum(masked, axis) / torch.sum(mask, axis)
    return torch.sum(masked) / torch.sum(mask)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Length (in tokens) that tokenizer inputs are padded/truncated to.
max_input_len = 256
# Length budget for a generated knowledge statement; generate() is called
# with max_length = max_output_len + 1.
max_output_len = 32
# Number of knowledge statements sampled per question (num_return_sequences).
m = 10
# top_p value passed to generate() for sampling.
top_p = 0.5
class InteractiveRainier:
    """Knowledge-augmented multiple-choice question answering.

    A knowledge model ('liujch1998/rainier-large') samples short knowledge
    statements for a question, and a QA model ('allenai/unifiedqa-t5-large')
    scores every answer choice once without knowledge and once per statement.
    """

    def __init__(self):
        """Load the tokenizer and both seq2seq models, moving the models to ``device``."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
        self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
        self.qa_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('allenai/unifiedqa-t5-large').to(device)
        # reduction='none' keeps one loss per token so they can be averaged
        # per sequence later; positions labelled -100 are ignored.
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    def parse_choices(self, s):
        '''
        Split serialized choices into a list of choice texts.

        s: serialized_choices '(A) ... (B) ... (C) ...'
           Lowercase labels '(a) ... (b) ...' are used when '(A)' is absent.
           s is expected to start with the first label: the first three
           characters of each segment are dropped as the label.

        Returns the choice texts in order, with surrounding spaces stripped.
        '''
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            # Each choice ends where the next label in alphabetical order starts.
            next_key = chr(ord(key) + 1)
            pos = s.find(f'({next_key})')
            if pos == -1:
                break
            choices.append(s[3:pos].strip(' '))
            s = s[pos:]
            key = next_key
        # Whatever remains after the last label found is the final choice.
        choices.append(s[3:].strip(' '))
        return choices

    def run(self, question):
        """Answer a multiple-choice question with and without generated knowledge.

        question: the question text, then a literal backslash-n separator (the
            two characters, not a newline), then the serialized choices
            accepted by ``parse_choices``.

        Returns a dict with keys:
            'question': the input string, unchanged.
            'knowledges': the empty string (no-knowledge baseline) followed by
                the distinct generated statements in generation order.
            'knowless_pred': the choice preferred when no knowledge is given.
            'knowful_pred': the choice with the highest probability under any
                of the prompts.
            'selected_knowledge': the statement whose prompt produced the
                single most confident prediction.

        Raises IndexError if ``question`` does not contain the separator.
        """
        tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
        knowledges_ids = self.rainier_model.generate(
            input_ids=tokenized.input_ids,
            max_length=max_output_len + 1,
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        )  # (K, L); begins with 0 ([BOS]); ends with 1 ([EOS])
        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # De-duplicate while keeping generation order, so the listing and the
        # argmax tie-breaking below do not depend on set iteration order.
        knowledges = list(dict.fromkeys(knowledges))
        # Index 0 is the no-knowledge baseline.
        knowledges = [''] + knowledges

        prompts = [question + (f' \\n {knowledge}' if knowledge != '' else '') for knowledge in knowledges]
        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
        prompts = [prompt.lower() for prompt in prompts]
        choices = [choice.lower() for choice in choices]

        # The prompts are the same for every choice, so tokenize them once.
        tokenized_prompts = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1+K, L)

        answer_logitss = []
        for choice in choices:
            tokenized_choices = self.tokenizer([choice], return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device)  # (1, L)
            pad_mask = (tokenized_choices.input_ids == self.tokenizer.pad_token_id)
            # Padding positions must not contribute to the loss.
            tokenized_choices.input_ids[pad_mask] = -100
            labels = tokenized_choices.input_ids.repeat(len(knowledges), 1)  # (1+K, L)

            with torch.no_grad():
                logits = self.qa_model(
                    input_ids=tokenized_prompts.input_ids,
                    attention_mask=tokenized_prompts.attention_mask,
                    labels=labels,
                ).logits  # (1+K, L, V)

                losses = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
                losses = losses.view(labels.shape)  # (1+K, L)
                # Mean loss over the non-padding tokens of this choice.
                losses = reduce_mean(losses, ~pad_mask, axis=-1)  # (1+K)

            # A lower loss means a more likely choice, so the negated loss is its score.
            answer_logitss.append(-losses)

        answer_logitss = torch.stack(answer_logitss, dim=1)  # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)

        # Ensemble
        # Row 0 is the prompt without knowledge.
        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
        knowless_pred = choices[knowless_pred]

        # For each choice, take its best probability over all prompts.
        answer_probs = answer_probss.max(dim=0).values  # (C)
        knowful_pred = answer_probs.argmax(dim=0).item()
        knowful_pred = choices[knowful_pred]
        # The prompt (row) holding the overall highest probability.
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]

        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }
# Single module-level instance used by predict(); constructing it loads the
# tokenizer and both models.
rainier = InteractiveRainier()
def predict(question, choices):
    """Run the demo for one question and format the result as plain text.

    ``question`` and ``choices`` are joined with the literal backslash-n
    separator and passed to ``rainier.run``. The returned report lists the
    answer without knowledge, the answer with knowledge, every knowledge
    entry, and the knowledge selected for the prediction.
    """
    result = rainier.run(f'{question} \\n {choices}')
    knowledge_lines = [f' {knowledge}\n' for knowledge in result['knowledges']]
    sections = [
        f'QA model answer without knowledge: {result["knowless_pred"]}\n',
        f'QA model answer with knowledge: {result["knowful_pred"]}\n',
        '\n',
        'All generated knowledges:\n',
        *knowledge_lines,
        '\n',
        f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n',
    ]
    return ''.join(sections)
# Two free-text inputs: the question, and the serialized answer choices in
# the '(A) ... (B) ... (C) ...' form that InteractiveRainier.parse_choices reads.
input_question = gr.inputs.Textbox(label='Question:')
# The component class is spelled `Textbox`; `TextBox` is not an attribute of
# gr.inputs and would raise AttributeError at import time.
input_choices = gr.inputs.Textbox(label='Choices:')
output_text = gr.outputs.Textbox(label='Output')

# Wrap predict() in a web UI and start serving it.
gr.Interface(
    fn=predict,
    inputs=[input_question, input_choices],
    outputs=output_text,
    title="Rainier",
).launch()