bcadkins01 committed on
Commit
111a2ce
·
verified ·
1 Parent(s): 085e6ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -90
app.py CHANGED
@@ -10,117 +10,92 @@ import io
10
  from PIL import Image
11
  import cairosvg
12
  import pandas as pd
 
13
 
14
- # Page Configuration
15
- st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
 
 
 
16
 
17
- # Load Models
18
  @st.cache_resource(show_spinner="Loading Models...", ttl=600)
19
  def load_models():
20
- # Load your molecule generation model
 
 
 
 
21
  model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
22
  access_token = os.getenv("HUGGING_FACE_TOKEN")
23
- model = BartForConditionalGeneration.from_pretrained(model_name, use_auth_token=access_token)
24
- tokenizer = BartTokenizer.from_pretrained(model_name, use_auth_token=access_token)
25
- # Load ADMET-AI model
 
 
 
 
26
  admet_model = ADMETModel()
 
27
  return model, tokenizer, admet_model
28
 
 
29
  model, tokenizer, admet_model = load_models()
30
 
31
- # Set Generation Parameters
32
  st.sidebar.header('Generation Parameters')
33
- creativity = st.sidebar.slider('Creativity (Temperature):', 0.0, 2.0, 1.0, step=0.1)
34
- num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)
35
 
36
- # String Format Option
37
- string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))
 
 
 
 
 
 
 
38
 
39
- # Generate Molecules Button
40
- if st.button('Generate Molecules'):
41
- st.info("Generating molecules... Please wait.")
42
- # Generate molecules
43
- core_smiles = "C1C(=O)N(C)C(=O)C1" # Beta-lactam core structure
44
- input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
45
- output_ids = model.generate(
46
- input_ids=input_ids,
47
- max_length=128,
48
- temperature=creativity,
49
- do_sample=True,
50
- top_k=50,
51
- num_return_sequences=num_molecules,
52
- num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
53
- )
54
- generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
55
- molecule_names = [f"Mol{str(i).zfill(2)}" for i in range(1, len(generated_smiles) + 1)]
56
-
57
- # Create DataFrame for generated molecules
58
- df_molecules = pd.DataFrame({
59
- 'Molecule Name': molecule_names,
60
- 'SMILES': generated_smiles
61
- })
62
-
63
- # Display generated SMILES for debugging
64
- st.write("Generated SMILES:")
65
- st.write(df_molecules)
66
-
67
- # ADMET Predictions
68
- preds = admet_model.predict(smiles=df_molecules['SMILES'].tolist())
69
-
70
- # Ensure 'SMILES' is a column in preds
71
- if 'SMILES' not in preds.columns:
72
- preds['SMILES'] = df_molecules['SMILES']
73
-
74
- # Merge predictions with generated molecules
75
- df_results = pd.merge(df_molecules, preds, on='SMILES', how='inner')
76
-
77
- # Set 'Molecule Name' as index
78
- df_results.set_index('Molecule Name', inplace=True)
79
-
80
- # Display Molecules
81
- st.subheader('Generated Molecules')
82
- cols_per_row = min(5, len(df_results))
83
- cols = st.columns(cols_per_row)
84
- for idx, (mol_name, row) in enumerate(df_results.iterrows()):
85
- smiles = row['SMILES']
86
- img = generate_molecule_image(smiles, use_safe_visualization=(string_format == 'SAFE'))
87
- with cols[idx % cols_per_row]:
88
- if isinstance(img, Image.Image):
89
- st.image(img, caption=mol_name)
90
- else:
91
- st.error(f"Could not generate image for {mol_name}")
92
- # Display molecule string
93
- string_to_display = safe.encode(smiles) if string_format == 'SAFE' else smiles
94
- st.code(string_to_display)
95
- # Copy-to-clipboard functionality
96
- st_copy_button(string_to_display, key=f'copy_{mol_name}')
97
- # Display ADMET properties
98
- st.write("**ADMET Properties:**")
99
- st.write(row.drop(['SMILES']))
100
- else:
101
- st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
102
 
 
 
 
 
 
 
103
 
104
- # Function Definitions
105
  def generate_molecule_image(input_string, use_safe_visualization=True):
 
 
 
 
