# fm4m-eval-demo / app.py
# (Hugging Face Space page header captured with the source: author "ipd",
#  commit "init" 5306c2a, file size 18.2 kB.)
import gradio as gr
import numpy as np
import pandas as pd
from tempfile import NamedTemporaryFile
from PIL import Image
from rdkit import RDLogger
from sklearn.model_selection import train_test_split
from molecule_generation_helpers import *
from property_prediction_helpers import *
# Set to True to reveal the normally hidden debugging widgets (dataset
# previews, column selectors, model/task selectors, dataset-info textbox).
DEBUG_VISIBLE = False
# Only let RDKit log messages at ERROR level or above through.
RDLogger.logger().setLevel(RDLogger.ERROR)
# Predefined datasets offered in the "Select Dataset" dropdown.
# Each value is a single comma-separated string of the form
#   "<train csv>, <test csv>, <input column>, <output column>"
# (adjust the paths to your local data layout). The " " entry is a blank
# placeholder choice; the property-prediction Clear button resets the
# dropdown to it.
predefined_datasets = dict(
    [
        (" ", " "),
        ("BACE", "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class"),
        ("ESOL", "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop"),
    ]
)
# Models available for evaluation. Order matters: the evaluation steps pick
# models by index into this list.
models_enabled = ["MorganFingerprint", "SMI-TED", "SELFIES-TED", "MHG-GED"]
# Empty results table shown before / after an evaluation run.
blank_df = pd.DataFrame({column: [] for column in ("id", "Model", "Score")})
# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
    """Preview a predefined dataset and reset the column selectors.

    Args:
        dataset_name: Key chosen in the dataset dropdown. A key that is not in
            ``predefined_datasets`` (e.g. "Custom Dataset" or None) is
            reported as "custom".

    Returns:
        A 4-tuple for (preview table, input-column dropdown, output-column
        dropdown, dataset-info textbox):

        * training CSV readable -> its first rows, both dropdowns offering its
          columns with no selection, and ``dataset_name.lower()``;
        * otherwise -> an empty DataFrame, both dropdowns emptied, and the
          lower-cased name ("custom" for non-predefined keys).
    """
    spec = predefined_datasets.get(dataset_name)
    if not spec:
        dataset_name = "Custom"
    else:
        # The spec is "<train csv>, <test csv>, <input col>, <output col>";
        # only the training CSV is needed for the preview.
        train_path = spec.split(",")[0].strip()
        df = None
        if train_path:  # the blank " " placeholder has no file to read
            try:
                df = pd.read_csv(train_path)
            except (OSError, ValueError):
                # Best effort: a missing, unreadable, empty or malformed file
                # falls back to an empty preview. (pandas' EmptyDataError and
                # ParserError, and UnicodeDecodeError, subclass ValueError.)
                df = None
        if df is not None:
            return (
                df.head(),
                gr.update(choices=list(df.columns), value=None),
                gr.update(choices=list(df.columns), value=None),
                dataset_name.lower(),
            )
    return (
        pd.DataFrame(),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        dataset_name.lower(),
    )
# Function to handle dataset selection (predefined or custom)
def handle_dataset_selection(selected_dataset, state):
    """Record the chosen dataset and adjust the settings panel and task type.

    Side effect: stores the dropdown key in ``state["dataset_name"]``, or the
    string "CUSTOM" for any key that is not predefined (e.g. "Custom Dataset").

    Returns:
        (visibility update for the custom-dataset settings accordion,
        task type). The accordion is shown for non-predefined choices, or
        always when DEBUG_VISIBLE is set. BACE maps to "Classification",
        ESOL to "Regression", anything else to None.
    """
    is_predefined = selected_dataset in predefined_datasets
    state["dataset_name"] = selected_dataset if is_predefined else "CUSTOM"

    known_tasks = {"BACE": "Classification", "ESOL": "Regression"}
    task_type = known_tasks.get(selected_dataset)

    # Upload fields are only relevant when the user picked a custom dataset.
    show_settings = (not is_predefined) or DEBUG_VISIBLE
    return gr.update(visible=show_settings), task_type
