SMI-TED-demo1 / app.py
Enzo Reis de Oliveira
Fixing bug on pandas
8650a46
raw
history blame
3.58 kB
import os
import sys
import json
import pandas as pd
import gradio as gr
# 1) Adjust path before importing the loader
# Put the bundled inference directory (<this file's dir>/smi-ted/inference)
# at the front of sys.path so the `smi_ted_light` package is importable.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)
# This import must stay after the sys.path.insert above; it relies on it.
from smi_ted_light.load import load_smi_ted
# 2) Load the SMI-TED Light model
# Loaded once at module import; the module-level `model` is then reused by
# every call to process_inputs.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
# NOTE(review): checkpoint and vocabulary file names are presumably resolved
# relative to `folder` by load_smi_ted — confirm against smi_ted_light/load.py.
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
# 3) Single function to process either a single SMILES or a CSV of SMILES

# Name of the CSV written for download (relative to the working directory).
_OUTPUT_CSV = "embeddings.csv"


def _embed(smiles: str) -> list:
    """Return the embedding of one SMILES string as a plain Python list.

    Uses the first row of the module-level ``model.encode(..., return_torch=True)``.
    """
    return model.encode(smiles, return_torch=True)[0].tolist()


def _read_smiles_column(path):
    """Read the 'smiles' column (case-insensitive header) from a CSV file.

    The file is parsed as ';'-separated first (the app's original format). If
    that yields no 'smiles' column, it is parsed again as ','-separated, so a
    standard comma-separated CSV is accepted too.

    Returns:
        The non-missing, non-blank SMILES strings (whitespace-stripped), or
        ``None`` if neither delimiter yields a 'smiles' column.
    """
    for sep in (";", ","):
        df_in = pd.read_csv(path, sep=sep)
        smiles_cols = [col for col in df_in.columns if str(col).strip().lower() == "smiles"]
        if smiles_cols:
            # dropna() first: astype(str) would otherwise turn missing cells into "nan".
            column = df_in[smiles_cols[0]].dropna().astype(str).str.strip()
            return column[column != ""].tolist()
    return None


def _write_embeddings_csv(smiles_list, embeddings):
    """Write SMILES and their embeddings to ``_OUTPUT_CSV`` and return its path.

    Columns are ``smiles`` followed by ``dim_0 .. dim_{d-1}``, identical for
    single and batch mode.
    """
    dim = len(embeddings[0]) if embeddings else 0
    columns = ["smiles"] + [f"dim_{i}" for i in range(dim)]
    rows = [[sm] + vec for sm, vec in zip(smiles_list, embeddings)]
    pd.DataFrame(rows, columns=columns).to_csv(_OUTPUT_CSV, index=False)
    return _OUTPUT_CSV


def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for a single SMILES string or a CSV batch.

    Args:
        smiles: SMILES from the textbox; used only when no file is uploaded.
            ``None`` is treated like an empty string.
        file_obj: The uploaded CSV (a path string or an object exposing the path
            as ``.name``), or ``None``. When given, it takes precedence over
            ``smiles``.

    Returns:
        A ``(message, file_update)`` tuple for the two Gradio outputs. On success
        the message is a status text (batch) or the embedding as JSON (single),
        and ``file_update`` shows the written CSV. On failure the message
        describes the error and the download component is hidden.
    """
    # Batch mode: an uploaded CSV takes precedence over the textbox.
    if file_obj is not None:
        try:
            # Accept both a plain path and a file wrapper carrying the path in `.name`.
            path = getattr(file_obj, "name", file_obj)
            smiles_list = _read_smiles_column(path)
            if smiles_list is None:
                return (
                    "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
                    gr.update(visible=False),
                )
            if not smiles_list:
                return (
                    "Error: The 'Smiles' column of the CSV contains no SMILES.",
                    gr.update(visible=False),
                )
            embeddings = [_embed(sm) for sm in smiles_list]
            out_path = _write_embeddings_csv(smiles_list, embeddings)
            msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
            return msg, gr.update(value=out_path, visible=True)
        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing the app.
            return f"Error processing batch: {e}", gr.update(visible=False)

    # Single mode.
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = _embed(smiles)
        # Also save a CSV (with header) so single mode offers the same download.
        out_path = _write_embeddings_csv([smiles], [vec])
        return json.dumps(vec), gr.update(value=out_path, visible=True)
    except Exception as e:
        return f"Error extracting embedding: {e}", gr.update(visible=False)
# 4) Build the Gradio Blocks interface
# The button sends the textbox and the uploaded file to process_inputs, which
# fills the message box and reveals the CSV download on success.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED-Embeddings-Extraction

        **Single mode:** paste a SMILES string in the left box.

        **Batch mode:** upload a semicolon-separated (`;`) CSV file with a column named `smiles` (header match is case-insensitive); each row of that column is one SMILES. If a file is uploaded, it takes precedence over the text box.

        In both cases, an `embeddings.csv` file will be generated for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])
    generate_btn = gr.Button("Extract Embeddings")
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        # Hidden until process_inputs returns a file to download.
        download_csv = gr.File(label="Download embeddings.csv", visible=False)
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv],
    )

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0")