Spaces:
Sleeping
Sleeping
File size: 7,964 Bytes
e295393 87e6612 e295393 83d667e 111a2ce e295393 83d667e e295393 111a2ce 83d667e 0db5108 e295393 111a2ce 83d667e e295393 111a2ce e295393 83d667e e295393 83d667e e295393 83d667e 111a2ce 83d667e 111a2ce 83d667e 0db5108 111a2ce e295393 0db5108 111a2ce 231b5fc a3a6043 111a2ce 085e6ac 83d667e 87e6612 111a2ce 87e6612 111a2ce e295393 87e6612 e295393 87e6612 231b5fc 87e6612 e295393 111a2ce e295393 83d667e 111a2ce 83d667e 111a2ce 83d667e 111a2ce 313eced 111a2ce 58b2e8d 111a2ce 313eced 111a2ce 83d667e 111a2ce 83d667e 111a2ce 83d667e 111a2ce 83d667e 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 111a2ce 87e6612 a3a6043 231b5fc a3a6043 87e6612 a3a6043 87e6612 a3a6043 111a2ce 87e6612 111a2ce a3a6043 111a2ce a3a6043 111a2ce a3a6043 87e6612 a3a6043 111a2ce a3a6043 87e6612 a3a6043 87e6612 a3a6043 87e6612 111a2ce a3a6043 111a2ce 0db5108 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
# Page configuration: browser-tab title plus the full-width ("wide") layout.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the BART molecule generator, its tokenizer and the ADMET-AI predictor.

    The result is cached by ``st.cache_resource`` so reruns reuse the same
    objects; with ``ttl=600`` the cached entry expires after 600 seconds and
    the models are loaded again on the next call after that.

    If the ``HUGGING_FACE_TOKEN`` environment variable is unset, an error is
    shown and the script run is halted via ``st.stop()``.

    Returns:
        tuple: ``(model, tokenizer, admet_model)``.
    """
    hf_token = os.getenv("HUGGING_FACE_TOKEN")
    if hf_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # Both the model weights and the tokenizer come from the same Hub repo.
    repo_id = "bcadkins01/beta_lactam_generator"
    generator = BartForConditionalGeneration.from_pretrained(repo_id, token=hf_token)
    bart_tokenizer = BartTokenizer.from_pretrained(repo_id, token=hf_token)

    # ADMET-AI predictor is constructed last, after the generator is ready.
    return generator, bart_tokenizer, ADMETModel()
# Fetch the (cached) generator, tokenizer and ADMET predictor.
model, tokenizer, admet_model = load_models()

# Sidebar controls. Widgets created inside `with st.sidebar:` are placed in the
# sidebar exactly as the `st.sidebar.<widget>(...)` form would place them.
with st.sidebar:
    st.header('Generation Parameters')

    # Slider labelled as a sampling temperature; its value is kept in `creativity`.
    creativity = st.slider(
        'Creativity (Temperature):',
        help="Higher values lead to more diverse (or wild) outputs.",
        min_value=0.0,
        max_value=2.4,
        step=0.2,
        value=1.0,
    )

    # Requested molecule count. If max_value is raised, update the
    # "(up to 3)" wording in the help text as well.
    num_molecules = st.number_input(
        'Number of Molecules to Generate:',
        help="Select the number of molecules you want to generate (up to 3).",
        min_value=1,
        max_value=3,
        value=3,
    )
# Function to Generate Molecule Images
def generate_molecule_image(input_string, use_safe=False):
    """Render a molecule as a PIL image.

    Args:
        input_string: a SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: draw via ``safe.to_image`` instead of RDKit's SMILES drawer.
            That path expects ``safe.to_image`` to return SVG markup as a str,
            which is rasterised to PNG with cairosvg.

    Returns:
        A ``PIL.Image.Image`` (250x250 for the SMILES path), or None when the
        input is None/blank, RDKit cannot parse the SMILES (or parses it to a
        molecule with no atoms), or rendering raises. In the last case the
        error is also shown with ``st.error``.
    """
    # Nothing to draw. Checked up front so a None SAFE string does not fall
    # through to the SMILES parser (which would raise and report a misleading
    # rendering error), and so a blank string is not drawn as an empty picture.
    if input_string is None or (isinstance(input_string, str) and not input_string.strip()):
        return None

    try:
        if use_safe:
            svg_str = safe.to_image(input_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        mol = Chem.MolFromSmiles(input_string)
        # None means RDKit rejected the SMILES; zero atoms means nothing to draw.
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Draw.MolToImage(mol, size=(250, 250))
    except Exception as e:
        # UI boundary: report the failure to the user and let the caller
        # show its own placeholder.
        st.error(f"Error generating molecule image: {e}")
        return None
def _generation_kwargs(n_molecules, temperature):
    """Build the decoding keyword arguments for ``model.generate``.

    ``temperature > 0`` selects plain multinomial sampling at that temperature,
    so the sidebar "Creativity (Temperature)" value actually affects the
    output. ``temperature == 0`` keeps deterministic diverse (group) beam search.

    Group beam search requires ``num_beams`` to be divisible by
    ``num_beam_groups``. The beam count (at least 5, and at least 2 per
    returned sequence) is therefore rounded up to a multiple of
    ``n_molecules``. For example, ``max(2 * 2, 5) == 5`` beams would be
    rejected for 2 groups.

    Args:
        n_molecules: number of sequences to return (values below 1 become 1).
        temperature: sampling temperature; 0 disables sampling.

    Returns:
        dict of ``generate`` keyword arguments (inputs and max_length excluded).
    """
    n_molecules = max(1, int(n_molecules))
    if temperature > 0:
        return {
            "do_sample": True,
            "temperature": float(temperature),
            "num_beams": 1,
            "num_return_sequences": n_molecules,
        }

    min_beams = max(2 * n_molecules, 5)
    num_beams = -(-min_beams // n_molecules) * n_molecules  # ceil to a multiple
    kwargs = {
        "do_sample": False,
        "num_beams": num_beams,
        "num_return_sequences": n_molecules,
    }
    if n_molecules > 1:
        # One beam group per requested molecule; the penalty pushes groups apart.
        kwargs["num_beam_groups"] = n_molecules
        kwargs["diversity_penalty"] = 0.5
    return kwargs


def _is_valid_smiles(smiles):
    """Return True if ``smiles`` is a non-blank string that RDKit parses into a molecule with atoms.

    ``Chem.MolFromSmiles("")`` returns an empty (zero-atom) molecule rather
    than None, so blank outputs need the explicit checks.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and mol.GetNumAtoms() > 0


