Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,117 +10,92 @@ import io
|
|
10 |
from PIL import Image
|
11 |
import cairosvg
|
12 |
import pandas as pd
|
|
|
13 |
|
14 |
-
# Page Configuration
|
15 |
-
st.set_page_config(
|
|
|
|
|
|
|
16 |
|
17 |
-
# Load Models
|
18 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
19 |
def load_models():
|
20 |
-
|
|
|
|
|
|
|
|
|
21 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
22 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
admet_model = ADMETModel()
|
|
|
27 |
return model, tokenizer, admet_model
|
28 |
|
|
|
29 |
model, tokenizer, admet_model = load_models()
|
30 |
|
31 |
-
# Set Generation Parameters
|
32 |
st.sidebar.header('Generation Parameters')
|
33 |
-
creativity = st.sidebar.slider('Creativity (Temperature):', 0.0, 2.0, 1.0, step=0.1)
|
34 |
-
num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)
|
35 |
|
36 |
-
#
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
max_length=128,
|
48 |
-
temperature=creativity,
|
49 |
-
do_sample=True,
|
50 |
-
top_k=50,
|
51 |
-
num_return_sequences=num_molecules,
|
52 |
-
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
53 |
-
)
|
54 |
-
generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
|
55 |
-
molecule_names = [f"Mol{str(i).zfill(2)}" for i in range(1, len(generated_smiles) + 1)]
|
56 |
-
|
57 |
-
# Create DataFrame for generated molecules
|
58 |
-
df_molecules = pd.DataFrame({
|
59 |
-
'Molecule Name': molecule_names,
|
60 |
-
'SMILES': generated_smiles
|
61 |
-
})
|
62 |
-
|
63 |
-
# Display generated SMILES for debugging
|
64 |
-
st.write("Generated SMILES:")
|
65 |
-
st.write(df_molecules)
|
66 |
-
|
67 |
-
# ADMET Predictions
|
68 |
-
preds = admet_model.predict(smiles=df_molecules['SMILES'].tolist())
|
69 |
-
|
70 |
-
# Ensure 'SMILES' is a column in preds
|
71 |
-
if 'SMILES' not in preds.columns:
|
72 |
-
preds['SMILES'] = df_molecules['SMILES']
|
73 |
-
|
74 |
-
# Merge predictions with generated molecules
|
75 |
-
df_results = pd.merge(df_molecules, preds, on='SMILES', how='inner')
|
76 |
-
|
77 |
-
# Set 'Molecule Name' as index
|
78 |
-
df_results.set_index('Molecule Name', inplace=True)
|
79 |
-
|
80 |
-
# Display Molecules
|
81 |
-
st.subheader('Generated Molecules')
|
82 |
-
cols_per_row = min(5, len(df_results))
|
83 |
-
cols = st.columns(cols_per_row)
|
84 |
-
for idx, (mol_name, row) in enumerate(df_results.iterrows()):
|
85 |
-
smiles = row['SMILES']
|
86 |
-
img = generate_molecule_image(smiles, use_safe_visualization=(string_format == 'SAFE'))
|
87 |
-
with cols[idx % cols_per_row]:
|
88 |
-
if isinstance(img, Image.Image):
|
89 |
-
st.image(img, caption=mol_name)
|
90 |
-
else:
|
91 |
-
st.error(f"Could not generate image for {mol_name}")
|
92 |
-
# Display molecule string
|
93 |
-
string_to_display = safe.encode(smiles) if string_format == 'SAFE' else smiles
|
94 |
-
st.code(string_to_display)
|
95 |
-
# Copy-to-clipboard functionality
|
96 |
-
st_copy_button(string_to_display, key=f'copy_{mol_name}')
|
97 |
-
# Display ADMET properties
|
98 |
-
st.write("**ADMET Properties:**")
|
99 |
-
st.write(row.drop(['SMILES']))
|
100 |
-
else:
|
101 |
-
st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
# Function
|
105 |
def generate_molecule_image(input_string, use_safe_visualization=True):
|
|
|
|
|
|
|
|
|
106 |
try:
|
107 |
if use_safe_visualization:
|
108 |
try:
|
109 |
-
# Attempt to decode as SAFE string
|
110 |
smiles = safe.decode(input_string)
|
111 |
-
# Encode back to SAFE string
|
112 |
safe_string = safe.encode(smiles)
|
113 |
-
except Exception:
|
114 |
-
#
|
115 |
-
|
116 |
-
|
|
|
117 |
svg_str = safe.to_image(safe_string)
|
118 |
-
# Convert SVG to PNG bytes
|
119 |
png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
|
120 |
-
# Create an image object
|
121 |
img = Image.open(io.BytesIO(png_bytes))
|
122 |
else:
|
123 |
-
# Generate standard molecule image
|
124 |
mol = Chem.MolFromSmiles(input_string)
|
125 |
if mol:
|
126 |
img = Draw.MolToImage(mol, size=(200, 200)) # Adjusted size
|
@@ -128,15 +103,125 @@ def generate_molecule_image(input_string, use_safe_visualization=True):
|
|
128 |
img = None
|
129 |
return img
|
130 |
except Exception as e:
|
131 |
-
# Collect exceptions for later reporting
|
132 |
-
|
133 |
-
|
134 |
-
import streamlit.components.v1 as components
|
135 |
|
|
|
136 |
def st_copy_button(text, key):
|
137 |
-
"""
|
|
|
|
|
138 |
components.html(f"""
|
139 |
<button onclick="navigator.clipboard.writeText('{text}')" style="padding:5px;">Copy</button>
|
140 |
""", height=45)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
|
|
|
10 |
from PIL import Image
|
11 |
import cairosvg
|
12 |
import pandas as pd
|
13 |
+
import streamlit.components.v1 as components
|
14 |
|
15 |
+
# **Page Configuration**
|
16 |
+
st.set_page_config(
|
17 |
+
page_title='Beta-Lactam Molecule Generator',
|
18 |
+
layout='wide'
|
19 |
+
)
|
20 |
|
21 |
+
# **Load Models**
|
22 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
23 |
def load_models():
|
24 |
+
"""
|
25 |
+
Load the molecule generation model and the ADMET-AI model.
|
26 |
+
Caches the models to avoid reloading on every run.
|
27 |
+
"""
|
28 |
+
# **Load your molecule generation model**
|
29 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
30 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
31 |
+
if access_token is None:
|
32 |
+
st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
|
33 |
+
st.stop()
|
34 |
+
model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
|
35 |
+
tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
|
36 |
+
|
37 |
+
# **Load ADMET-AI model**
|
38 |
admet_model = ADMETModel()
|
39 |
+
|
40 |
return model, tokenizer, admet_model
|
41 |
|
42 |
+
# **Load models once and reuse**
|
43 |
model, tokenizer, admet_model = load_models()
|
44 |
|
45 |
+
# **Set Generation Parameters in Sidebar**
|
46 |
st.sidebar.header('Generation Parameters')
|
|
|
|
|
47 |
|
48 |
+
# **Creativity Slider (Temperature)**
|
49 |
+
creativity = st.sidebar.slider(
|
50 |
+
'Creativity (Temperature):',
|
51 |
+
min_value=0.0,
|
52 |
+
max_value=2.0,
|
53 |
+
value=1.0,
|
54 |
+
step=0.1,
|
55 |
+
help="Higher values lead to more diverse outputs."
|
56 |
+
)
|
57 |
|
58 |
+
# **Number of Molecules to Generate**
|
59 |
+
num_molecules = st.sidebar.number_input(
|
60 |
+
'Number of Molecules to Generate:',
|
61 |
+
min_value=1,
|
62 |
+
max_value=5,
|
63 |
+
value=5,
|
64 |
+
help="Select the number of molecules you want to generate."
|
65 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
+
# **String Format Option (SMILES or SAFE)**
|
68 |
+
string_format = st.sidebar.radio(
|
69 |
+
'String Format:',
|
70 |
+
('SMILES', 'SAFE'),
|
71 |
+
help="Choose the format for displaying molecule strings."
|
72 |
+
)
|
73 |
|
74 |
+
# **Function to Generate Molecule Images**
|
75 |
def generate_molecule_image(input_string, use_safe_visualization=True):
|
76 |
+
"""
|
77 |
+
Generates an image of the molecule from the input string.
|
78 |
+
Supports SAFE visualization if enabled.
|
79 |
+
"""
|
80 |
try:
|
81 |
if use_safe_visualization:
|
82 |
try:
|
83 |
+
# **Attempt to decode as SAFE string**
|
84 |
smiles = safe.decode(input_string)
|
85 |
+
# **Encode back to SAFE string**
|
86 |
safe_string = safe.encode(smiles)
|
87 |
+
except Exception as e:
|
88 |
+
# **Handle decoding errors**
|
89 |
+
st.error(f"Error decoding SAFE string: {e}")
|
90 |
+
return None
|
91 |
+
# **Generate SVG image with fragment highlights**
|
92 |
svg_str = safe.to_image(safe_string)
|
93 |
+
# **Convert SVG to PNG bytes**
|
94 |
png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
|
95 |
+
# **Create an image object**
|
96 |
img = Image.open(io.BytesIO(png_bytes))
|
97 |
else:
|
98 |
+
# **Generate standard molecule image**
|
99 |
mol = Chem.MolFromSmiles(input_string)
|
100 |
if mol:
|
101 |
img = Draw.MolToImage(mol, size=(200, 200)) # Adjusted size
|
|
|
103 |
img = None
|
104 |
return img
|
105 |
except Exception as e:
|
106 |
+
# **Collect exceptions for later reporting**
|
107 |
+
st.error(f"Error generating molecule image: {e}")
|
108 |
+
return None
|
|
|
109 |
|
110 |
+
# **Function to Create Copy-to-Clipboard Button**
|
111 |
def st_copy_button(text, key):
|
112 |
+
"""
|
113 |
+
Creates a copy-to-clipboard button for the given text.
|
114 |
+
"""
|
115 |
components.html(f"""
|
116 |
<button onclick="navigator.clipboard.writeText('{text}')" style="padding:5px;">Copy</button>
|
117 |
""", height=45)
|
118 |
|
119 |
+
# **Generate Molecules Button**
|
120 |
+
if st.button('Generate Molecules'):
|
121 |
+
st.info("Generating molecules... Please wait.")
|
122 |
+
|
123 |
+
# **Beta-lactam core structure**
|
124 |
+
core_smiles = "C1C(=O)N(C)C(=O)C1"
|
125 |
+
|
126 |
+
# **Tokenize the core SMILES**
|
127 |
+
input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
|
128 |
+
|
129 |
+
# **Generate molecules using the model**
|
130 |
+
output_ids = model.generate(
|
131 |
+
input_ids=input_ids,
|
132 |
+
max_length=128,
|
133 |
+
temperature=creativity,
|
134 |
+
do_sample=True,
|
135 |
+
top_k=50,
|
136 |
+
num_return_sequences=num_molecules,
|
137 |
+
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
138 |
+
)
|
139 |
+
|
140 |
+
# **Decode generated molecule SMILES**
|
141 |
+
generated_smiles = [
|
142 |
+
tokenizer.decode(ids, skip_special_tokens=True)
|
143 |
+
for ids in output_ids
|
144 |
+
]
|
145 |
+
|
146 |
+
# **Create molecule names**
|
147 |
+
molecule_names = [
|
148 |
+
f"Mol{str(i).zfill(2)}"
|
149 |
+
for i in range(1, len(generated_smiles) + 1)
|
150 |
+
]
|
151 |
+
|
152 |
+
# **Create DataFrame for generated molecules**
|
153 |
+
df_molecules = pd.DataFrame({
|
154 |
+
'Molecule Name': molecule_names,
|
155 |
+
'SMILES': generated_smiles
|
156 |
+
})
|
157 |
+
|
158 |
+
# **Invalid SMILES Check**
|
159 |
+
from rdkit import Chem
|
160 |
+
|
161 |
+
# **Function to validate SMILES**
|
162 |
+
def is_valid_smile(smile):
|
163 |
+
return Chem.MolFromSmiles(smile) is not None
|
164 |
+
|
165 |
+
# **Apply validation function**
|
166 |
+
df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
|
167 |
+
df_valid = df_molecules[df_molecules['Valid']].copy()
|
168 |
+
|
169 |
+
# **Inform user if any molecules were invalid**
|
170 |
+
invalid_molecules = df_molecules[~df_molecules['Valid']]
|
171 |
+
if not invalid_molecules.empty:
|
172 |
+
st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")
|
173 |
+
|
174 |
+
# **Check if there are valid molecules to proceed**
|
175 |
+
if df_valid.empty:
|
176 |
+
st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
|
177 |
+
else:
|
178 |
+
# **ADMET Predictions**
|
179 |
+
preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())
|
180 |
+
|
181 |
+
# **Ensure 'SMILES' is a column in preds**
|
182 |
+
if 'SMILES' not in preds.columns:
|
183 |
+
preds['SMILES'] = df_valid['SMILES'].values
|
184 |
+
|
185 |
+
# **Merge predictions with valid molecules**
|
186 |
+
df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
|
187 |
+
|
188 |
+
# **Set 'Molecule Name' as index**
|
189 |
+
df_results.set_index('Molecule Name', inplace=True)
|
190 |
+
|
191 |
+
# **Check if df_results is empty after merging**
|
192 |
+
if df_results.empty:
|
193 |
+
st.error("No valid molecules were generated after predictions. Please try adjusting the generation parameters.")
|
194 |
+
else:
|
195 |
+
# **Display Molecules**
|
196 |
+
st.subheader('Generated Molecules')
|
197 |
+
|
198 |
+
# **Determine number of columns per row**
|
199 |
+
cols_per_row = min(5, len(df_results))
|
200 |
+
|
201 |
+
# **Create columns in Streamlit**
|
202 |
+
cols = st.columns(cols_per_row)
|
203 |
+
|
204 |
+
# **Iterate over each molecule to display**
|
205 |
+
for idx, (mol_name, row) in enumerate(df_results.iterrows()):
|
206 |
+
smiles = row['SMILES']
|
207 |
+
img = generate_molecule_image(
|
208 |
+
smiles,
|
209 |
+
use_safe_visualization=(string_format == 'SAFE')
|
210 |
+
)
|
211 |
+
with cols[idx % cols_per_row]:
|
212 |
+
if img is not None and isinstance(img, Image.Image):
|
213 |
+
st.image(img, caption=mol_name)
|
214 |
+
else:
|
215 |
+
st.error(f"Could not generate image for {mol_name}")
|
216 |
+
# **Display molecule string in chosen format**
|
217 |
+
string_to_display = safe.encode(smiles) if string_format == 'SAFE' else smiles
|
218 |
+
st.code(string_to_display)
|
219 |
+
# **Copy-to-clipboard functionality**
|
220 |
+
st_copy_button(string_to_display, key=f'copy_{mol_name}')
|
221 |
+
# **Display ADMET properties**
|
222 |
+
st.write("**ADMET Properties:**")
|
223 |
+
st.write(row.drop(['SMILES', 'Valid']))
|
224 |
+
else:
|
225 |
+
st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|
226 |
|
227 |
+
|