# vqa-guessing-game/model/model/response_model.py
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import BlipProcessor, BlipForQuestionAnswering
from sentence_transformers import SentenceTransformer
import openai
from PIL import Image


def get_response_model(args, response_type):
    if response_type == "QA":
        return ResponseModelQA(args.device, args.include_what)
    elif response_type == "VQA1":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt1")
    elif response_type == "VQA2":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt2")
    elif response_type == "VQA3":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="blip")
    elif response_type == "VQA4":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="git")
    else:
        raise ValueError(f"{response_type} is not a valid response type.")
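
# Usage sketch (assumed calling convention; the actual driver script lives elsewhere in
# the repo). The factory expects an argparse-style namespace with at least `device`,
# `include_what`, and, for the VQA variants, a `question_generator`. The name `qg` and
# the literals below are illustrative assumptions only:
#
#   args = argparse.Namespace(device="cuda", include_what=False, question_generator=qg)
#   response_model = get_response_model(args, "VQA1")
#   answer = response_model.get_response(question, image, caption, target_questions)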


class ResponseModel(nn.Module):
    # Class for the other ResponseModels to inherit from
    def __init__(self, device, include_what):
        super(ResponseModel, self).__init__()
        self.device = device
        self.include_what = include_what
        self.model = None

    def get_response(self, question, image, caption, target_questions, **kwargs):
        raise NotImplementedError

    def get_p_r_qy(self, response, question, images, captions, **kwargs):
        raise NotImplementedError


class ResponseModelQA(ResponseModel):
    def __init__(self, device, include_what):
        super(ResponseModelQA, self).__init__(device, include_what)
        if not self.include_what:
            tokenizer = AutoTokenizer.from_pretrained("AmazonScience/qanlu")
            model = AutoModelForQuestionAnswering.from_pretrained("AmazonScience/qanlu")
            self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)  # remove device=0 for cpu
        elif self.include_what:
            tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
            model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
            # stored on self.model so that get_response below can call it for wh-questions
            self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)  # remove device=0 for cpu

    def get_response(self, question, image, caption, target_questions, **kwargs):
        if self.include_what:
            answer = self.model({'context': caption, 'question': question})
            return answer['answer'].split(' ')[-1]
        else:
            answer = self.model({'context': f"Yes. No. {caption}", 'question': question})
            response, score = answer['answer'], answer['score']
            if score > 0.5:
                response = response.lower().replace('.', '')
                if "yes" in response.split() and "no" not in response.split():
                    response = 'yes'
                elif "no" in response.split() and "yes" not in response.split():
                    response = 'no'
                else:
                    response = 'yes' if question in target_questions else 'no'
            else:
                response = 'yes' if question in target_questions else 'no'
            return response

    def get_p_r_qy(self, response, question, images, captions, **kwargs):
        if self.include_what:
            raise NotImplementedError
        else:
            p_r_qy = torch.zeros(len(captions))
            qa_input = {'context': [f"Yes. No. {c}" for c in captions], 'question': [question for _ in captions]}
            answers = self.model(qa_input)
            for idx, answer in enumerate(answers):
                curr_ans, score = answer['answer'], answer['score']
                if curr_ans.strip() in ["Yes.", "No."]:
                    if response is None:
                        if curr_ans.strip() == "No.": p_r_qy[idx] = 1 - score
                        if curr_ans.strip() == "Yes.": p_r_qy[idx] = score
                    elif curr_ans.strip().lower().replace('.', '') == response:
                        p_r_qy[idx] = score
                    else:
                        p_r_qy[idx] = 1 - score
                else:
                    p_r_qy[idx] = 0.5
            return p_r_qy.to(self.device)
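
    # Sketch of the intended semantics (assumed from the variable name and usage; not
    # documented in this file): p_r_qy[k] approximates P(response | question, caption k),
    # so a downstream guesser can reweight its belief over the candidate captions, e.g.
    #
    #   probs = qa_model.get_p_r_qy("yes", "Is it outdoors?", images=None, captions=candidate_captions)
    #   belief = belief * probs  # hypothetical update step; `belief` is not defined in this file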


class ResponseModelVQA(ResponseModel):
    def __init__(self, device, include_what, question_generator, vqa_type):
        super(ResponseModelVQA, self).__init__(device, include_what)
        self.vqa_type = vqa_type
        self.question_generator = question_generator
        self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
        if vqa_type == "vilt1":
            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
            self.model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
            self.vocab = list(self.model.config.label2id.keys())
        elif vqa_type == "vilt2":
            self.processor = AutoProcessor.from_pretrained("tufa15nik/vilt-finetuned-vqasi")
            self.model = ViltForQuestionAnswering.from_pretrained("tufa15nik/vilt-finetuned-vqasi").to(device)
            self.vocab = list(self.model.config.label2id.keys())
        elif vqa_type == "blip":
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
        elif vqa_type == "git":
            pass  # GIT-based VQA is not implemented here; self.model stays None
        else:
            raise ValueError(f"{vqa_type} is not a valid vqa_type.")

    def get_response(self, question, image, caption, target_questions, is_a=False):
        encoding = self.processor(image, question, return_tensors="pt").to(self.device)
        if not is_a:
            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
            is_a_responses = []
            if question in ["What is in the photo?", "What is in the picture?", "What is in the background?"]:
                is_a_questions = []
            for q in is_a_questions:
                is_a_responses.append(self.get_response(q, image, caption, target_questions, is_a=True))
            no_cnt = sum([i.lower() == "no" for i in is_a_responses])
            if len(is_a_responses) > 0 and no_cnt / len(is_a_responses) >= 0.5:
                if question.startswith("How many"): return "0"
                else: return "nothing"
        if self.vqa_type in ["vilt1", "vilt2"]:
            outputs = self.model(**encoding)
            logits = torch.nn.functional.softmax(outputs.logits, dim=1)
            idx = logits.argmax(-1).item()
            response = self.model.config.id2label[idx]
            response = response.lower().replace('.', '').strip()
        elif self.vqa_type == "blip":
            outputs = self.model.generate(**encoding)
            response = self.processor.decode(outputs[0], skip_special_tokens=True)
        return response
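
    # Example call (sketch, assuming a ViLT variant and a local image; the file name is
    # illustrative only). Passing is_a=True skips the "is there ...?" pre-check, so no
    # question_generator is required for this call:
    #
    #   image = Image.open("kitchen.jpg").convert("RGB")
    #   vqa_model.get_response("How many chairs are there?", image, caption=None,
    #                          target_questions=None, is_a=True)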

    def get_p_r_qy(self, response, question, images, captions, is_a=False):
        p_r_qy = torch.zeros(len(captions))
        logits_arr = []
        for i, image in enumerate(images):
            with torch.no_grad():
                if len(question) > 150: question = ""  # ignore question if too long
                encoding = self.processor(image, question, return_tensors="pt").to(self.device)
                outputs = self.model(**encoding)
                logits = torch.nn.functional.softmax(outputs.logits, dim=1)
                idx = logits.argmax(-1).item()
                curr_response = self.model.config.id2label[idx]
                curr_response = curr_response.lower().replace('.', '').strip()
                if not self.include_what or is_a:
                    if response is None:
                        # use the probability mass on the model's "yes"/"no" answer labels
                        if curr_response == "yes": p_r_qy[i] = logits[0][self.model.config.label2id["yes"]].item()
                        elif curr_response == "no": p_r_qy[i] = 1 - logits[0][self.model.config.label2id["no"]].item()
                        else: p_r_qy[i] = 0.5
                    elif curr_response == response:
                        p_r_qy[i] = logits[0][idx].item()
                    else:
                        p_r_qy[i] = 1 - logits[0][idx].item()
                else:
                    logits_arr.append(logits)
        if not self.include_what or is_a:
            return p_r_qy.to(self.device)
        else:
            logits = torch.concat(logits_arr)
            if response is None:
                # probability of each image's own top answer
                p_r_qy = logits.max(dim=1).values
            else:
                response_idx = self.get_response_idx(response)
                p_r_qy = logits[:, response_idx]
            # TODO: check this; consider rerunning also without the geometric mean
            if response == "nothing":
                is_a_questions = self.question_generator.generate_is_there_question_v2(question)
                for idx, (caption, image) in enumerate(zip(captions, images)):
                    current_responses = []
                    for is_a_q in is_a_questions:
                        current_responses.append(self.get_response(is_a_q, image, caption, None, is_a=True))
                    no_cnt = sum([i.lower() == "no" for i in current_responses])
                    if len(current_responses) > 0 and no_cnt / len(current_responses) >= 0.5:
                        p_r_qy[idx] = 1.0
            return p_r_qy.to(self.device)

    def get_response_idx(self, response):
        if response in self.model.config.label2id:
            return self.model.config.label2id[response]
        else:
            # fall back to the closest in-vocabulary answer by sentence-embedding similarity
            embs = self.sentence_transformer.encode(self.vocab, convert_to_tensor=True)
            emb_response = self.sentence_transformer.encode([response], convert_to_tensor=True)
            dists = torch.nn.CosineSimilarity(-1)(emb_response, embs)
            best_response_idx = torch.argmax(dists)
            best_response = self.vocab[best_response_idx]
            return self.model.config.label2id[best_response]
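

# Minimal smoke-test sketch. Assumptions: a local image path is supplied on the command
# line and the ViLT / sentence-transformer weights can be downloaded from the Hub; this
# driver is illustrative only and is not the project's entry point.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", required=True, help="path to an image file")
    parser.add_argument("--question", default="Is there a person in the photo?")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    cli_args = parser.parse_args()

    # is_a=True skips the generate_is_there_question_v2 pre-check, so no question_generator is needed
    vqa = ResponseModelVQA(cli_args.device, include_what=False, question_generator=None, vqa_type="vilt1")
    img = Image.open(cli_args.image).convert("RGB")
    print(vqa.get_response(cli_args.question, img, caption=None, target_questions=None, is_a=True))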