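# Response models: an extractive QA model that answers questions from image captions,
# and several VQA models (ViLT, BLIP) that answer questions directly from images.
# Each model can also score how likely a given response is for each candidate
# image/caption (get_p_r_qy).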
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import BlipProcessor, BlipForQuestionAnswering
from sentence_transformers import SentenceTransformer
import openai
from PIL import Image
def get_response_model(args, response_type):
    if response_type == "QA":
        return ResponseModelQA(args.device, args.include_what)
    elif response_type == "VQA1":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt1")
    elif response_type == "VQA2":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="vilt2")
    elif response_type == "VQA3":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="blip")
    elif response_type == "VQA4":
        return ResponseModelVQA(args.device, args.include_what, args.question_generator, vqa_type="git")
    else:
        raise ValueError(f"{response_type} is not a valid response type.")
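# Example usage (a sketch, not from the original file; assumes `args` carries the
# `device`, `include_what`, and `question_generator` attributes used above):
#   args = argparse.Namespace(device="cuda", include_what=False, question_generator=None)
#   response_model = get_response_model(args, "VQA1")
#   answer = response_model.get_response("Is there a dog?", image, caption, target_questions=[])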
class ResponseModel(nn.Module):
    # Class for the other ResponseModels to inherit from
    def __init__(self, device, include_what):
        super(ResponseModel, self).__init__()
        self.device = device
        self.include_what = include_what
        self.model = None
    def get_response(self, question, image, caption, target_questions, **kwargs):
        raise NotImplementedError

    def get_p_r_qy(self, response, question, images, captions, **kwargs):
        raise NotImplementedError
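# Subclasses implement two operations: get_response answers `question` for a single
# image/caption pair, and get_p_r_qy scores, for each candidate image/caption, how
# likely `response` is as the answer to `question` (i.e. p(r | q, y)).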
class ResponseModelQA(ResponseModel):
    def __init__(self, device, include_what):
        super(ResponseModelQA, self).__init__(device, include_what)
        if not self.include_what:
            tokenizer = AutoTokenizer.from_pretrained("AmazonScience/qanlu")
            model = AutoModelForQuestionAnswering.from_pretrained("AmazonScience/qanlu")
            self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)  # remove device=0 for cpu
        else:
            tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
            model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
            # get_response reads self.model for both question types, so the pipeline is
            # stored under self.model here as well.
            self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=0)  # remove device=0 for cpu
    def get_response(self, question, image, caption, target_questions, **kwargs):
        if self.include_what:
            answer = self.model({'context': caption, 'question': question})
            return answer['answer'].split(' ')[-1]
        else:
            # Prepend "Yes. No." so the extractive QA model can select a yes/no span.
            answer = self.model({'context': f"Yes. No. {caption}", 'question': question})
            response, score = answer['answer'], answer['score']
            if score > 0.5:
                response = response.lower().replace('.', '')
                if "yes" in response.split() and "no" not in response.split():
                    response = 'yes'
                elif "no" in response.split() and "yes" not in response.split():
                    response = 'no'
                else:
                    response = 'yes' if question in target_questions else 'no'
            else:
                response = 'yes' if question in target_questions else 'no'
            return response
    def get_p_r_qy(self, response, question, images, captions, **kwargs):
        if self.include_what:
            raise NotImplementedError
        else:
            p_r_qy = torch.zeros(len(captions))
            qa_input = {'context': [f"Yes. No. {c}" for c in captions], 'question': [question for _ in captions]}
            answers = self.model(qa_input)
            for idx, answer in enumerate(answers):
                curr_ans, score = answer['answer'], answer['score']
                if curr_ans.strip() in ["Yes.", "No."]:
                    if response is None:
                        if curr_ans.strip() == "No.":
                            p_r_qy[idx] = 1 - score
                        if curr_ans.strip() == "Yes.":
                            p_r_qy[idx] = score
                    elif curr_ans.strip().lower().replace('.', '') == response:
                        p_r_qy[idx] = score
                    else:
                        p_r_qy[idx] = 1 - score
                else:
                    # The extracted span is neither "Yes." nor "No.": fall back to an uninformative 0.5.
                    p_r_qy[idx] = 0.5
            return p_r_qy.to(self.device)
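# Illustrative example (not from the original file): for question "Is there a dog?",
# response "yes", and captions ["a dog on a couch", "a cat on a chair"], the span
# extracted from "Yes. No. a dog on a couch" is typically "Yes.", so p_r_qy[0] gets
# that span's score, p_r_qy[1] gets 1 - score, and any caption whose extracted span
# is neither "Yes." nor "No." gets 0.5.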
class ResponseModelVQA(ResponseModel):
    def __init__(self, device, include_what, question_generator, vqa_type):
        super(ResponseModelVQA, self).__init__(device, include_what)
        self.vqa_type = vqa_type
        self.question_generator = question_generator
        self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
        if vqa_type == "vilt1":
            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
            self.model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
            self.vocab = list(self.model.config.label2id.keys())
        elif vqa_type == "vilt2":
            self.processor = AutoProcessor.from_pretrained("tufa15nik/vilt-finetuned-vqasi")
            self.model = ViltForQuestionAnswering.from_pretrained("tufa15nik/vilt-finetuned-vqasi").to(device)
            self.vocab = list(self.model.config.label2id.keys())
        elif vqa_type == "blip":
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
        elif vqa_type == "git":
            pass  # no GIT model is set up in this file
        else:
            raise ValueError(f"{vqa_type} is not a valid vqa_type.")
    def get_response(self, question, image, caption, target_questions, is_a=False):
        encoding = self.processor(image, question, return_tensors="pt").to(self.device)
        if not is_a:
            # First ask auxiliary "Is there a ...?" questions about the entities in the
            # question; if most of them are answered "no", short-circuit to "nothing" / "0".
            is_a_questions = self.question_generator.generate_is_there_question_v2(question)
            is_a_responses = []
            if question in ["What is in the photo?", "What is in the picture?", "What is in the background?"]:
                is_a_questions = []
            for q in is_a_questions:
                is_a_responses.append(self.get_response(q, image, caption, target_questions, is_a=True))
            no_cnt = sum([i.lower() == "no" for i in is_a_responses])
            if len(is_a_responses) > 0 and no_cnt / len(is_a_responses) >= 0.5:
                if question[:8] == "How many":
                    return "0"
                else:
                    return "nothing"
        if self.vqa_type in ["vilt1", "vilt2"]:
            outputs = self.model(**encoding)
            logits = torch.nn.functional.softmax(outputs.logits, dim=1)
            idx = logits.argmax(-1).item()
            response = self.model.config.id2label[idx]
            response = response.lower().replace('.', '').strip()
        elif self.vqa_type == "blip":
            outputs = self.model.generate(**encoding)
            response = self.processor.decode(outputs[0], skip_special_tokens=True)
        return response
    def get_p_r_qy(self, response, question, images, captions, is_a=False):
        p_r_qy = torch.zeros(len(captions))
        logits_arr = []
        for i, image in enumerate(images):
            with torch.no_grad():
                if len(question) > 150:
                    question = ""  # ignore the question if it is too long
                encoding = self.processor(image, question, return_tensors="pt").to(self.device)
                outputs = self.model(**encoding)
                logits = torch.nn.functional.softmax(outputs.logits, dim=1)
                idx = logits.argmax(-1).item()
                curr_response = self.model.config.id2label[idx]
                curr_response = curr_response.lower().replace('.', '').strip()
                if not self.include_what or is_a:
                    if response is None:
                        # 3 and 9 are the hard-coded indices used here for the "yes" and "no"
                        # labels in the ViLT VQA answer vocabulary.
                        if curr_response == "yes":
                            p_r_qy[i] = logits[0][3].item()
                        elif curr_response == "no":
                            p_r_qy[i] = 1 - logits[0][9].item()
                        else:
                            p_r_qy[i] = 0.5
                    elif curr_response == response:
                        p_r_qy[i] = logits[0][idx].item()
                    else:
                        p_r_qy[i] = 1 - logits[0][idx].item()
                else:
                    logits_arr.append(logits)
        if not self.include_what or is_a:
            return p_r_qy.to(self.device)
        else:
            logits = torch.concat(logits_arr)
            if response is None:
                # Take, for each image, the probability of that image's own top answer.
                top_answers = logits.argmax(1)
                p_r_qy = logits[torch.arange(logits.size(0)), top_answers]
            else:
                response_idx = self.get_response_idx(response)
                p_r_qy = logits[:, response_idx]
            # TODO: consider rerunning also without the geometric mean
            if response == "nothing":
                is_a_questions = self.question_generator.generate_is_there_question_v2(question)
                for idx, (caption, image) in enumerate(zip(captions, images)):
                    current_responses = []
                    for is_a_q in is_a_questions:
                        current_responses.append(self.get_response(is_a_q, image, caption, None, is_a=True))
                    no_cnt = sum([i.lower() == "no" for i in current_responses])
                    if len(current_responses) > 0 and no_cnt / len(current_responses) >= 0.5:
                        p_r_qy[idx] = 1.0
            return p_r_qy.to(self.device)
    def get_response_idx(self, response):
        if response in self.model.config.label2id:
            return self.model.config.label2id[response]
        else:
            # Fall back to the closest in-vocabulary answer by sentence-embedding similarity.
            embs = self.sentence_transformer.encode(self.vocab, convert_to_tensor=True)
            emb_response = self.sentence_transformer.encode([response], convert_to_tensor=True)
            dists = torch.nn.CosineSimilarity(-1)(emb_response, embs)
            best_response_idx = torch.argmax(dists)
            best_response = self.vocab[best_response_idx]
            return self.model.config.label2id[best_response]
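

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file). It exercises only the
    # caption-based QA model, since that path needs no image or question generator.
    # Assumes the checkpoints can be downloaded and a GPU is available, because the
    # pipelines above are built with device=0.
    import argparse
    args = argparse.Namespace(device="cuda", include_what=False, question_generator=None)
    qa_model = get_response_model(args, "QA")
    caption = "a brown dog sitting on a red couch"
    print(qa_model.get_response("Is there a dog?", None, caption, target_questions=[]))
    print(qa_model.get_p_r_qy("yes", "Is there a couch?", None, [caption, "a cat on a chair"]))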