Spaces:
Running
Running
File size: 3,707 Bytes
4d799f2 e1e6b13 843425c 5465560 843425c 5465560 843425c 64428bf 5465560 214fccd cb4cd4f 214fccd cb4cd4f 214fccd cb4cd4f 214fccd cb4cd4f 5465560 214fccd cb4cd4f 214fccd 5465560 214fccd cb4cd4f 64428bf 5465560 64428bf 214fccd cb4cd4f 214fccd fe92162 f3e37c7 4d799f2 cb4cd4f 5465560 4d799f2 f3e37c7 5465560 f3e37c7 4d799f2 e1e6b13 4d799f2 5465560 214fccd f3e37c7 ddae879 4d799f2 5465560 4d799f2 5465560 214fccd 4d799f2 f63af71 64428bf ddae879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd
# 1) Adjust path before importing the loader.
# The SMI-TED inference code lives under <this file's dir>/smi-ted/inference.
# That directory is prepended to sys.path so that `smi_ted_light` resolves from
# there; the import below must stay AFTER this insert or it will fail (or pick
# up a different `smi_ted_light` found earlier on the path).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)
from smi_ted_light.load import load_smi_ted
# 2) Load the SMI-TED Light model once, at module import time.
# The resulting module-level `model` is shared by the request handlers in this
# file, so loading cost is paid only on startup.
# NOTE(review): presumably the checkpoint and vocab files are resolved relative
# to `folder` (MODEL_DIR) — confirm against smi_ted_light.load.load_smi_ted.
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
folder=MODEL_DIR,
ckpt_filename="smi-ted-Light_40.pt",
vocab_filename="bert_vocab_curated.txt",
)
# 3) Single function to process either a single SMILES or a CSV of SMILES


def _embed(smiles: str) -> list:
    """Return the SMI-TED embedding of one SMILES string as a plain Python list.

    ``model.encode`` is called with a single string and row ``[0]`` of its
    output is taken, so the output is presumably a batch of size one —
    TODO confirm against ``smi_ted_light.load``.
    """
    return model.encode(smiles, return_torch=True)[0].tolist()


def _write_embeddings_csv(smiles_list, embeddings) -> str:
    """Write SMILES and their embeddings to a new ``embeddings.csv``; return its path.

    The CSV has a ``smiles`` column followed by ``dim_0 … dim_{d-1}``.
    Every call writes into its own fresh temporary directory, so concurrent
    requests never overwrite (or leak) each other's results, while the file
    name offered for download is still ``embeddings.csv``.
    """
    n_dims = len(embeddings[0]) if embeddings else 0
    columns = ["smiles"] + [f"dim_{i}" for i in range(n_dims)]
    rows = [[sm] + list(vec) for sm, vec in zip(smiles_list, embeddings)]
    out_df = pd.DataFrame(rows, columns=columns)
    # NOTE: the directory is not removed here because Gradio must still read the
    # file after this function returns; it lives under the OS temp dir.
    out_dir = tempfile.mkdtemp(prefix="smi_ted_")
    out_path = os.path.join(out_dir, "embeddings.csv")
    out_df.to_csv(out_path, index=False)
    return out_path


def _process_batch(file_obj):
    """Embed every SMILES of an uploaded CSV; return (message, download update)."""
    # Gradio 3 passes a tempfile wrapper exposing `.name`; Gradio 4+ passes the
    # filepath string itself. Accept both.
    csv_path = getattr(file_obj, "name", file_obj)
    df_in = pd.read_csv(csv_path)
    # Look for a column named exactly "smiles" (case-insensitive); columns that
    # merely contain it as a prefix/suffix are not accepted.
    smiles_cols = [col for col in df_in.columns if str(col).lower() == "smiles"]
    if not smiles_cols:
        return (
            "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
            gr.update(visible=False),
        )
    # Drop missing cells first: astype(str) would otherwise turn them into the
    # literal string "nan" and embed it as if it were a molecule. Blank entries
    # are skipped too, mirroring the single-mode check.
    cleaned = df_in[smiles_cols[0]].dropna().astype(str).str.strip()
    smiles_list = [sm for sm in cleaned if sm]
    if not smiles_list:
        return "Error: The 'Smiles' column contains no SMILES.", gr.update(visible=False)
    embeddings = [_embed(sm) for sm in smiles_list]
    out_path = _write_embeddings_csv(smiles_list, embeddings)
    msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
    return msg, gr.update(value=out_path, visible=True)


def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for a single SMILES or for a CSV of SMILES.

    Args:
        smiles: Text from the single-mode textbox (may be empty or ``None``).
        file_obj: The uploaded CSV from the ``gr.File`` component, or ``None``
            when nothing was uploaded. An upload takes precedence over ``smiles``.

    Returns:
        A ``(message, download_update)`` pair. ``message`` is the JSON-encoded
        embedding (single mode), a batch summary, or an error text;
        ``download_update`` is a ``gr.update`` that shows the generated
        ``embeddings.csv`` on success and hides the download otherwise.
        Errors are reported in ``message`` rather than raised.
    """
    # Batch mode: a CSV upload takes precedence over the text box.
    if file_obj is not None:
        try:
            return _process_batch(file_obj)
        except Exception as e:  # UI boundary: surface any failure to the user
            return f"Error processing batch: {e}", gr.update(visible=False)
    # Single mode.
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    try:
        vec = _embed(smiles)
        out_path = _write_embeddings_csv([smiles], [vec])
        return json.dumps(vec), gr.update(value=out_path, visible=True)
    except Exception as e:  # UI boundary: surface any failure to the user
        return f"Error extracting embedding: {e}", gr.update(visible=False)
# 4) Build the Gradio Blocks interface.
# The instructions below must match process_inputs: batch mode reads the column
# named "smiles" (case-insensitive), not simply the first column.
with gr.Blocks() as demo:
    gr.Markdown(
        "# SMI-TED-Embeddings-Extraction\n\n"
        "**Single mode:** paste a SMILES string in the left box.\n\n"
        "**Batch mode:** upload a CSV file with a column named `smiles` "
        "(case-insensitive) containing one SMILES per row.\n\n"
        "In both cases, an `embeddings.csv` file is generated for download, "
        "with the first column as SMILES and the embedding values in the "
        "following columns."
    )
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])
    generate_btn = gr.Button("Extract Embeddings")
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
        # Hidden until process_inputs returns a file to download.
        download_csv = gr.File(label="Download embeddings.csv", visible=False)
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv],
    )

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0")
|