Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import os | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from transformers import BartForConditionalGeneration, BartTokenizer | |
from admet_ai import ADMETModel | |
import safe | |
import io | |
from PIL import Image | |
import cairosvg | |
import pandas as pd | |
# Page setup: browser tab title plus a full-width layout for the result grid.
st.set_page_config(layout="wide", page_title="Beta-Lactam Molecule Generator")
# Load Models | |
@st.cache_resource
def _load_models_cached(model_name, access_token):
    """
    Load the BART generator, its tokenizer and the ADMET-AI predictor.

    Decorated with ``st.cache_resource`` so the objects are created once per
    server process and reused across Streamlit reruns (Streamlit re-executes
    the whole script on every widget interaction).

    Args:
        model_name: Hugging Face Hub id of the generator model.
        access_token: Hub token passed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer, admet_model)``.
    """
    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
    admet_model = ADMETModel()
    return model, tokenizer, admet_model


def load_models():
    """
    Load the molecule generation model and the ADMET-AI model.

    The token check runs on every call so a missing token is always reported
    in the UI. The expensive loading is delegated to a cached helper, so the
    models are not reloaded on every rerun.

    Returns:
        Tuple ``(model, tokenizer, admet_model)``.

    Stops the Streamlit script (``st.stop``) when HUGGING_FACE_TOKEN is unset
    or empty.
    """
    model_name = "bcadkins01/beta_lactam_generator"
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    # An empty string is as unusable as an unset variable.
    if not access_token:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()
    return _load_models_cached(model_name, access_token)
# Obtain the generator, its tokenizer and the ADMET predictor for this run.
model, tokenizer, admet_model = load_models()

# Everything below is rendered inside the sidebar container.
with st.sidebar:
    st.header('Generation Parameters')

    # Creativity slider, labelled as the sampling temperature (0.0 to 2.4).
    creativity = st.slider(
        'Creativity (Temperature):',
        min_value=0.0,
        max_value=2.4,
        value=1.0,
        step=0.2,
        help="Higher values lead to more diverse (or wild) outputs.",
    )

    # How many candidate molecules to request per click (1 to 3, default 3).
    num_molecules = st.number_input(
        'Number of Molecules to Generate:',
        min_value=1,
        max_value=3,
        value=3,
        help="Select the number of molecules you want to generate (up to 3).",
    )
# Function to Generate Molecule Images | |
def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: If True, render with ``safe.to_image`` (SVG rasterised to PNG
            via cairosvg). Otherwise parse as SMILES with RDKit and draw a
            250x250 image.

    Returns:
        A PIL image on success. None when the input is empty/None, when RDKit
        cannot parse the SMILES, or when rendering raises (in that case the
        error is also shown with ``st.error``).
    """
    # Without this guard a None SAFE string would be handed to the SMILES
    # parser, producing a misleading error message in the UI.
    if not input_string:
        return None
    try:
        if use_safe:
            svg = safe.to_image(input_string)
            # cairosvg wants bytes; accept either str or bytes from safe.to_image.
            svg_bytes = svg.encode('utf-8') if isinstance(svg, str) else svg
            png_bytes = cairosvg.svg2png(bytestring=svg_bytes)
            return Image.open(io.BytesIO(png_bytes))
        mol = Chem.MolFromSmiles(input_string)
        if not mol:
            # RDKit returns None for unparsable SMILES.
            return None
        return Draw.MolToImage(mol, size=(250, 250))
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the page.
        st.error(f"Error generating molecule image: {e}")
        return None
# Generate Molecules Button | |
if st.button('Generate Molecules'):
    # Prompt fed to the seq2seq generator.
    # NOTE(review): this SMILES is a five-membered ring (N-methylsuccinimide,
    # 1-methylpyrrolidine-2,5-dione), not a four-membered beta-lactam
    # (azetidin-2-one). Kept unchanged because the fine-tuned model may have been
    # trained on this exact prompt -- confirm with the model author.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # Use the sidebar "Creativity" value as the sampling temperature. The slider
    # allows 0.0, but transformers rejects a non-positive temperature when
    # sampling, so clamp to a small positive floor.
    temperature = max(float(creativity), 0.1)

    with st.spinner("Generating molecules... Please wait."):
        input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
        # Nucleus (top-p) sampling with a single beam. top_k=0 disables top-k
        # filtering, so only top_p truncates the distribution.
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=128,
            do_sample=True,
            temperature=temperature,
            top_k=0,
            top_p=0.9,
            num_return_sequences=int(num_molecules),
            num_beams=1,
        )

    # Decode each sampled sequence. SMILES never contain whitespace, so any
    # surrounding spaces from decoding are removed before parsing.
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]

    # Generic demo labels: Mol01, Mol02, ...
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    # A SMILES counts as valid when RDKit can parse it.
    df_molecules['Valid'] = df_molecules['SMILES'].apply(
        lambda smi: Chem.MolFromSmiles(smi) is not None
    )
    df_valid = df_molecules[df_molecules['Valid']].copy()

    num_invalid = len(df_molecules) - len(df_valid)
    if num_invalid:
        st.warning(f"{num_invalid} generated molecules were invalid and excluded from predictions.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        with st.spinner("Predicting ADMET properties..."):
            preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Assumes predict() returns one row per input SMILES, in input order,
        # when it does not already expose a 'SMILES' column -- TODO confirm
        # against admet_ai.
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        # Sampling can yield the same SMILES more than once. De-duplicate the
        # predictions so the merge stays one row per generated molecule instead
        # of producing a cross product with duplicated molecule names.
        preds = preds.drop_duplicates(subset='SMILES')
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # ADMET-AI output columns to show. Select only those actually present so
        # a missing column does not crash the page with a KeyError.
        admet_properties = [
            'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins',
        ]
        available_properties = [p for p in admet_properties if p in df_results.columns]
        missing_properties = [p for p in admet_properties if p not in df_results.columns]
        if missing_properties:
            st.warning(f"ADMET predictions did not include: {', '.join(missing_properties)}")
        df_results_filtered = df_results[['SMILES', 'Valid'] + available_properties]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)
            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']
                # Render everything for this molecule, including any errors,
                # inside its own column.
                with cols[idx % cols_per_row]:
                    img = generate_molecule_image(smiles)
                    if isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    st.write("**SMILES:**")
                    st.text(smiles)

                    # SAFE encoding may fail (e.g. for molecules it cannot
                    # fragment); report it and continue with the other fields.
                    try:
                        safe_string = safe.encode(smiles)
                    except Exception as e:
                        safe_string = None
                        st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    st.write("**ADMET Properties:**")
                    st.write(row.drop(['SMILES', 'Valid']))
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")