File size: 2,375 Bytes
9cd0749
 
 
 
 
64428bf
 
9cd0749
64428bf
fe92162
 
 
 
64428bf
 
 
 
fe92162
64428bf
 
 
 
 
 
fe92162
64428bf
 
fe92162
 
 
 
64428bf
 
 
fe92162
64428bf
 
fe92162
64428bf
fe92162
 
 
 
 
 
 
 
 
 
 
 
64428bf
fe92162
 
 
 
64428bf
fe92162
64428bf
 
 
fe92162
 
 
 
 
64428bf
 
fe92162
 
64428bf
 
 
fe92162
64428bf
9cd0749
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os, sys

BASE_DIR = os.path.dirname(__file__)
INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.append(INFERENCE_DIR) 

import gradio as gr
from smi_ted_light.load import load_smi_ted  

# 1) Ajusta o path para o inference do SMI-TED
BASE_DIR = os.path.dirname(__file__)
INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.append(INFERENCE_DIR)

# 2) Caminho onde estão pesos e vocabulário
MODEL_DIR = os.path.join("smi-ted", "inference", "smi_ted_light")

# 3) Carrega o modelo SMI-TED (Light)
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)

# 4) Função utilizada pela interface Gradio
def gerar_embedding(smiles: str):
    """
    Recebe uma string SMILES e devolve:
      - embedding (lista de 768 floats)
      - caminho para um CSV com esse embedding, pronto para download
    Em caso de erro, devolve um dicionário com a mensagem e nenhum arquivo.
    """
    smiles = smiles.strip()
    if not smiles:
        return {"erro": "digite uma sequência SMILES primeiro"}, None

    try:
        # model.encode devolve tensor shape (1, 768)
        vetor_torch = model.encode(smiles, return_torch=True)[0]
        embedding = vetor_torch.tolist()

        # Cria um CSV temporário com uma única linha (o embedding)
        df = pd.DataFrame([embedding])
        tmp = tempfile.NamedTemporaryFile(
            delete=False,
            suffix=".csv",
            prefix="embedding_",
        )
        csv_path = tmp.name
        df.to_csv(csv_path, index=False)
        tmp.close()

        return embedding, csv_path

    except Exception as e:
        return {"erro": str(e)}, None

# 5) Define a interface Gradio com dois outputs: JSON e arquivo para download
demo = gr.Interface(
    fn=gerar_embedding,
    inputs=gr.Textbox(label="SMILES", placeholder="Ex.: CCO"),
    outputs=[
        gr.JSON(label="Embedding (lista de floats)"),
        gr.File(label="Baixar embedding em CSV"),
    ],
    title="SMI-TED Embedding Generator",
    description=(
        "Cole uma sequência SMILES e receba o embedding gerado pelo modelo "
        "SMI-TED Light treinado pela IBM Research. "
        "Você também pode baixar o embedding em CSV."
    ),
)

# 6) Roda localmente ou no HF Space
if __name__ == "__main__":
    demo.launch()