# Hugging Face Space file view: app.py (author: bcadkins01)
# Last commit af1913b "Update app.py" (verified); file size 4.79 kB
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
# Page configuration: browser-tab title and a full-width layout for the app.
st.set_page_config(
    page_title="Beta-Lactam Molecule Generator",
    layout="wide",
)
# Load Models.
# st.cache_resource shares the returned objects across reruns and sessions;
# with ttl=600 the cached entry expires after 600 seconds and is rebuilt.
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the molecule generator, its tokenizer, and the ADMET-AI predictor.

    The Hugging Face access token is read from the HUGGING_FACE_TOKEN
    environment variable (None when unset).

    Returns:
        Tuple ``(model, tokenizer, admet_model)``.
    """
    repo_id = "bcadkins01/beta_lactam_generator"  # Replace with your actual model path
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token`; kept as-is because the pinned version is unknown here.
    hf_auth = {"use_auth_token": os.getenv("HUGGING_FACE_TOKEN")}

    generator = BartForConditionalGeneration.from_pretrained(repo_id, **hf_auth)
    bart_tokenizer = BartTokenizer.from_pretrained(repo_id, **hf_auth)
    admet_predictor = ADMETModel()
    return generator, bart_tokenizer, admet_predictor


model, tokenizer, admet_model = load_models()
# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# Sampling temperature for model.generate(do_sample=True). transformers rejects
# a temperature that is not strictly positive when sampling, so the slider's
# minimum is 0.1 instead of 0.0 (which crashed generation).
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:', min_value=1, max_value=5, value=5
)
# String Format Option: controls both the displayed/copied string and the
# image renderer ('SAFE' draws via the safe library, 'SMILES' via RDKit).
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))
import streamlit.components.v1 as components


# Function Definitions
# These helpers must be defined BEFORE the "Generate Molecules" handler below:
# Streamlit executes this script top to bottom on every rerun, so calling a
# function whose `def` has not run yet raises NameError.


def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule string as a PIL image.

    Args:
        input_string: A SMILES string, or (when ``use_safe_visualization`` is
            True) a SAFE string.
        use_safe_visualization: If True, draw with ``safe.to_image`` (SVG
            rasterised to PNG via cairosvg); otherwise draw with RDKit.

    Returns:
        A ``PIL.Image.Image`` on success, ``None`` if RDKit cannot parse the
        SMILES (RDKit path), or the caught ``Exception`` instance if rendering
        fails. The exception is returned rather than raised so the caller can
        report it per molecule without aborting the page.
    """
    try:
        if not use_safe_visualization:
            mol = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(mol, size=(200, 200)) if mol else None

        try:
            # Input may already be SAFE: decode, then re-encode to a SAFE string.
            safe_string = safe.encode(safe.decode(input_string))
        except Exception:
            # Decoding/encoding failed: treat the input as plain SMILES.
            safe_string = safe.encode(input_string)
        # SVG with fragment highlights.
        # NOTE(review): assumes safe.to_image returns an SVG *str* — confirm.
        svg_str = safe.to_image(safe_string)
        png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
        return Image.open(io.BytesIO(png_bytes))
    except Exception as e:
        return e


def st_copy_button(text, key):
    """Render a small HTML "Copy" button that writes ``text`` to the clipboard.

    Args:
        text: The string to copy.
        key: Accepted for call-site compatibility but not used;
            ``components.html`` is called without it.
    """
    import html
    import json

    # json.dumps produces a valid JavaScript string literal (quotes and
    # backslashes escaped), so SMILES stereo bonds such as C/C=C\C are copied
    # verbatim instead of losing the backslash. html.escape then makes that
    # literal safe inside the double-quoted onclick attribute.
    js_literal = html.escape(json.dumps(text), quote=True)
    components.html(f"""
<button onclick="navigator.clipboard.writeText({js_literal})" style="padding:5px;">Copy</button>
""", height=45)


def _generate_smiles(temperature, count):
    """Sample ``count`` molecule strings from the generator model."""
    # NOTE(review): this was labelled "Beta-lactam core structure", but
    # "C1C(=O)N(C)C(=O)C1" is a five-membered ring (N-methylsuccinimide); a
    # beta-lactam is the four-membered azetidin-2-one ring (e.g. "O=C1CCN1").
    # Left unchanged because the model may have been trained on this exact
    # prompt — confirm before editing.
    core_smiles = "C1C(=O)N(C)C(=O)C1"
    # Pass the full encoding (input_ids + attention_mask) to generate().
    inputs = tokenizer(core_smiles, return_tensors='pt')
    output_ids = model.generate(
        **inputs,
        max_length=128,
        temperature=temperature,
        do_sample=True,
        top_k=50,
        num_return_sequences=count,
    )
    # Defensively strip surrounding whitespace; SMILES never contain spaces.
    return [tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in output_ids]


def _predict_admet(generated_molecules):
    """Predict ADMET properties for the RDKit-parsable molecules.

    Args:
        generated_molecules: Mapping of molecule name -> SMILES.

    Returns:
        DataFrame indexed by molecule name ('Molecule Name'), with one row per
        molecule whose SMILES RDKit could parse; empty when none parse.
    """
    # Filter first so prediction rows stay aligned with their names even when
    # the generator emits invalid SMILES.
    valid = {
        name: smi
        for name, smi in generated_molecules.items()
        if Chem.MolFromSmiles(smi) is not None
    }
    if not valid:
        return pd.DataFrame()
    preds = admet_model.predict(smiles=list(valid.values()))
    preds.index = pd.Index(list(valid), name='Molecule Name')
    return preds


def _render_molecule(mol_name, smiles, preds, use_safe):
    """Show one molecule's image, string, copy button and ADMET row in the current container."""
    img = generate_molecule_image(smiles, use_safe_visualization=use_safe)
    if isinstance(img, Image.Image):
        st.image(img, caption=mol_name)
    else:
        st.error(f"Could not generate image for {mol_name}")
        if isinstance(img, Exception):
            st.exception(img)

    # Display molecule string (SAFE encoding can fail, e.g. on invalid SMILES).
    string_to_display = smiles
    if use_safe:
        try:
            string_to_display = safe.encode(smiles)
        except Exception as exc:
            st.warning(f"SAFE encoding failed for {mol_name}; showing SMILES instead ({exc})")
    st.code(string_to_display)
    st_copy_button(string_to_display, key=f'copy_{mol_name}')

    st.write("**ADMET Properties:**")
    if mol_name in preds.index:
        st.write(preds.loc[mol_name])
    else:
        st.warning(f"No ADMET prediction for {mol_name}: RDKit could not parse its SMILES.")


# Generate Molecules Button
if st.button('Generate Molecules'):
    with st.spinner("Generating molecules... Please wait."):
        generated_smiles = _generate_smiles(creativity, num_molecules)
        molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]
        generated_molecules = dict(zip(molecule_names, generated_smiles))
        preds = _predict_admet(generated_molecules)

    # Display Molecules
    st.subheader('Generated Molecules')
    use_safe = string_format == 'SAFE'
    cols_per_row = min(5, len(molecule_names))
    cols = st.columns(cols_per_row)
    for idx, mol_name in enumerate(molecule_names):
        with cols[idx % cols_per_row]:
            _render_molecule(mol_name, generated_molecules[mol_name], preds, use_safe)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")