import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig
import json
from torch.utils.data import Dataset, DataLoader

# Load the LLaVA instruction-tuning conversations.
instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)


class CustomTextDataset(Dataset):
    def __init__(self, json_data, image_embedding_dict, tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        self.entries = []

        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            # Skip conversations whose image has no precomputed embedding.
            if image_embedding is None:
                continue

            conversations = entry['conversations']
            for i in range(len(conversations)):
                # Each human turn is paired with the assistant turn that follows it.
                if conversations[i]['from'] == 'human' and i + 1 < len(conversations):
                    # Rough character-length filter to keep the tokenized pair within the context window.
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > self.maxContext:
                        continue
                    question = 'Question: ' + conversations[i]['value'].lstrip('\n')
                    answer = 'Answer: ' + conversations[i + 1]['value']
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                    })

        print(f'Number of QnA entries: {len(self.entries)}')

    def getEmbeddingForImage(self, image):
        # Return the precomputed embedding for this image, or None if it is missing.
        if image in self.image_embedding_dict:
            return self.image_embedding_dict[image]
        return None

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]

        Q_caption_tokens = self.tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = self.tokenizer.encode(entry['QnAText'], add_special_tokens=True)
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)

        # Truncate to the context length, then right-pad so every sample has the same size
        # and the default collate function can stack them into a batch.
        QnA_captions_tokens = QnA_captions_tokens[:self.maxContext]
        QnA_captions_tokens = QnA_captions_tokens + \
            [self.tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))

        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                'image_embedding': entry['image_embedding'].to("cuda"),
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length
                }


# Precomputed image embeddings keyed by image file name.
imgEmbDict = torch.load('img_embeddings_dict.pth')

# Phi-2's tokenizer has no pad token by default, so reuse the EOS token for padding.
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict, tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)
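
# A minimal sanity-check sketch (not part of the pipeline above): pull one batch from
# custom_dataloader and inspect its shapes. This assumes img_embeddings_dict.pth maps
# image file names to precomputed embedding tensors of a single fixed shape, as the
# dataset code expects; the exact embedding dimensions depend on how they were saved.
batch = next(iter(custom_dataloader))
print(batch['QnA_tokens'].shape)       # expected: (batch_size, maxContext), i.e. (10, 512)
print(batch['image_embedding'].shape)  # (batch_size, ...) -- depends on the stored embeddings
print(batch['QTokensLength'][:3])      # token length of each question prefix (e.g. for masking the question when computing the loss -- an assumption, not shown above)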