SMI-TED-demo1 / app.py
Enzo Reis de Oliveira
Searching for smiles regardless of the position column
cb4cd4f
raw
history blame
3.71 kB
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd
# 1) Adjust path before importing the loader.
# The SMI-TED inference code is vendored under <this file's dir>/smi-ted/inference
# (presumably not pip-installed — verify), so that directory is put first on
# sys.path. The `from smi_ted_light...` import below must stay AFTER this insert.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)
from smi_ted_light.load import load_smi_ted  # noqa: E402 — depends on the sys.path insert above
# 2) Load the SMI-TED Light model once, at module import time. The module-level
# `model` is reused by process_inputs for every request.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    # NOTE(review): both filenames are presumably resolved relative to MODEL_DIR
    # by load_smi_ted — confirm against the loader.
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
# 3) Single function to process either a single SMILES or a CSV of SMILES
def _encode_smiles(smiles: str) -> list:
    """Return the SMI-TED embedding of one SMILES string as a plain list."""
    # NOTE(review): encode(..., return_torch=True) is assumed to return a tensor
    # with one row per input; [0] selects the row for this single SMILES.
    return model.encode(smiles, return_torch=True)[0].tolist()


def _write_embeddings_csv(smiles_list: list, vectors: list) -> str:
    """Write SMILES and their embeddings to a new ``embeddings.csv``.

    Columns are ``smiles`` followed by ``dim_0`` … ``dim_{d-1}``. The file is
    created inside a fresh temporary directory so concurrent requests never
    overwrite (or serve) each other's results, while the download keeps the
    name ``embeddings.csv``. Temp directories are not removed here.

    Returns:
        Path of the written CSV file.
    """
    width = len(vectors[0]) if vectors else 0
    columns = ["smiles"] + [f"dim_{i}" for i in range(width)]
    rows = [[sm] + vec for sm, vec in zip(smiles_list, vectors)]
    out_path = os.path.join(tempfile.mkdtemp(prefix="smi_ted_"), "embeddings.csv")
    pd.DataFrame(rows, columns=columns).to_csv(out_path, index=False)
    return out_path


def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for one SMILES string or a CSV of SMILES.

    An uploaded CSV takes precedence over the text box (batch mode); otherwise
    the text box content is embedded (single mode). Either way the result is
    written to a downloadable ``embeddings.csv`` whose columns are ``smiles``
    followed by ``dim_0`` … ``dim_{d-1}``.

    Args:
        smiles: SMILES string from the text box (may be empty or ``None``).
        file_obj: Value of the ``gr.File`` component, or ``None`` when no file
            was uploaded.

    Returns:
        A ``(message, file_update)`` pair for the message textbox and the
        download component. In single mode the message is the embedding as
        JSON; on any error the message describes the problem and the download
        component is hidden.
    """
    # Batch mode: a CSV upload takes precedence over the text box.
    if file_obj is not None:
        try:
            # Older Gradio passes a tempfile wrapper exposing `.name`; newer
            # versions may pass the path string itself.
            df_in = pd.read_csv(getattr(file_obj, "name", file_obj))
            # Look for a column named "smiles" (case-insensitive, surrounding
            # whitespace ignored) at any position; prefixes/suffixes are rejected.
            smiles_cols = [
                col for col in df_in.columns if str(col).strip().lower() == "smiles"
            ]
            if not smiles_cols:
                return (
                    "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
                    gr.update(visible=False),
                )
            smiles_col = smiles_cols[0]
            # Drop empty cells: astype(str) alone would turn NaN into the literal "nan".
            smiles_list = [
                s.strip()
                for s in df_in[smiles_col].dropna().astype(str)
                if s.strip()
            ]
            if not smiles_list:
                return (
                    f"Error: Column '{smiles_col}' contains no SMILES.",
                    gr.update(visible=False),
                )
            vectors = [_encode_smiles(sm) for sm in smiles_list]
            out_path = _write_embeddings_csv(smiles_list, vectors)
            msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
            return msg, gr.update(value=out_path, visible=True)
        except Exception as e:  # UI boundary: surface any failure to the user
            return f"Error processing batch: {e}", gr.update(visible=False)

    # Single mode.
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = _encode_smiles(smiles)
        out_path = _write_embeddings_csv([smiles], [vec])
        return json.dumps(vec), gr.update(value=out_path, visible=True)
    except Exception as e:  # UI boundary: surface any failure to the user
        return f"Error extracting embedding: {e}", gr.update(visible=False)
# 4) Build the Gradio Blocks interface.
# The on-screen instructions must match process_inputs: batch mode looks up a
# column named "smiles" (case-insensitive) wherever it sits in the CSV.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED-Embeddings-Extraction

        **Single mode:** paste a SMILES string in the left box.

        **Batch mode:** upload a CSV file with a column named `smiles` (case-insensitive, in any position), one SMILES per row.

        In both cases, an `embeddings.csv` file will be generated for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])
    generate_btn = gr.Button("Extract Embeddings")
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        # Hidden until process_inputs returns a file to download.
        download_csv = gr.File(label="Download embeddings.csv", visible=False)
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv],
    )

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the container.
    demo.launch(server_name="0.0.0.0")