106
  try:
107
  if use_safe_visualization:
108
  try:
109
- # Attempt to decode as SAFE string
110
  smiles = safe.decode(input_string)
111
- # Encode back to SAFE string
112
  safe_string = safe.encode(smiles)
113
- except Exception:
114
- # If decoding fails, assume input is SMILES and encode to SAFE
115
- safe_string = safe.encode(input_string)
116
- # Generate SVG image with fragment highlights
 
117
  svg_str = safe.to_image(safe_string)
118
- # Convert SVG to PNG bytes
119
  png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
120
- # Create an image object
121
  img = Image.open(io.BytesIO(png_bytes))
122
  else:
123
- # Generate standard molecule image
124
  mol = Chem.MolFromSmiles(input_string)
125
  if mol:
126
  img = Draw.MolToImage(mol, size=(200, 200)) # Adjusted size
@@ -128,15 +103,125 @@ def generate_molecule_image(input_string, use_safe_visualization=True):
128
  img = None
129
  return img
130
  except Exception as e:
131
- # Collect exceptions for later reporting
132
- return e
133
-
134
- import streamlit.components.v1 as components
135
 
 
136
  def st_copy_button(text, key):
137
- """Creates a copy-to-clipboard button."""
 
 
138
  components.html(f"""
139
  <button onclick="navigator.clipboard.writeText('{text}')" style="padding:5px;">Copy</button>
140
  """, height=45)
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
10
  from PIL import Image
11
  import cairosvg
12
  import pandas as pd
13
+ import streamlit.components.v1 as components
14
 
15
# Page configuration: browser-tab title and wide page layout
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
20
 
21
# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the molecule generator, its tokenizer and the ADMET-AI model.

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. When the variable is not set, an error message is
    shown and the script run is halted with ``st.stop()``.

    The result is cached through ``st.cache_resource`` (with ``ttl=600``) so
    the models are not rebuilt on every script rerun.

    Returns:
        tuple: ``(model, tokenizer, admet_model)``.
    """
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # Generator and tokenizer come from the same Hub repository.
    repo_id = "bcadkins01/beta_lactam_generator"  # Replace with your actual model path
    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    smiles_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)

    # The ADMET-AI predictor is built last, matching the original load order.
    return generator, smiles_tokenizer, ADMETModel()
41
 
42
# Load models once and reuse them for the rest of this script run
model, tokenizer, admet_model = load_models()

# Generation parameters are collected in the sidebar
st.sidebar.header('Generation Parameters')

# Creativity slider (sampling temperature).
# The lower bound is 0.1 rather than 0.0: the value is passed to
# model.generate(..., do_sample=True, temperature=creativity), and sampling
# requires a strictly positive temperature, so 0.0 would fail at generation
# time instead of producing molecules.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs."
)

# Number of molecules to generate (1-5, default 5)
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=5,
    value=5,
    help="Select the number of molecules you want to generate."
)

# String format used when displaying molecules (SMILES or SAFE)
string_format = st.sidebar.radio(
    'String Format:',
    ('SMILES', 'SAFE'),
    help="Choose the format for displaying molecule strings."
)
73
 
74
# Function to generate molecule images
def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule string as an image.

    Args:
        input_string: Molecule string to draw. With SAFE visualization it is
            passed through ``safe.decode``; otherwise it is parsed as SMILES
            by RDKit.
        use_safe_visualization: When true, draw via ``safe.to_image`` and
            convert the resulting SVG to PNG; when false, draw a 200x200
            RDKit image.

    Returns:
        The rendered image, or ``None`` when the string cannot be
        decoded/parsed or any rendering step raises. Failures are reported to
        the page with ``st.error``.
    """
    try:
        if not use_safe_visualization:
            # Standard RDKit depiction
            mol = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(mol, size=(200, 200)) if mol else None

        try:
            # Round-trip through decode/encode to obtain the SAFE string to draw
            safe_string = safe.encode(safe.decode(input_string))
        except Exception as e:
            st.error(f"Error decoding SAFE string: {e}")
            return None

        # SVG from safe.to_image -> PNG bytes -> PIL image
        svg_markup = safe.to_image(safe_string)
        png_data = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        # Report the failure on the page and signal it to the caller with None
        st.error(f"Error generating molecule image: {e}")
        return None
 
