File size: 7,903 Bytes
e295393
 
 
 
87e6612
e295393
 
 
 
 
 
 
 
83d667e
111a2ce
 
 
 
e295393
83d667e
e295393
 
111a2ce
 
 
 
83d667e
0db5108
e295393
111a2ce
 
 
 
 
 
83d667e
e295393
111a2ce
e295393
 
83d667e
e295393
 
83d667e
e295393
 
83d667e
111a2ce
 
 
83d667e
111a2ce
83d667e
0db5108
111a2ce
e295393
0db5108
111a2ce
 
 
231b5fc
a3a6043
 
111a2ce
085e6ac
83d667e
87e6612
111a2ce
87e6612
 
111a2ce
e295393
87e6612
 
 
 
 
 
 
e295393
87e6612
 
 
231b5fc
87e6612
 
e295393
 
111a2ce
 
e295393
83d667e
111a2ce
 
 
83d667e
111a2ce
 
83d667e
111a2ce
 
313eced
111a2ce
58b2e8d
 
4a19ce8
 
 
 
58b2e8d
4a19ce8
111a2ce
313eced
4a19ce8
111a2ce
83d667e
111a2ce
 
 
 
 
83d667e
111a2ce
 
 
 
 
83d667e
111a2ce
 
 
 
 
83d667e
87e6612
111a2ce
 
 
87e6612
111a2ce
 
 
87e6612
111a2ce
 
 
 
87e6612
111a2ce
 
 
87e6612
111a2ce
 
87e6612
111a2ce
 
 
87e6612
111a2ce
 
87e6612
111a2ce
 
87e6612
a3a6043
231b5fc
a3a6043
 
87e6612
a3a6043
87e6612
a3a6043
 
111a2ce
87e6612
111a2ce
a3a6043
111a2ce
 
a3a6043
111a2ce
a3a6043
87e6612
 
 
 
 
 
 
 
 
a3a6043
111a2ce
 
 
 
 
a3a6043
87e6612
 
 
a3a6043
87e6612
 
 
 
 
 
 
 
a3a6043
87e6612
111a2ce
a3a6043
 
111a2ce
 
0db5108
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import streamlit as st
import torch
import os
from rdkit import Chem
from rdkit.Chem import Draw
from transformers import BartForConditionalGeneration, BartTokenizer
from admet_ai import ADMETModel
import safe
import io
from PIL import Image
import cairosvg
import pandas as pd

# Browser-tab title and full-width page layout for the app.
st.set_page_config(page_title="Beta-Lactam Molecule Generator", layout="wide")

