Spaces:
Sleeping
Sleeping
File size: 4,789 Bytes
e295393 af1913b e295393 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd
# Page Configuration
# Passes the page title and the 'wide' layout option to Streamlit.
# NOTE(review): Streamlit's docs ask for set_page_config to run before other
# st.* commands -- keep this call ahead of the rest of the script (confirm
# against the Streamlit version this Space pins).
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the molecule generator, its tokenizer and the ADMET-AI predictor.

    The result is wrapped in ``st.cache_resource`` (``ttl=600``), so the
    loaders below are not re-run on every script execution.

    Returns:
        tuple: ``(model, tokenizer, admet_model)`` in that order.
    """
    # Hub repository of the generation model -- replace with your actual model path.
    repo_id = "bcadkins01/beta_lactam_generator"

    # os.getenv returns None when HUGGING_FACE_TOKEN is not set; the value is
    # forwarded unchanged to both from_pretrained calls.
    hub_auth = {"use_auth_token": os.getenv("HUGGING_FACE_TOKEN")}

    generator = BartForConditionalGeneration.from_pretrained(repo_id, **hub_auth)
    smiles_tokenizer = BartTokenizer.from_pretrained(repo_id, **hub_auth)

    # ADMET-AI model, built with its default arguments.
    admet_predictor = ADMETModel()

    return generator, smiles_tokenizer, admet_predictor
import streamlit.components.v1 as components


# Function Definitions
# NOTE: the helpers are defined *before* the UI code. Streamlit executes the
# script top to bottom on every rerun, so calling them ahead of their `def`
# raises NameError as soon as the "Generate Molecules" button is clicked.
def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule string as a PIL image.

    Args:
        input_string: SMILES or SAFE string describing the molecule.
        use_safe_visualization: When true, draw through ``safe.to_image`` and
            convert the SVG to PNG with cairosvg; otherwise draw with RDKit's
            ``Draw.MolToImage`` at 200x200.

    Returns:
        A ``PIL.Image.Image`` on success; ``None`` when RDKit cannot parse
        ``input_string`` in the non-SAFE path; or the caught exception
        instance when anything raises while rendering. Callers are expected
        to test the result with ``isinstance(result, Image.Image)``.
    """
    try:
        if use_safe_visualization:
            try:
                # Treat the input as SAFE first: decode it and encode it back.
                safe_string = safe.encode(safe.decode(input_string))
            except Exception:
                # If that fails, assume the input is SMILES and encode to SAFE.
                safe_string = safe.encode(input_string)
            # safe.to_image output is handled as SVG text, converted to PNG
            # bytes, then opened as a PIL image.
            svg_str = safe.to_image(safe_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        # Standard RDKit depiction.
        mol = Chem.MolFromSmiles(input_string)
        if mol is None:
            return None
        return Draw.MolToImage(mol, size=(200, 200))  # Adjusted size
    except Exception as e:
        # The exception object is handed back instead of raised so one bad
        # molecule does not abort rendering of the others.
        return e


def st_copy_button(text, key):
    """Creates a copy-to-clipboard button.

    Args:
        text: String written to the clipboard when the button is clicked.
        key: Identifier supplied by the caller. It is not used by the
            rendered component; the parameter is kept so existing call
            sites keep working.
    """
    import html
    import json

    # json.dumps produces a fully escaped JavaScript string literal (SMILES can
    # contain backslashes, e.g. C/C=C\C, which a hand-quoted '...' literal would
    # swallow); html.escape then makes that literal safe inside the
    # double-quoted onclick attribute.
    js_literal = html.escape(json.dumps(text), quote=True)
    components.html(
        f"""
        <button onclick="navigator.clipboard.writeText({js_literal})" style="padding:5px;">Copy</button>
        """,
        height=45,
    )


model, tokenizer, admet_model = load_models()

# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# Sampling needs a strictly positive temperature, so the slider starts at 0.1
# rather than 0.0.
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)
# String Format Option
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))

# Generate Molecules Button
if st.button('Generate Molecules'):
    # A spinner disappears once the work is done; st.info would stay on screen.
    with st.spinner("Generating molecules... Please wait."):
        # Prompt fed to the generator.
        # NOTE(review): described as the beta-lactam core, but this SMILES
        # closes a five-membered ring (two carbonyls flanking N-methyl), not
        # a four-membered lactam -- confirm it is the prompt the model
        # was trained with before changing it.
        core_smiles = "C1C(=O)N(C)C(=O)C1"
        input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
        output_ids = model.generate(
            input_ids,
            max_length=128,
            temperature=creativity,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_molecules,
        )
        generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
        molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]
        generated_molecules = dict(zip(molecule_names, generated_smiles))

        # ADMET Predictions
        # Sampled strings are not guaranteed to be valid molecules, so only
        # the ones RDKit can parse are sent to the predictor.
        valid_names = [
            name
            for name, smi in generated_molecules.items()
            if smi and Chem.MolFromSmiles(smi) is not None
        ]
        preds = None
        if valid_names:
            preds = admet_model.predict(smiles=[generated_molecules[name] for name in valid_names])
            preds['Molecule Name'] = valid_names
            preds.set_index('Molecule Name', inplace=True)

    # Display Molecules
    st.subheader('Generated Molecules')
    cols_per_row = min(5, len(molecule_names))
    cols = st.columns(cols_per_row)
    for idx, mol_name in enumerate(molecule_names):
        smiles = generated_molecules[mol_name]
        img = generate_molecule_image(smiles, use_safe_visualization=(string_format == 'SAFE'))
        with cols[idx % cols_per_row]:
            # The helper returns None or an exception object on failure.
            if isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # Display molecule string
            string_to_display = smiles
            if string_format == 'SAFE':
                try:
                    string_to_display = safe.encode(smiles)
                except Exception:
                    # Encoding can fail for strings that are not valid
                    # molecules; fall back to the raw SMILES instead of
                    # aborting the whole page.
                    st.warning(f"Could not convert {mol_name} to SAFE; showing SMILES instead.")
            st.code(string_to_display)

            # Copy-to-clipboard functionality
            st_copy_button(string_to_display, key=f'copy_{mol_name}')

            # Display ADMET properties
            st.write("**ADMET Properties:**")
            if preds is not None and mol_name in preds.index:
                st.write(preds.loc[mol_name])
            else:
                st.warning(f"No ADMET prediction for {mol_name}: the generated string is not a valid SMILES.")
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
|