Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import os | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from transformers import BartForConditionalGeneration, BartTokenizer | |
| from admet_ai import ADMETModel | |
| import safe | |
| import io | |
| from PIL import Image | |
| import cairosvg | |
| import pandas as pd | |
| import streamlit.components.v1 as components | |
| import json # For safely encoding text in JavaScript | |
# **Page Configuration**
# Browser-tab title and wide layout for the whole app.
PAGE_SETTINGS = {
    'page_title': 'Beta-Lactam Molecule Generator',
    'layout': 'wide',
}
st.set_page_config(**PAGE_SETTINGS)
# **Load Models**
@st.cache_resource
def load_models():
    """
    Load the molecule generation model and the ADMET-AI model.

    The result is cached with ``st.cache_resource`` so the models are built
    once and reused across Streamlit reruns instead of being reloaded on
    every widget interaction.

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. When it is missing, an error is shown and the
    script run is halted with ``st.stop()``; nothing is cached in that case
    because the function never returns.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple: the BART generation
        model, its tokenizer, and the ADMET-AI predictor.
    """
    # **Load your molecule generation model**
    model_name = "bcadkins01/beta_lactam_generator"  # Replace with your actual model path
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    if access_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # **Load ADMET-AI model**
    admet_model = ADMETModel()

    return model, tokenizer, admet_model
# **Load models once and reuse**
model, tokenizer, admet_model = load_models()

# **Set Generation Parameters in Sidebar**
st.sidebar.header('Generation Parameters')

# **Creativity Slider (Temperature)**
# The lower bound is 0.1, not 0.0: this value is passed as `temperature` to
# `model.generate(..., do_sample=True)` further down, and sampling requires a
# strictly positive temperature.
creativity = st.sidebar.slider(
    'Creativity (Temperature):',
    min_value=0.1,
    max_value=2.0,
    value=1.0,
    step=0.1,
    help="Higher values lead to more diverse outputs."
)

# **Number of Molecules to Generate**
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:',
    min_value=1,
    max_value=3,  # Reduced from 5 to 3
    value=3,
    help="Select the number of molecules you want to generate (up to 3)."
)
# **Function to Generate Molecule Images**
def generate_molecule_image(input_string, use_safe=False):
    """
    Build an image for a molecule given as a string.

    With ``use_safe=True`` (and a non-None input) the string is rendered via
    ``safe.to_image`` and the resulting SVG is converted to a PNG-backed PIL
    image. Otherwise the string is parsed as SMILES with RDKit and drawn at
    200x200 pixels.

    Returns the image, or None when the SMILES cannot be parsed or when any
    exception is raised (the exception is reported with ``st.error``).
    """
    try:
        render_as_safe = use_safe and input_string is not None
        if not render_as_safe:
            # Standard RDKit rendering from a SMILES string
            molecule = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(molecule, size=(200, 200)) if molecule else None

        # SAFE rendering: SVG text -> PNG bytes -> PIL image
        svg_markup = safe.to_image(input_string)
        png_data = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        st.error(f"Error generating molecule image: {e}")
        return None
# **Function to Create Copy-to-Clipboard Button**
def st_copy_button(text, key):
    """
    Render a small "Copy" button that copies `text` to the clipboard.

    The button is emitted as a raw HTML snippet through
    ``streamlit.components.v1.html``.

    Args:
        text: The string placed on the clipboard when the button is clicked.
        key: Identifier supplied by callers. It is not used by this
            function and is kept so existing call sites keep working.
    """
    from html import escape  # local import: only needed for attribute escaping

    # json.dumps produces a valid JavaScript string literal, but that literal
    # is delimited by double quotes, which would end the double-quoted
    # onclick="..." attribute at its first character. HTML-escaping the
    # literal keeps the attribute intact; the browser decodes the entities
    # before running the handler.
    js_literal = escape(json.dumps(text), quote=True)
    button_html = f"""
    <div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
        <button onclick="navigator.clipboard.writeText({js_literal})" style="
            padding:5px;
        ">Copy</button>
    </div>
    """
    components.html(button_html, height=35)
