"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch


def prepare_qa_input(sample, num_captions, num_captions_fid):
    """Group each question's captions into question-prefixed text chunks, in place.

    For every (question, captions) pair taken from ``sample['text_input']`` and
    ``sample['captions']``, the first ``num_captions`` captions are stripped,
    each terminated with ``'. '``, and concatenated in consecutive groups of
    ``num_captions_fid`` captions. Each group is lowercased, stripped, and
    prefixed with the lowercased/stripped question followed by a literal
    " <backslash>n " separator (the two characters backslash and ``n``, not a
    newline). The final group may hold fewer than ``num_captions_fid``
    captions, including when a question has fewer than ``num_captions``
    captions available.

    Args:
        sample: Mutable mapping with ``'text_input'`` (iterable of question
            strings) and ``'captions'`` (iterable of lists of caption strings),
            paired positionally via ``zip``. It is updated in place.
        num_captions: Maximum number of captions to use per question.
        num_captions_fid: Number of captions concatenated into each group.

    Returns:
        None. Writes ``sample['question_captions']``: one list per question,
        each holding that question's grouped strings.

    Raises:
        AssertionError: If an entry of ``sample['captions']`` is not a list.
    """
    sample_question_captions = []

    for question, captions in zip(sample['text_input'], sample['captions']):
        assert isinstance(captions, list)
        # The question prefix is identical for every group; build it once.
        question_prefix = question.lower().strip() + " \\n "
        selected = captions[0:num_captions]
        # Flush on the last caption actually present rather than on
        # ``num_captions``: when fewer captions exist, the trailing partial
        # group would otherwise be silently dropped.
        last_index = len(selected)

        question_captions = []
        question_caption = ''
        for cap_index, caption in enumerate(selected, start=1):
            question_caption += caption.strip() + '. '
            # ``or`` short-circuits, so the modulo is skipped on the last caption.
            if cap_index == last_index or cap_index % num_captions_fid == 0:
                question_captions.append(question_prefix + question_caption.lower().strip())
                question_caption = ''
        sample_question_captions.append(question_captions)

    sample['question_captions'] = sample_question_captions