# Load Models
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """
    Load the molecule generation model/tokenizer and the ADMET-AI model.

    Wrapped in ``st.cache_resource`` (ttl=600) so Streamlit reuses the loaded
    objects across reruns instead of reloading them on every interaction.

    Returns:
        tuple: ``(model, tokenizer, admet_model)`` -- the BART generator, its
        tokenizer, and an ``ADMETModel`` instance.

    Side effects:
        If the ``HUGGING_FACE_TOKEN`` environment variable is unset or empty,
        shows an error in the app and halts the script run via ``st.stop()``.
    """
    model_name = "bcadkins01/beta_lactam_generator"
    access_token = os.getenv("HUGGING_FACE_TOKEN")
    # An empty value is treated like a missing one: otherwise token="" would be
    # passed to from_pretrained and presumably fail later with a less helpful
    # authentication error instead of this message.
    if not access_token:
        st.error("Access token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
        st.stop()
    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # ADMET property predictor (constructed with its default settings).
    admet_model = ADMETModel()

    return model, tokenizer, admet_model

# Obtain the generator, tokenizer and ADMET model; load_models is wrapped in
# st.cache_resource, so reruns get the cached objects back.
model, tokenizer, admet_model = load_models()

# --- Sidebar: user-tunable generation settings ------------------------------
with st.sidebar:
    st.header('Generation Parameters')

    # "Creativity" slider, labelled as the sampling temperature, in [0.0, 2.4].
    creativity = st.slider(
        'Creativity (Temperature):',
        min_value=0.0,
        max_value=2.4,
        value=1.0,
        step=0.2,
        help="Higher values lead to more diverse (or wild) outputs.",
    )

    # How many candidate molecules one click requests (1 to 3, default 3).
    num_molecules = st.number_input(
        'Number of Molecules to Generate:',
        min_value=1,
        max_value=3,
        value=3,
        help="Select the number of molecules you want to generate (up to 3).",
    )

# Molecule rendering helper
def generate_molecule_image(input_string, use_safe=False):
    """
    Render a molecule as a PIL image.

    Args:
        input_string: A SMILES string, or a SAFE string when ``use_safe`` is True.
        use_safe: When True and ``input_string`` is not None, render through
            ``safe.to_image`` and rasterise the resulting SVG to PNG with
            cairosvg. Otherwise parse ``input_string`` as SMILES with RDKit and
            draw a 250x250 image.

    Returns:
        The rendered image, or None when RDKit cannot parse the SMILES or when
        any step raises (the exception text is shown in the app via st.error).
    """
    try:
        if not (use_safe and input_string is not None):
            molecule = Chem.MolFromSmiles(input_string)
            return Draw.MolToImage(molecule, size=(250, 250)) if molecule else None

        # NOTE(review): assumes safe.to_image returns SVG markup as a str -- confirm.
        svg_markup = safe.to_image(input_string)
        png_data = cairosvg.svg2png(bytestring=svg_markup.encode('utf-8'))
        return Image.open(io.BytesIO(png_data))
    except Exception as err:
        st.error(f"Error generating molecule image: {err}")
        return None

# Generate Molecules Button
if st.button('Generate Molecules'):
    # Use the sidebar "creativity" value as the sampling temperature. Sampling
    # requires a strictly positive temperature, and the slider starts at 0.0,
    # so clamp to a small positive floor.
    temperature = max(float(creativity), 0.05)

    # Prompt fed to the generator.
    # NOTE(review): this SMILES parses to a 5-membered cyclic imide
    # (N-methylsuccinimide), not a 4-membered beta-lactam ring such as
    # azetidin-2-one ("O=C1CCN1"). Kept unchanged because the model may have been
    # trained on this exact prompt -- confirm with the model author.
    core_smiles = "C1C(=O)N(C)C(=O)C1"

    with st.spinner("Generating molecules... Please wait."):
        # Tokenize the core SMILES
        input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids

        # Nucleus (top-p) sampling -- not beam search (num_beams=1).
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=128,
            do_sample=True,
            temperature=temperature,
            top_k=0,          # 0 disables top-k filtering
            top_p=0.9,        # nucleus (top-p) sampling
            num_return_sequences=int(num_molecules),
            num_beams=1,
        )

    # Decode generated molecule SMILES
    generated_smiles = [
        tokenizer.decode(ids, skip_special_tokens=True)
        for ids in output_ids
    ]

    # Generic display names for the demo: Mol01, Mol02, ...
    molecule_names = [f"Mol{i:02d}" for i in range(1, len(generated_smiles) + 1)]

    df_molecules = pd.DataFrame({
        'Molecule Name': molecule_names,
        'SMILES': generated_smiles,
    })

    def is_valid_smile(smile):
        """Return True when RDKit can parse ``smile`` into a molecule."""
        return Chem.MolFromSmiles(smile) is not None

    # Flag and split out SMILES that RDKit cannot parse.
    df_molecules['Valid'] = df_molecules['SMILES'].apply(is_valid_smile)
    df_valid = df_molecules[df_molecules['Valid']].copy()

    invalid_molecules = df_molecules[~df_molecules['Valid']]
    if not invalid_molecules.empty:
        st.warning(f"{len(invalid_molecules)} generated molecules were invalid and excluded from predictions.")

    if df_valid.empty:
        st.error("No valid molecules were generated. Please try adjusting the generation parameters.")
    else:
        # ADMET Predictions
        with st.spinner("Predicting ADMET properties..."):
            preds = admet_model.predict(smiles=df_valid['SMILES'].tolist())

        if 'SMILES' not in preds.columns:
            # NOTE(review): assumes one prediction row per input SMILES, in
            # input order -- confirm against admet_ai.
            preds['SMILES'] = df_valid['SMILES'].values

        # Sampling can return the same SMILES more than once. Without this, the
        # merge pairs every duplicate with every duplicate prediction row
        # (k copies -> k*k rows) and molecule names appear repeatedly.
        preds = preds.drop_duplicates(subset='SMILES')

        df_results = pd.merge(df_valid, preds, on='SMILES', how='inner')
        df_results.set_index('Molecule Name', inplace=True)

        # Select only desired ADMET properties
        admet_properties = [
            'molecular_weight', 'logP', 'hydrogen_bond_acceptors',
            'hydrogen_bond_donors', 'QED', 'ClinTox', 'hERG', 'BBB_Martins'
        ]
        # Tolerate properties the predictor did not return: warn instead of
        # raising KeyError on the column selection below.
        missing_properties = [p for p in admet_properties if p not in df_results.columns]
        if missing_properties:
            st.warning(f"ADMET properties not returned by the model: {', '.join(missing_properties)}")
        shown_properties = [p for p in admet_properties if p in df_results.columns]
        df_results_filtered = df_results[['SMILES', 'Valid'] + shown_properties]

        if df_results_filtered.empty:
            st.error("No valid ADMET predictions were obtained. Please try adjusting the generation parameters.")
        else:
            # Display Molecules
            st.subheader('Generated Molecules')
            cols_per_row = min(3, len(df_results_filtered))  # Max 3 columns
            cols = st.columns(cols_per_row)

            for idx, (mol_name, row) in enumerate(df_results_filtered.iterrows()):
                smiles = row['SMILES']

                # Attempt to encode to SAFE; a failure is reported but not fatal.
                try:
                    safe_string = safe.encode(smiles)
                except Exception as e:
                    safe_string = None
                    st.error(f"Could not convert to SAFE encoding for {mol_name}: {e}")

                # Standard RDKit depiction from the SMILES string
                img = generate_molecule_image(smiles)

                with cols[idx % cols_per_row]:
                    if img is not None and isinstance(img, Image.Image):
                        st.image(img, caption=mol_name)
                    else:
                        st.error(f"Could not generate image for {mol_name}")

                    st.write("**SMILES:**")
                    st.text(smiles)

                    # Display SAFE encoding (and its visualization) if available
                    if safe_string:
                        st.write("**SAFE Encoding:**")
                        st.text(safe_string)
                        safe_img = generate_molecule_image(safe_string, use_safe=True)
                        if safe_img is not None:
                            st.image(safe_img, caption=f"{mol_name} (SAFE Visualization)")

                    # Display selected ADMET properties
                    st.write("**ADMET Properties:**")
                    admet_data = row.drop(['SMILES', 'Valid'])
                    st.write(admet_data)
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")