# **Generate Molecules Button**
# On click: generate SMILES from the core prompt, validate them with RDKit,
# predict ADMET properties for the valid ones, and render each molecule with
# its SMILES, SAFE encoding and selected properties.
if st.button('Generate Molecules'):
    # **Beta-lactam core structure**
    # NOTE(review): this SMILES appears to encode a five-membered cyclic imide
    # rather than a four-membered beta-lactam ring - confirm it matches the
    # prompt format the model expects before changing it.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # A spinner disappears once generation finishes, unlike a persistent
    # info box that would keep saying "please wait" next to the results.
    with st.spinner("Generating molecules... Please wait."):
        # **Tokenize the core SMILES**
        input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

        # **Generate molecules using the model**
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=128,
            temperature=creativity,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_molecules,
            num_beams=max(num_molecules, 5)  # Ensure num_beams >= num_return_sequences
        )

    # **Decode generated molecule SMILES**
    # Surrounding whitespace is stripped so it cannot affect validation or
    # the SMILES-keyed merge below.
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]

    # **Create molecule names**
    molecule_names = [
        f"Mol{i:02d}"
        for i in range(1, len(generated_smiles) + 1)
    ]

    # **Create DataFrame for generated molecules**
    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles
    })

    # **Invalid SMILES Check**
    def is_valid_smile(smile):
        """Return True when RDKit parses `smile` into a molecule with atoms."""
        if not smile:
            # An empty string parses to an atom-less molecule, not None,
            # so it has to be rejected explicitly.
            return False
        mol = Chem.MolFromSmiles(smile)
        return mol is not None and mol.GetNumAtoms() > 0

    # Apply validation function
    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # Identical SMILES would be multiplied by the SMILES-keyed merge below
    # (each copy matches every copy of its prediction), so keep one of each.
    valid_count = len(df_valid)
    df_valid = df_valid.drop_duplicates(subset='SMILES', keep='first')
    duplicate_count = valid_count - len(df_valid)
    if duplicate_count:
        st.info(f"{duplicate_count} duplicate molecules were removed.")

    # Check if there are valid molecules to proceed
    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Ensure 'SMILES' is a column in preds
        # (assumes predictions come back in input order - TODO confirm)
        if 'SMILES' not in preds.columns:
            preds['SMILES'] = df_valid['SMILES'].values
        # Drop the prediction index so the 'SMILES' column is the only
        # candidate for the merge key.
        preds = preds.reset_index(drop=True)

        # Merge predictions with valid molecules
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')

        # Set 'Molecule Name' as index
        df_results.set_index('Molecule Name', inplace=True)

        # Select only desired ADMET properties
        admet_properties = [
            'molecular weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]

        # Resolve each requested property against the columns that actually
        # exist, accepting either a space or an underscore as the word
        # separator, so one differently spelled column name does not abort
        # the whole page with a KeyError.
        available_properties = []
        missing_properties = []
        for prop in admet_properties:
            for candidate in (prop, prop.replace(' ', '_')):
                if candidate in df_results.columns:
                    available_properties.append(candidate)
                    break
            else:
                missing_properties.append(prop)
        if missing_properties:
            st.warning(
                "ADMET properties not found in the predictions and skipped: "
                + ", ".join(missing_properties)
            )

        df_results_filtered = df_results[['SMILES', 'Valid'] + available_properties]

        # Check if df_results_filtered is empty after filtering
        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # Attempt to encode to SAFE
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                # Generate molecule image (SMILES or SAFE)
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if img is not None and isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    # Display SMILES string
                    st.write("**SMILES:**")
                    st.text(smiles)
                    st_copy_button(smiles, key=f'copy_smiles_{mol_name}')

                    # Display SAFE encoding if available
                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        st_copy_button(safe_string, key=f'copy_safe_{mol_name}')

                        # Optionally display SAFE visualization
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    admet_data = row.drop(['SMILES', 'Valid'])
                    st.write(admet_data)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")