File size: 5,169 Bytes
4d799f2
 
e1e6b13
 
 
843425c
073cdd9
843425c
 
 
 
 
 
073cdd9
843425c
64428bf
 
 
 
 
 
214fccd
073cdd9
214fccd
 
073cdd9
6b91e18
cb4cd4f
073cdd9
 
cb4cd4f
 
 
 
 
 
 
 
073cdd9
 
 
68704e5
073cdd9
214fccd
073cdd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214fccd
cb4cd4f
073cdd9
 
 
 
c9e9b6b
073cdd9
c9e9b6b
 
 
 
073cdd9
 
 
 
214fccd
cb4cd4f
214fccd
5465560
214fccd
073cdd9
64428bf
 
5465560
64428bf
214fccd
 
 
 
 
073cdd9
cbc085f
cb4cd4f
073cdd9
 
4d799f2
 
 
f3e37c7
5465560
 
f3e37c7
4d799f2
 
e1e6b13
4d799f2
5465560
 
214fccd
f3e37c7
ddae879
4d799f2
073cdd9
5465560
4d799f2
5465560
214fccd
 
 
4d799f2
f63af71
64428bf
ddae879
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd

# 1) Adjust sys.path before importing the loader: the SMI-TED inference code
#    lives in ./smi-ted/inference next to this file, and that directory must be
#    on the import path for `smi_ted_light` to resolve. Inserting at index 0
#    gives it priority over any same-named package found later on sys.path.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)

# This import must come after the sys.path change above, which is why it is not
# grouped with the top-of-file imports.
from smi_ted_light.load import load_smi_ted

# 2) Load the SMI-TED Light model once, at module import time. The resulting
#    module-level `model` is reused by process_inputs for every request.
#    NOTE(review): the checkpoint and vocab file names are presumably resolved
#    relative to MODEL_DIR by load_smi_ted — confirm against the loader.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)

def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for a single SMILES or for a CSV batch.

    Batch mode takes precedence: when ``file_obj`` is given, the CSV is read
    with its delimiter auto-detected. The first column whose name equals
    "smiles" (case-insensitive) is encoded row by row, producing one output row
    per input row. Otherwise the stripped ``smiles`` string is encoded.

    Each call writes its ``embeddings.csv`` into a fresh temporary directory,
    so concurrent requests cannot overwrite each other's download.

    Args:
        smiles: SMILES string from the single-mode textbox (may be empty/None).
        file_obj: Uploaded CSV from the batch-mode file component, or None.
            Accepted either as a path string or as an object exposing the
            path as ``.name``.

    Returns:
        A ``(message, download_update)`` pair for the two Gradio outputs. The
        first item is a status/error text (or the JSON-encoded embedding in
        single mode). The second is a ``gr.update`` that shows the CSV on
        success and hides the download slot on failure.
    """
    if file_obj is not None:
        return _process_batch(file_obj)
    return _process_single(smiles)


def _try_encode(sm):
    """Return the embedding of ``sm`` as a Python list, or None if it is empty or fails."""
    if not sm:
        return None
    try:
        return model.encode(sm, return_torch=True)[0].tolist()
    except Exception:
        # Any encoder failure (presumably an unparsable SMILES) marks the input invalid.
        return None


def _write_embeddings_csv(df):
    """Write ``df`` as ``embeddings.csv`` inside a new temp directory; return the path."""
    # A per-request directory keeps the user-facing file name while avoiding
    # races between concurrent users writing one shared CWD-relative file.
    out_path = os.path.join(tempfile.mkdtemp(prefix="smi_ted_"), "embeddings.csv")
    df.to_csv(out_path, index=False)
    return out_path


def _process_batch(file_obj):
    """Handle batch (CSV) mode; see ``process_inputs`` for the return contract."""
    try:
        # The file component may hand over a path string or a tempfile-like object.
        csv_path = getattr(file_obj, "name", file_obj)
        # sep=None with the python engine sniffs the delimiter (';', ',', ...).
        df_in = pd.read_csv(csv_path, sep=None, engine="python")

        smiles_cols = [c for c in df_in.columns if str(c).lower() == "smiles"]
        if not smiles_cols:
            return (
                "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
                gr.update(visible=False),
            )
        # fillna("") stops empty cells from turning into the literal SMILES "nan";
        # strip() matches the whitespace handling of single mode.
        smiles_list = (
            df_in[smiles_cols[0]].fillna("").astype(str).str.strip().tolist()
        )

        out_records = []
        invalid_smiles = []
        for sm in smiles_list:
            vec = _try_encode(sm)
            if vec is None:
                invalid_smiles.append(sm)
                # Missing dim_* keys become empty cells once the DataFrame unifies
                # columns, so the row still lines up with its input row.
                out_records.append({"smiles": f"SMILES {sm} was invalid"})
            else:
                record = {"smiles": sm}
                record.update({f"dim_{i}": v for i, v in enumerate(vec)})
                out_records.append(record)

        # Keep a "smiles" header even when the CSV had no data rows.
        out_df = (
            pd.DataFrame(out_records) if out_records else pd.DataFrame(columns=["smiles"])
        )
        out_path = _write_embeddings_csv(out_df)

        valid = len(smiles_list) - len(invalid_smiles)
        if invalid_smiles:
            invalid_count = len(invalid_smiles)
            msg = (
                f"{valid} SMILES processed successfully. "
                f"{invalid_count} entr{'y' if invalid_count == 1 else 'ies'} "
                f"could not be parsed by RDKit:\n"
                + "\n".join(f"- {sm}" for sm in invalid_smiles)
            )
        else:
            msg = f"Processed batch of {valid} SMILES. Download embeddings.csv."

        return msg, gr.update(value=out_path, visible=True)

    except Exception as e:
        # UI boundary: report the failure in the message box instead of crashing.
        return f"Error processing batch: {e}", gr.update(visible=False)


def _process_single(smiles):
    """Handle single-SMILES mode; see ``process_inputs`` for the return contract."""
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)

    vec = _try_encode(smiles)
    if vec is None:
        return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)

    cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
    try:
        out_path = _write_embeddings_csv(pd.DataFrame([[smiles] + vec], columns=cols))
    except OSError as e:
        # A disk error is not an invalid SMILES; report it as what it is.
        return f"Error writing embeddings.csv: {e}", gr.update(visible=False)
    return json.dumps(vec), gr.update(value=out_path, visible=True)


# 4) Gradio interface: the single-SMILES textbox and the CSV upload both feed
#    process_inputs, which fills the message box and the download slot.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED-Embeddings-Extraction
        **Single mode:** paste a SMILES string in the left box.  
        **Batch mode:** upload a CSV file with a column named `smiles` (case-insensitive; the delimiter is auto-detected), one SMILES per row. If a CSV is uploaded, it takes precedence over the text box.  
        In both cases, an `embeddings.csv` file will be generated for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )

    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])

    generate_btn = gr.Button("Extract Embeddings")

    with gr.Row():
        # Read-only: shows either a status/error message or the JSON embedding.
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=4)
        # Hidden until process_inputs returns a CSV path with visible=True.
        download_csv = gr.File(label="Download embeddings.csv", visible=False)

    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv],
    )

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0")