import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig

instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)
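
# Each record in llava_instruct_150k.json, as consumed below, looks roughly like:
#   {"image": "000000123456.jpg",
#    "conversations": [{"from": "human", "value": "<image>\nWhat is ...?"},
#                      {"from": "gpt", "value": "It is ..."}]}
# (the example filename and text are illustrative, not taken from the data)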

class CustomTextDataset(Dataset):
    def __init__(self, json_data, image_embedding_dict, tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        self.entries = []
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                # Skip samples whose image has no precomputed embedding.
                continue
            conversations = entry['conversations']
            # Stop one turn early: each human turn is paired with the reply that follows it.
            for i in range(len(conversations) - 1):
                if conversations[i]['from'] == 'human':
                    # Rough character-count filter to keep question+answer pairs
                    # short enough to fit the token context window.
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > self.maxContext:
                        continue
                    # Remove the LLaVA '<image>' placeholder from the prompt text.
                    # (str.lstrip strips a character set, not a prefix, so replace() is used instead.)
                    question = 'Question: ' + conversations[i]['value'].replace('<image>', '').strip()
                    answer = 'Answer: ' + conversations[i + 1]['value']
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                    })
        print(f'number of entries = {len(self.entries)}')

    def getEmbeddingForImage(self, image):
        # Returns the precomputed embedding, or None if the image is missing.
        return self.image_embedding_dict.get(image)

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        # Tokenize the question alone to record where the answer starts,
        # then tokenize the full question+answer text used as model input.
        Q_caption_tokens = self.tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = self.tokenizer.encode(entry['QnAText'], add_special_tokens=True)
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)
        # Truncate to the context window, then right-pad to a fixed length so
        # the default collate function can stack samples into a batch.
        QnA_captions_tokens = QnA_captions_tokens[:self.maxContext]
        QnA_captions_tokens = QnA_captions_tokens + \
            [self.tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))
        # Note: moving the embedding to CUDA here requires the DataLoader to run
        # in the main process (num_workers=0, the default).
        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                'image_embedding': entry['image_embedding'].to("cuda"),
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length
                }
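
# The dict loaded below is assumed to map image filename -> feature tensor from
# some vision encoder. That encoder is not part of this script; as a hedged
# sketch, such a dict could be produced with CLIP along these lines (the model
# choice and the `image_filenames` list are assumptions, not taken from here):
#
#     from PIL import Image
#     from transformers import CLIPImageProcessor, CLIPVisionModel
#
#     vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
#     processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
#     embeddings = {}
#     for name in image_filenames:  # hypothetical list of image paths
#         pixels = processor(images=Image.open(name), return_tensors="pt").pixel_values
#         embeddings[name] = vision_tower(pixels).last_hidden_state.squeeze(0)
#     torch.save(embeddings, 'img_embeddings_dict.pth')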

# Precomputed image embeddings, keyed by image filename.
imgEmbDict = torch.load('img_embeddings_dict.pth')

model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# phi-2's tokenizer has no dedicated pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict, tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)
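
# Quick sanity check: pull one batch and inspect what a training loop would
# receive. A minimal sketch, assuming a CUDA device is available (to match the
# .to("cuda") call in __getitem__).
batch = next(iter(custom_dataloader))
print(batch['QnA_tokens'].shape)       # (batch_size, maxContext), e.g. (10, 512)
print(batch['image_embedding'].shape)  # (batch_size, *embedding_shape)
print(batch['QTokensLength'][:4])      # token offsets where each answer begins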