bcadkins01's picture
Update app.py
4a19ce8 verified
raw
history blame
7.9 kB
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
# Page Configuration: browser-tab title and a full-width page layout.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the SMILES generator, its tokenizer, and the ADMET-AI predictor.

    The result is held in Streamlit's resource cache (entries expire after
    ``ttl=600`` seconds), so most reruns reuse the already-loaded objects.

    Returns:
        tuple: ``(model, tokenizer, admet_model)`` where ``model`` is a
        ``BartForConditionalGeneration``, ``tokenizer`` a ``BartTokenizer``
        (both fetched from the Hugging Face Hub with the
        ``HUGGING_FACE_TOKEN`` credential), and ``admet_model`` an
        ``ADMETModel``.

    If ``HUGGING_FACE_TOKEN`` is not set, an error is shown and the current
    Streamlit script run is halted via ``st.stop()``.
    """
    repo_id = "bcadkins01/beta_lactam_generator"

    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # Generator and tokenizer come from the same (token-protected) Hub repo.
    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    bart_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)

    admet_predictor = ADMETModel()
    return generator, bart_tokenizer, admet_predictor
# Load the models (served from Streamlit's resource cache after first load).
model, tokenizer, admet_model = load_models()

# Set Generation Parameters in Sidebar
st.sidebar.header('Generation Parameters')

# Creativity slider (sampling temperature). The range starts at 0.2 rather
# than 0.0 because Hugging Face `generate` rejects a temperature of 0 when
# sampling (do_sample=True).
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.2,
    max_value=2.4,
    value=1.0,
    step=0.2,
    help="Higher values lead to more diverse (or wild) outputs."
)

# Number of Molecules to Generate (integer bounds -> integer result).
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,  # Adjust as needed
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)
# Function to Generate Molecule Images
def generate_molecule_image(input_string, use_safe=False):
    """Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is
            True. ``None`` is accepted and yields ``None``.
        use_safe: If True, draw with ``safe.to_image`` and rasterize the SVG
            to PNG with cairosvg; otherwise parse ``input_string`` as SMILES
            and draw a 250x250 image with RDKit.

    Returns:
        A ``PIL.Image.Image``, or ``None`` when there is no input, the SMILES
        cannot be parsed, or rendering raises (in which case the error is
        reported with ``st.error``).
    """
    # Nothing to draw. Previously a None SAFE string fell through to the SMILES
    # branch and surfaced as an RDKit error message in the UI.
    if input_string is None:
        return None
    try:
        if use_safe:
            # NOTE(review): assumes safe.to_image returns the SVG as a str
            # (it is encoded to bytes below) -- confirm for the installed version.
            svg_str = safe.to_image(input_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))
        # RDKit returns None for SMILES it cannot parse.
        mol = Chem.MolFromSmiles(input_string)
        if mol is None:
            return None
        return Draw.MolToImage(mol, size=(250, 250))
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        st.error(f"Error generating molecule image: {e}")
        return None
# Generate Molecules Button
if st.button('Generate Molecules'):
    # Prompt SMILES fed to the seq2seq generator.
    # NOTE(review): "C1C(=O)N(C)C(=O)C1" is a five-membered ring
    # (N-methylsuccinimide), not a four-membered beta-lactam (azetidin-2-one,
    # e.g. "O=C1CCN1"). Left unchanged because the model may have been trained
    # on this exact prompt -- confirm with the model author.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # The sidebar "Creativity" slider is the sampling temperature (it was
    # previously ignored in favour of a hard-coded 1.2). transformers rejects
    # temperature <= 0 when sampling, so clamp to a small positive value.
    temperature = max(float(creativity), 0.05)

    with st.spinner("Generating molecules... Please wait."):
        # Tokenize the core SMILES; keep the attention mask for generate().
        inputs = tokenizer(core_smiles, return_tensors='pt')
        # Nucleus (top-p) sampling, not beam search: num_beams=1, and top_k=0
        # disables top-k filtering so only top_p restricts the candidates.
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=128,
            do_sample=True,
            temperature=temperature,
            top_k=0,
            top_p=0.9,
            num_return_sequences=int(num_molecules),
            num_beams=1,
        )

    # Decode generated molecule SMILES
    generated_smiles = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Generic demo names: Mol01, Mol02, ...
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]

    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    # Invalid SMILES check: RDKit returns None for SMILES it cannot parse.
    def is_valid_smile(smile):
        return Chem.MolFromSmiles(smile) is not None

    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Ensure 'SMILES' is a column in preds.
        if 'SMILES' not in preds.columns:
            # NOTE(review): assumes predict() returns one row per input SMILES,
            # in input order -- confirm against the admet_ai version in use.
            preds['SMILES'] = df_valid['SMILES'].values

        # Key the merge on the 'SMILES' column alone, and drop duplicate
        # SMILES: sampling can return the same molecule more than once, and
        # duplicate keys on both sides of an inner merge yield a cross product
        # (repeated rows and repeated molecule names).
        preds = preds.reset_index(drop=True).drop_duplicates(subset='SMILES')
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # Select only desired ADMET properties (skip any the model did not return).
        admet_properties = [
            'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]
        missing_properties = [p for p in admet_properties if p not in df_results.columns]
        if missing_properties:
            st.warning(f"ADMET properties not returned by the model: {', '.join(missing_properties)}")
        shown_properties = [p for p in admet_properties if p in df_results.columns]
        df_results_filtered = df_results[['SMILES', 'Valid'] + shown_properties]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']
                # Molecule image rendered from the SMILES string.
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    st.write("**SMILES:**")
                    st.text(smiles)

                    # SAFE encoding is optional extra output; a failure is
                    # reported under this molecule rather than aborting the loop.
                    try:
                        safe_string = safe.encode(smiles)
                    except Exception as e:
                        safe_string = None
                        st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    st.write(row.drop(['SMILES', 'Valid']))
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")