# Hugging Face Space (status shown on page: "Sleeping")
import html
import io
import json
import os

import cairosvg
import pandas as pd
import safe
import streamlit as st
import streamlit.components.v1 as components
import torch
from admet_ai import ADMETModel
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
# Page Configuration
# Keep this as the first Streamlit call in the script.
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')


# Load Models
@st.cache_resource
def load_models():
    """Load the BART molecule generator, its tokenizer and the ADMET-AI model.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` makes the slow, memory-heavy model loading happen
    once per server process, and later reruns reuse the same objects.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple.
    """
    model_name = "bcadkins01/beta_lactam_generator"
    # None when the environment variable is unset.
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    # NOTE(review): newer transformers releases deprecate ``use_auth_token`` in
    # favour of ``token``; kept as-is for compatibility with older installs.
    model = BartForConditionalGeneration.from_pretrained(model_name, use_auth_token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, use_auth_token=access_token)
    admet_model = ADMETModel()
    return model, tokenizer, admet_model


model, tokenizer, admet_model = load_models()
# Function Definitions
#
# Streamlit executes this script top to bottom on every rerun, so every helper
# must be defined before the UI code below calls it. Defining them after the
# button handler made the first "Generate Molecules" click fail with NameError.


def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule string as a PIL image.

    Args:
        input_string: A SMILES string. When ``use_safe_visualization`` is true,
            a SAFE string is also accepted.
        use_safe_visualization: If true, draw through ``safe.to_image`` with
            fragment highlighting. Otherwise draw a plain 200x200 RDKit depiction.

    Returns:
        A ``PIL.Image.Image`` on success.
        ``None`` when RDKit cannot parse the SMILES or parses it to zero atoms.
        The caught exception instance if rendering raised.
        Callers tell these apart with ``isinstance``.
    """
    try:
        if use_safe_visualization:
            try:
                # Round-trip through SMILES so we always draw a freshly encoded SAFE string.
                safe_string = safe.encode(safe.decode(input_string))
            except Exception:
                # Not decodable as SAFE: treat the input as SMILES instead.
                safe_string = safe.encode(input_string)
            # NOTE(review): assumes safe.to_image returns an SVG *string* — confirm
            # against the installed safe-mol / datamol version.
            svg_str = safe.to_image(safe_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        mol = Chem.MolFromSmiles(input_string)
        # MolFromSmiles("") yields an empty (but non-None) Mol; treat it as unparseable.
        if mol is None or mol.GetNumAtoms() == 0:
            return None
        return Draw.MolToImage(mol, size=(200, 200))
    except Exception as exc:
        # Returned rather than raised so the caller can report it next to the molecule.
        return exc


def st_copy_button(text, key):
    """Render a small "Copy" button that copies ``text`` to the clipboard.

    Args:
        text: The string to copy.
        key: Accepted for API compatibility; not used by this implementation.
    """
    # json.dumps produces a correctly escaped JavaScript string literal, so
    # backslashes in stereo SMILES such as F/C=C\F survive and quotes cannot
    # break out. html.escape then makes that literal safe inside the
    # double-quoted onclick attribute.
    js_literal = html.escape(json.dumps(text), quote=True)
    # NOTE(review): navigator.clipboard inside the component iframe may need the
    # clipboard-write permission in some browsers — confirm in deployment.
    components.html(
        f'<button onclick="navigator.clipboard.writeText({js_literal})" '
        'style="padding:5px;">Copy</button>',
        height=45,
    )


def _is_valid_smiles(smiles):
    """Return True if RDKit parses ``smiles`` into a molecule with at least one atom."""
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and mol.GetNumAtoms() > 0


def _predict_admet(named_smiles):
    """Run ADMET-AI on the parseable molecules in ``named_smiles``.

    Args:
        named_smiles: Mapping of display name (e.g. ``"Mol01"``) to SMILES.

    Returns:
        A DataFrame of predictions indexed by molecule name (index name
        ``'Molecule Name'``). Returns ``None`` if no molecule was parseable, or
        if the prediction rows could not be matched back to the input names.
    """
    valid = {name: smi for name, smi in named_smiles.items() if _is_valid_smiles(smi)}
    if not valid:
        return None
    # NOTE(review): assumes predict() returns one row per input SMILES, in
    # input order, when given a list.
    preds = admet_model.predict(smiles=list(valid.values()))
    if len(preds) != len(valid):
        # Rows are matched to names by position. Show nothing rather than mislabel rows.
        return None
    preds.index = pd.Index(list(valid), name='Molecule Name')
    return preds


def _render_molecule(mol_name, smiles, preds, use_safe):
    """Render one molecule: depiction, SMILES/SAFE string, copy button and ADMET table."""
    img = generate_molecule_image(smiles, use_safe_visualization=use_safe)
    if isinstance(img, Image.Image):
        st.image(img, caption=mol_name)
    else:
        message = f"Could not generate image for {mol_name}"
        if isinstance(img, Exception):
            message += f": {img}"
        st.error(message)

    # Display molecule string
    string_to_display = smiles
    if use_safe:
        try:
            string_to_display = safe.encode(smiles)
        except Exception:
            # Invalid generated SMILES cannot be SAFE-encoded. Show the raw text
            # instead of aborting the whole page.
            st.caption("SAFE encoding failed; showing raw SMILES.")
    st.code(string_to_display)
    # Copy-to-clipboard functionality
    st_copy_button(string_to_display, key=f'copy_{mol_name}')

    # Display ADMET properties
    st.write("**ADMET Properties:**")
    if preds is not None and mol_name in preds.index:
        st.write(preds.loc[mol_name])
    else:
        st.write("ADMET predictions unavailable for this molecule.")


# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# transformers raises a ValueError for temperature 0.0 when do_sample=True,
# so the slider's lower bound is 0.1.
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)
# String Format Option
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))

# Prompt fed to the generator's encoder.
# NOTE(review): this SMILES is a five-membered cyclic imide
# (N-methylsuccinimide), not the four-membered 2-azetidinone ring of a
# beta-lactam (e.g. "O=C1CCN1"). It is kept unchanged because the model may
# have been trained on exactly this prompt — confirm with the model author.
CORE_SMILES = "C1C(=O)N(C)C(=O)C1"

# Generate Molecules Button
if st.button('Generate Molecules'):
    with st.spinner("Generating molecules... Please wait."):
        inputs = tokenizer(CORE_SMILES, return_tensors='pt')
        output_ids = model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_length=128,
            temperature=creativity,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_molecules,
        )
        # strip() drops any surrounding whitespace left by the BPE decoder.
        generated_smiles = [
            tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in output_ids
        ]
        molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]
        generated_molecules = dict(zip(molecule_names, generated_smiles))
        # ADMET Predictions
        preds = _predict_admet(generated_molecules)

    # Display Molecules
    use_safe = string_format == 'SAFE'
    st.subheader('Generated Molecules')
    cols_per_row = min(5, len(molecule_names))
    # Start a fresh row of columns every cols_per_row molecules.
    for row_start in range(0, len(molecule_names), cols_per_row):
        row_names = molecule_names[row_start:row_start + cols_per_row]
        for col, mol_name in zip(st.columns(cols_per_row), row_names):
            with col:
                _render_molecule(mol_name, generated_molecules[mol_name], preds, use_safe)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")