# Spaces: Sleeping  (status text captured from the hosting page; not part of the program)
import streamlit as st | |
import torch | |
import os | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from transformers import BartForConditionalGeneration, BartTokenizer | |
from admet_ai import ADMETModel | |
import safe | |
import io | |
from PIL import Image | |
import cairosvg | |
import pandas as pd | |
import streamlit.components.v1 as components | |
import json # For safely encoding text in JavaScript | |
# **Page Configuration** | |
# Configure the browser tab title and use the full-width page layout.
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
# **Load Models** | |
@st.cache_resource
def load_models():
    """
    Load the molecule generation model, its tokenizer, and the ADMET-AI model.

    The function is wrapped in ``st.cache_resource`` so the models are built
    once and reused on later script reruns instead of being reloaded on
    every widget interaction.

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. If it is not set, an error is shown and the
    script run is stopped.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple.
    """
    # **Load your molecule generation model**
    model_name = "bcadkins01/beta_lactam_generator"  # Replace with your actual model path
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    if access_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # **Load ADMET-AI model**
    admet_model = ADMETModel()

    return model, tokenizer, admet_model
# **Load models once and reuse** | |
# Obtain the generation model, tokenizer and ADMET predictor.
model, tokenizer, admet_model = load_models()

# --- Sidebar: generation parameters ---
sidebar = st.sidebar
sidebar.header('Generation Parameters')

# Sampling temperature, exposed to the user as "creativity".
creativity = sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.0,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs.",
)

# How many molecules to request per generation run (1 to 3).
num_molecules = sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,
    value=3,
    help="Select the number of molecules you want to generate (up to 3).",
)
# **Function to Generate Molecule Images** | |
def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule as an image.

    When ``use_safe`` is true and ``input_string`` is not None, the string is
    rendered through ``safe.to_image`` (SVG) and converted to a PNG-backed
    PIL image. Otherwise the string is parsed as SMILES with RDKit and drawn
    at 200x200.

    Returns the image, or None when the SMILES cannot be parsed or when any
    exception occurs (in which case an error is also shown in the app).
    """
    try:
        render_as_safe = use_safe and input_string is not None
        if not render_as_safe:
            # Plain SMILES path: RDKit parse, then draw.
            parsed = Chem.MolFromSmiles(input_string)
            if not parsed:
                return None
            return Draw.MolToImage(parsed, size=(200, 200))

        # SAFE path: SVG markup -> PNG bytes -> PIL image.
        svg_markup = safe.to_image(input_string)
        png_data = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        st.error(f"Error generating molecule image: {e}")
        return None
# **Function to Create Copy-to-Clipboard Button** | |
def st_copy_button(text, key):
    """
    Render a right-aligned "Copy" button that writes ``text`` to the clipboard.

    Args:
        text: The string to copy when the button is clicked.
        key: Accepted for call-site compatibility; currently unused by the
            rendered component.
    """
    import html  # local import so this function does not rely on a new file-level import

    # Encode the text as a JavaScript string literal...
    js_literal = json.dumps(text)
    # ...then escape it for the HTML attribute it is embedded in. json.dumps
    # emits double quotes, which would otherwise terminate the double-quoted
    # onclick="..." attribute and leave the handler broken.
    attr_safe_literal = html.escape(js_literal, quote=True)

    button_html = f"""
    <div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
        <button onclick="navigator.clipboard.writeText({attr_safe_literal})" style="
            padding:5px;
        ">Copy</button>
    </div>
    """
    components.html(button_html, height=35)
# **Generate Molecules Button** | |
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # **Beta-lactam core structure**
    # NOTE(review): this SMILES closes a five-membered ring with two carbonyls
    # next to the nitrogen, whereas a beta-lactam ring is four-membered.
    # Confirm this is the seed the model was trained with.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # **Tokenize the core SMILES**
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # **Generate molecules using the model**
    n_requested = int(num_molecules)
    generation_kwargs = {
        'input_ids': input_ids,
        'max_length': 128,
        'num_return_sequences': n_requested,
        # Ensure num_beams >= num_return_sequences
        'num_beams': max(n_requested, 5),
    }
    if creativity > 0:
        generation_kwargs.update(do_sample=True, temperature=creativity, top_k=50)
    else:
        # The slider allows 0.0, but sampling requires a strictly positive
        # temperature. Treat 0 as "no randomness" and fall back to
        # deterministic beam search instead of raising.
        generation_kwargs.update(do_sample=False)
    output_ids = model.generate(**generation_kwargs)

    # **Decode generated molecule SMILES**
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # **Create molecule names** (Mol01, Mol02, ...)
    molecule_names = [
        f"Mol{str(i).zfill(2)}"
        for i in range(1, len(generated_smiles) + 1)
    ]

    # **Create DataFrame for generated molecules**
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles
    })

    # **Invalid SMILES Check**
    def is_valid_smile(smile):
        """Return True if `smile` parses to a molecule with at least one atom."""
        # An empty string parses to an empty molecule rather than None, so
        # reject blank output explicitly.
        if not smile or not smile.strip():
            return False
        mol = Chem.MolFromSmiles(smile)
        return mol is not None and mol.GetNumAtoms() > 0

    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # Check if there are valid molecules to proceed
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Ensure 'SMILES' is a column in preds
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        # The generator can emit the same SMILES more than once. Keep one
        # prediction row per SMILES so the merge below is many-to-one and
        # yields exactly one row per generated molecule instead of a
        # cross-product of the duplicates.
        preds = preds.drop_duplicates(subset='SMILES').reset_index(drop=True)

        # Merge predictions with valid molecules
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')

        # Set 'Molecule Name' as index
        df_results.set_index('Molecule Name', inplace=True)

        # Select only desired ADMET properties
        admet_properties = [
            'molecular weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]
        # Only select columns the predictor actually returned; a missing
        # column would otherwise raise KeyError and abort the whole page.
        available_properties = [p for p in admet_properties if p in df_results.columns]
        missing_properties = [p for p in admet_properties if p not in df_results.columns]
        if missing_properties:
            st.warning(
                "Some ADMET properties were not returned by the model and are omitted: "
                + ", ".join(missing_properties)
            )
        df_results_filtered = df_results[['SMILES', 'Valid'] + available_properties]

        # Check if df_results_filtered is empty after filtering
        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # Attempt to encode to SAFE
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                # Generate molecule image from the SMILES string
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if img is not None and isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    # Display SMILES string
                    st.write("**SMILES:**")
                    st.text(smiles)
                    st_copy_button(smiles, key=f'copy_smiles_{mol_name}')

                    # Display SAFE encoding if available
                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        st_copy_button(safe_string, key=f'copy_safe_{mol_name}')

                        # Optionally display SAFE visualization
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    admet_data = row.drop(['SMILES', 'Valid'])
                    st.write(admet_data)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")