109
 
110
# Function to create a copy-to-clipboard button
def st_copy_button(text, key):
    """Render a "Copy" button that writes ``text`` to the browser clipboard.

    Args:
        text: Value to copy. It is converted with ``str()`` before being
            embedded in the button's ``onclick`` handler.
        key: Accepted for call-site compatibility; not used by this
            implementation.

    The text is embedded as a JSON-encoded JavaScript string literal and the
    whole handler is HTML-attribute-escaped. Molecule strings can contain
    backslashes and quotes; interpolating them raw into a quoted JS literal
    would change the copied value or break the handler.
    """
    import html
    import json

    # json.dumps yields a double-quoted literal with backslashes, quotes and
    # non-ASCII characters escaped, which is also a valid JavaScript string.
    js_literal = json.dumps(str(text))
    # Escape for the double-quoted HTML attribute; the browser decodes the
    # entities before evaluating the handler.
    onclick = html.escape(f"navigator.clipboard.writeText({js_literal})", quote=True)
    components.html(f"""
    <button onclick="{onclick}" style="padding:5px;">Copy</button>
    """, height=45)
118
 
119
# Generate Molecules Button
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # Prompt SMILES handed to the generator model.
    # NOTE(review): this string describes a five-membered ring with two
    # carbonyls on the nitrogen, whereas a beta-lactam ring is four-membered.
    # Left unchanged because the model may expect exactly this prompt; confirm.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # Tokenize the prompt
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # Sample candidate molecules from the model
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        temperature=creativity,
        do_sample=True,
        top_k=50,
        num_return_sequences=num_molecules,
        num_beams=max(num_molecules, 5),  # Ensure num_beams >= num_return_sequences
    )

    # Decode generated token ids into SMILES strings
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # Sequential names: Mol01, Mol02, ...
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]

    # One row per generated molecule
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    # Invalid SMILES check
    from rdkit import Chem

    def is_valid_smile(smile):
        """Return True when ``smile`` is non-empty and RDKit can parse it."""
        # An empty decode is rejected explicitly rather than relying on how
        # the parser treats an empty string.
        if not smile or not smile.strip():
            return False
        return Chem.MolFromSmiles(smile) is not None

    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # Sampling can return the same SMILES more than once. The merge below is
    # keyed on 'SMILES', so repeated keys would multiply rows (many-to-many
    # join) and repeat molecule names; keep only the first occurrence.
    duplicate_count = int(df_valid.duplicated(subset='SMILES').sum())
    if duplicate_count:
        df_valid = df_valid.drop_duplicates(subset='SMILES').copy()
        st.info(f"{duplicate_count} duplicate molecules were removed before predictions.")

    # Check if there are valid molecules to proceed
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET predictions for the valid, de-duplicated SMILES
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Ensure 'SMILES' is a column in preds so it can serve as the merge key
        # (assumes predictions come back in input order - TODO confirm)
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        # Merge predictions with valid molecules and index by molecule name
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # Check if df_results is empty after merging
        if df_results.empty:
            st.error("No valid molecules were generated after predictions. Please try adjusting the generation parameters.")
        else:
            # Display molecules
            st.subheader('Generated Molecules')

            # At most five columns; molecules wrap around the columns
            cols_per_row = min(5, len(df_results))
            cols = st.columns(cols_per_row)

            show_safe = string_format == 'SAFE'
            for idx, (mol_name, row) in enumerate(df_results.iterrows()):
                smiles = row['SMILES']
                img = generate_molecule_image(
                    smiles,
                    use_safe_visualization=show_safe
                )
                with cols[idx % cols_per_row]:
                    if isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")
                    # Display molecule string in chosen format
                    string_to_display = safe.encode(smiles) if show_safe else smiles
                    st.code(string_to_display)
                    # Copy-to-clipboard functionality
                    st_copy_button(string_to_display, key=f'copy_{mol_name}')
                    # Display ADMET properties (bookkeeping columns removed)
                    st.write("**ADMET Properties:**")
                    st.write(row.drop(['SMILES', 'Valid']))
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
226
 
227
+