Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
-
import gradio as gr
|
2 |
import spaces
|
3 |
-
|
4 |
from rdkit import Chem
|
5 |
from rdkit.Chem import Draw
|
6 |
import numpy as np
|
@@ -16,45 +15,9 @@ import glob
|
|
16 |
import csv
|
17 |
from datetime import datetime
|
18 |
import json
|
19 |
-
# import spaces
|
20 |
|
21 |
from evaluator import Evaluator
|
22 |
-
|
23 |
-
|
24 |
-
### loader
|
25 |
-
from graph_decoder.diffusion_model import GraphDiT
|
26 |
-
def count_parameters(model):
    """Count the trainable and total parameter elements of ``model``.

    Args:
        model: Any object exposing ``parameters()``, which yields items that
            provide ``numel()`` and a ``requires_grad`` attribute.

    Returns:
        A ``(trainable_params, all_param)`` tuple of element counts.
    """
    # Walk the parameters exactly once; ``parameters()`` may be a generator.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return trainable, total
|
38 |
-
|
39 |
-
def load_graph_decoder(path='model_labeled'):
    """Build a ``GraphDiT`` from the files under ``path`` and report its size.

    The model is constructed from ``{path}/config.yaml`` and
    ``{path}/data.meta.json`` with ``torch.float32`` as its dtype, then
    ``init_model(path)`` and ``disable_grads()`` are called on it. A one-line
    summary of the trainable/total parameter counts is printed.

    Args:
        path: Directory containing the model config and data-info files.

    Returns:
        The initialised ``GraphDiT`` instance.
    """
    decoder = GraphDiT(
        model_config_path=f"{path}/config.yaml",
        data_info_path=f"{path}/data.meta.json",
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    decoder.disable_grads()

    n_trainable, n_total = count_parameters(decoder)
    trainable_pct = 100 * n_trainable / n_total
    print(
        "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
            path, n_trainable, n_total, trainable_pct
        )
    )
    return decoder
|
58 |
|
59 |
# Load the CSV data
|
60 |
known_labels = pd.read_csv('data/known_labels.csv')
|
@@ -92,8 +55,6 @@ def random_properties():
|
|
92 |
def load_model(model_choice):
    """Load the graph decoder for ``model_choice`` and choose a device.

    The device is CUDA when ``torch.cuda.is_available()`` is true, otherwise
    CPU. The model is not moved to that device here; the device is only
    returned alongside it. A debug line with the currently detected device is
    printed after loading.

    Args:
        model_choice: Path passed through to ``load_graph_decoder``.

    Returns:
        A ``(model, device)`` tuple.
    """
    def _detect_device():
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    device = _detect_device()
    decoder = load_graph_decoder(path=model_choice)
    # Availability is queried again so the log reflects the state after loading.
    print('in load_model', _detect_device())
    return decoder, device
