# Hugging Face file-page header captured together with this source:
# bcadkins01's picture
# Update app.py
# 87e6612 verified
# raw
# history blame
# 8.67 kB
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
import streamlit.components.v1 as components
import json # For safely encoding text in JavaScript
# Page setup: browser-tab title and the full-width ("wide") layout.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Model loading (cached across reruns by st.cache_resource).
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the SMILES generator, its tokenizer and the ADMET-AI predictor.

    The result is cached with ``st.cache_resource`` (``ttl=600``), so cached
    objects expire after 600 seconds and are reloaded on the next call.

    Returns:
        tuple: ``(model, tokenizer, admet_model)`` -- a
        ``BartForConditionalGeneration``, a ``BartTokenizer`` and an
        ``ADMETModel``.

    If the ``HUGGING_FACE_TOKEN`` environment variable is not set, an error is
    shown and the current Streamlit run is halted with ``st.stop()``.
    """
    repo_id = "bcadkins01/beta_lactam_generator"  # Hugging Face repo of the generator
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    bart_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)
    predictor = ADMETModel()
    return generator, bart_tokenizer, predictor
# Fetch the (cached) generator, tokenizer and ADMET predictor for this run.
(model, tokenizer, admet_model) = load_models()
# **Sidebar: user-tunable generation parameters**
st.sidebar.header('Generation Parameters')

# Sampling temperature passed to model.generate. Generation runs with
# do_sample=True, and transformers rejects temperature <= 0 when sampling
# (ValueError), so the slider starts at 0.1 instead of 0.0.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs.",
)

# Number of sequences requested from model.generate (num_return_sequences).
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,
    value=3,
    help="Select the number of molecules you want to generate (up to 3).",
)
# **Molecule rendering**
def generate_molecule_image(input_string, use_safe=False):
    """Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: When True and ``input_string`` is not None, draw with
            ``safe.to_image`` and rasterise the resulting SVG to PNG via
            cairosvg. Otherwise parse ``input_string`` as SMILES and draw a
            200x200 RDKit depiction.

    Returns:
        A PIL image, or None when the SMILES does not parse or when any step
        raises. In the exception case the error is also shown via ``st.error``.
    """
    try:
        if not (use_safe and input_string is not None):
            molecule = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(molecule, size=(200, 200)) if molecule else None

        # The output of safe.to_image is encoded as SVG text here; this assumes
        # it is returned as a str -- TODO confirm against the safe library.
        svg_markup = safe.to_image(input_string)
        png_data = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_data))
    except Exception as exc:
        st.error(f"Error generating molecule image: {exc}")
        return None
# **Function to Create Copy-to-Clipboard Button**
def st_copy_button(text, key):
"""Creates a copy-to-clipboard button placed appropriately."""
# Safely encode the text for JavaScript
escaped_text = json.dumps(text)
button_html = f"""
<div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
<button onclick="navigator.clipboard.writeText({escaped_text})" style="
padding:5px;
">Copy</button>
</div>
"""
components.html(button_html, height=35)
# **Generate / predict / display flow**
def _render_molecule(mol_name, row):
    """Render one result row into the active Streamlit container.

    Shows the structure image, the SMILES, the SAFE encoding (when it can be
    computed) with its visualization, and the remaining ADMET columns.

    Args:
        mol_name: Display name such as "Mol01" (the results index value).
        row: A result row holding 'SMILES', 'Valid' and the selected ADMET
            property columns.
    """
    smiles = row['SMILES']
    try:
        safe_string = safe.encode(smiles)
    except Exception as e:
        safe_string = None
        st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

    img = generate_molecule_image(smiles)
    if isinstance(img, Image.Image):
        st.image(img, caption=mol_name)
    else:
        st.error(f"Could not generate image for {mol_name}")

    st.write("**SMILES:**")
    st.text(smiles)
    st_copy_button(smiles, key=f'copy_smiles_{mol_name}')

    if safe_string:
        st.write("**SAFE Encoding:**")
        st.text(safe_string)
        st_copy_button(safe_string, key=f'copy_safe_{mol_name}')
        safe_img = generate_molecule_image(safe_string, use_safe=True)
        if safe_img is not None:
            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

    st.write("**ADMET Properties:**")
    st.write(row.drop(['SMILES', 'Valid']))


if st.button('Generate Molecules'):
    # Prompt fed to the generator.
    # NOTE(review): this SMILES parses to a five-membered cyclic imide
    # (N-methylsuccinimide), not the four-membered azetidin-2-one ring of a
    # beta-lactam. It is left unchanged because the generator may have been
    # trained with this exact prompt -- confirm before changing it.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # A spinner disappears when generation finishes; st.info stayed on screen.
    with st.spinner("Generating molecules... Please wait."):
        input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=128,
            temperature=creativity,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_molecules,
            num_beams=max(num_molecules, 5),  # generate() needs num_beams >= num_return_sequences
        )

    # Strip stray whitespace so RDKit and ADMET-AI receive clean SMILES.
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]
    df_molecules = pd.DataFrame({'Molecule Name': molecule_names, 'SMILES': generated_smiles})

    def is_valid_smile(smile):
        """Return True if ``smile`` is non-empty and RDKit can parse it."""
        # Chem.MolFromSmiles("") yields an empty molecule rather than None,
        # so blank decodes must be rejected explicitly.
        return bool(smile) and Chem.MolFromSmiles(smile) is not None

    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    n_invalid = int((~df_molecules['Valid']).sum())
    if n_invalid:
        st.warning(f"{n_invalid} generated molecules were invalid and excluded from predictions.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        with st.spinner("Predicting ADMET properties..."):
            preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        if 'SMILES' not in preds.columns:
            # Assumes one prediction row per input SMILES, in input order -- TODO confirm.
            preds['SMILES'] = df_valid['SMILES'].values

        # Sampling can return the same SMILES more than once. Without
        # de-duplication, the merge below would pair every copy with every
        # matching prediction row (many-to-many) and show each duplicate
        # several times.
        preds = preds.drop_duplicates(subset='SMILES')
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # NOTE(review): the predictor may name the weight column
        # 'molecular_weight' rather than 'molecular weight' -- verify.
        admet_properties = [
            'molecular weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins',
        ]
        # Keep only the columns that actually came back, so a missing or
        # renamed property produces a warning instead of a KeyError.
        missing = [p for p in admet_properties if p not in df_results.columns]
        if missing:
            st.warning(f"ADMET properties not available and skipped: {', '.join(missing)}")
        present = [p for p in admet_properties if p in df_results.columns]
        df_results_filtered = df_results[['SMILES', 'Valid'] + present]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # at most 3 columns
            cols = st.columns(cols_per_row)
            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                with cols[idx % cols_per_row]:
                    _render_molecule(mol_name, row)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")