bcadkins01 committed on
Commit
e295393
·
verified ·
1 Parent(s): 276c5ce

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +122 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ from rdkit import Chem
5
+ from rdkit.Chem import Draw
6
+ from transformers import BartForConditionalGeneration, BartTokenizer
7
+ from admet_ai import ADMETModel
8
+ import safe
9
+ import io
10
+ from PIL import Image
11
+ import cairosvg
12
+ import pandas as pd
13
+
14
+ # Page Configuration
15
+ st.set_page_config(page_title='Beta-Lactam Molecule Generator', layout='wide')
16
+
17
+ # Load Models
18
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
def load_models():
    """Load the generation model, its tokenizer, and the ADMET-AI predictor.

    Wrapped in ``st.cache_resource`` (with ``ttl=600``) so the loaded objects
    are reused across script reruns instead of being rebuilt every time.

    Returns:
        A ``(model, tokenizer, admet_model)`` tuple holding the BART
        conditional-generation model, its tokenizer, and an ``ADMETModel``
        instance.
    """
    # Placeholder checkpoint id: replace with your actual model path.
    model_name = "your-new-beta-lactam-model-path"
    # May be None when the variable is unset; the value is forwarded as-is.
    access_token = os.getenv("HUGGING_FACE_TOKEN")

    # `token` is the supported keyword for authenticated downloads; the
    # older `use_auth_token` spelling is deprecated in transformers.
    model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
    tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)

    # Load ADMET-AI model
    admet_model = ADMETModel()
    return model, tokenizer, admet_model
28
+
29
import html
import json

import streamlit.components.v1 as components


# Function Definitions
# Streamlit executes this script from top to bottom on every rerun, so the
# helpers have to be defined before the button handler further down calls them.
def generate_molecule_image(input_string, use_safe_visualization=True):
    """Render a molecule string as a PIL image.

    Args:
        input_string: Molecule string to draw. In SAFE mode it is first
            tried as a SAFE string and, if that fails, treated as SMILES.
        use_safe_visualization: When true, draw through ``safe.to_image``
            and convert the SVG to PNG; otherwise draw with RDKit.

    Returns:
        A ``PIL.Image.Image`` on success. In RDKit mode ``None`` is returned
        when the string cannot be parsed. If anything raises, the exception
        instance itself is returned (not raised), so callers should check the
        result with ``isinstance(result, Image.Image)``.
    """
    try:
        if use_safe_visualization:
            try:
                # Round-trip through decode/encode to normalise a SAFE input.
                safe_string = safe.encode(safe.decode(input_string))
            except Exception:
                # If decoding fails, assume input is SMILES and encode to SAFE.
                safe_string = safe.encode(input_string)
            # Generate SVG image with fragment highlights, then rasterise it.
            svg_str = safe.to_image(safe_string)
            png_bytes = cairosvg.svg2png(bytestring=svg_str.encode('utf-8'))
            return Image.open(io.BytesIO(png_bytes))

        # Generate standard molecule image
        mol = Chem.MolFromSmiles(input_string)
        if mol is None:
            return None
        return Draw.MolToImage(mol, size=(200, 200))
    except Exception as e:
        # Hand the failure back to the caller instead of aborting the page.
        return e


def st_copy_button(text, key):
    """Render a "Copy" button that writes *text* to the clipboard.

    Args:
        text: String placed on the clipboard when the button is clicked.
        key: Accepted for call-site compatibility; currently unused.
    """
    # json.dumps produces a valid JavaScript string literal (quotes and the
    # backslashes used by SMILES stereo bonds are escaped); html.escape then
    # makes that literal safe inside the double-quoted onclick attribute.
    js_literal = html.escape(json.dumps(text), quote=True)
    components.html(f"""
    <button onclick="navigator.clipboard.writeText({js_literal})" style="padding:5px;">Copy</button>
    """, height=45)


model, tokenizer, admet_model = load_models()

# Set Generation Parameters
st.sidebar.header('Generation Parameters')
# The lower bound is 0.1 rather than 0.0: transformers requires a strictly
# positive temperature when sampling (do_sample=True) and raises otherwise.
creativity = st.sidebar.slider('Creativity (Temperature):', 0.1, 2.0, 1.0, step=0.1)
num_molecules = st.sidebar.number_input(
    'Number of Molecules to Generate:', min_value=1, max_value=5, value=5
)

# String Format Option
string_format = st.sidebar.radio('String Format:', ('SMILES', 'SAFE'))

# Generate Molecules Button
if st.button('Generate Molecules'):
    st.info("Generating molecules... Please wait.")

    # NOTE(review): this prompt is labelled as the beta-lactam core, but the
    # SMILES encodes a five-membered imide ring (two carbonyls flanking the
    # N-methyl nitrogen), not a four-membered azetidin-2-one. Confirm which
    # prompt the model was trained with before changing the value.
    core_smiles = "C1C(=O)N(C)C(=O)C1"
    input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
    output_ids = model.generate(
        input_ids,
        max_length=128,
        temperature=creativity,
        do_sample=True,
        top_k=50,
        num_return_sequences=num_molecules,
    )
    generated_smiles = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
    molecule_names = [f"Mol{str(i).zfill(2)}" for i in range(1, num_molecules + 1)]
    generated_molecules = dict(zip(molecule_names, generated_smiles))

    # ADMET Predictions
    preds = admet_model.predict(smiles=list(generated_molecules.values()))
    preds['Molecule Name'] = molecule_names
    preds.set_index('Molecule Name', inplace=True)

    # Display Molecules
    st.subheader('Generated Molecules')
    cols_per_row = min(5, num_molecules)
    cols = st.columns(cols_per_row)
    for idx, mol_name in enumerate(molecule_names):
        smiles = generated_molecules[mol_name]
        img = generate_molecule_image(smiles, use_safe_visualization=(string_format == 'SAFE'))
        with cols[idx % cols_per_row]:
            if isinstance(img, Image.Image):
                st.image(img, caption=mol_name)
            else:
                st.error(f"Could not generate image for {mol_name}")

            # Display molecule string
            string_to_display = smiles
            if string_format == 'SAFE':
                try:
                    string_to_display = safe.encode(smiles)
                except Exception:
                    # Sampled output is not guaranteed to be convertible;
                    # fall back to the raw string rather than abort the page.
                    st.warning(f"Could not convert {mol_name} to SAFE; showing SMILES instead.")
            st.code(string_to_display)

            # Copy-to-clipboard functionality
            st_copy_button(string_to_display, key=f'copy_{mol_name}')

            # Display ADMET properties
            st.write("**ADMET Properties:**")
            st.write(preds.loc[mol_name])
else:
    st.write("Click the 'Generate Molecules' button to generate beta-lactam molecules.")
121
+
122
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ rdkit-pypi
3
+ numpy
4
+ torch
5
+ transformers
6
+ admet-ai
7
+ safe-encoding
8
+ pandas
9
+ cairosvg
10
+ Pillow