liuganghuggingface commited on
Commit
35eb13d
·
verified ·
1 Parent(s): 3fdb021

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +47 -54
app.py CHANGED
@@ -81,7 +81,6 @@ def save_interesting_log(smiles, properties, suggested_properties):
81
  }
82
  writer.writerow(log_data)
83
 
84
- @spaces.GPU
85
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
86
  model, device = model_state
87
 
@@ -99,61 +98,54 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
99
  num_nodes = None if num_nodes == 0 else num_nodes
100
 
101
  for _ in range(repeating_time):
102
- try:
103
- # def generate_func():
104
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105
- # model.to(device)
106
- # print('Before generation, move model to', device)
107
- # return generated_molecule, img_list
108
- # generated_molecule, img_list = generate_func()
109
-
110
- model.to(device)
111
- generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
112
-
113
- # Create GIF if img_list is available
114
- gif_path = None
115
- if img_list and len(img_list) > 0:
116
- imgs = [np.array(pil_img) for pil_img in img_list]
117
- imgs.extend([imgs[-1]] * 10)
118
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
119
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
120
-
121
- if generated_molecule is not None:
122
- mol = Chem.MolFromSmiles(generated_molecule)
123
- if mol is not None:
124
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
125
- is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
126
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
127
- img = Draw.MolToImage(mol)
128
-
129
- # Evaluate the generated molecule
130
- suggested_properties = {}
131
- for prop, evaluator in evaluators.items():
132
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
133
-
134
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
135
-
136
- return (
137
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
138
- f"**{nan_message}**\n\n"
139
- f"**{novelty_status}**\n\n"
140
- f"**Suggested Properties:**\n{suggested_properties_text}",
141
- img,
142
- gif_path,
143
- properties, # Add this
144
- suggested_properties # Add this
145
- )
146
- else:
147
  return (
148
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
149
- None,
 
 
 
150
  gif_path,
151
- properties,
152
- None,
153
  )
154
- except Exception as e:
155
- print(f"Error in generation: {e}")
156
- continue
 
 
 
 
 
 
 
 
157
 
158
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
159
 
@@ -180,6 +172,7 @@ def numpy_to_python(obj):
180
  else:
181
  return obj
182
 
 
183
  def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
184
  result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
185
  # Check if the generation was successful
@@ -276,7 +269,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
276
  ```
277
  """)
278
 
279
- model_state = gr.State(lambda: load_model("model_all"))
280
 
281
  with gr.Row():
282
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
 
81
  }
82
  writer.writerow(log_data)
83
 
 
84
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
85
  model, device = model_state
86
 
 
98
  num_nodes = None if num_nodes == 0 else num_nodes
99
 
100
  for _ in range(repeating_time):
101
+ # try:
102
+ model.to(device)
103
+ generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
104
+
105
+ # Create GIF if img_list is available
106
+ gif_path = None
107
+ if img_list and len(img_list) > 0:
108
+ imgs = [np.array(pil_img) for pil_img in img_list]
109
+ imgs.extend([imgs[-1]] * 10)
110
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
111
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
112
+
113
+ if generated_molecule is not None:
114
+ mol = Chem.MolFromSmiles(generated_molecule)
115
+ if mol is not None:
116
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
117
+ is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
118
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
119
+ img = Draw.MolToImage(mol)
120
+
121
+ # Evaluate the generated molecule
122
+ suggested_properties = {}
123
+ for prop, evaluator in evaluators.items():
124
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
125
+
126
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
127
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  return (
129
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
130
+ f"**{nan_message}**\n\n"
131
+ f"**{novelty_status}**\n\n"
132
+ f"**Suggested Properties:**\n{suggested_properties_text}",
133
+ img,
134
  gif_path,
135
+ properties, # Add this
136
+ suggested_properties # Add this
137
  )
138
+ else:
139
+ return (
140
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
141
+ None,
142
+ gif_path,
143
+ properties,
144
+ None,
145
+ )
146
+ # except Exception as e:
147
+ # print(f"Error in generation: {e}")
148
+ # continue
149
 
150
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
151
 
 
172
  else:
173
  return obj
174
 
175
+ @spaces.GPU
176
  def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
177
  result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
178
  # Check if the generation was successful
 
269
  ```
270
  """)
271
 
272
+ model_state = gr.State(lambda: load_model("model_labeled"))
273
 
274
  with gr.Row():
275
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")