# Copyright 2024 Llamole Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import os
import re
from typing import TYPE_CHECKING, Optional, Dict, Any

import torch
from torch.utils.data import DataLoader

from ..extras.misc import get_logits_processor
from ..hparams import get_train_args
from ..model import GraphLLMForCausalMLM, load_tokenizer
from .dataset import MolQADataset

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import (
        DataArguments,
        FinetuningArguments,
        GeneratingArguments,
        ModelArguments,
    )

def remove_extra_spaces(text):
    # Replace multiple spaces with a single space
    cleaned_text = re.sub(r'\s+', ' ', text)
    # Strip leading and trailing spaces
    return cleaned_text.strip()

def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
    r"""Parse the evaluation arguments and dispatch to the dataset-specific runner."""
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )

    if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
        run_molqa(
            model_args, data_args, training_args, finetuning_args, generating_args
        )
    else:
        raise ValueError("Unknown dataset: {}.".format(data_args.dataset))


def run_molqa(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
):
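    r"""Two-phase MolQA evaluation: Phase 1 generates an answer and a candidate
    molecule (SMILES) for each question conditioned on the requested properties;
    Phase 2 plans retrosynthesis routes for the generated molecules."""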
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]

    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
    with open(data_info_path, "r") as f:
        dataset_info = json.load(f)

    tokenizer.pad_token = tokenizer.eos_token
    dataset_name = data_args.dataset.strip()
    try:
        filename = dataset_info[dataset_name]["file_name"]
    except KeyError:
        raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
    data_path = os.path.join(data_args.dataset_dir, filename)
    with open(data_path, "r") as f:
        original_data = json.load(f)

    # Create dataset and dataloader
    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
    dataloader = DataLoader(
        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
    )

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [
        tokenizer.eos_token_id
    ] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

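    # Load the graph-augmented causal LM together with its fine-tuned adapter weights.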
    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
    )

    all_results = []
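    # Property columns expected in the property tensor, in this fixed order.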
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]

    # Phase 1: Molecular Design
    global_idx = 0
    all_smiles = []
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)
        model.eval()
        with torch.no_grad():
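            # Design-only pass: generate the textual answer and a candidate SMILES;
            # retrosynthesis is deferred to Phase 2.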
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=False,
                rollback=True,
                **gen_kwargs,
            )

            batch_results = []
            for i in range(len(all_info_dict["smiles_list"])):
                original_data_idx = global_idx + i
                original_item = original_data[original_data_idx]

                llm_response = "".join(item for item in all_info_dict["text_lists"][i])
                result = {
                    "qa_idx": original_data_idx,
                    "instruction": original_item["instruction"],
                    "input": original_item["input"],
                    "llm_response": llm_response,
                    "response_design": remove_extra_spaces(llm_response),
                    "llm_smiles": all_info_dict["smiles_list"][i],
                    "property": {},
                }

                # Add non-NaN property values
                for j, prop_name in enumerate(property_names):
                    prop_value = property_data[i][j].item()
                    if not math.isnan(prop_value):
                        result["property"][prop_name] = prop_value

                batch_results.append(result)

            all_results.extend(batch_results)
            all_smiles.extend([result['llm_smiles'] for result in batch_results])
            global_idx += len(batch_results)
    
    # Phase 2: Retrosynthesis
    retro_batch_start = 0
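    # The dataloader is not shuffled, so batches line up index-for-index with
    # all_smiles and all_results collected in Phase 1.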
    for batch_idx, batch in enumerate(dataloader):
        
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        batch_size = input_ids.shape[0]
        batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]

        model.eval()
        with torch.no_grad():
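            # Retrosynthesis-only pass: plan synthesis routes for the Phase 1 SMILES.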
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_molecular_design=False,
                do_retrosynthesis=True,
                input_smiles_list=batch_smiles,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                **gen_kwargs,
            )

            batch_results = []
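            # Merge each molecule's retrosynthesis plan (reactions, templates, costs)
            # into its Phase 1 result.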
            for i in range(batch_size):
                result = all_results[retro_batch_start + i]
                retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
                result["llm_reactions"] = []
                if retro_plan["success"]:
                    for reaction, template, cost in zip(
                        retro_plan["reaction_list"],
                        retro_plan["templates"],
                        retro_plan["cost"],
                    ):
                        result["llm_reactions"].append(
                            {"reaction": reaction, "template": template, "cost": cost}
                        )

                # Join the decoded text chunks, skipping any None entries returned by generation.
                text_items = all_info_dict["text_lists"][i]
                if None in text_items:
                    print(f"List contains None: {text_items}")
                new_text = "".join(item for item in text_items if item is not None)
                    
                result["llm_response"] += new_text
                result["llm_response"] = remove_extra_spaces(result["llm_response"])
                result["response_retro"] = remove_extra_spaces(new_text)
                batch_results.append(result)

            retro_batch_start += batch_size
    
    print('all_results', all_results)
    print("\nSummary of results:")
    print_len = min(5, len(all_results))
    for result in all_results[:print_len]:
        print(f"\nData point {result['qa_idx']}:")
        print(f"  Instruction: {result['instruction']}")
        print(f"  Input: {result['input']}")
        print(f"  LLM Response: {result['llm_response']}")
        print(f"  LLM SMILES: {result['llm_smiles']}")
        print(f"  Number of reactions: {len(result['llm_reactions'])}")
        for prop_name, prop_value in result["property"].items():
            print(f"  {prop_name}: {prop_value}")

    print("\nAll data processed successfully.")