"""Streamlit app: sample candidate molecules from a fine-tuned BART model and
show each structure together with its ADMET-AI property predictions."""
import io
import json
import os

import cairosvg
import pandas as pd
import safe
import streamlit as st
import streamlit.components.v1 as components
import torch  # noqa: F401  (kept from the original imports; not referenced directly)
from admet_ai import ADMETModel
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer

# Page Configuration -- must be the first Streamlit command the script executes.
st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')

MODEL_NAME = "bcadkins01/beta_lactam_generator"

# Prompt passed to the generator.
# NOTE(review): this SMILES is a five-membered ring (N-methylsuccinimide), not a
# four-membered beta-lactam such as azetidin-2-one (O=C1CCN1). It is left
# unchanged because it is the model prompt -- confirm it matches what the model
# was fine-tuned on.
CORE_SMILES = "C1C(=O)N(C)C(=O)C1"

# HTML/JS for the copy-to-clipboard widget. __PAYLOAD__ is replaced with a
# JSON-encoded JavaScript string literal. Each components.html call renders in
# its own iframe, so the fixed element id cannot collide between buttons.
_COPY_BUTTON_TEMPLATE = """
<button id="copy-btn" type="button"
        style="padding: 0.25rem 0.75rem; border: 1px solid #ccc; border-radius: 0.4rem; background: white; cursor: pointer;">
  Copy
</button>
<script>
  const payload = __PAYLOAD__;
  const button = document.getElementById("copy-btn");
  button.addEventListener("click", async () => {
    try {
      await navigator.clipboard.writeText(payload);
      button.textContent = "Copied!";
    } catch (err) {
      button.textContent = "Copy failed";
    }
    setTimeout(() => { button.textContent = "Copy"; }, 1500);
  });
</script>
"""


# ---------------------------------------------------------------------------
# Function definitions. These must come before the app flow below, because
# Streamlit executes this script top to bottom on every rerun.
# ---------------------------------------------------------------------------

@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the molecule generator, its tokenizer and the ADMET-AI predictor.

    Streamlit caches the result so reruns reuse the loaded objects. Because of
    ``ttl=600``, the cache entry expires and the models are reloaded after ten
    minutes.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple.
    """
    # Optional Hugging Face token read from the environment (None if unset).
    # NOTE(review): newer transformers releases deprecate ``use_auth_token`` in
    # favour of ``token``; switch once the pinned version supports it.
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    model = BartForConditionalGeneration.from_pretrained(MODEL_NAME, use_auth_token=access_token)
    tokenizer = BartTokenizer.from_pretrained(MODEL_NAME, use_auth_token=access_token)
    admet_model = ADMETModel()
    return model, tokenizer, admet_model


def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or (in SAFE mode) a SAFE string. In SAFE
            mode, input that does not decode as SAFE is treated as SMILES and
            encoded to SAFE.
        use_safe_visualization: If True, draw with ``safe.to_image`` (fragment
            highlights) and rasterise the SVG with cairosvg. Otherwise draw a
            plain 200x200 RDKit depiction.

    Returns:
        One of:
        - a ``PIL.Image.Image`` on success;
        - ``None`` if RDKit cannot parse the SMILES (plain mode);
        - the caught ``Exception`` instance on any other failure, so the caller
          can report it instead of the whole app crashing.
    """
    try:
        if not use_safe_visualization:
            mol = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(mol, size=(200, 200)) if mol is not None else None

        try:
            # Round-trip already-SAFE input through decode/encode to normalise it.
            safe_string = safe.encode(safe.decode(input_string))
        except Exception:
            # Not decodable as SAFE: assume it is SMILES.
            safe_string = safe.encode(input_string)

        svg = safe.to_image(safe_string)
        # NOTE(review): depending on the environment, safe.to_image may return an
        # IPython SVG object (markup in ``.data``) rather than a plain str.
        svg_markup = getattr(svg, "data", svg)
        if isinstance(svg_markup, str):
            svg_markup = svg_markup.encode("utf-8")
        png_bytes = cairosvg.svg2png(bytestring=svg_markup)
        return Image.open(io.BytesIO(png_bytes))
    except Exception as exc:
        # Returned rather than raised; the caller decides how to report it.
        return exc


