Spaces:
Sleeping
Sleeping
import html
import io
import json
import os

import cairosvg
import pandas as pd
import safe
import streamlit as st
import streamlit.components.v1 as components
import torch
from admet_ai import ADMETModel
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
# Page configuration. Kept as the very first Streamlit command, which older
# Streamlit releases require.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Load Models
@st.cache_resource
def load_models():
    """Load the molecule generator, its tokenizer and the ADMET-AI predictor.

    Streamlit re-executes this whole script on every widget interaction, so
    the loaded objects are cached with ``st.cache_resource``. The weights are
    then loaded once per server process instead of on every rerun.

    Returns:
        tuple: ``(model, tokenizer, admet_model)``.
    """
    model_name = "bcadkins01/beta_lactam_generator"
    # If HUGGING_FACE_TOKEN is unset, ``None`` is passed and the download is
    # attempted without a token.
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    model = BartForConditionalGeneration.from_pretrained(model_name, use_auth_token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, use_auth_token=access_token)
    admet_model = ADMETModel()
    return model, tokenizer, admet_model


model, tokenizer, admet_model = load_models()
# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# The value is used as the sampling temperature with ``do_sample=True``.
# transformers rejects a temperature of 0, so the slider starts at 0.1.
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:', min_value=1, max_value=5, value=5
)
# Controls both how each molecule is drawn and which string is displayed.
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))
# Helper definitions come BEFORE the UI code that calls them. Streamlit runs
# this script top to bottom on every rerun, so a function defined below the
# ``st.button`` branch would not exist yet when the button is clicked.


def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string or, when ``use_safe_visualization`` is
            True, optionally a SAFE string.
        use_safe_visualization: If True, render via ``safe.to_image`` (SVG
            converted to PNG). Otherwise draw a plain 200x200 RDKit depiction.

    Returns:
        A ``PIL.Image.Image`` on success. Returns ``None`` if RDKit cannot parse
        the SMILES in plain mode. If rendering raised, returns the exception
        object itself so the caller can report it.
    """
    try:
        if not use_safe_visualization:
            mol = Chem.MolFromSmiles(input_string)
            if mol is None:
                return None
            return Draw.MolToImage(mol, size=(200, 200))

        try:
            # Input may already be SAFE: normalize it by decoding and re-encoding.
            safe_string = safe.encode(safe.decode(input_string))
        except Exception:
            # Not decodable as SAFE, so treat the input as SMILES.
            safe_string = safe.encode(input_string)
        svg_str = safe.to_image(safe_string)
        png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
        return Image.open(io.BytesIO(png_bytes))
    except Exception as e:
        # Return (not raise) so one bad molecule does not abort the whole page.
        return e


def st_copy_button(text, key):
    """Render a small "Copy" button that writes ``text`` to the clipboard.

    ``text`` is serialized with ``json.dumps`` into a JavaScript string literal
    and then HTML-escaped for the ``onclick`` attribute. Quotes and backslashes
    therefore survive intact. This matters because SMILES use ``\\`` for
    double-bond stereochemistry, and a naive ``'{text}'`` interpolation would
    silently drop it.

    ``key`` is accepted for call compatibility but is not used:
    ``components.html`` is not given a widget key here.

    NOTE(review): browsers may block ``navigator.clipboard`` inside the
    component iframe. Confirm this works in the deployed Space.
    """
    js_literal = json.dumps(str(text))
    onclick = html.escape(f"navigator.clipboard.writeText({js_literal})", quote=True)
    components.html(
        f'<button onclick="{onclick}" style="padding:5px;">Copy</button>',
        height=45,
    )


# Seed string fed to the generator (original comment: "Beta-lactam core structure").
# NOTE(review): this SMILES is a five-membered cyclic imide (N-methylsuccinimide),
# not the four-membered azetidin-2-one ring of a beta-lactam (e.g. "O=C1CCN1").
# It is left unchanged because it is the prompt the model is queried with;
# confirm against how the model was trained.
CORE_SMILES = "C1C(=O)N(C)C(=O)C1"

# Generate Molecules Button
if st.button('Generate Molecules'):
    with st.spinner("Generating molecules... Please wait."):
        inputs = tokenizer(CORE_SMILES, return_tensors='pt')
        output_ids = model.generate(
            **inputs,  # input_ids plus attention_mask
            max_length=128,
            temperature=creativity,
            do_sample=True,
            top_k=50,
            num_return_sequences=num_molecules,
        )
        generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
        molecule_names = [f"Mol{i:02d}" for i in range(1, num_molecules + 1)]
        generated_molecules = dict(zip(molecule_names, generated_smiles))

        # Only RDKit-parsable SMILES are sent to ADMET-AI. Prediction rows are
        # matched back to molecule names by position, so the two lists must
        # line up. NOTE(review): ADMET-AI presumably drops or fails on
        # unparsable SMILES; verify.
        valid_names = [
            name for name in molecule_names
            if Chem.MolFromSmiles(generated_molecules[name]) is not None
        ]
        preds = None
        if valid_names:
            preds = admet_model.predict(smiles=[generated_molecules[name] for name in valid_names])
            preds['Molecule Name'] = valid_names
            preds.set_index('Molecule Name', inplace=True)

    # Display Molecules
    st.subheader('Generated Molecules')
    use_safe = string_format == 'SAFE'
    cols_per_row = min(5, num_molecules)
    cols = st.columns(cols_per_row)
    for idx, mol_name in enumerate(molecule_names):
        smiles = generated_molecules[mol_name]
        img = generate_molecule_image(smiles, use_safe_visualization=use_safe)
        with cols[idx % cols_per_row]:
            if isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # Display the molecule string, falling back to SMILES if SAFE encoding fails.
            string_to_display = smiles
            if use_safe:
                try:
                    string_to_display = safe.encode(smiles)
                except Exception:
                    st.warning("SAFE encoding failed; showing SMILES.")
            st.code(string_to_display)
            st_copy_button(string_to_display, key=f'copy_{mol_name}')

            # Display ADMET properties (only computed for valid SMILES).
            st.write("**ADMET Properties:**")
            if preds is not None and mol_name in preds.index:
                st.write(preds.loc[mol_name])
            else:
                st.write("Not available (invalid SMILES).")
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")