|
98 |
|
99 |
# Create a flagged folder if it doesn't exist
|
@@ -122,7 +83,6 @@ def save_interesting_log(smiles, properties, suggested_properties):
|
|
122 |
|
123 |
@spaces.GPU(duration=60)
|
124 |
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
125 |
-
print('in generate_graph', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
126 |
model, device = model_state
|
127 |
|
128 |
properties = [CH4, CO2, H2, N2, O2]
|
@@ -139,53 +99,61 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
|
|
139 |
num_nodes = None if num_nodes == 0 else num_nodes
|
140 |
|
141 |
for _ in range(repeating_time):
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
if
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
return (
|
169 |
-
f"**
|
170 |
-
|
171 |
-
f"**{novelty_status}**\n\n"
|
172 |
-
f"**Suggested Properties:**\n{suggested_properties_text}",
|
173 |
-
img,
|
174 |
gif_path,
|
175 |
-
properties,
|
176 |
-
|
177 |
)
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
None,
|
182 |
-
gif_path,
|
183 |
-
properties,
|
184 |
-
None,
|
185 |
-
)
|
186 |
-
# except Exception as e:
|
187 |
-
# print(f"Error in generation: {e}")
|
188 |
-
# continue
|
189 |
|
190 |
return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
|
191 |
|
|
|
|
|
1 |
import spaces
|
2 |
+
import gradio as gr
|
3 |
from rdkit import Chem
|
4 |
from rdkit.Chem import Draw
|
5 |
import numpy as np
|
|
|
15 |
import csv
|
16 |
from datetime import datetime
|
17 |
import json
|
|
|
18 |
|
19 |
from evaluator import Evaluator
|
20 |
+
from loader import load_graph_decoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# Load the CSV data
|
23 |
known_labels = pd.read_csv('data/known_labels.csv')
|
|
|
55 |
def load_model(model_choice):
    """Load the graph decoder for ``model_choice`` and choose a device.

    The device is CUDA when ``torch.cuda.is_available()`` is true, otherwise
    CPU. The model is not moved to that device here; the device is only
    returned alongside it.

    Args:
        model_choice: Path passed through to ``load_graph_decoder``.

    Returns:
        A ``(model, device)`` tuple.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    decoder = load_graph_decoder(path=model_choice)
    return decoder, device
|
59 |
|
60 |
# Create a flagged folder if it doesn't exist
|
|
|
83 |
|
84 |
@spaces.GPU(duration=60)
|
85 |
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
|
|
|
86 |
model, device = model_state
|
87 |
|
88 |
properties = [CH4, CO2, H2, N2, O2]
|
|
|
99 |
num_nodes = None if num_nodes == 0 else num_nodes
|
100 |
|
101 |
for _ in range(repeating_time):
|
102 |
+
try:
|
103 |
+
# def generate_func():
|
104 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
105 |
+
# model.to(device)
|
106 |
+
# print('Before generation, move model to', device)
|
107 |
+
# return generated_molecule, img_list
|
108 |
+
# generated_molecule, img_list = generate_func()
|
109 |
+
|
110 |
+
model.to(device)
|
111 |
+
generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
112 |
+
|
113 |
+
# Create GIF if img_list is available
|
114 |
+
gif_path = None
|
115 |
+
if img_list and len(img_list) > 0:
|
116 |
+
imgs = [np.array(pil_img) for pil_img in img_list]
|
117 |
+
imgs.extend([imgs[-1]] * 10)
|
118 |
+
gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
|
119 |
+
imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
|
120 |
+
|
121 |
+
if generated_molecule is not None:
|
122 |
+
mol = Chem.MolFromSmiles(generated_molecule)
|
123 |
+
if mol is not None:
|
124 |
+
standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
|
125 |
+
is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
|
126 |
+
novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
|
127 |
+
img = Draw.MolToImage(mol)
|
128 |
+
|
129 |
+
# Evaluate the generated molecule
|
130 |
+
suggested_properties = {}
|
131 |
+
for prop, evaluator in evaluators.items():
|
132 |
+
suggested_properties[prop] = evaluator([standardized_smiles])[0]
|
133 |
+
|
134 |
+
suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
|
135 |
+
|
136 |
+
return (
|
137 |
+
f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
|
138 |
+
f"**{nan_message}**\n\n"
|
139 |
+
f"**{novelty_status}**\n\n"
|
140 |
+
f"**Suggested Properties:**\n{suggested_properties_text}",
|
141 |
+
img,
|
142 |
+
gif_path,
|
143 |
+
properties, # Add this
|
144 |
+
suggested_properties # Add this
|
145 |
+
)
|
146 |
+
else:
|
147 |
return (
|
148 |
+
f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
|
149 |
+
None,
|
|
|
|
|
|
|
150 |
gif_path,
|
151 |
+
properties,
|
152 |
+
None,
|
153 |
)
|
154 |
+
except Exception as e:
|
155 |
+
print(f"Error in generation: {e}")
|
156 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
|
159 |
|