# SMI-TED-demo1 / app.py
# Enzo Reis de Oliveira — commit ddae879: "Fixing bug" (2.38 kB)
import os
import sys
import json
import tempfile
import pandas as pd
import gradio as gr
from PIL import Image
# 1) Adjust sys.path BEFORE importing the loader: the SMI-TED Light package is
#    looked up under <directory of this file>/smi-ted/inference (presumably a
#    vendored copy rather than an installed package — confirm).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
# Inserted at index 0 so this directory is searched before any other sys.path
# entry when resolving the import below.
sys.path.insert(0, INFERENCE_PATH)
# 2) Import the SMI-TED Light loader; relies on the sys.path insert above, so
#    this statement must stay after it.
from smi_ted_light.load import load_smi_ted
# 3) Load the model once, at module import time. The resulting global `model`
#    is what the request handler uses to encode SMILES strings.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    # NOTE(review): checkpoint/vocab names are presumably resolved relative to
    # `folder` by the loader — confirm against smi_ted_light.load.
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
# 4) Turn a SMILES string into an embedding vector plus a temporary CSV file.
def _como_json(obj) -> str:
    """Serialize *obj* to JSON text for display in the UI.

    ``ensure_ascii=False`` keeps accented characters (e.g. in the Portuguese
    error messages) readable instead of emitting escape sequences.
    """
    return json.dumps(obj, ensure_ascii=False)


def _salvar_csv_temporario(vetor) -> str:
    """Write *vetor* as a single CSV row to a new temporary file; return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    be served by the download component; it is not removed on success.
    """
    # Write through the handle we already own instead of re-opening the file by
    # name. newline="" is what pandas requires for text-mode handles so line
    # endings are not translated twice.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    )
    try:
        with tmp:  # guarantees the handle is closed even if to_csv raises
            pd.DataFrame([vetor]).to_csv(tmp, index=False)
    except BaseException:
        # Best-effort cleanup so a failed write does not leave a stray file.
        try:
            os.remove(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name


def gerar_embedding_e_csv(smiles: str):
    """Encode a SMILES string with the loaded SMI-TED model.

    Args:
        smiles: SMILES sequence typed by the user. Surrounding whitespace is
            ignored; ``None`` is treated like an empty string.

    Returns:
        A 2-tuple matching the two outputs wired to the "Gerar Embedding"
        button: ``(json_text, file_update)``.
        On success, ``json_text`` is the embedding serialized as a JSON list
        and ``file_update`` makes the download component visible, pointing at
        a temporary CSV holding the vector as one row.
        On failure (empty input or any error while encoding/writing),
        ``json_text`` is a JSON object ``{"erro": message}`` and the download
        component is hidden.
    """
    smiles = (smiles or "").strip()
    if not smiles:
        erro = {"erro": "digite uma sequência SMILES primeiro"}
        return _como_json(erro), gr.update(visible=False)
    try:
        # NOTE(review): assumes encode() returns a batch and [0] is the row for
        # this single SMILES; return_torch=True yields a tensor, hence .tolist().
        vetor = model.encode(smiles, return_torch=True)[0].tolist()
        caminho_csv = _salvar_csv_temporario(vetor)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        # Some exceptions stringify to "", which would show a blank error.
        erro = {"erro": str(e) or type(e).__name__}
        return _como_json(erro), gr.update(visible=False)
    return _como_json(vetor), gr.update(value=caminho_csv, visible=True)
# 5) Build the interface with gr.Blocks: a header, an input row (SMILES box +
#    button) and an output row (JSON embedding + initially hidden CSV file).
with gr.Blocks() as demo:
    # Header text; built by implicit concatenation, same value as a
    # triple-quoted block starting and ending with a newline.
    gr.Markdown(
        "\n"
        "# SMI-TED Embedding Generator\n"
        "Cole uma sequência SMILES e:\n"
        "- Veja o vetor embedding (JSON)\n"
        "- Baixe-o em CSV\n"
    )

    # Input row.
    with gr.Row():
        smiles_in = gr.Textbox(placeholder="Ex.: CCO", label="SMILES")
        gerar_btn = gr.Button("Gerar Embedding")

    # Output row.
    with gr.Row():
        embedding_out = gr.Textbox(
            placeholder="O vetor aparecerá aqui…",
            lines=4,
            interactive=False,
            label="Embedding (JSON)",
        )
        # Hidden until the handler returns an update that makes it visible.
        download_csv = gr.File(visible=False, label="Baixar CSV")

    # The handler returns two values: (JSON text, update for the File output).
    gerar_btn.click(
        gerar_embedding_e_csv,
        inputs=smiles_in,
        outputs=[embedding_out, download_csv],
    )

if __name__ == "__main__":
    # Bind to all network interfaces rather than only localhost.
    demo.launch(server_name="0.0.0.0")