# Generate Molecules Button
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # NOTE(review): despite the name, this SMILES is a five-membered cyclic
    # imide (N-methylsuccinimide), not the four-membered azetidin-2-one ring
    # of a beta-lactam. It is the prompt given to the generator, so it is left
    # unchanged; confirm it matches the input the model was fine-tuned with.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # Tokenize the core SMILES as the encoder input.
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # Decoding strategy follows the sidebar: sampling at `creativity` when it
    # is > 0, deterministic diverse beam search when it is 0.
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        **_generation_kwargs(num_molecules, creativity),
    )

    # Decode each returned sequence; strip stray whitespace around the text so
    # RDKit sees the bare SMILES.
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]

    # Generic display names Mol01, Mol02, ... in generation order.
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]

    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    # Flag unparsable/empty SMILES so they are excluded from ADMET prediction.
    df_molecules['Valid'] = df_molecules['SMILES'].apply(_is_valid_smiles)
    valid_mask = df_molecules['Valid']

    invalid_count = int((~valid_mask).sum())
    if invalid_count:
        st.warning(f"{invalid_count} generated molecules were invalid and excluded from predictions.")

    # Predictions are joined back on 'SMILES' below; repeated SMILES would turn
    # that merge into a cross product (duplicate rows and index labels), so
    # keep only the first molecule for each distinct SMILES.
    df_valid = df_molecules[valid_mask].drop_duplicates(subset='SMILES').copy()
    duplicate_count = int(valid_mask.sum()) - len(df_valid)
    if duplicate_count:
        st.info(f"{duplicate_count} duplicate molecules were removed.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())
        # NOTE(review): when predict() returns no 'SMILES' column, this assumes
        # one row per input SMILES in input order — confirm against ADMET-AI.
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values

        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner').set_index('Molecule Name')

        # Only these ADMET properties are shown.
        admet_properties = [
            'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins',
        ]
        df_results_filtered = df_results[['SMILES', 'Valid'] + admet_properties]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # SAFE encoding can fail for some molecules; report it and
                # continue without the SAFE section for that molecule.
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    st.write("**SMILES:**")
                    st.text(smiles)

                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Remaining row entries are the selected ADMET properties.
                    st.write("**ADMET Properties:**")
                    st.write(row.drop(['SMILES', 'Valid']))
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|