liuganghuggingface commited on
Commit
9298151
·
verified ·
1 Parent(s): 0c60789

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +54 -86
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import gradio as gr
2
  import spaces
3
-
4
  from rdkit import Chem
5
  from rdkit.Chem import Draw
6
  import numpy as np
@@ -16,45 +15,9 @@ import glob
16
  import csv
17
  from datetime import datetime
18
  import json
19
- # import spaces
20
 
21
  from evaluator import Evaluator
22
- # from loader import load_graph_decoder
23
-
24
- ### loader
25
- from graph_decoder.diffusion_model import GraphDiT
26
- def count_parameters(model):
27
- r"""
28
- Returns the number of trainable parameters and number of all parameters in the model.
29
- """
30
- trainable_params, all_param = 0, 0
31
- for param in model.parameters():
32
- num_params = param.numel()
33
- all_param += num_params
34
- if param.requires_grad:
35
- trainable_params += num_params
36
-
37
- return trainable_params, all_param
38
-
39
- def load_graph_decoder(path='model_labeled'):
40
- model_config_path = f"{path}/config.yaml"
41
- data_info_path = f"{path}/data.meta.json"
42
-
43
- model = GraphDiT(
44
- model_config_path=model_config_path,
45
- data_info_path=data_info_path,
46
- # model_dtype=torch.float16,
47
- model_dtype=torch.float32,
48
- )
49
- model.init_model(path)
50
- model.disable_grads()
51
-
52
- trainable_params, all_param = count_parameters(model)
53
- param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
54
- path, trainable_params, all_param, 100 * trainable_params / all_param
55
- )
56
- print(param_stats)
57
- return model
58
 
59
  # Load the CSV data
60
  known_labels = pd.read_csv('data/known_labels.csv')
@@ -92,8 +55,6 @@ def random_properties():
92
  def load_model(model_choice):
93
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
  model = load_graph_decoder(path=model_choice)
95
- # model.to(device)
96
- print('in load_model', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
97
  return (model, device)
98
 
99
  # Create a flagged folder if it doesn't exist
@@ -122,7 +83,6 @@ def save_interesting_log(smiles, properties, suggested_properties):
122
 
123
  @spaces.GPU(duration=60)
124
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
125
- print('in generate_graph', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
126
  model, device = model_state
127
 
128
  properties = [CH4, CO2, H2, N2, O2]
@@ -139,53 +99,61 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
139
  num_nodes = None if num_nodes == 0 else num_nodes
140
 
141
  for _ in range(repeating_time):
142
- # try:
143
- generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
144
-
145
- # Create GIF if img_list is available
146
- gif_path = None
147
- if img_list and len(img_list) > 0:
148
- imgs = [np.array(pil_img) for pil_img in img_list]
149
- imgs.extend([imgs[-1]] * 10)
150
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
151
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
152
-
153
- if generated_molecule is not None:
154
- mol = Chem.MolFromSmiles(generated_molecule)
155
- if mol is not None:
156
- standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
157
- is_novel = standardized_smiles not in knwon_smiles['smiles'].values
158
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
159
- img = Draw.MolToImage(mol)
160
-
161
- # Evaluate the generated molecule
162
- suggested_properties = {}
163
- for prop, evaluator in evaluators.items():
164
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
165
-
166
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
167
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  return (
169
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
170
- f"**{nan_message}**\n\n"
171
- f"**{novelty_status}**\n\n"
172
- f"**Suggested Properties:**\n{suggested_properties_text}",
173
- img,
174
  gif_path,
175
- properties, # Add this
176
- suggested_properties # Add this
177
  )
178
- else:
179
- return (
180
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
181
- None,
182
- gif_path,
183
- properties,
184
- None,
185
- )
186
- # except Exception as e:
187
- # print(f"Error in generation: {e}")
188
- # continue
189
 
190
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
191
 
 
 
1
  import spaces
2
+ import gradio as gr
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
5
  import numpy as np
 
15
  import csv
16
  from datetime import datetime
17
  import json
 
18
 
19
  from evaluator import Evaluator
20
+ from loader import load_graph_decoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Load the CSV data
23
  known_labels = pd.read_csv('data/known_labels.csv')
 
55
  def load_model(model_choice):
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  model = load_graph_decoder(path=model_choice)
 
 
58
  return (model, device)
59
 
60
  # Create a flagged folder if it doesn't exist
 
83
 
84
  @spaces.GPU(duration=60)
85
  def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
 
86
  model, device = model_state
87
 
88
  properties = [CH4, CO2, H2, N2, O2]
 
99
  num_nodes = None if num_nodes == 0 else num_nodes
100
 
101
  for _ in range(repeating_time):
102
+ try:
103
+ # def generate_func():
104
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
105
+ # model.to(device)
106
+ # print('Before generation, move model to', device)
107
+ # return generated_molecule, img_list
108
+ # generated_molecule, img_list = generate_func()
109
+
110
+ model.to(device)
111
+ generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
112
+
113
+ # Create GIF if img_list is available
114
+ gif_path = None
115
+ if img_list and len(img_list) > 0:
116
+ imgs = [np.array(pil_img) for pil_img in img_list]
117
+ imgs.extend([imgs[-1]] * 10)
118
+ gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
119
+ imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
120
+
121
+ if generated_molecule is not None:
122
+ mol = Chem.MolFromSmiles(generated_molecule)
123
+ if mol is not None:
124
+ standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
125
+ is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
126
+ novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
127
+ img = Draw.MolToImage(mol)
128
+
129
+ # Evaluate the generated molecule
130
+ suggested_properties = {}
131
+ for prop, evaluator in evaluators.items():
132
+ suggested_properties[prop] = evaluator([standardized_smiles])[0]
133
+
134
+ suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
135
+
136
+ return (
137
+ f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
138
+ f"**{nan_message}**\n\n"
139
+ f"**{novelty_status}**\n\n"
140
+ f"**Suggested Properties:**\n{suggested_properties_text}",
141
+ img,
142
+ gif_path,
143
+ properties, # Add this
144
+ suggested_properties # Add this
145
+ )
146
+ else:
147
  return (
148
+ f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
149
+ None,
 
 
 
150
  gif_path,
151
+ properties,
152
+ None,
153
  )
154
+ except Exception as e:
155
+ print(f"Error in generation: {e}")
156
+ continue
 
 
 
 
 
 
 
 
157
 
158
  return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
159