Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import spaces
|
2 |
import gradio as gr
|
3 |
from rdkit import Chem
|
4 |
from rdkit.Chem import Draw
|
@@ -15,6 +14,7 @@ import glob
|
|
15 |
import csv
|
16 |
from datetime import datetime
|
17 |
import json
|
|
|
18 |
|
19 |
from evaluator import Evaluator
|
20 |
# from loader import load_graph_decoder
|
@@ -90,6 +90,8 @@ def random_properties():
|
|
90 |
def load_model(model_choice):
|
91 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
92 |
model = load_graph_decoder(path=model_choice)
|
|
|
|
|
93 |
return (model, device)
|
94 |
|
95 |
# Create a flagged folder if it doesn't exist
|
@@ -118,7 +120,7 @@ def save_interesting_log(smiles, properties, suggested_properties):
|
|
118 |
|
119 |
# @spaces.GPU(duration=60)
|
120 |
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
121 |
-
print('in
|
122 |
model, device = model_state
|
123 |
|
124 |
properties = [CH4, CO2, H2, N2, O2]
|
@@ -135,54 +137,53 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
|
|
135 |
num_nodes = None if num_nodes == 0 else num_nodes
|
136 |
|
137 |
for _ in range(repeating_time):
|
138 |
-
try:
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
return (
|
166 |
-
f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
|
167 |
-
f"**{nan_message}**\n\n"
|
168 |
-
f"**{novelty_status}**\n\n"
|
169 |
-
f"**Suggested Properties:**\n{suggested_properties_text}",
|
170 |
-
img,
|
171 |
-
gif_path,
|
172 |
-
properties, # Add this
|
173 |
-
suggested_properties # Add this
|
174 |
-
)
|
175 |
-
else:
|
176 |
return (
|
177 |
-
f"**
|
178 |
-
|
|
|
|
|
|
|
179 |
gif_path,
|
180 |
-
properties,
|
181 |
-
|
182 |
)
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
|
188 |
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from rdkit import Chem
|
3 |
from rdkit.Chem import Draw
|
|
|
14 |
import csv
|
15 |
from datetime import datetime
|
16 |
import json
|
17 |
+
# import spaces
|
18 |
|
19 |
from evaluator import Evaluator
|
20 |
# from loader import load_graph_decoder
|
|
|
90 |
def load_model(model_choice):
|
91 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
92 |
model = load_graph_decoder(path=model_choice)
|
93 |
+
model.to(device)
|
94 |
+
print('in load_model', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
95 |
return (model, device)
|
96 |
|
97 |
# Create a flagged folder if it doesn't exist
|
|
|
120 |
|
121 |
# @spaces.GPU(duration=60)
|
122 |
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
123 |
+
print('in generate_graph', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
124 |
model, device = model_state
|
125 |
|
126 |
properties = [CH4, CO2, H2, N2, O2]
|
|
|
137 |
num_nodes = None if num_nodes == 0 else num_nodes
|
138 |
|
139 |
for _ in range(repeating_time):
|
140 |
+
# try:
|
141 |
+
generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
142 |
+
|
143 |
+
# Create GIF if img_list is available
|
144 |
+
gif_path = None
|
145 |
+
if img_list and len(img_list) > 0:
|
146 |
+
imgs = [np.array(pil_img) for pil_img in img_list]
|
147 |
+
imgs.extend([imgs[-1]] * 10)
|
148 |
+
gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
|
149 |
+
imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
|
150 |
+
|
151 |
+
if generated_molecule is not None:
|
152 |
+
mol = Chem.MolFromSmiles(generated_molecule)
|
153 |
+
if mol is not None:
|
154 |
+
standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
|
155 |
+
is_novel = standardized_smiles not in knwon_smiles['smiles'].values
|
156 |
+
novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
|
157 |
+
img = Draw.MolToImage(mol)
|
158 |
+
|
159 |
+
# Evaluate the generated molecule
|
160 |
+
suggested_properties = {}
|
161 |
+
for prop, evaluator in evaluators.items():
|
162 |
+
suggested_properties[prop] = evaluator([standardized_smiles])[0]
|
163 |
+
|
164 |
+
suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
|
165 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
return (
|
167 |
+
f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
|
168 |
+
f"**{nan_message}**\n\n"
|
169 |
+
f"**{novelty_status}**\n\n"
|
170 |
+
f"**Suggested Properties:**\n{suggested_properties_text}",
|
171 |
+
img,
|
172 |
gif_path,
|
173 |
+
properties, # Add this
|
174 |
+
suggested_properties # Add this
|
175 |
)
|
176 |
+
else:
|
177 |
+
return (
|
178 |
+
f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
|
179 |
+
None,
|
180 |
+
gif_path,
|
181 |
+
properties,
|
182 |
+
None,
|
183 |
+
)
|
184 |
+
# except Exception as e:
|
185 |
+
# print(f"Error in generation: {e}")
|
186 |
+
# continue
|
187 |
|
188 |
return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
|
189 |
|