Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,117 +10,92 @@ import io
|
|
| 10 |
from PIL import Image
|
| 11 |
import cairosvg
|
| 12 |
import pandas as pd
|
|
|
|
| 13 |
|
| 14 |
-
# Page Configuration
|
| 15 |
-
st.set_page_config(
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
# Load Models
|
| 18 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
| 19 |
def load_models():
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
| 22 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
admet_model = ADMETModel()
|
|
|
|
| 27 |
return model, tokenizer, admet_model
|
| 28 |
|
|
|
|
| 29 |
model, tokenizer, admet_model = load_models()
|
| 30 |
|
| 31 |
-
# Set Generation Parameters
|
| 32 |
st.sidebar.header('Generation Parameters')
|
| 33 |
-
creativity = st.sidebar.slider('Creativity (Temperature):', 0.0, 2.0, 1.0, step=0.1)
|
| 34 |
-
num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
max_length=128,
|
| 48 |
-
temperature=creativity,
|
| 49 |
-
do_sample=True,
|
| 50 |
-
top_k=50,
|
| 51 |
-
num_return_sequences=num_molecules,
|
| 52 |
-
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
| 53 |
-
)
|
| 54 |
-
generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
|
| 55 |
-
molecule_names = [f"Mol{str(i).zfill(2)}" for i in range(1, len(generated_smiles) + 1)]
|
| 56 |
-
|
| 57 |
-
# Create DataFrame for generated molecules
|
| 58 |
-
df_molecules = pd.DataFrame({
|
| 59 |
-
'Molecule Name': molecule_names,
|
| 60 |
-
'SMILES': generated_smiles
|
| 61 |
-
})
|
| 62 |
-
|
| 63 |
-
# Display generated SMILES for debugging
|
| 64 |
-
st.write("Generated SMILES:")
|
| 65 |
-
st.write(df_molecules)
|
| 66 |
-
|
| 67 |
-
# ADMET Predictions
|
| 68 |
-
preds = admet_model.predict(smiles=df_molecules['SMILES'].tolist())
|
| 69 |
-
|
| 70 |
-
# Ensure 'SMILES' is a column in preds
|
| 71 |
-
if 'SMILES' not in preds.columns:
|
| 72 |
-
preds['SMILES'] = df_molecules['SMILES']
|
| 73 |
-
|
| 74 |
-
# Merge predictions with generated molecules
|
| 75 |
-
df_results = pd.merge(df_molecules, preds, on='SMILES', how='inner')
|
| 76 |
-
|
| 77 |
-
# Set 'Molecule Name' as index
|
| 78 |
-
df_results.set_index('Molecule Name', inplace=True)
|
| 79 |
-
|
| 80 |
-
# Display Molecules
|
| 81 |
-
st.subheader('Generated Molecules')
|
| 82 |
-
cols_per_row = min(5, len(df_results))
|
| 83 |
-
cols = st.columns(cols_per_row)
|
| 84 |
-
for idx, (mol_name, row) in enumerate(df_results.iterrows()):
|
| 85 |
-
smiles = row['SMILES']
|
| 86 |
-
img = generate_molecule_image(smiles, use_safe_visualization=(string_format == 'SAFE'))
|
| 87 |
-
with cols[idx % cols_per_row]:
|
| 88 |
-
if isinstance(img, Image.Image):
|
| 89 |
-
st.image(img, caption=mol_name)
|
| 90 |
-
else:
|
| 91 |
-
st.error(f"Could not generate image for {mol_name}")
|
| 92 |
-
# Display molecule string
|
| 93 |
-
string_to_display = safe.encode(smiles) if string_format == 'SAFE' else smiles
|
| 94 |
-
st.code(string_to_display)
|
| 95 |
-
# Copy-to-clipboard functionality
|
| 96 |
-
st_copy_button(string_to_display, key=f'copy_{mol_name}')
|
| 97 |
-
# Display ADMET properties
|
| 98 |
-
st.write("**ADMET Properties:**")
|
| 99 |
-
st.write(row.drop(['SMILES']))
|
| 100 |
-
else:
|
| 101 |
-
st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
# Function
|
| 105 |
def generate_molecule_image(input_string, use_safe_visualization=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
try:
|
| 107 |
if use_safe_visualization:
|
| 108 |
try:
|
| 109 |
-
# Attempt to decode as SAFE string
|
| 110 |
smiles = safe.decode(input_string)
|
| 111 |
-
# Encode back to SAFE string
|
| 112 |
safe_string = safe.encode(smiles)
|
| 113 |
-
except Exception:
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
svg_str = safe.to_image(safe_string)
|
| 118 |
-
# Convert SVG to PNG bytes
|
| 119 |
png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
|
| 120 |
-
# Create an image object
|
| 121 |
img = Image.open(io.BytesIO(png_bytes))
|
| 122 |
else:
|
| 123 |
-
# Generate standard molecule image
|
| 124 |
mol = Chem.MolFromSmiles(input_string)
|
| 125 |
if mol:
|
| 126 |
img = Draw.MolToImage(mol, size=(200, 200)) # Adjusted size
|
|
@@ -128,15 +103,125 @@ def generate_molecule_image(input_string, use_safe_visualization=True):
|
|
| 128 |
img = None
|
| 129 |
return img
|
| 130 |
except Exception as e:
|
| 131 |
-
# Collect exceptions for later reporting
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
import streamlit.components.v1 as components
|
| 135 |
|
|
|
|
| 136 |
def st_copy_button(text, key):
|
| 137 |
-
"""
|
|
|
|
|
|
|
| 138 |
components.html(f"""
|
| 139 |
<button onclick="navigator.clipboard.writeText('{text}')" style="padding:5px;">Copy</button>
|
| 140 |
""", height=45)
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
|
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
import cairosvg
|
| 12 |
import pandas as pd
|
| 13 |
+
import streamlit.components.v1 as components
|
| 14 |
|
| 15 |
+
# **Page Configuration**
|
| 16 |
+
st.set_page_config(
|
| 17 |
+
page_title='Beta-Lactam Molecule Generator',
|
| 18 |
+
layout='wide'
|
| 19 |
+
)
|
| 20 |
|
| 21 |
+
# **Load Models**
|
| 22 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
| 23 |
def load_models():
|
| 24 |
+
"""
|
| 25 |
+
Load the molecule generation model and the ADMET-AI model.
|
| 26 |
+
Caches the models to avoid reloading on every run.
|
| 27 |
+
"""
|
| 28 |
+
# **Load your molecule generation model**
|
| 29 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
| 30 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
| 31 |
+
if access_token is None:
|
| 32 |
+
st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
|
| 33 |
+
st.stop()
|
| 34 |
+
model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
|
| 35 |
+
tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
|
| 36 |
+
|
| 37 |
+
# **Load ADMET-AI model**
|
| 38 |
admet_model = ADMETModel()
|
| 39 |
+
|
| 40 |
return model, tokenizer, admet_model
|
| 41 |
|
| 42 |
+
# **Load models once and reuse**
|
| 43 |
model, tokenizer, admet_model = load_models()
|
| 44 |
|
| 45 |
+
# **Set Generation Parameters in Sidebar**
|
| 46 |
st.sidebar.header('Generation Parameters')
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
# **Creativity Slider (Temperature)**
|
| 49 |
+
creativity = st.sidebar.slider(
|
| 50 |
+
'Creativity (Temperature):',
|
| 51 |
+
min_value=0.0,
|
| 52 |
+
max_value=2.0,
|
| 53 |
+
value=1.0,
|
| 54 |
+
step=0.1,
|
| 55 |
+
help="Higher values lead to more diverse outputs."
|
| 56 |
+
)
|
| 57 |
|
| 58 |
+
# **Number of Molecules to Generate**
|
| 59 |
+
num_molecules = st.sidebar.number_input(
|
| 60 |
+
'Number of Molecules to Generate:',
|
| 61 |
+
min_value=1,
|
| 62 |
+
max_value=5,
|
| 63 |
+
value=5,
|
| 64 |
+
help="Select the number of molecules you want to generate."
|
| 65 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
+
# **String Format Option (SMILES or SAFE)**
|
| 68 |
+
string_format = st.sidebar.radio(
|
| 69 |
+
'String Format:',
|
| 70 |
+
('SMILES', 'SAFE'),
|
| 71 |
+
help="Choose the format for displaying molecule strings."
|
| 72 |
+
)
|
| 73 |
|
| 74 |
+
# **Function to Generate Molecule Images**
|
| 75 |
def generate_molecule_image(input_string, use_safe_visualization=True):
|
| 76 |
+
"""
|
| 77 |
+
Generates an image of the molecule from the input string.
|
| 78 |
+
Supports SAFE visualization if enabled.
|
| 79 |
+
"""
|
| 80 |
try:
|
| 81 |
if use_safe_visualization:
|
| 82 |
try:
|
| 83 |
+
# **Attempt to decode as SAFE string**
|
| 84 |
smiles = safe.decode(input_string)
|
| 85 |
+
# **Encode back to SAFE string**
|
| 86 |
safe_string = safe.encode(smiles)
|
| 87 |
+
except Exception as e:
|
| 88 |
+
# **Handle decoding errors**
|
| 89 |
+
st.error(f"Error decoding SAFE string: {e}")
|
| 90 |
+
return None
|
| 91 |
+
# **Generate SVG image with fragment highlights**
|
| 92 |
svg_str = safe.to_image(safe_string)
|
| 93 |
+
# **Convert SVG to PNG bytes**
|
| 94 |
png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
|
| 95 |
+
# **Create an image object**
|
| 96 |
img = Image.open(io.BytesIO(png_bytes))
|
| 97 |
else:
|
| 98 |
+
# **Generate standard molecule image**
|
| 99 |
mol = Chem.MolFromSmiles(input_string)
|
| 100 |
if mol:
|
| 101 |
img = Draw.MolToImage(mol, size=(200, 200)) # Adjusted size
|
|
|
|
| 103 |
img = None
|
| 104 |
return img
|
| 105 |
except Exception as e:
|
| 106 |
+
# **Collect exceptions for later reporting**
|
| 107 |
+
st.error(f"Error generating molecule image: {e}")
|
| 108 |
+
return None
|
|
|
|
| 109 |
|
| 110 |
+
# **Function to Create Copy-to-Clipboard Button**
|
| 111 |
def st_copy_button(text, key):
|
| 112 |
+
"""
|
| 113 |
+
Creates a copy-to-clipboard button for the given text.
|
| 114 |
+
"""
|
| 115 |
components.html(f"""
|
| 116 |
<button onclick="navigator.clipboard.writeText('{text}')" style="padding:5px;">Copy</button>
|
| 117 |
""", height=45)
|
| 118 |
|
| 119 |
+
# **Generate Molecules Button**
|
| 120 |
+
if st.button('Generate Molecules'):
|
| 121 |
+
st.info("Generating molecules... Please wait.")
|
| 122 |
+
|
| 123 |
+
# **Beta-lactam core structure**
|
| 124 |
+
core_smiles = "C1C(=O)N(C)C(=O)C1"
|
| 125 |
+
|
| 126 |
+
# **Tokenize the core SMILES**
|
| 127 |
+
input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
|
| 128 |
+
|
| 129 |
+
# **Generate molecules using the model**
|
| 130 |
+
output_ids = model.generate(
|
| 131 |
+
input_ids=input_ids,
|
| 132 |
+
max_length=128,
|
| 133 |
+
temperature=creativity,
|
| 134 |
+
do_sample=True,
|
| 135 |
+
top_k=50,
|
| 136 |
+
num_return_sequences=num_molecules,
|
| 137 |
+
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# **Decode generated molecule SMILES**
|
| 141 |
+
generated_smiles = [
|
| 142 |
+
tokenizer.decode(ids, skip_special_tokens=True)
|
| 143 |
+
for ids in output_ids
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
# **Create molecule names**
|
| 147 |
+
molecule_names = [
|
| 148 |
+
f"Mol{str(i).zfill(2)}"
|
| 149 |
+
for i in range(1, len(generated_smiles) + 1)
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
# **Create DataFrame for generated molecules**
|
| 153 |
+
df_molecules = pd.DataFrame({
|
| 154 |
+
'Molecule Name': molecule_names,
|
| 155 |
+
'SMILES': generated_smiles
|
| 156 |
+
})
|
| 157 |
+
|
| 158 |
+
# **Invalid SMILES Check**
|
| 159 |
+
from rdkit import Chem
|
| 160 |
+
|
| 161 |
+
# **Function to validate SMILES**
|
| 162 |
+
def is_valid_smile(smile):
|
| 163 |
+
return Chem.MolFromSmiles(smile) is not None
|
| 164 |
+
|
| 165 |
+
# **Apply validation function**
|
| 166 |
+
df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
|
| 167 |
+
df_valid = df_molecules[df_molecules['Valid']].copy()
|
| 168 |
+
|
| 169 |
+
# **Inform user if any molecules were invalid**
|
| 170 |
+
invalid_molecules = df_molecules[~df_molecules['Valid']]
|
| 171 |
+
if not invalid_molecules.empty:
|
| 172 |
+
st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")
|
| 173 |
+
|
| 174 |
+
# **Check if there are valid molecules to proceed**
|
| 175 |
+
if df_valid.empty:
|
| 176 |
+
st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
|
| 177 |
+
else:
|
| 178 |
+
# **ADMET Predictions**
|
| 179 |
+
preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())
|
| 180 |
+
|
| 181 |
+
# **Ensure 'SMILES' is a column in preds**
|
| 182 |
+
if 'SMILES' not in preds.columns:
|
| 183 |
+
preds['SMILES'] = df_valid['SMILES'].values
|
| 184 |
+
|
| 185 |
+
# **Merge predictions with valid molecules**
|
| 186 |
+
df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
|
| 187 |
+
|
| 188 |
+
# **Set 'Molecule Name' as index**
|
| 189 |
+
df_results.set_index('Molecule Name', inplace=True)
|
| 190 |
+
|
| 191 |
+
# **Check if df_results is empty after merging**
|
| 192 |
+
if df_results.empty:
|
| 193 |
+
st.error("No valid molecules were generated after predictions. Please try adjusting the generation parameters.")
|
| 194 |
+
else:
|
| 195 |
+
# **Display Molecules**
|
| 196 |
+
st.subheader('Generated Molecules')
|
| 197 |
+
|
| 198 |
+
# **Determine number of columns per row**
|
| 199 |
+
cols_per_row = min(5, len(df_results))
|
| 200 |
+
|
| 201 |
+
# **Create columns in Streamlit**
|
| 202 |
+
cols = st.columns(cols_per_row)
|
| 203 |
+
|
| 204 |
+
# **Iterate over each molecule to display**
|
| 205 |
+
for idx, (mol_name, row) in enumerate(df_results.iterrows()):
|
| 206 |
+
smiles = row['SMILES']
|
| 207 |
+
img = generate_molecule_image(
|
| 208 |
+
smiles,
|
| 209 |
+
use_safe_visualization=(string_format == 'SAFE')
|
| 210 |
+
)
|
| 211 |
+
with cols[idx % cols_per_row]:
|
| 212 |
+
if img is not None and isinstance(img, Image.Image):
|
| 213 |
+
st.image(img, caption=mol_name)
|
| 214 |
+
else:
|
| 215 |
+
st.error(f"Could not generate image for {mol_name}")
|
| 216 |
+
# **Display molecule string in chosen format**
|
| 217 |
+
string_to_display = safe.encode(smiles) if string_format == 'SAFE' else smiles
|
| 218 |
+
st.code(string_to_display)
|
| 219 |
+
# **Copy-to-clipboard functionality**
|
| 220 |
+
st_copy_button(string_to_display, key=f'copy_{mol_name}')
|
| 221 |
+
# **Display ADMET properties**
|
| 222 |
+
st.write("**ADMET Properties:**")
|
| 223 |
+
st.write(row.drop(['SMILES', 'Valid']))
|
| 224 |
+
else:
|
| 225 |
+
st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|
| 226 |
|
| 227 |
+
|