John6666's picture
Upload 351 files
e84842d verified
raw
history blame
1.25 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
def prepare_qa_input(sample, num_captions, num_captions_fid):
    """Attach grouped question+caption passages to ``sample`` in place.

    For every (question, captions) pair in ``sample['text_input']`` /
    ``sample['captions']``, the first ``num_captions`` captions are split into
    consecutive groups of ``num_captions_fid``. Each group becomes one passage
    of the form ``"<question> \\n <cap>. <cap>. ..."``, lowercased. The
    separator is a literal backslash followed by ``n``, not a newline.
    The last group may be smaller than ``num_captions_fid``. It is always
    emitted, even when the sample has fewer than ``num_captions`` captions.

    Args:
        sample: Mapping with keys ``'text_input'`` (iterable of questions) and
            ``'captions'`` (iterable of caption lists, one per question).
            Questions and captions must support ``.strip()``/``.lower()``;
            presumably ``str`` -- confirm against callers.
        num_captions: Maximum number of captions used per question.
        num_captions_fid: Number of captions packed into each passage.

    Returns:
        None. The result is stored as ``sample['question_captions']``: one list
        of passage strings per question.

    Raises:
        AssertionError: if an entry of ``sample['captions']`` is not a list.
    """
    sample_question_captions = []

    for question, captions in zip(sample['text_input'], sample['captions']):
        assert isinstance(captions, list)
        question_captions = []
        pending = []  # caption fragments for the passage currently being built
        for cap_id, cap_ in enumerate(captions[0:num_captions]):
            pending.append(cap_.strip() + '. ')
            if (cap_id + 1) % num_captions_fid == 0:
                question_captions.append(_join_question_and_captions(question, pending))
                pending = []
        # Flush the trailing partial group. This also covers samples that have
        # fewer than ``num_captions`` captions, which would otherwise lose them.
        if pending:
            question_captions.append(_join_question_and_captions(question, pending))
        sample_question_captions.append(question_captions)

    sample['question_captions'] = sample_question_captions


def _join_question_and_captions(question, caption_parts):
    """Format one passage: lowercased question, literal ``\\n``, lowercased captions."""
    # " \\n " is a literal backslash-n separator (presumably the UnifiedQA-style
    # "question \n context" input format -- confirm against the QA model used).
    return question.lower().strip() + " \\n " + ''.join(caption_parts).lower().strip()