Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import os | |
from rdkit import Chem | |
from rdkit.Chem import Draw, Descriptors | |
from transformers import BartForConditionalGeneration, BartTokenizer | |
from admet_ai import ADMETModel | |
import safe | |
import io | |
from PIL import Image | |
import cairosvg | |
import pandas as pd | |
import streamlit.components.v1 as components | |
# **Page Configuration**
# Settings are collected in one mapping and applied with a single call; this
# is the first Streamlit command issued by the script.
page_settings = {
    'page_title': 'Beta-Lactam Molecule Generator',
    'layout': 'wide',
}
st.set_page_config(**page_settings)
# **Load Models**
@st.cache_resource
def load_models():
    """
    Load the molecule generation model, its tokenizer, and the ADMET-AI model.

    The function is wrapped in ``st.cache_resource`` so the three objects are
    built once and reused, instead of being re-created every time Streamlit
    re-executes the script (which happens on every widget interaction).

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. When it is not set, an error is shown and the
    script run is halted via ``st.stop()``.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple.
    """
    # **Load your molecule generation model**
    model_name = "bcadkins01/beta_lactam_generator"  # Replace with your actual model path
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    if access_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()
    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
    # **Load ADMET-AI model**
    admet_model = ADMETModel()
    return model, tokenizer, admet_model
# **Load models once and reuse**
model, tokenizer, admet_model = load_models()

# **Set Generation Parameters in Sidebar**
st.sidebar.header('Generation Parameters')

# **Creativity Slider (Temperature)**
# The value is passed to model.generate() as `temperature` together with
# do_sample=True. Sampling requires a strictly positive temperature, so the
# slider starts at 0.1 instead of 0.0; selecting 0.0 used to make generation
# fail with a ValueError.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs."
)

# **Number of Molecules to Generate**
# All bounds are ints, so the widget returns an int that can be used directly
# as num_return_sequences.
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,  # Reduced from 5 to 3
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)
# **Function to Generate Molecule Images**
def generate_molecule_image(smiles):
    """
    Render the molecule described by a SMILES string as a 200x200 image.

    Returns the image produced by RDKit's ``Draw.MolToImage``, or ``None``
    when the SMILES cannot be parsed. Any exception raised while parsing or
    drawing is reported through ``st.error`` and also results in ``None``.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        return Draw.MolToImage(parsed, size=(200, 200)) if parsed else None
    except Exception as e:
        st.error(f"Error generating molecule image: {e}")
        return None
# **Function to Create Copy-to-Clipboard Button**
def st_copy_button(text, key):
    """
    Render a "Copy" button that writes ``text`` to the browser clipboard.

    Args:
        text: The string to copy. It is embedded in the button's inline
            JavaScript handler, so it is escaped before interpolation.
        key: Accepted for call-site compatibility; the body does not use it.

    The button is right-aligned and shifted upwards with a negative top
    margin so that it sits next to the preceding element instead of
    overlapping the text.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import block.
    import html
    import json

    # json.dumps yields a double-quoted JavaScript string literal in which
    # backslashes, quotes and control characters are escaped. html.escape then
    # makes that literal safe inside the double-quoted onclick attribute.
    # Interpolating the raw text let a backslash (common in SMILES
    # stereo-bond notation) be read as a JS escape and let a quote character
    # terminate the string or the attribute early.
    js_literal = html.escape(json.dumps(text), quote=True)
    button_html = f"""
    <div style="display: flex; justify-content: flex-end;">
        <button onclick="navigator.clipboard.writeText({js_literal})" style="padding:5px; margin-top: -40px; position: relative; z-index: 1;">Copy</button>
    </div>
    """
    components.html(button_html, height=45)
# **Generate Molecules Button**
# Overview of this section:
#   1. When 'Generate Molecules' is clicked, sample SMILES from the model,
#      validate them with RDKit, run ADMET predictions and store the resulting
#      table in st.session_state.
#   2. On every run, render whatever table is stored.
# The table is kept in session state because st.button() is only True on the
# run triggered by its own click. The per-molecule toggle buttons trigger
# further reruns, and rendering inside the 'Generate Molecules' branch made
# the results disappear as soon as a toggle was pressed.

# Session-state key under which the most recent results table is stored.
RESULTS_STATE_KEY = 'generated_results'


# **Function to validate SMILES**
def is_valid_smile(smile):
    """Return True when RDKit can parse ``smile`` into a molecule."""
    return Chem.MolFromSmiles(smile) is not None


def _flip_toggle(state_key):
    """Invert the boolean stored in ``st.session_state`` under ``state_key``."""
    st.session_state[state_key] = not st.session_state[state_key]


generate_clicked = st.button('Generate Molecules')