# Function to select input and output columns and display a message
def select_columns(input_column, output_column, train_data, test_data, state):
    """Build the dataset descriptor string consumed by the evaluation step.

    Args:
        input_column: Selected input (feature) column name.
        output_column: Selected output (label) column name.
        train_data: Value of the train ``gr.File``: an object exposing the
            file path as ``.name``, a plain path string, or None.
        test_data: Same as ``train_data`` for the test file.
        state: Session state dict; ``state["dataset_name"]`` is appended.

    Returns:
        "<train path>,<test path>,<input column>,<output column>,<dataset name>"
        when all four selections are present; otherwise ``gr.update()`` so the
        message already shown (e.g. for a predefined dataset) is kept.
    """
    if not (train_data and test_data and input_column and output_column):
        return gr.update()

    def _path(uploaded):
        # Accept a plain path string as-is; otherwise read the temp-file path
        # from ``.name`` as before.
        return uploaded if isinstance(uploaded, str) else uploaded.name

    train_path = _path(train_data)
    test_path = _path(test_data)
    return f"{train_path},{test_path},{input_column},{output_column},{state['dataset_name']}"
# Function to display the head of the uploaded CSV file
def display_csv_head(file):
    """Preview an uploaded CSV and offer its columns in both column selectors.

    Args:
        file: Value of a ``gr.File`` component: an object exposing the file
            path as ``.name``, a plain path string, or None.

    Returns:
        (first rows of the CSV, input-column dropdown update, output-column
        dropdown update); with no file, an empty DataFrame and empty choices.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
    # A plain path string is used as-is; otherwise the temp-file path is
    # taken from ``.name``.
    path = file if isinstance(file, str) else file.name
    # Load the CSV file into a DataFrame
    df = pd.read_csv(path)
    return (
        df.head(),
        gr.update(choices=list(df.columns)),
        gr.update(choices=list(df.columns)),
    )
def _write_temp_csv(frame, prefix):
    """Write *frame* (without its index) to a new temporary CSV; return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    be passed on by path.
    """
    # Write through the already-open handle: reopening a NamedTemporaryFile by
    # name while it is still open is not possible on Windows.
    with NamedTemporaryFile(
        mode="w",
        prefix=prefix,
        suffix=".csv",
        delete=False,
        newline="",
        encoding="utf-8",
    ) as handle:
        frame.to_csv(handle, index=False)
    return handle.name


def process_custom_file(file, selected_dataset):
    """Split one uploaded CSV into temporary train/test CSV files.

    The upload must be smaller than 50 KiB and contain ``input`` and
    ``output`` columns; 20% of the rows (random, unseeded split) go to the
    test file.

    Args:
        file: Value of the "Upload Custom Dataset" ``gr.File``: an object
            exposing the path as ``.name``, a plain path string, or None.
        selected_dataset: Current dataset-dropdown key.

    Returns:
        On success ``(train_csv_path, test_csv_path, "input", "output",
        task_type)``, where task_type is "Classification" for integer or
        boolean labels and "Regression" otherwise. Otherwise four ``None``
        values, plus ``gr.update()`` (leave the task radio unchanged) when a
        predefined dataset is selected, else ``None``.
    """
    import os  # imported explicitly; the module-level imports do not name os

    if file:
        path = file if isinstance(file, str) else file.name
        # Only small uploads are accepted.
        if os.path.getsize(path) < 50 * 1024:
            df = pd.read_csv(path)
            if "input" in df.columns and "output" in df.columns:
                train, test = train_test_split(df, test_size=0.2)
                labels = df["output"]
                # Integer (any width) or boolean labels are class labels;
                # anything else (e.g. floats) is a regression target.
                is_categorical = pd.api.types.is_integer_dtype(
                    labels
                ) or pd.api.types.is_bool_dtype(labels)
                task_type = "Classification" if is_categorical else "Regression"
                return (
                    _write_temp_csv(train, "fm4m-train-"),
                    _write_temp_csv(test, "fm4m-test-"),
                    "input",
                    "output",
                    task_type,
                )
    # Nothing usable: clear the file and column outputs. For a predefined
    # dataset the task radio is left untouched; otherwise it is cleared too.
    return (
        None,
        None,
        None,
        None,
        gr.update() if selected_dataset in predefined_datasets else None,
    )
def update_plot_choices(current, state):
    """Refresh the plot-type radio from the results currently held in *state*.

    A plot type is offered only when its backing entry in *state* is not None:
    "roc_auc" -> "ROC-AUC", "RMSE" -> "Parity Plot", "x_batch" -> "Latent Space".
    If *current* is still offered, only the choices are updated. Otherwise the
    first available option is selected, or None when nothing is available.
    """
    plot_sources = (
        ("roc_auc", "ROC-AUC"),
        ("RMSE", "Parity Plot"),
        ("x_batch", "Latent Space"),
    )
    available = [label for key, label in plot_sources if state.get(key) is not None]
    if current in available:
        return gr.update(choices=available)
    fallback = available[0] if available else None
    return gr.update(choices=available, value=fallback)
def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
    """Merge the stored results of the clicked log-table row into ``state``.

    Args:
        df: Current contents of the log table (must have an ``id`` column).
        evt: Gradio selection event; ``evt.index[0]`` is the clicked row's
            position in the table.
        state: Session state dict whose ``"results"`` maps a row ``id`` to a
            dict of that run's outputs; an unknown id merges nothing.
    """
    row_position = evt.index[0]
    # evt.index is positional, so read the id by position (``iat``) rather
    # than by index label (``at``); the two only coincide for a RangeIndex.
    run_id = df["id"].iat[row_position]
    # NOTE(review): dict.update merges, so keys from a previously selected row
    # that this row's results lack remain in state and still gate
    # update_plot_choices. Confirm every results entry carries all plot keys
    # (possibly as None).
    state.update(state["results"].get(run_id, {}))
# Sample molecules offered in the Molecule Generation tab. Each display key
# maps to an input SMILES string and the path of a pre-rendered preview image
# (replace the image paths with your own if needed).
smiles_image_mapping = {
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}
# Load images for selection
def load_image(path):
    """Open the stored preview image of a sample molecule.

    Args:
        path: A key of ``smiles_image_mapping`` such as "Mol 1" (despite the
            name, not a file path), or None when the radio is cleared.

    Returns:
        The opened PIL image, or None when the key is unknown/None or the
        image file is missing or cannot be identified.
    """
    try:
        image_path = smiles_image_mapping[path]["image"]
    except (KeyError, TypeError):
        # Unknown or cleared selection: show no image.
        return None
    try:
        return Image.open(image_path)
    except OSError:
        # Missing/unreadable file; PIL's UnidentifiedImageError is an OSError.
        return None
# Function to handle image selection
def handle_image_selection(image_key):
    """Return (SMILES, rendered image) for a sample-molecule key.

    The image is produced by ``smiles_to_image`` from the sample's SMILES.
    A falsy key (e.g. the radio being cleared) yields (None, None).
    """
    if image_key:
        chosen_smiles = smiles_image_mapping[image_key]["smiles"]
        return chosen_smiles, smiles_to_image(chosen_smiles)
    return None, None
# Introduction
with gr.Blocks() as introduction:
    # The Markdown file is rendered with sanitize_html=False, i.e. trusted to
    # contain raw HTML. It is read as UTF-8 explicitly so the result does not
    # depend on the platform's default locale encoding.
    with open("INTRODUCTION.md", encoding="utf-8") as f:
        gr.Markdown(f.read(), sanitize_html=False)
# Property Prediction
# Tab that runs evaluate_and_log (defined elsewhere) for several model
# combinations on a predefined or uploaded dataset and lists the runs in a
# results table; clicking a row plots that run.
with gr.Blocks() as property_prediction:
    # Per-session state. "results" maps a log-table row id to that run's
    # outputs (read by log_selected, reset by clear_eval); "dataset_name" is
    # written by handle_dataset_selection.
    state = gr.State({"model_name": "Default - Auto", "results": {}})
    gr.HTML(
        '''
<p style="text-align: center">
Task : Property Prediction
<br>
Models are finetuned with different combination of modalities on the uploaded or selected built data set.
</p>
'''
    )
    with gr.Row():
        # Left column: dataset selection, (mostly hidden) settings, buttons.
        with gr.Column():
            # Dropdown menu for predefined datasets including "Custom Dataset" option
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )
            # Display the message for selected columns. Its value is the
            # dataset descriptor passed to evaluate_and_log; it is written by
            # load_predefined_dataset and select_columns.
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=DEBUG_VISIBLE
            )
            # Made visible by handle_dataset_selection for non-predefined
            # datasets (always, when DEBUG_VISIBLE is set).
            with gr.Accordion(
                "Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
            ) as settings:
                # File upload options for custom dataset (train and test).
                # Only custom_file is user-facing: process_custom_file splits
                # it and fills the hidden train_file / test_file components.
                custom_file = gr.File(
                    label="Upload Custom Dataset",
                    file_types=[".csv"],
                )
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Predefined dataset displays
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Dropdowns for selecting input and output columns for the custom dataset
                input_column_selector = gr.Dropdown(
                    label="Select Input Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                # When a custom train file is uploaded, display its head and update column selectors
                train_file.change(
                    display_csv_head,
                    inputs=train_file,
                    outputs=[
                        train_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # When a custom test file is uploaded, display its head.
                # Note: this also replaces the column-selector choices with
                # the test file's columns.
                test_file.change(
                    display_csv_head,
                    inputs=test_file,
                    outputs=[
                        test_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
            # Hidden unless debugging; eval_part sets it to each step's models.
            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
            )
            # Hidden unless debugging; set by handle_dataset_selection and
            # process_custom_file.
            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"],
                label="Task Type",
                visible=DEBUG_VISIBLE,
            )
            # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
            # When a predefined dataset is selected, load its head and update column selectors
            # (The first step clears any previous custom upload.)
            dataset_selector.change(lambda: None, outputs=custom_file).then(
                handle_dataset_selection,
                inputs=[dataset_selector, state],
                outputs=[settings, task_radiobutton],
            ).then(
                load_predefined_dataset,
                inputs=dataset_selector,
                outputs=[
                    predefined_display,
                    input_column_selector,
                    output_column_selector,
                    selected_columns_message,
                ],
            )
            # Split an uploaded dataset into train/test files and set the
            # column names and task type.
            custom_file.change(
                process_custom_file,
                inputs=[custom_file, dataset_selector],
                outputs=[
                    train_file,
                    test_file,
                    input_column_selector,
                    output_column_selector,
                    task_radiobutton,
                ],
            )
            eval_clear_button = gr.Button("Clear")
            eval_button = gr.Button("Submit", variant="primary")
            # Progress indicator (steps 0-8), shown only while evaluating.
            step_slider = gr.Slider(
                minimum=0,
                maximum=8,
                value=0,
                label="Progress",
                show_label=True,
                interactive=False,
                visible=False,
            )
        # Right Column
        with gr.Column():
            log_table = gr.Dataframe(value=blank_df, interactive=False)
            plot_radio = gr.Radio(choices=[], label="Select Plot Type")
            plot_output = gr.Plot(label="Visualization")
            # Clicking a result row loads that run's results into state,
            # refreshes the offered plot types and redraws the plot.
            log_table.select(log_selected, [log_table, state]).success(
                update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
            ).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)

    def clear_eval(state):
        """Forget stored results; reset the plot, plot choices and log table."""
        state["results"] = {}
        return None, gr.update(choices=[], value=None), blank_df

    def eval_part(part, step, selector, show_progress=False):
        """Chain one evaluation step onto the event chain ``part``.

        Sets model_checkbox to ``models_enabled[i]`` for each ``i`` in
        ``selector``, runs evaluate_and_log (its output replaces log_table,
        presumably with one more row), then moves step_slider to ``step``.
        Returns the extended chain.
        """
        return (
            # ``selector`` is bound per eval_part call, so each lambda keeps
            # its own model selection.
            part.then(
                lambda: [models_enabled[x] for x in selector],
                outputs=model_checkbox,
            )
            .then(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    log_table,
                    state,
                ],
                outputs=log_table,
                show_progress=show_progress,
            )
            .then(lambda: step, outputs=step_slider, show_progress=False)
        )

    # Submit: lock both buttons, build the dataset descriptor, then clear any
    # previous results.
    part = (
        eval_button.click(
            lambda: (
                gr.update(interactive=False),
                gr.update(interactive=False),
            ),
            outputs=[eval_clear_button, eval_button],
        )
        .then(
            select_columns,
            inputs=[
                input_column_selector,
                output_column_selector,
                train_file,
                test_file,
                state,
            ],
            outputs=selected_columns_message,
        )
        .then(
            clear_eval,
            inputs=state,
            outputs=[
                plot_output,
                plot_radio,
                log_table,
            ],
        )
    )
    # Show the progress slider at 0.
    part = part.then(
        lambda: gr.update(value=0, visible=True),
        outputs=step_slider,
        show_progress=False,
    )
    # Evaluation steps (indices into models_enabled):
    # 1 MorganFingerprint, 2 SMI-TED, 3 SELFIES-TED, 4 MHG-GED,
    # 5 SMI-TED + SELFIES-TED, 6 SELFIES-TED + MHG-GED,
    # 7 SMI-TED + MHG-GED, 8 SMI-TED + SELFIES-TED + MHG-GED.
    part = eval_part(part, 1, [0], True)
    part = eval_part(part, 2, [1])
    part = eval_part(part, 3, [2])
    part = eval_part(part, 4, [3])
    part = eval_part(part, 5, [1, 2])
    part = eval_part(part, 6, [2, 3])
    part = eval_part(part, 7, [1, 3])
    part = eval_part(part, 8, [1, 2, 3])
    # Hide the progress slider and unlock both buttons again.
    part = part.then(
        lambda: gr.update(visible=False),
        outputs=step_slider,
        show_progress=False,
    )
    part.then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[eval_clear_button, eval_button],
    )
    # Redraw whenever a different plot type is chosen.
    plot_radio.change(
        display_plot, inputs=[plot_radio, state], outputs=plot_output
    )
    # Clear: drop results, then reset the dataset dropdown to the blank " "
    # entry (which fires dataset_selector.change above).
    eval_clear_button.click(
        clear_eval,
        inputs=state,
        outputs=[
            plot_output,
            plot_radio,
            log_table,
        ],
    ).then(lambda: " ", outputs=dataset_selector)
# Molecule Generation
# Tab that takes an input SMILES (typed or picked from the sample molecules),
# runs generate_canonical (defined elsewhere) on it and shows the result.
with gr.Blocks() as molecule_generation:
    gr.HTML(
        '''
<p style="text-align: center">
Task : Molecule Generation
<br>
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
</p>
'''
    )
    with gr.Row():
        # Left column: input SMILES, its image, sample picker and buttons.
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Show images for selection
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping.keys()),
                    label="Select from sample molecules",
                    value=None,
                )
                # Show the sample's stored preview image.
                # NOTE(review): handle_image_selection (below) and
                # smiles_input.change also write image_display for the same
                # selection, so the image finally shown depends on which
                # update completes last.
                image_selector.change(load_image, image_selector, image_display)
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")
        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")
    # Handle image selection: copy the sample's SMILES into the input box and
    # show the image produced from it.
    image_selector.change(
        handle_image_selection,
        inputs=image_selector,
        outputs=[smiles_input, image_display],
    )
    # Whenever the SMILES text changes, show smiles_to_image's result for it.
    smiles_input.change(
        smiles_to_image, inputs=smiles_input, outputs=image_display
    )
    # Generate button to display canonical SMILES and molecule image.
    # Both buttons are disabled while generate_canonical runs; it fills the
    # property table, the generated-output text and the generated image.
    generate_button.click(
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[clear_button, generate_button],
    ).then(
        generate_canonical,
        inputs=smiles_input,
        outputs=[property_table, generated_output, gen_image_display],
    ).then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[clear_button, generate_button],
    )
    # Reset every input and output component of this tab.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        outputs=[
            smiles_input,
            image_display,
            image_selector,
            gen_image_display,
            generated_output,
            property_table,
        ],
    )
# Render with tabs
# NOTE(review): allowed_paths=["./"] permits Gradio to serve files from
# anywhere under the working directory to every client of this 0.0.0.0-bound
# server; consider narrowing it to the directories the UI actually needs.
tab_pages = [introduction, property_prediction, molecule_generation]
tab_titles = ["Introduction", "Property Prediction", "Molecule Generation"]
demo = gr.TabbedInterface(tab_pages, tab_titles)
demo.launch(server_name="0.0.0.0", allowed_paths=["./"])