# bcadkins01's picture
# Update app.py
# a3a6043 verified
# raw
# history blame
# 8.83 kB
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
import streamlit.components.v1 as components
# **Page Configuration**
# Set the page title and switch the app to the wide layout.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# **Load Models**
@st.cache_resource(ttl=600, show_spinner="Loading Models...")
def load_models():
    """
    Build the three objects the app works with.

    Reads the Hugging Face access token from the HUGGING_FACE_TOKEN
    environment variable; when the variable is unset, an error is shown
    and the script run is stopped.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple: the BART generation
        model, its tokenizer, and an ADMET-AI model instance. The result is
        cached through ``st.cache_resource`` (ttl of 600).
    """
    # **Guard: the token is required before anything is downloaded**
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # **Molecule generation model and tokenizer (same repository)**
    repo_id = "bcadkins01/beta_lactam_generator"
    hub_kwargs = {"token": hf_token}
    generator = BartForConditionalGeneration.from_pretrained(repo_id, **hub_kwargs)
    smiles_tokenizer = BartTokenizer.from_pretrained(repo_id, **hub_kwargs)

    # **ADMET-AI model**
    admet_predictor = ADMETModel()

    return generator, smiles_tokenizer, admet_predictor
# **Load models once and reuse**
model, tokenizer, admet_model = load_models()

# **Set Generation Parameters in Sidebar**
st.sidebar.header('Generation Parameters')

# **Creativity Slider (Temperature)**
# The value is passed to model.generate() as `temperature` together with
# do_sample=True. Sampling requires a strictly positive temperature, so the
# slider starts at 0.1 instead of 0.0 (0.0 makes generate() raise ValueError).
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs."
)