if generate_clicked:
    # Discard the previous table so a failed generation does not leave stale
    # molecules on screen.
    if RESULTS_STATE_KEY in st.session_state:
        del st.session_state[RESULTS_STATE_KEY]

    st.info("Generating molecules... Please wait.")

    # **Beta-lactam core structure**
    # NOTE(review): this SMILES closes a five-membered ring containing two
    # carbonyls, whereas a beta-lactam ring is four-membered. Confirm it is
    # the prompt the generator model expects.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # **Tokenize the core SMILES**
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # **Generate molecules using the model**
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        temperature=creativity,
        do_sample=True,
        top_k=50,
        num_return_sequences=num_molecules,
        num_beams=max(num_molecules, 5)  # Ensure num_beams >= num_return_sequences
    )

    # **Decode generated molecule SMILES**
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # **Create molecule names** (Mol01, Mol02, ...)
    molecule_names = [
        f"Mol{str(i).zfill(2)}"
        for i in range(1, len(generated_smiles) + 1)
    ]

    # The names are reused by every generation, so reset their display mode;
    # a new molecule should start out showing SMILES.
    for name in molecule_names:
        st.session_state[f'toggle_{name}'] = False

    # **Create DataFrame for generated molecules**
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles
    })

    # **Invalid SMILES Check**
    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)

    # **Inform user if any molecules were invalid**
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # Sampling can return the same SMILES more than once. Keep only the first
    # occurrence: duplicates would multiply rows in the merge on 'SMILES'
    # below and repeat molecule names, which are used as widget keys.
    df_valid = (
        df_molecules[df_molecules['Valid']]
        .drop_duplicates(subset='SMILES')
        .copy()
    )

    # **Check if there are valid molecules to proceed**
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # **ADMET Predictions**
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # **Ensure 'SMILES' is a column in preds**
        # Assumes predict() returns one row per input SMILES in input order.
        # TODO confirm against the ADMET-AI API.
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        # **Merge predictions with valid molecules**
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')

        # **Set 'Molecule Name' as index**
        df_results.set_index('Molecule Name', inplace=True)

        # **Select only desired ADMET properties**
        admet_properties = [
            'molecular weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]
        # Select only the columns the prediction table actually contains so a
        # renamed or absent property does not abort the run with a KeyError.
        available_properties = [
            prop for prop in admet_properties if prop in df_results.columns
        ]
        missing_properties = [
            prop for prop in admet_properties if prop not in df_results.columns
        ]
        if missing_properties:
            st.warning(
                "These ADMET properties were not returned by the model and are omitted: "
                + ", ".join(missing_properties)
            )
        df_results_filtered = df_results[['SMILES', 'Valid'] + available_properties]

        # **Check if df_results_filtered is empty after filtering**
        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            st.session_state[RESULTS_STATE_KEY] = df_results_filtered

# **Display Molecules**
if RESULTS_STATE_KEY in st.session_state:
    stored_results = st.session_state[RESULTS_STATE_KEY]

    st.subheader('Generated Molecules')

    # **Determine number of columns per row**
    cols_per_row = min(3, len(stored_results))  # Max 3 columns

    # **Create columns in Streamlit**
    cols = st.columns(cols_per_row)

    # **Iterate over each molecule to display**
    for idx, (mol_name, row) in enumerate(stored_results.iterrows()):
        smiles = row['SMILES']
        img = generate_molecule_image(smiles)

        # **Initialize session state for toggle buttons**
        toggle_key = f'toggle_{mol_name}'
        if toggle_key not in st.session_state:
            st.session_state[toggle_key] = False  # False means SMILES is displayed

        with cols[idx % cols_per_row]:
            if img is not None and isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # **Toggle Button to Switch Between SMILES and SAFE**
            # The flip happens in an on_click callback, which runs before the
            # script is re-executed, so the label below is always computed
            # from the already-updated state.
            st.button(
                'Toggle to SAFE' if not st.session_state[toggle_key] else 'Toggle to SMILES',
                key=toggle_key + '_button',
                on_click=_flip_toggle,
                args=(toggle_key,)
            )

            # **Display molecule string in chosen format**
            # Start from the SMILES and replace it only when SAFE encoding
            # succeeds. This way the string handed to the copy button is
            # always defined; previously a failed encoding left the SAFE
            # variable unbound and raised NameError.
            displayed_string = smiles
            if st.session_state[toggle_key]:
                try:
                    displayed_string = safe.encode(smiles)
                except Exception as e:
                    st.error(f"Could not convert to SAFE encoding: {e}")
            st.code(displayed_string)

            # **Copy-to-clipboard functionality**
            st_copy_button(displayed_string, key=f'copy_{mol_name}')

            # **Display selected ADMET properties**
            st.write("**ADMET Properties:**")
            # Drop 'SMILES' and 'Valid' columns for display
            admet_data = row.drop(['SMILES', 'Valid'])
            st.write(admet_data)
elif not generate_clicked:
    # Nothing generated yet in this session (and no attempt on this run).
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")