Spaces:
Running
Running
File size: 2,375 Bytes
9cd0749 64428bf 9cd0749 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf fe92162 64428bf 9cd0749 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os, sys
BASE_DIR = os.path.dirname(__file__)
INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.append(INFERENCE_DIR)
import gradio as gr
from smi_ted_light.load import load_smi_ted
# 1) Ajusta o path para o inference do SMI-TED
BASE_DIR = os.path.dirname(__file__)
INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.append(INFERENCE_DIR)
# 2) Caminho onde estão pesos e vocabulário
MODEL_DIR = os.path.join("smi-ted", "inference", "smi_ted_light")
# 3) Carrega o modelo SMI-TED (Light)
model = load_smi_ted(
folder=MODEL_DIR,
ckpt_filename="smi-ted-Light_40.pt",
vocab_filename="bert_vocab_curated.txt",
)
# 4) Função utilizada pela interface Gradio
def gerar_embedding(smiles: str):
"""
Recebe uma string SMILES e devolve:
- embedding (lista de 768 floats)
- caminho para um CSV com esse embedding, pronto para download
Em caso de erro, devolve um dicionário com a mensagem e nenhum arquivo.
"""
smiles = smiles.strip()
if not smiles:
return {"erro": "digite uma sequência SMILES primeiro"}, None
try:
# model.encode devolve tensor shape (1, 768)
vetor_torch = model.encode(smiles, return_torch=True)[0]
embedding = vetor_torch.tolist()
# Cria um CSV temporário com uma única linha (o embedding)
df = pd.DataFrame([embedding])
tmp = tempfile.NamedTemporaryFile(
delete=False,
suffix=".csv",
prefix="embedding_",
)
csv_path = tmp.name
df.to_csv(csv_path, index=False)
tmp.close()
return embedding, csv_path
except Exception as e:
return {"erro": str(e)}, None
# 5) Define a interface Gradio com dois outputs: JSON e arquivo para download
demo = gr.Interface(
fn=gerar_embedding,
inputs=gr.Textbox(label="SMILES", placeholder="Ex.: CCO"),
outputs=[
gr.JSON(label="Embedding (lista de floats)"),
gr.File(label="Baixar embedding em CSV"),
],
title="SMI-TED Embedding Generator",
description=(
"Cole uma sequência SMILES e receba o embedding gerado pelo modelo "
"SMI-TED Light treinado pela IBM Research. "
"Você também pode baixar o embedding em CSV."
),
)
# 6) Roda localmente ou no HF Space
if __name__ == "__main__":
demo.launch() |