File size: 8,554 Bytes
13362e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
# Copyright 2024 Llamole Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Dict, Any
from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
from ..extras.misc import get_logits_processor
from ..extras.ploting import plot_loss
from ..model import load_tokenizer
from ..hparams import get_infer_args, get_train_args
from ..model import GraphLLMForCausalMLM
from .dataset import MolQADataset
import re
import os
import json
import math
import torch
from torch.utils.data import DataLoader
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import (
DataArguments,
FinetuningArguments,
GeneratingArguments,
ModelArguments,
)
def remove_extra_spaces(text):
    """Normalize the whitespace of ``text``.

    Every run of whitespace characters (spaces, tabs, newlines, ...) is
    collapsed into a single space, and whitespace at both ends of the
    string is removed.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string.
    """
    # The pattern is ``\s+``, so newlines and tabs are collapsed too, not just
    # literal space characters.
    return re.sub(r'\s+', ' ', text).strip()
def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
    """Parse the run arguments and dispatch to the matching evaluation routine.

    Args:
        args: Optional argument mapping forwarded unchanged to
            ``get_train_args``.

    Raises:
        ValueError: If ``data_args.dataset`` is not one of the supported
            MolQA dataset names.
    """
    # NOTE: a leftover debugging ``print(args)`` followed by an unconditional
    # ``raise ValueError('stop')`` used to sit here, which made everything
    # below unreachable. Both were removed so evaluation can actually run.
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )

    # Guard clause: reject anything that is not a known MolQA variant.
    if data_args.dataset not in ("molqa", "molqa_drug", "molqa_material"):
        raise ValueError(f"Unknown dataset: {data_args.dataset}.")

    run_molqa(
        model_args, data_args, training_args, finetuning_args, generating_args
    )
def _join_text_pieces(pieces):
    """Concatenate generated text pieces, skipping (and reporting) ``None`` entries."""
    if None in pieces:
        print(f"List contains None: {pieces}")
        pieces = [piece for piece in pieces if piece is not None]
    return "".join(pieces)


def run_molqa(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
):
    """Run the two-phase MolQA evaluation and print the collected results.

    The dataset file is resolved through ``dataset_info.json`` inside
    ``data_args.dataset_dir``. Phase 1 calls ``model.generate`` with
    ``do_molecular_design=True`` to obtain one SMILES string and a text
    response per sample. Phase 2 iterates the same dataloader again and calls
    ``model.generate`` with ``do_retrosynthesis=True`` on the SMILES produced
    in phase 1, attaching the returned reactions and text to each result.

    Results are only printed (a full dump plus a summary of the first five
    entries); nothing is returned or written to disk.

    Args:
        model_args: Model configuration, passed to the tokenizer and model loaders.
        data_args: Supplies ``dataset``, ``dataset_dir`` and ``cutoff_len``.
        training_args: Supplies ``per_device_eval_batch_size``; also passed to
            the model loader.
        finetuning_args: Passed to the model loader.
        generating_args: Converted with ``to_dict()`` into generation kwargs.

    Raises:
        ValueError: If the dataset name has no ``file_name`` entry in
            ``dataset_info.json``.
    """
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]
    tokenizer.pad_token = tokenizer.eos_token

    # Resolve the data file for the requested dataset.
    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
    with open(data_info_path, "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    dataset_name = data_args.dataset.strip()
    try:
        filename = dataset_info[dataset_name]["file_name"]
    except KeyError as err:
        raise ValueError(
            f"Dataset {dataset_name} not found in dataset_info.json"
        ) from err

    data_path = os.path.join(data_args.dataset_dir, f"{filename}")
    with open(data_path, "r", encoding="utf-8") as f:
        original_data = json.load(f)

    # Create dataset and dataloader. shuffle=False so that batch positions can
    # be mapped back to indices in ``original_data``.
    # NOTE(review): this assumes MolQADataset keeps the order and length of
    # ``original_data`` -- confirm against its implementation.
    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
    dataloader = DataLoader(
        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
    )

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [
        tokenizer.eos_token_id
    ] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
    )

    all_results = []
    all_smiles = []
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]

    # Phase 1: Molecular Design
    model.eval()
    for batch in dataloader:
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)

        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=False,
                rollback=True,
                **gen_kwargs,
            )

        for i, smiles in enumerate(all_info_dict["smiles_list"]):
            # Results are appended one at a time, so the current length of
            # ``all_results`` is the running index into ``original_data``.
            original_data_idx = len(all_results)
            original_item = original_data[original_data_idx]
            llm_response = _join_text_pieces(all_info_dict["text_lists"][i])

            result = {
                "qa_idx": original_data_idx,
                "instruction": original_item["instruction"],
                "input": original_item["input"],
                "llm_response": llm_response,
                "response_design": remove_extra_spaces(llm_response),
                "llm_smiles": smiles,
                "property": {},
            }

            # Add non-NaN property values
            for j, prop_name in enumerate(property_names):
                prop_value = property_data[i][j].item()
                if not math.isnan(prop_value):
                    result["property"][prop_name] = prop_value

            all_results.append(result)
            all_smiles.append(smiles)

    # Phase 2: Retrosynthesis
    model.eval()
    retro_batch_start = 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        batch_size = input_ids.shape[0]
        batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]

        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_molecular_design=False,
                do_retrosynthesis=True,
                input_smiles_list=batch_smiles,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                **gen_kwargs,
            )

        for i in range(batch_size):
            # Same sample as in phase 1: results are stored in dataloader order.
            result = all_results[retro_batch_start + i]
            retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]

            result["llm_reactions"] = []
            if retro_plan["success"]:
                result["llm_reactions"].extend(
                    {"reaction": reaction, "template": template, "cost": cost}
                    for reaction, template, cost in zip(
                        retro_plan["reaction_list"],
                        retro_plan["templates"],
                        retro_plan["cost"],
                    )
                )

            new_text = _join_text_pieces(all_info_dict["text_lists"][i])
            result["llm_response"] = remove_extra_spaces(
                result["llm_response"] + new_text
            )
            result["response_retro"] = remove_extra_spaces(new_text)

        retro_batch_start += batch_size

    print('all_results', all_results)
    print("\nSummary of results:")
    for result in all_results[:5]:
        print(f"\nData point {result['qa_idx']}:")
        print(f" Instruction: {result['instruction']}")
        print(f" Input: {result['input']}")
        print(f" LLM Response: {result['llm_response']}")
        print(f" LLM SMILES: {result['llm_smiles']}")
        print(f" Number of reactions: {len(result['llm_reactions'])}")
        for prop_name, prop_value in result["property"].items():
            print(f" {prop_name}: {prop_value}")

    print("\nAll data processed successfully.")