Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	| import streamlit as st | |
| import torch | |
| import os | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw | |
| from transformers import BartForConditionalGeneration, BartTokenizer | |
| from admet_ai import ADMETModel | |
| import safe | |
| import io | |
| from PIL import Image | |
| import cairosvg | |
| import pandas as pd | |
# Page setup: browser-tab title and a full-width layout for the whole app.
st.set_page_config(layout='wide', page_title='Beta-Lactam Molecule Generator')
# Load Models
@st.cache_resource
def load_models():
    """
    Load the molecule generation model, its tokenizer and the ADMET-AI model.

    The function is wrapped in ``st.cache_resource`` so the returned objects
    are built once and handed back on later Streamlit reruns instead of being
    re-created on every widget interaction.

    The Hugging Face access token is read from the ``HUGGING_FACE_TOKEN``
    environment variable. When it is missing, an error is shown in the app
    and the script run is halted with ``st.stop()``.

    Returns:
        tuple: ``(model, tokenizer, admet_model)`` where ``model`` is a
        ``BartForConditionalGeneration``, ``tokenizer`` a ``BartTokenizer``
        and ``admet_model`` an ``ADMETModel`` instance.
    """
    model_name = "bcadkins01/beta_lactam_generator"

    access_token = os.getenv("HUGGING_FACE_TOKEN")
    if access_token is None:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()

    # Generator and tokenizer are both fetched from the same repository,
    # authenticated with the token read above.
    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # ADMET-AI property predictor, built with its default arguments.
    admet_model = ADMETModel()

    return model, tokenizer, admet_model
# Obtain the generator, its tokenizer and the ADMET predictor used below.
model, tokenizer, admet_model = load_models()

# ---- Sidebar: generation parameters ----
st.sidebar.header('Generation Parameters')

# Temperature control, exposed to the user as "creativity".
creativity = st.sidebar.slider(
    label='Creativity (Temperature):',
    value=1.0,
    min_value=0.0,
    max_value=2.4,
    step=0.2,
    help="Higher values lead to more diverse (or wild) outputs.",
)

# How many molecules to request per click; raise max_value to allow more.
num_molecules = st.sidebar.number_input(
    label='Number of Molecules to Generate:',
    value=3,
    min_value=1,
    max_value=3,
    help="Select the number of molecules you want to generate (up to 3).",
)
# Molecule rendering helper
def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule string as an image.

    Args:
        input_string: SMILES text, or a SAFE string when ``use_safe`` is True.
        use_safe: When True (and ``input_string`` is not None) the picture is
            produced by ``safe.to_image`` and converted to PNG; otherwise the
            string is parsed by RDKit as SMILES.

    Returns:
        An image object on success, or ``None`` when RDKit cannot parse the
        SMILES or when any exception is raised (the exception text is shown
        in the app via ``st.error``).
    """
    try:
        if not use_safe or input_string is None:
            # Plain SMILES path: parse with RDKit and draw at 250x250.
            parsed = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(parsed, size=(250, 250)) if parsed else None

        # SAFE path: the result of safe.to_image is handled as SVG text,
        # rasterised to PNG bytes and then opened as a PIL image.
        svg_markup = safe.to_image(input_string)
        png_payload = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_payload))
    except Exception as e:
        st.error(f"Error generating molecule image: {e}")
        return None
# Generate Molecules Button
# Flow on click: sample SMILES from the generator -> keep those RDKit can
# parse -> predict ADMET properties -> show image, SMILES, SAFE encoding and
# the selected properties for each molecule.
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # Prompt fed to the generator.
    # NOTE(review): this SMILES closes a five-membered ring (C-C(=O)-N-C(=O)-C),
    # whereas a beta-lactam ring has four members. It is left untouched because
    # the model may have been trained with exactly this prompt - confirm.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    # Tokenize the prompt
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

    # Use the sidebar "creativity" value as the sampling temperature. The
    # slider allows 0.0, but sampling requires a strictly positive
    # temperature, so clamp to a small floor (near-greedy behaviour).
    temperature = max(float(creativity), 0.1)

    # Generate molecules with nucleus (top-p) sampling
    output_ids = model.generate(
        input_ids=input_ids,
        max_length=128,
        do_sample=True,
        temperature=temperature,
        top_k=0,     # Disable top-k filtering
        top_p=0.9,   # Nucleus (top-p) sampling
        num_return_sequences=int(num_molecules),
        num_beams=1,
    )

    # Decode generated token ids back to SMILES text
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # Generic display names: Mol01, Mol02, ...
    molecule_names = [
        f"Mol{str(i).zfill(2)}"
        for i in range(1, len(generated_smiles) + 1)
    ]

    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles
    })

    def is_valid_smile(smile):
        """Return True when RDKit parses *smile* into a molecule with at least one atom."""
        mol = Chem.MolFromSmiles(smile)
        # The atom-count check rejects an empty decoded string, which would
        # otherwise pass as a zero-atom molecule.
        return mol is not None and mol.GetNumAtoms() > 0

    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    # Inform user if any molecules were invalid
    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    # Sampling can return the same SMILES more than once. Duplicates would
    # turn the merge on 'SMILES' below into a many-to-many join and multiply
    # the displayed rows, so keep only the first occurrence of each molecule.
    n_before = len(df_valid)
    df_valid = df_valid.drop_duplicates(subset='SMILES')
    n_duplicates = n_before - len(df_valid)
    if n_duplicates:
        st.info(f"{n_duplicates} duplicate molecules were removed.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        # Drop whatever index predict() returned so it cannot clash with a
        # 'SMILES' column during the merge.
        preds = preds.reset_index(drop=True)
        if 'SMILES' not in preds.columns:
            # Assumes one prediction row per input SMILES, in input order
            # - TODO confirm against the ADMET-AI documentation.
            preds['SMILES'] = df_valid['SMILES'].values

        # Attach predictions to the molecule table and index by display name
        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # ADMET properties to show; any that the prediction table does not
        # provide are skipped instead of raising KeyError.
        admet_properties = [
            'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]
        shown_properties = [p for p in admet_properties if p in df_results.columns]
        df_results_filtered = df_results[['SMILES', 'Valid'] + shown_properties]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # SAFE encoding is best-effort: report the failure and carry on
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                # Standard 2D depiction from the SMILES string
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    # Display SMILES string
                    st.write("**SMILES:**")
                    st.text(smiles)

                    # Display SAFE encoding and its visualization if available
                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    admet_data = row.drop(['SMILES', 'Valid'])
                    st.write(admet_data)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
