Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -13,20 +13,20 @@ import pandas as pd
|
|
13 |
import streamlit.components.v1 as components
|
14 |
import json # For safely encoding text in JavaScript
|
15 |
|
16 |
-
#
|
17 |
st.set_page_config(
|
18 |
page_title='Beta-Lactam Molecule Generator',
|
19 |
layout='wide'
|
20 |
)
|
21 |
|
22 |
-
#
|
23 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
24 |
def load_models():
|
25 |
"""
|
26 |
Load the molecule generation model and the ADMET-AI model.
|
27 |
Caches the models to avoid reloading on every run.
|
28 |
"""
|
29 |
-
#
|
30 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
31 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
32 |
if access_token is None:
|
@@ -35,24 +35,24 @@ def load_models():
|
|
35 |
model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
|
36 |
tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
|
37 |
|
38 |
-
#
|
39 |
admet_model = ADMETModel()
|
40 |
|
41 |
return model, tokenizer, admet_model
|
42 |
|
43 |
-
#
|
44 |
model, tokenizer, admet_model = load_models()
|
45 |
|
46 |
-
#
|
47 |
st.sidebar.header('Generation Parameters')
|
48 |
|
49 |
-
#
|
50 |
creativity = st.sidebar.slider(
|
51 |
'Creativity (Temperature):',
|
52 |
min_value=0.0,
|
53 |
-
max_value=2.
|
54 |
value=1.0,
|
55 |
-
step=0.
|
56 |
help="Higher values lead to more diverse outputs."
|
57 |
)
|
58 |
|
@@ -65,7 +65,7 @@ num_molecules = st.sidebar.number_input(
|
|
65 |
help="Select the number of molecules you want to generate (up to 3)."
|
66 |
)
|
67 |
|
68 |
-
#
|
69 |
def generate_molecule_image(input_string, use_safe=False):
|
70 |
"""
|
71 |
Generates an image of the molecule from the input string.
|
@@ -91,10 +91,10 @@ def generate_molecule_image(input_string, use_safe=False):
|
|
91 |
st.error(f"Error generating molecule image: {e}")
|
92 |
return None
|
93 |
|
94 |
-
#
|
95 |
def st_copy_button(text, key):
|
96 |
-
"""Creates a copy-to-clipboard button
|
97 |
-
#
|
98 |
escaped_text = json.dumps(text)
|
99 |
button_html = f"""
|
100 |
<div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
|
@@ -105,17 +105,17 @@ def st_copy_button(text, key):
|
|
105 |
"""
|
106 |
components.html(button_html, height=35)
|
107 |
|
108 |
-
#
|
109 |
if st.button('Generate Molecules'):
|
110 |
st.info("Generating molecules... Please wait.")
|
111 |
|
112 |
-
#
|
113 |
core_smiles = "C1C(=O)N(C)C(=O)C1"
|
114 |
|
115 |
-
#
|
116 |
input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
|
117 |
|
118 |
-
#
|
119 |
output_ids = model.generate(
|
120 |
input_ids=input_ids,
|
121 |
max_length=128,
|
@@ -126,25 +126,25 @@ if st.button('Generate Molecules'):
|
|
126 |
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
127 |
)
|
128 |
|
129 |
-
#
|
130 |
generated_smiles = [
|
131 |
tokenizer.decode(ids, skip_special_tokens=True)
|
132 |
for ids in output_ids
|
133 |
]
|
134 |
|
135 |
-
#
|
136 |
molecule_names = [
|
137 |
f"Mol{str(i).zfill(2)}"
|
138 |
for i in range(1, len(generated_smiles) + 1)
|
139 |
]
|
140 |
|
141 |
-
#
|
142 |
df_molecules = pd.DataFrame({
|
143 |
'Molecule Name': molecule_names,
|
144 |
'SMILES': generated_smiles
|
145 |
})
|
146 |
|
147 |
-
#
|
148 |
# Function to validate SMILES
|
149 |
def is_valid_smile(smile):
|
150 |
return Chem.MolFromSmiles(smile) is not None
|
|
|
13 |
import streamlit.components.v1 as components
|
14 |
import json # For safely encoding text in JavaScript
|
15 |
|
16 |
+
# Page Configuration
|
17 |
st.set_page_config(
|
18 |
page_title='Beta-Lactam Molecule Generator',
|
19 |
layout='wide'
|
20 |
)
|
21 |
|
22 |
+
# Load Models
|
23 |
@st.cache_resource(show_spinner="Loading Models...", ttl=600)
|
24 |
def load_models():
|
25 |
"""
|
26 |
Load the molecule generation model and the ADMET-AI model.
|
27 |
Caches the models to avoid reloading on every run.
|
28 |
"""
|
29 |
+
# Load your molecule generation model
|
30 |
model_name = "bcadkins01/beta_lactam_generator" # Replace with your actual model path
|
31 |
access_token = os.getenv("HUGGING_FACE_TOKEN")
|
32 |
if access_token is None:
|
|
|
35 |
model = BartForConditionalGeneration.from_pretrained(model_name, token=access_token)
|
36 |
tokenizer = BartTokenizer.from_pretrained(model_name, token=access_token)
|
37 |
|
38 |
+
# Load ADMET-AI model
|
39 |
admet_model = ADMETModel()
|
40 |
|
41 |
return model, tokenizer, admet_model
|
42 |
|
43 |
+
# Load models once and reuse
|
44 |
model, tokenizer, admet_model = load_models()
|
45 |
|
46 |
+
# Set Generation Parameters in Sidebar
|
47 |
st.sidebar.header('Generation Parameters')
|
48 |
|
49 |
+
# Creativity Slider (Temperature)
|
50 |
creativity = st.sidebar.slider(
|
51 |
'Creativity (Temperature):',
|
52 |
min_value=0.0,
|
53 |
+
max_value=2.4,
|
54 |
value=1.0,
|
55 |
+
step=0.2,
|
56 |
help="Higher values lead to more diverse outputs."
|
57 |
)
|
58 |
|
|
|
65 |
help="Select the number of molecules you want to generate (up to 3)."
|
66 |
)
|
67 |
|
68 |
+
# Function to Generate Molecule Images
|
69 |
def generate_molecule_image(input_string, use_safe=False):
|
70 |
"""
|
71 |
Generates an image of the molecule from the input string.
|
|
|
91 |
st.error(f"Error generating molecule image: {e}")
|
92 |
return None
|
93 |
|
94 |
+
# Function to Create Copy-to-Clipboard Button
|
95 |
def st_copy_button(text, key):
|
96 |
+
"""Creates a copy-to-clipboard button."""
|
97 |
+
# Encode the text for JavaScript
|
98 |
escaped_text = json.dumps(text)
|
99 |
button_html = f"""
|
100 |
<div style="text-align: right; margin-top: -10px; margin-bottom: 10px;">
|
|
|
105 |
"""
|
106 |
components.html(button_html, height=35)
|
107 |
|
108 |
+
# Generate Molecules Button
|
109 |
if st.button('Generate Molecules'):
|
110 |
st.info("Generating molecules... Please wait.")
|
111 |
|
112 |
+
# Beta-lactam core structure
|
113 |
core_smiles = "C1C(=O)N(C)C(=O)C1"
|
114 |
|
115 |
+
# Tokenize the core SMILES
|
116 |
input_ids = tokenizer(core_smiles, return_tensors='pt').input_ids
|
117 |
|
118 |
+
# Generate molecules using the model
|
119 |
output_ids = model.generate(
|
120 |
input_ids=input_ids,
|
121 |
max_length=128,
|
|
|
126 |
num_beams=max(num_molecules, 5) # Ensure num_beams >= num_return_sequences
|
127 |
)
|
128 |
|
129 |
+
# Decode generated molecule SMILES
|
130 |
generated_smiles = [
|
131 |
tokenizer.decode(ids, skip_special_tokens=True)
|
132 |
for ids in output_ids
|
133 |
]
|
134 |
|
135 |
+
# Create generic molecule names for demo
|
136 |
molecule_names = [
|
137 |
f"Mol{str(i).zfill(2)}"
|
138 |
for i in range(1, len(generated_smiles) + 1)
|
139 |
]
|
140 |
|
141 |
+
# Create df for generated molecules
|
142 |
df_molecules = pd.DataFrame({
|
143 |
'Molecule Name': molecule_names,
|
144 |
'SMILES': generated_smiles
|
145 |
})
|
146 |
|
147 |
+
# Invalid SMILES Check
|
148 |
# Function to validate SMILES
|
149 |
def is_valid_smile(smile):
|
150 |
return Chem.MolFromSmiles(smile) is not None
|