def st_copy_button(text, key):
    """Render a small button that copies ``text`` to the clipboard.

    Args:
        text: The string to copy.
        key: Accepted for call-site compatibility. It is not needed, because
            every ``components.html`` call renders in its own isolated iframe.
    """
    # json.dumps yields a valid JS string literal. Escaping "</" stops the
    # payload from closing the <script> element early.
    payload = json.dumps(str(text)).replace("</", "<\\/")
    components.html(_COPY_BUTTON_TEMPLATE.replace("__PAYLOAD__", payload), height=45)


def _is_valid_smiles(smiles):
    """Return True if RDKit parses ``smiles`` into a molecule with at least one atom."""
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    return mol is not None and mol.GetNumAtoms() > 0


def _generate_smiles(model, tokenizer, core_smiles, num_molecules, temperature):
    """Sample ``num_molecules`` decoded strings from the generator, prompted with ``core_smiles``."""
    inputs = tokenizer(core_smiles, return_tensors='pt')
    output_ids = model.generate(
        **inputs,  # input_ids plus attention_mask
        max_length=128,
        temperature=temperature,
        do_sample=True,
        top_k=50,
        num_return_sequences=num_molecules,
    )
    # Strip surrounding whitespace, which is not part of a valid SMILES.
    return [tokenizer.decode(ids, skip_special_tokens=True).strip() for ids in output_ids]


def _predict_admet(admet_model, named_smiles):
    """Predict ADMET properties and return a DataFrame indexed by molecule name.

    Only SMILES that RDKit can parse are sent to ADMET-AI. Its output rows are
    then relabelled positionally with the matching names.

    Returns:
        The DataFrame, or None when no molecule is valid or when the returned
        row count does not match the number of inputs (rows could not be
        aligned safely).
    """
    valid = {name: smi for name, smi in named_smiles.items() if _is_valid_smiles(smi)}
    if not valid:
        return None
    preds = admet_model.predict(smiles=list(valid.values()))
    if len(preds) != len(valid):
        # ADMET-AI filtered rows on its own, so a positional relabel would be wrong.
        return None
    preds.index = pd.Index(list(valid), name='Molecule Name')
    return preds


def _format_molecule_string(smiles, as_safe):
    """Return ``(text, ok)``: the SAFE encoding of ``smiles`` if requested, else the SMILES.

    ``ok`` is False when SAFE encoding failed and the raw SMILES is returned instead.
    """
    if not as_safe:
        return smiles, True
    try:
        return safe.encode(smiles), True
    except Exception:  # safe raises its own encode/fragmentation errors on bad input
        return smiles, False


# ---------------------------------------------------------------------------
# App flow
# ---------------------------------------------------------------------------

model, tokenizer, admet_model = load_models()

# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# The minimum is 0.1, not 0.0: with do_sample=True, transformers requires a
# strictly positive temperature.
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input('Number of Molecules to Generate:', min_value=1, max_value=5, value=5)

# String Format Option
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))
use_safe = string_format == 'SAFE'

if st.button('Generate Molecules'):
    with st.spinner("Generating molecules... Please wait."):
        generated_smiles = _generate_smiles(model, tokenizer, CORE_SMILES, num_molecules, creativity)

    molecule_names = [f"Mol{str(i).zfill(2)}" for i in range(1, num_molecules + 1)]
    generated_molecules = dict(zip(molecule_names, generated_smiles))

    # ADMET Predictions. Failures are reported in the UI instead of aborting the page.
    try:
        preds = _predict_admet(admet_model, generated_molecules)
    except Exception as exc:
        st.warning(f"ADMET prediction failed: {exc}")
        preds = None

    # Display Molecules
    st.subheader('Generated Molecules')
    cols_per_row = min(5, num_molecules)
    cols = st.columns(cols_per_row)
    for idx, mol_name in enumerate(molecule_names):
        smiles = generated_molecules[mol_name]
        img = generate_molecule_image(smiles, use_safe_visualization=use_safe)
        with cols[idx % cols_per_row]:
            if isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # Display molecule string
            string_to_display, encoded_ok = _format_molecule_string(smiles, use_safe)
            if not encoded_ok:
                st.caption("SAFE encoding failed; showing SMILES instead.")
            st.code(string_to_display)

            # Copy-to-clipboard functionality
            st_copy_button(string_to_display, key=f'copy_{mol_name}')

            # Display ADMET properties
            st.write("**ADMET Properties:**")
            if preds is not None and mol_name in preds.index:
                st.write(preds.loc[mol_name])
            else:
                st.write("Not available (invalid SMILES or prediction failed).")
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")