Spaces:
Sleeping
Sleeping
File size: 8,666 Bytes
e295393 87e6612 e295393 111a2ce 87e6612 e295393 111a2ce e295393 111a2ce e295393 111a2ce af1913b e295393 111a2ce e295393 111a2ce e295393 111a2ce e295393 111a2ce e295393 111a2ce e295393 111a2ce a3a6043 111a2ce 085e6ac 111a2ce 87e6612 111a2ce 87e6612 111a2ce e295393 87e6612 e295393 87e6612 e295393 111a2ce e295393 111a2ce e295393 87e6612 a3a6043 87e6612 111a2ce 87e6612 e295393 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 a3a6043 87e6612 a3a6043 87e6612 a3a6043 87e6612 a3a6043 111a2ce 87e6612 111a2ce a3a6043 111a2ce a3a6043 111a2ce a3a6043 87e6612 a3a6043 111a2ce a3a6043 87e6612 a3a6043 87e6612 a3a6043 87e6612 111a2ce a3a6043 111a2ce 87e6612 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
import streamlit.components.v1 as components
import json # For safely encoding text in JavaScript
# Page configuration: browser-tab title and wide layout.
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
# Model loading (cached by Streamlit)
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Return the generation model, its tokenizer, and the ADMET-AI model.

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. When it is not set, an error is shown and the
    Streamlit script is stopped. The function is wrapped in
    ``st.cache_resource`` (with ``ttl=600``) so the result is reused
    instead of being rebuilt on every script run.
    """
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # Hub repository holding both the BART weights and the tokenizer files.
    repo_id = "bcadkins01/beta_lactam_generator"
    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    smiles_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)

    # The ADMET-AI predictor is built last, after the generator and tokenizer.
    return generator, smiles_tokenizer, ADMETModel()
# Load the models once and reuse them; load_models() is wrapped in
# st.cache_resource, so repeated script runs get the cached objects.
model, tokenizer, admet_model = load_models()

# Sidebar: generation parameters
st.sidebar.header('Generation Parameters')

# Creativity slider (sampling temperature).
# The lower bound is 0.1 rather than 0.0: the generation call samples with
# do_sample=True, and transformers requires a strictly positive temperature
# when sampling, so a slider value of 0.0 would raise instead of generating.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs."
)

# Number of molecules to generate (integer widget, capped at 3).
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,  # Reduced from 5 to 3
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)
# Function to generate molecule images
def generate_molecule_image(input_string, use_safe=False):
    """Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe``
            is True.
        use_safe: When True, draw with ``safe.to_image`` and convert the
            resulting SVG to PNG with cairosvg. Otherwise parse the string
            with RDKit and draw a 200x200 image.

    Returns:
        The rendered image, or ``None`` when the input is ``None``/empty,
        when RDKit cannot parse the SMILES, or when rendering raises (the
        error is then also reported with ``st.error``).
    """
    # Bail out early on None/empty input so it never reaches the SMILES
    # parser: without this guard a None SAFE string would be handed to
    # Chem.MolFromSmiles and fail with an unrelated-looking error.
    if not input_string:
        return None

    try:
        if use_safe:
            # SAFE rendering yields SVG text; convert it to PNG bytes so it
            # can be opened as a PIL image.
            svg_str = safe.to_image(input_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        # Standard SMILES rendering through RDKit.
        mol = Chem.MolFromSmiles(input_string)
        if mol is None:
            return None
        return Draw.MolToImage(mol, size=(200, 200))
    except Exception as e:
        # UI boundary: the drawing libraries can raise many exception types,
        # so report the failure to the user instead of crashing the page.
        st.error(f"Error generating molecule image: {e}")
        return None
# Function to create a copy-to-clipboard button
def st_copy_button(text, key):
    """Render a "Copy" button that writes ``text`` to the clipboard.

    Args:
        text: The string to copy when the button is clicked.
        key: Not used when rendering; kept so call sites that pass a
            per-widget key keep working.
    """
    import html  # local import: only needed for attribute escaping here

    # json.dumps produces a valid JavaScript string literal, but that literal
    # is delimited by double quotes. Placed verbatim inside the double-quoted
    # onclick="..." attribute, its first quote would terminate the attribute
    # and the handler would never run. HTML-escaping turns the quotes into
    # &quot;, which the browser decodes back before executing the script.
    js_literal = html.escape(json.dumps(text), quote=True)
    button_html = f"""
    <div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
    <button onclick="navigator.clipboard.writeText({js_literal})" style="
    padding:5px;
    ">Copy</button>
    </div>
    """
    components.html(button_html, height=35)
# Generate Molecules button: run generation, validate the SMILES, predict
# ADMET properties, and render one card per molecule.
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # Core structure used as the generation prompt.
    # NOTE(review): this SMILES encodes a five-membered ring with two
    # carbonyls next to an N-methyl nitrogen, not a four-membered
    # beta-lactam ring - confirm it is the prompt the model expects.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # Tokenize the core SMILES
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # Generate molecules using the model.
    # num_beams is kept >= num_return_sequences.
    generation_kwargs = dict(
        input_ids=input_ids,
        max_length=128,
        num_return_sequences=num_molecules,
        num_beams=max(num_molecules, 5),
    )
    # Sampling requires a strictly positive temperature; a value of 0 would
    # make generate() raise, so fall back to non-sampling beam search then.
    if creativity > 0:
        generation_kwargs.update(do_sample=True, temperature=creativity, top_k=50)
    output_ids = model.generate(**generation_kwargs)

    # Decode generated molecule SMILES
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # Molecule names: Mol01, Mol02, ...
    molecule_names = [
        f"Mol{str(i).zfill(2)}"
        for i in range(1, len(generated_smiles) + 1)
    ]

    # DataFrame of generated molecules
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles
    })

    def is_valid_smile(smile):
        """Return True when RDKit parses ``smile`` into a molecule with atoms."""
        # An empty string parses to an atom-less molecule rather than None,
        # so "is not None" alone would accept it.
        if not smile:
            return False
        mol = Chem.MolFromSmiles(smile)
        return mol is not None and mol.GetNumAtoms() > 0

    # Apply validation
    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # The model can return the same SMILES more than once. Duplicates would
    # be multiplied by the merge on 'SMILES' below (each copy matches every
    # copy), so keep only the first occurrence of each molecule.
    duplicate_count = int(df_valid.duplicated(subset='SMILES').sum())
    if duplicate_count:
        df_valid = df_valid.drop_duplicates(subset='SMILES').copy()
        st.info(f"{duplicate_count} duplicate molecules were removed.")

    # Check if there are valid molecules to proceed
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Ensure 'SMILES' is a column in preds.
        # NOTE(review): assumes predict() returns rows in input order - confirm.
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        # Merge predictions with valid molecules and index by molecule name
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # Select only desired ADMET properties
        admet_properties = [
            'molecular weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]

        # Match requested names to the prediction columns while ignoring the
        # space/underscore difference, and skip (with a warning) anything the
        # predictions do not contain, instead of failing with a KeyError.
        # NOTE(review): ADMET-AI appears to name the weight column
        # 'molecular_weight' - confirm against the installed version.
        columns_by_key = {
            str(col).replace(' ', '_'): col for col in df_results.columns
        }
        selected_columns = []
        missing_properties = []
        for prop in admet_properties:
            match = columns_by_key.get(prop.replace(' ', '_'))
            if match is None:
                missing_properties.append(prop)
            else:
                selected_columns.append(match)
        if missing_properties:
            st.warning(
                "ADMET predictions did not include: " + ", ".join(missing_properties)
            )

        df_results_filtered = df_results[['SMILES', 'Valid'] + selected_columns]

        # Check if df_results_filtered is empty after filtering
        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display molecules, at most three per row
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # Attempt to encode to SAFE; a failure only hides the SAFE section
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                # Generate molecule image from the SMILES string
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if img is not None and isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    # Display SMILES string
                    st.write("**SMILES:**")
                    st.text(smiles)
                    st_copy_button(smiles, key=f'copy_smiles_{mol_name}')

                    # Display SAFE encoding if available
                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        st_copy_button(safe_string, key=f'copy_safe_{mol_name}')

                        # Optionally display SAFE visualization
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    admet_data = row.drop(['SMILES', 'Valid'])
                    st.write(admet_data)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|