import torch
from torch.utils.data import Dataset

from ..extras.constants import BOND_INDEX  # noqa: F401  (kept from the original module; unused in this file)


def dict_to_list(data_dict, mol_properties):
    """Flatten a property dict into a fixed-order list, filling missing entries with NaN."""
    return [data_dict.get(prop, float("nan")) for prop in mol_properties]
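# Hypothetical example of dict_to_list, for illustration only (values made up):
#   dict_to_list({"BBBP": 1.0, "HIV": 0.0}, ["BBBP", "HIV", "BACE"])
#   -> [1.0, 0.0, nan]   # "BACE" is absent from the dict, so it becomes NaN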
class MolQADataset(Dataset):
    """Molecular QA dataset pairing chat-formatted prompts with property targets."""

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fixed property order; dict_to_list fills missing entries with NaN so the
        # target tensor always has the same shape.
        mol_properties = [
            "BBBP",
            "HIV",
            "BACE",
            "CO2",
            "N2",
            "O2",
            "FFV",
            "TC",
            "SC",
            "SA",
        ]
        item = self.data[idx]
        instruction = item["instruction"]
        input_text = item["input"]
        property_data = torch.tensor(dict_to_list(item["property"], mol_properties))

        # Merge the instruction and the molecule-specific input into a single user turn.
        combined_input = f"{instruction}\n{input_text}"
        messages = [{"role": "user", "content": combined_input}]

        # Render the conversation with the tokenizer's chat template, leaving the
        # assistant turn open so the model generates the answer.
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize to a fixed length so examples batch without a custom collator.
        encoding = self.tokenizer(
            chat_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        # squeeze() drops the leading batch dimension added by return_tensors="pt".
        return {
            "input_ids": encoding.input_ids.squeeze(),
            "attention_mask": encoding.attention_mask.squeeze(),
            "property": property_data,
        }
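# Minimal usage sketch (illustration only; not part of the original module).
# It assumes a Hugging Face tokenizer that provides a chat template via
# apply_chat_template; the model name and the data record are placeholders.
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
#   data = [{
#       "instruction": "Predict the blood-brain barrier permeability.",
#       "input": "C1=CC=CC=C1",
#       "property": {"BBBP": 1.0},
#   }]
#   dataset = MolQADataset(data, tokenizer, max_len=128)
#   batch = next(iter(DataLoader(dataset, batch_size=1)))
#   # batch["input_ids"]: LongTensor of shape (1, 128)
#   # batch["property"]: [1.0, nan, nan, ...] in mol_properties order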