# Copyright 2024 Llamole Team
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.utils.data import Dataset

from ..extras.constants import BOND_INDEX


def dict_to_list(data_dict, mol_properties):
    """Convert a property dict into an ordered list, filling missing entries with NaN."""
    return [data_dict.get(prop, float("nan")) for prop in mol_properties]


class MolQADataset(Dataset):
    """Dataset of molecular QA records prepared for chat-style tokenization.

    Each record is expected to provide "instruction", "input", and "property"
    fields; the property dict is converted to a fixed-order tensor of targets.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fixed ordering of the molecular property targets.
        mol_properties = [
            "BBBP",
            "HIV",
            "BACE",
            "CO2",
            "N2",
            "O2",
            "FFV",
            "TC",
            "SC",
            "SA",
        ]
        item = self.data[idx]
        instruction = item["instruction"]
        input_text = item["input"]
        property_data = dict_to_list(item["property"], mol_properties)
        property_data = torch.tensor(property_data)

        # Combine instruction and input
        combined_input = f"{instruction}\n{input_text}"

        # Create messages for chat template
        messages = [{"role": "user", "content": combined_input}]

        # Apply chat template
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the chat text
        encoding = self.tokenizer(
            chat_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        return {
            # squeeze(0) drops the batch dimension added by return_tensors="pt".
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "property": property_data,
        }
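
# Illustrative usage sketch (not part of the original module). The tokenizer
# checkpoint and the record fields below are placeholders for demonstration;
# any chat-template-capable tokenizer and MolQA-style records work the same way.
#
#   from transformers import AutoTokenizer
#   from torch.utils.data import DataLoader
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
#   records = [
#       {
#           "instruction": "Design a molecule with high blood-brain barrier permeability.",
#           "input": "Return the SMILES string of the designed molecule.",
#           "property": {"BBBP": 1.0},  # missing properties default to NaN
#       }
#   ]
#   dataset = MolQADataset(records, tokenizer, max_len=512)
#   loader = DataLoader(dataset, batch_size=1)
#   batch = next(iter(loader))
#   # batch["input_ids"]: (1, 512), batch["attention_mask"]: (1, 512),
#   # batch["property"]: (1, 10), ordered as mol_properties above.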