liuganghuggingface commited on
Commit
020f142
·
verified ·
1 Parent(s): 401e17a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -47
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
@@ -15,6 +14,7 @@ import glob
15
  import csv
16
  from datetime import datetime
17
  import json
 
18
 
19
  from evaluator import Evaluator
20
  # from loader import load_graph_decoder
@@ -90,6 +90,8 @@ def random_properties():
90
  def load_model(model_choice):
91
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
  model = load_graph_decoder(path=model_choice)
 
 
93
  return (model, device)
94
 
95
  # Create a flagged folder if it doesn't exist
@@ -118,7 +120,7 @@ def save_interesting_log(smiles, properties, suggested_properties):
118
 
119
  # @spaces.GPU(duration=60)
120
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
121
- print('in generate graph', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
122
  model, device = model_state
123
 
124
  properties = [CH4, CO2, H2, N2, O2]
@@ -135,54 +137,53 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
135
  num_nodes = None if num_nodes == 0 else num_nodes
136
 
137
  for _ in range(repeating_time):
138
- try:
139
- model.to(device)
140
- generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
141
-
142
- # Create GIF if img_list is available
143
- gif_path = None
144
- if img_list and len(img_list) > 0:
145
- imgs = [np.array(pil_img) for pil_img in img_list]
146
- imgs.extend([imgs[-1]] * 10)
147
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
148
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
149
-
150
- if generated_molecule is not None:
151
- mol = Chem.MolFromSmiles(generated_molecule)
152
- if mol is not None:
153
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
154
- is_novel = standardized_smiles not in knwon_smiles['smiles'].values
155
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
156
- img = Draw.MolToImage(mol)
157
-
158
- # Evaluate the generated molecule
159
- suggested_properties = {}
160
- for prop, evaluator in evaluators.items():
161
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
162
-
163
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
164
-
165
- return (
166
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
167
- f"**{nan_message}**\n\n"
168
- f"**{novelty_status}**\n\n"
169
- f"**Suggested Properties:**\n{suggested_properties_text}",
170
- img,
171
- gif_path,
172
- properties, # Add this
173
- suggested_properties # Add this
174
- )
175
- else:
176
  return (
177
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
178
- None,
 
 
 
179
  gif_path,
180
- properties,
181
- None,
182
  )
183
- except Exception as e:
184
- print(f"Error in generation: {e}")
185
- continue
 
 
 
 
 
 
 
 
186
 
187
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
188
 
 
 
1
  import gradio as gr
2
  from rdkit import Chem
3
  from rdkit.Chem import Draw
 
14
  import csv
15
  from datetime import datetime
16
  import json
17
+ # import spaces
18
 
19
  from evaluator import Evaluator
20
  # from loader import load_graph_decoder
 
90
  def load_model(model_choice):
91
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
  model = load_graph_decoder(path=model_choice)
93
+ model.to(device)
94
+ print('in load_model', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
95
  return (model, device)
96
 
97
  # Create a flagged folder if it doesn't exist
 
120
 
121
  # @spaces.GPU(duration=60)
122
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
123
+ print('in generate_graph', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
124
  model, device = model_state
125
 
126
  properties = [CH4, CO2, H2, N2, O2]
 
137
  num_nodes = None if num_nodes == 0 else num_nodes
138
 
139
  for _ in range(repeating_time):
140
+ # try:
141
+ generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
142
+
143
+ # Create GIF if img_list is available
144
+ gif_path = None
145
+ if img_list and len(img_list) > 0:
146
+ imgs = [np.array(pil_img) for pil_img in img_list]
147
+ imgs.extend([imgs[-1]] * 10)
148
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
149
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
150
+
151
+ if generated_molecule is not None:
152
+ mol = Chem.MolFromSmiles(generated_molecule)
153
+ if mol is not None:
154
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
155
+ is_novel = standardized_smiles not in knwon_smiles['smiles'].values
156
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
157
+ img = Draw.MolToImage(mol)
158
+
159
+ # Evaluate the generated molecule
160
+ suggested_properties = {}
161
+ for prop, evaluator in evaluators.items():
162
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
163
+
164
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
165
+
 
 
 
 
 
 
 
 
 
 
 
 
166
  return (
167
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
168
+ f"**{nan_message}**\n\n"
169
+ f"**{novelty_status}**\n\n"
170
+ f"**Suggested Properties:**\n{suggested_properties_text}",
171
+ img,
172
  gif_path,
173
+ properties, # Add this
174
+ suggested_properties # Add this
175
  )
176
+ else:
177
+ return (
178
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
179
+ None,
180
+ gif_path,
181
+ properties,
182
+ None,
183
+ )
184
+ # except Exception as e:
185
+ # print(f"Error in generation: {e}")
186
+ # continue
187
 
188
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
189