Spaces:
Running
Running
File size: 5,169 Bytes
4d799f2 e1e6b13 843425c 073cdd9 843425c 073cdd9 843425c 64428bf 214fccd 073cdd9 214fccd 073cdd9 6b91e18 cb4cd4f 073cdd9 cb4cd4f 073cdd9 68704e5 073cdd9 214fccd 073cdd9 214fccd cb4cd4f 073cdd9 c9e9b6b 073cdd9 c9e9b6b 073cdd9 214fccd cb4cd4f 214fccd 5465560 214fccd 073cdd9 64428bf 5465560 64428bf 214fccd 073cdd9 cbc085f cb4cd4f 073cdd9 4d799f2 f3e37c7 5465560 f3e37c7 4d799f2 e1e6b13 4d799f2 5465560 214fccd f3e37c7 ddae879 4d799f2 073cdd9 5465560 4d799f2 5465560 214fccd 4d799f2 f63af71 64428bf ddae879 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import sys
import json
import pandas as pd
import gradio as gr
# 1) Adjust sys.path before importing the loader: the import below only
#    resolves once INFERENCE_PATH (<this file's dir>/smi-ted/inference) is on
#    the module search path, so the order of these statements matters.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
sys.path.insert(0, INFERENCE_PATH)
from smi_ted_light.load import load_smi_ted
# 2) Load the SMI-TED Light model once, at import time. The checkpoint and
#    vocabulary file names are passed together with MODEL_DIR; presumably the
#    loader looks them up inside that folder (confirm against load_smi_ted).
MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
model = load_smi_ted(
    folder=MODEL_DIR,
    ckpt_filename="smi-ted-Light_40.pt",
    vocab_filename="bert_vocab_curated.txt",
)
def process_inputs(smiles: str, file_obj):
    """Extract SMI-TED embeddings for one SMILES string or for a CSV batch.

    Batch mode takes precedence: when ``file_obj`` is not ``None`` the text
    box content is ignored and the uploaded CSV is processed instead.

    Args:
        smiles: SMILES typed into the text box. ``None`` is treated like an
            empty string.
        file_obj: Uploaded CSV, or ``None`` for single mode. May be a path
            (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``.

    Returns:
        A ``(message, file_update)`` pair: the text for the message box and a
        ``gr.update(...)`` for the download component. The download component
        is made visible and pointed at ``embeddings.csv`` only when that file
        was written.
    """
    if file_obj is not None:
        return _process_batch(file_obj)
    return _process_single(smiles)


def _upload_path(file_obj):
    """Return the filesystem path of an upload, whatever form it arrives in."""
    # A plain path must be returned as-is: str has no ``.name`` and
    # ``pathlib.Path.name`` is only the final component, not the full path.
    if isinstance(file_obj, (str, os.PathLike)):
        return file_obj
    return file_obj.name


def _process_batch(file_obj):
    """Embed every row of the CSV's SMILES column and write embeddings.csv."""
    try:
        # sep=None makes the python engine auto-detect the delimiter (";", "," ...).
        df_in = pd.read_csv(_upload_path(file_obj), sep=None, engine="python")
        # Find the "smiles" column case-insensitively; padding around the
        # header (e.g. from "; "-separated files) is tolerated.
        smiles_cols = [c for c in df_in.columns if str(c).strip().lower() == "smiles"]
        if not smiles_cols:
            return (
                "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
                gr.update(visible=False),
            )
        smiles_list = df_in[smiles_cols[0]].astype(str).tolist()

        out_records = []
        invalid_smiles = []
        embed_dim = None  # learned from the first successful embedding
        for sm in smiles_list:
            try:
                vec = model.encode(sm, return_torch=True)[0].tolist()
            except Exception:
                # Any encoding failure is reported to the user as an invalid
                # SMILES; the row is kept so the output lines up with the input.
                invalid_smiles.append(sm)
                record = {"smiles": f"SMILES {sm} was invalid"}
                if embed_dim is not None:
                    # Pad with empty cells once the embedding width is known.
                    record.update({f"dim_{i}": None for i in range(embed_dim)})
            else:
                if embed_dim is None:
                    embed_dim = len(vec)
                record = {"smiles": sm}
                record.update({f"dim_{i}": v for i, v in enumerate(vec)})
            out_records.append(record)

        # Building the frame from dicts unifies the columns, so invalid rows
        # seen before the width was known simply get empty dim_* cells.
        out_df = pd.DataFrame(out_records)
        out_df.to_csv("embeddings.csv", index=False)

        valid = len(smiles_list) - len(invalid_smiles)
        if invalid_smiles:
            invalid_count = len(invalid_smiles)
            msg = (
                f"{valid} SMILES processed successfully. "
                f"{invalid_count} entr{'y' if invalid_count==1 else 'ies'} "
                f"could not be parsed by RDKit:\n"
                + "\n".join(f"- {sm}" for sm in invalid_smiles)
            )
        else:
            msg = f"Processed batch of {valid} SMILES. Download embeddings.csv."
        return msg, gr.update(value="embeddings.csv", visible=True)
    except Exception as e:
        # Top-level boundary for the request: surface the problem in the UI
        # instead of letting the handler crash.
        return f"Error processing batch: {e}", gr.update(visible=False)


def _process_single(smiles):
    """Embed one SMILES string, write embeddings.csv and return the vector as JSON."""
    # The text box value can be None (e.g. cleared component or API call).
    smiles = (smiles or "").strip()
    if not smiles:
        return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
    # Only an encoding failure means "invalid SMILES"; file-system errors are
    # reported separately below so the message is not misleading.
    try:
        vec = model.encode(smiles, return_torch=True)[0].tolist()
    except Exception:
        return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)
    cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
    df_out = pd.DataFrame([[smiles] + vec], columns=cols)
    try:
        df_out.to_csv("embeddings.csv", index=False)
    except OSError as e:
        return f"Error writing embeddings.csv: {e}", gr.update(visible=False)
    return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
# 4) Gradio interface.
# The batch-mode help text names the required `smiles` column, because the
# handler rejects CSV files that lack a column with that name.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# SMI-TED-Embeddings-Extraction
**Single mode:** paste a SMILES string in the left box.
**Batch mode:** upload a CSV file with a column named `smiles` (case-insensitive) containing one SMILES per row.
In both cases, an `embeddings.csv` file will be extracted for download, with the first column as SMILES and the embedding values in the following columns.
"""
    )
    with gr.Row():
        smiles_in = gr.Textbox(label="SMILES (single mode)", placeholder="e.g. CCO")
        file_in = gr.File(label="SMILES CSV (batch mode)", file_types=[".csv"])
    generate_btn = gr.Button("Extract Embeddings")
    with gr.Row():
        output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=4)
        # Hidden until the handler returns an update that points it at a file.
        download_csv = gr.File(label="Download embeddings.csv", visible=False)
    # The handler returns (message, file update), matching the outputs order.
    generate_btn.click(
        fn=process_inputs,
        inputs=[smiles_in, file_in],
        outputs=[output_msg, download_csv]
    )
if __name__ == "__main__":
    # Listen on all interfaces so the app is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0")
|