# **Number of Molecules to Generate**
# Used as num_return_sequences for generation; capped at 3.
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)
# **Function to Generate Molecule Images**
def generate_molecule_image(smiles):
    """
    Draw the molecule described by a SMILES string.

    Returns the 200x200 image produced by ``Draw.MolToImage``, or ``None``
    when RDKit yields no usable molecule for the string. Any exception is
    reported through ``st.error`` and also results in ``None``.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        return Draw.MolToImage(parsed, size=(200, 200)) if parsed else None
    except Exception as exc:
        st.error(f"Error generating molecule image: {exc}")
        return None
# **Function to Create Copy-to-Clipboard Button**
def st_copy_button(text, key):
    """
    Render a "Copy" button that writes *text* to the browser clipboard.

    The button is emitted as a small HTML component and is right-aligned
    with a negative top margin so it sits next to the text shown above it.

    Args:
        text: The string to copy (e.g. a SMILES or SAFE string).
        key: Accepted for call-site compatibility; it is not used when
            building the component.
    """
    import html
    import json

    # SMILES strings may contain backslashes (cis/trans bond markers) and other
    # characters that are special inside a JavaScript string or an HTML
    # attribute. json.dumps produces a correctly escaped JavaScript string
    # literal; html.escape then makes that literal safe inside the
    # double-quoted onclick attribute.
    js_literal = html.escape(json.dumps(str(text)), quote=True)

    # Adjusted styling to position the button
    button_html = f"""
    <div style="display: flex; justify-content: flex-end;">
        <button onclick="navigator.clipboard.writeText({js_literal})" style="padding:5px; margin-top: -40px; position: relative; z-index: 1;">Copy</button>
    </div>
    """
    components.html(button_html, height=45)
# **Generate Molecules Button**
# Results are kept in st.session_state so they survive the script reruns that
# Streamlit triggers whenever any widget (e.g. a toggle button) is clicked.
_RESULTS_KEY = 'generated_results'

# **ADMET properties shown for each molecule**
_ADMET_PROPERTIES = [
    'molecular weight', 'logP', 'hydrogen_bond_acceptors',
    'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
]


def _toggle_format(state_key):
    """Flip the SMILES/SAFE display flag stored under *state_key*."""
    st.session_state[state_key] = not st.session_state[state_key]


def _is_valid_smiles(smiles):
    """Return True when RDKit parses *smiles* into a molecule with atoms."""
    mol = Chem.MolFromSmiles(smiles)
    # An empty string parses to a molecule with zero atoms; reject it too.
    return mol is not None and mol.GetNumAtoms() > 0


def _generate_results():
    """
    Generate molecules and attach ADMET predictions.

    Returns:
        A DataFrame indexed by molecule name with the columns 'SMILES',
        'Valid' and the available ADMET properties, or ``None`` when no
        usable molecule or prediction was obtained (an error is shown).
    """
    st.info("Generating molecules... Please wait.")

    # **Beta-lactam core structure**
    # NOTE(review): this SMILES closes a five-membered ring with two carbonyls,
    # not a four-membered azetidin-2-one ring; confirm it is the prompt the
    # model expects.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # **Tokenize the core SMILES**
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # **Generate molecules using the model**
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        temperature=creativity,
        do_sample=True,
        top_k=50,
        num_return_sequences=num_molecules,
        num_beams=max(num_molecules, 5)  # Ensure num_beams >= num_return_sequences
    )

    # **Decode generated molecule SMILES**
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # **Create DataFrame for generated molecules**
    df_molecules = pd.DataFrame({
        'Molecule Name': [
            f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)
        ],
        'SMILES': generated_smiles
    })

    # **Invalid SMILES Check**
    df_molecules['Valid'] = df_molecules['SMILES'].apply(_is_valid_smiles)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # **Inform user if any molecules were invalid**
    invalid_count = len(df_molecules) - len(df_valid)
    if invalid_count:
        st.warning(f"{invalid_count} generated molecules were invalid and excluded from predictions.")

    # Sampling can return the same SMILES more than once. Keep the first
    # occurrence only: duplicate keys would multiply rows in the merge below
    # and produce repeated molecule names.
    df_valid = df_valid.drop_duplicates(subset='SMILES')

    # **Check if there are valid molecules to proceed**
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
        return None

    # **ADMET Predictions**
    preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

    # **Ensure 'SMILES' is a column in preds**
    if 'SMILES' not in preds.columns:
        preds['SMILES'] = df_valid['SMILES'].values

    # **Merge predictions with valid molecules**
    df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
    df_results.set_index('Molecule Name', inplace=True)

    # **Select only desired ADMET properties**
    # Only columns actually present are selected so that one missing
    # property does not abort the whole page with a KeyError.
    available = [name for name in _ADMET_PROPERTIES if name in df_results.columns]
    missing = [name for name in _ADMET_PROPERTIES if name not in df_results.columns]
    if missing:
        st.warning(f"ADMET properties not found in the predictions and omitted: {', '.join(missing)}")
    df_results_filtered = df_results[['SMILES', 'Valid'] + available]

    # **Check if df_results_filtered is empty after filtering**
    if df_results_filtered.empty:
        st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        return None
    return df_results_filtered


def _render_results(df_results_filtered):
    """Show each molecule's image, SMILES/SAFE string, copy button and ADMET data."""
    st.subheader('Generated Molecules')

    # **Lay molecules out in at most 3 columns**
    cols_per_row = min(3, len(df_results_filtered))
    cols = st.columns(cols_per_row)

    for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
        smiles = row['SMILES']
        img = generate_molecule_image(smiles)

        # **Initialize session state for toggle buttons**
        toggle_key = f'toggle_{mol_name}'
        if toggle_key not in st.session_state:
            st.session_state[toggle_key] = False  # False means SMILES is displayed
        show_safe = st.session_state[toggle_key]

        with cols[idx % cols_per_row]:
            if isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # **Toggle Button to Switch Between SMILES and SAFE**
            # The flag is flipped in an on_click callback, which runs before
            # the rerun, so the label and the displayed string stay in sync.
            st.button(
                'Toggle to SMILES' if show_safe else 'Toggle to SAFE',
                key=toggle_key + '_button',
                on_click=_toggle_format,
                args=(toggle_key,)
            )

            # **Display molecule string in chosen format**
            # display_string always holds what is shown, so the copy button
            # never references a value that failed to be computed.
            display_string = smiles
            if show_safe:
                try:
                    display_string = safe.encode(smiles)
                except Exception as e:
                    st.error(f"Could not convert to SAFE encoding: {e}")
            st.code(display_string)

            # **Copy-to-clipboard functionality**
            st_copy_button(display_string, key=f'copy_{mol_name}')

            # **Display selected ADMET properties**
            st.write("**ADMET Properties:**")
            # Drop 'SMILES' and 'Valid' columns for display
            admet_data = row.drop(['SMILES', 'Valid'])
            st.write(admet_data)


generate_clicked = st.button('Generate Molecules')
if generate_clicked:
    st.session_state[_RESULTS_KEY] = _generate_results()

latest_results = st.session_state.get(_RESULTS_KEY)
if latest_results is not None:
    _render_results(latest_results)
elif not generate_clicked:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")