File size: 5,355 Bytes
4d799f2
 
e1e6b13
 
 
843425c
073cdd9
843425c
 
 
 
 
 
073cdd9
843425c
64428bf
 
 
 
 
 
214fccd
073cdd9
214fccd
 
073cdd9
6b91e18
cb4cd4f
073cdd9
 
cb4cd4f
 
 
 
 
 
 
 
862c2e6
 
 
 
 
 
 
073cdd9
 
 
68704e5
073cdd9
214fccd
073cdd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214fccd
cb4cd4f
073cdd9
 
862c2e6
073cdd9
 
c9e9b6b
862c2e6
 
073cdd9
 
 
 
214fccd
cb4cd4f
214fccd
5465560
214fccd
073cdd9
64428bf
 
5465560
64428bf
214fccd
 
 
 
 
073cdd9
cbc085f
cb4cd4f
073cdd9
 
4d799f2
 
 
f3e37c7
862c2e6
5465560
 
862c2e6
 
 
 
 
 
4d799f2
 
e1e6b13
4d799f2
5465560
 
214fccd
f3e37c7
ddae879
4d799f2
073cdd9
5465560
4d799f2
5465560
214fccd
 
 
4d799f2
f63af71
64428bf
ddae879
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd

# 1) Adjust sys.path before importing the loader: prepending
#    <this file's dir>/smi-ted/inference lets `smi_ted_light` be imported from
#    there, ahead of any other module with the same name.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)

from smi_ted_light.load import load_smi_ted

# 2) Load the SMI-TED Light model once, at import time; the resulting `model`
#    object is reused by process_inputs for every request.
#    NOTE(review): the checkpoint and vocab file names are presumably resolved
#    relative to `folder` (MODEL_DIR) by load_smi_ted — confirm in the loader.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)

# Upper bound on the number of SMILES accepted in one batch upload.
_MAX_BATCH_SMILES = 1000


def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for one SMILES string or a CSV batch.

    Batch mode takes precedence: when *file_obj* is given, the CSV is read
    (delimiter auto-detected) and every non-blank value of its ``smiles``
    column (matched case-insensitively) is encoded. Otherwise *smiles* is
    encoded on its own.

    Args:
        smiles: SMILES string from the single-mode textbox (may be empty or
            ``None``).
        file_obj: The uploaded CSV from ``gr.File``, or ``None`` when nothing
            was uploaded. Depending on the Gradio version this is either a
            path string or a tempfile-like object exposing ``.name``; both
            are accepted.

    Returns:
        A ``(message, file_update)`` tuple. ``message`` is the text for the
        message box: an error or summary message, or, in single mode, the
        JSON-encoded embedding. ``file_update`` is a ``gr.update`` that shows
        the generated ``embeddings.csv`` download on success and hides it
        otherwise.
    """
    if file_obj is not None:
        return _process_batch(file_obj)
    return _process_single(smiles)


def _upload_path(file_obj):
    """Return the filesystem path of a Gradio upload (path string or tempfile wrapper)."""
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return file_obj.name


def _encode(smiles):
    """Return the model's embedding for one SMILES as a plain Python list."""
    # encode() is given a single string; [0] selects the first (only) row.
    return model.encode(smiles, return_torch=True)[0].tolist()


def _write_embeddings_csv(df):
    """Write *df* to ``embeddings.csv`` in a fresh temp dir and return the path."""
    # A per-request directory keeps simultaneous requests from overwriting
    # each other's output while preserving the user-facing file name.
    out_path = os.path.join(tempfile.mkdtemp(prefix="smi_ted_"), "embeddings.csv")
    df.to_csv(out_path, index=False)
    return out_path


