File size: 3,260 Bytes
4d799f2
 
e1e6b13
 
 
843425c
5465560
843425c
 
 
 
 
 
5465560
843425c
64428bf
 
 
 
 
 
5465560
214fccd
5465560
214fccd
 
 
 
 
 
 
 
5465560
214fccd
 
 
5465560
214fccd
 
5465560
214fccd
5465560
64428bf
 
5465560
64428bf
214fccd
5465560
214fccd
 
 
 
fe92162
f3e37c7
4d799f2
5465560
4d799f2
 
 
f3e37c7
5465560
 
f3e37c7
4d799f2
 
e1e6b13
4d799f2
5465560
 
214fccd
f3e37c7
ddae879
4d799f2
5465560
 
4d799f2
5465560
214fccd
 
 
4d799f2
f63af71
64428bf
ddae879
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd

# 1) Adjust path before importing the loader.
#    BASE_DIR is the directory containing this file; the "smi-ted/inference"
#    folder inside it is put on sys.path so the import below can be resolved.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
# Prepend (index 0) so this folder is searched before any other sys.path entry.
sys.path.insert(0, INFERENCE_PATH)

# Must stay below the sys.path tweak above; presumably the package lives in
# INFERENCE_PATH/smi_ted_light (same folder as MODEL_DIR) - verify on deploy.
from smi_ted_light.load import load_smi_ted

# 2) Load the SMI-TED Light model.
#    This runs once at import time; `model` is the module-level object that
#    process_inputs() calls for every request.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)

# 3) Single function to process either a single SMILES or a CSV of SMILES
def _uploaded_path(file_obj):
    """Return the filesystem path of an uploaded file.

    Depending on the Gradio version and the ``type`` setting of ``gr.File``,
    the callback receives either a plain path (``str`` / ``os.PathLike``) or
    a tempfile-like wrapper exposing ``.name``; both forms are accepted.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def _embed(smiles):
    """Encode one SMILES string with the loaded model and return a plain list."""
    return model.encode(smiles, return_torch=True)[0].tolist()


def _write_embeddings_csv(smiles_list, embeddings):
    """Write SMILES + embeddings to a per-request ``embeddings.csv``.

    The CSV has a ``smiles`` column followed by ``dim_0 .. dim_{n-1}``.
    Returns the path of the written file.
    """
    out_df = pd.DataFrame(embeddings)
    out_df.columns = [f"dim_{i}" for i in range(out_df.shape[1])]
    out_df.insert(0, "smiles", smiles_list)
    # A fresh directory per request keeps concurrent users from overwriting
    # each other's download while preserving the "embeddings.csv" file name.
    # NOTE: these directories are not cleaned up here; they live in the
    # system temp location.
    out_dir = tempfile.mkdtemp(prefix="smi_ted_")
    out_path = os.path.join(out_dir, "embeddings.csv")
    out_df.to_csv(out_path, index=False)
    return out_path


def process_inputs(smiles: str, file_obj):
    """Extract embeddings for a single SMILES or for a CSV of SMILES.

    Args:
        smiles: SMILES string typed by the user (single mode). ``None`` is
            treated like an empty string.
        file_obj: Uploaded CSV (batch mode) or ``None``. When given, it takes
            precedence over ``smiles``; SMILES are read from the first column
            (the first row is parsed as the header by ``pd.read_csv``).
            Empty/NaN cells are skipped.

    Returns:
        A ``(message, file_update)`` pair for the two Gradio outputs. On
        success ``message`` is a status line (batch) or the embedding as JSON
        (single) and ``file_update`` reveals the generated CSV; on failure
        ``message`` describes the problem and the download stays hidden.
    """
    hidden = gr.update(visible=False)

    # If a CSV file is provided, process in batch
    if file_obj is not None:
        # Broad catch is deliberate: this is the UI boundary, and any failure
        # (bad CSV, invalid SMILES, model error) should be shown to the user.
        try:
            df_in = pd.read_csv(_uploaded_path(file_obj))
            # Drop missing cells first so NaN is not embedded as the text "nan".
            column = df_in.iloc[:, 0].dropna().astype(str).str.strip()
            smiles_list = [s for s in column if s]
            if not smiles_list:
                return "No SMILES found in the first column of the CSV file.", hidden
            embeddings = [_embed(s) for s in smiles_list]
            csv_path = _write_embeddings_csv(smiles_list, embeddings)
            msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
            return msg, gr.update(value=csv_path, visible=True)
        except Exception as e:
            return f"Error processing batch: {e}", hidden

    # Otherwise, process a single SMILES
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", hidden
    try:
        vec = _embed(smiles)
        csv_path = _write_embeddings_csv([smiles], [vec])
        return json.dumps(vec), gr.update(value=csv_path, visible=True)
    except Exception as e:
        return f"Error extracting embedding: {e}", hidden

# 4) Build the Gradio Blocks interface
#    Components render in creation order: instructions, an input row
#    (SMILES textbox + CSV upload), the trigger button, then an output row
#    (message box + download link).
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # SMI-TED-Embeddings-Extraction
        **Single mode:** paste a SMILES string in the left box.  
        **Batch mode:** upload a CSV file where each row has a SMILES in the first column.  
        In both cases, an `embeddings.csv` file will be extracted for download, with the first column as SMILES and the embedding values in the following columns.
        """
    )

    # Inputs: process_inputs() uses the uploaded file when one is present and
    # falls back to the textbox otherwise.
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in   = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])

    generate_btn = gr.Button("Extract Embeddings")

    # Outputs: the download component starts hidden; process_inputs() returns
    # a gr.update that reveals it on success and hides it on error.
    with gr.Row():
        output_msg   = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        download_csv = gr.File(label="Download embeddings.csv", visible=False)

    # The order of `inputs` / `outputs` matches process_inputs' parameters and
    # its returned (message, file_update) pair.
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv]
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    # "0.0.0.0" makes the server listen on all network interfaces rather than
    # localhost only (e.g. when running inside a container).
    demo.launch(server_name="0.0.0.0")