def _process_batch(file_obj):
    """Embed every SMILES in the uploaded CSV; return ``(message, file_update)``."""
    try:
        # sep=None with the python engine makes pandas sniff the delimiter (';', ',', ...).
        df_in = pd.read_csv(_upload_path(file_obj), sep=None, engine="python")

        # Look for a "smiles" column, case-insensitively.
        smiles_cols = [c for c in df_in.columns if str(c).lower() == "smiles"]
        if not smiles_cols:
            return (
                "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
                gr.update(visible=False),
            )

        # Skip blank cells: astype(str) alone would turn NaN into the literal
        # string "nan" and then report it as an invalid SMILES.
        smiles_list = [
            s for s in df_in[smiles_cols[0]].dropna().astype(str).str.strip() if s
        ]
        if not smiles_list:
            return (
                "Error: The 'Smiles' column does not contain any SMILES.",
                gr.update(visible=False),
            )
        if len(smiles_list) > _MAX_BATCH_SMILES:
            return (
                f"Error: Maximum {_MAX_BATCH_SMILES} SMILES allowed per batch "
                f"(you provided {len(smiles_list)}).",
                gr.update(visible=False),
            )

        out_records = []
        invalid_smiles = []
        for sm in smiles_list:
            try:
                vec = _encode(sm)
            except Exception:
                # Keep a row for the failure so output rows line up with the
                # input; DataFrame leaves its dim_* cells empty (NaN).
                invalid_smiles.append(sm)
                out_records.append({"smiles": f"SMILES {sm} was invalid"})
                continue
            record = {"smiles": sm}
            record.update({f"dim_{i}": v for i, v in enumerate(vec)})
            out_records.append(record)

        out_path = _write_embeddings_csv(pd.DataFrame(out_records))

        invalid_count = len(invalid_smiles)
        valid = len(smiles_list) - invalid_count
        if invalid_smiles:
            msg = (
                f"{valid} SMILES processed successfully. "
                f"{invalid_count} entr{'y' if invalid_count == 1 else 'ies'} could not be parsed by RDKit:\n"
                + "\n".join(f"- {s}" for s in invalid_smiles)
            )
        else:
            msg = f"Processed batch of {valid} SMILES. Download embeddings.csv."

        return msg, gr.update(value=out_path, visible=True)

    except Exception as e:
        # UI boundary: report read/encode/write failures in the message box
        # rather than failing the whole request.
        return f"Error processing batch: {e}", gr.update(visible=False)


def _process_single(smiles):
    """Embed one SMILES string; return ``(JSON embedding or message, file_update)``."""
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = _encode(smiles)
    except Exception:
        return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)

    # Write outside the try so an I/O failure is not misreported as an invalid SMILES.
    cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
    out_path = _write_embeddings_csv(pd.DataFrame([[smiles] + vec], columns=cols))
    return json.dumps(vec), gr.update(value=out_path, visible=True)


# 4) Gradio interface
with gr.Blocks() as demo:
    # Usage notes shown above the inputs; keep them in sync with the parsing
    # rules in process_inputs (batch mode looks up a "smiles" column by name).
    gr.Markdown(
        """
        # SMI-TED-Embeddings-Extraction

        **Single mode:** paste a SMILES string in the left box.  
        **Batch mode:** upload a CSV file with a column named `smiles` (case-insensitive) containing one SMILES per row.  
        - **Maximum 1000 SMILES per batch.** Processing time increases with batch size due to Hugging Face environment limits.  
        _This is just a demo environment; for heavy-duty usage, please visit:_  
        https://github.com/IBM/materials/tree/main/models/smi_ted  
        to download the model and run your own experiments.

        - In both cases, an `embeddings.csv` file is generated for download, with the SMILES in the first column and the embedding values in the following columns. In batch mode, SMILES that cannot be encoded are kept as rows marked invalid, with empty embedding values.
        """
    )

    # Inputs: single SMILES textbox and batch CSV upload, side by side.
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])

    generate_btn = gr.Button("Extract Embeddings")

    # Outputs: the message/JSON box, and a download slot that stays hidden
    # until process_inputs returns an update making it visible.
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=4)
        download_csv = gr.File(label="Download embeddings.csv", visible=False)

    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv],
    )

if __name__ == "__main__":
    # Listen on all interfaces by default so the app is reachable from outside
    # the host/container. Passing server_name explicitly makes Gradio ignore
    # GRADIO_SERVER_NAME, so read that variable here to keep it as an override.
    demo.launch(server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"))