Enzo Reis de Oliveira committed on
Commit
073cdd9
·
1 Parent(s): cbc085f

Better error message for batch

Browse files
Files changed (1) hide show
  1. app.py +49 -22
app.py CHANGED
@@ -4,14 +4,14 @@ import json
4
  import pandas as pd
5
  import gradio as gr
6
 
7
- # 1) Adjust path before importing the loader
8
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
  INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
10
  sys.path.insert(0, INFERENCE_PATH)
11
 
12
  from smi_ted_light.load import load_smi_ted
13
 
14
- # 2) Load the SMI-TED Light model
15
  MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
16
  model = load_smi_ted(
17
  folder=MODEL_DIR,
@@ -19,14 +19,15 @@ model = load_smi_ted(
19
  vocab_filename="bert_vocab_curated.txt",
20
  )
21
 
22
- # 3) Single function to process either a single SMILES or a CSV of SMILES
23
  def process_inputs(smiles: str, file_obj):
24
- # Se um arquivo CSV for fornecido, processa em batch
25
  if file_obj is not None:
26
  try:
 
27
  df_in = pd.read_csv(file_obj.name, sep=None, engine='python')
28
 
29
- smiles_cols = [col for col in df_in.columns if col.lower() == "smiles"]
 
30
  if not smiles_cols:
31
  return (
32
  "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
@@ -35,42 +36,68 @@ def process_inputs(smiles: str, file_obj):
35
  smiles_col = smiles_cols[0]
36
  smiles_list = df_in[smiles_col].astype(str).tolist()
37
 
38
- if (len(smiles_list) > 100):
39
- return (
40
- "Error: The CSV must have up to 100 Smiles.",
41
- gr.update(visible=False),
42
- )
43
 
44
- embeddings = []
45
  for sm in smiles_list:
46
- vec = model.encode(sm, return_torch=True)[0].tolist()
47
- embeddings.append(vec)
48
-
49
- out_df = pd.DataFrame(embeddings)
50
- out_df.insert(0, "smiles", smiles_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  out_df.to_csv("embeddings.csv", index=False)
52
 
53
- msg = f"Processed batch of {len(smiles_list)} SMILES. Download embeddings.csv."
 
 
 
 
 
 
 
 
 
 
 
54
  return msg, gr.update(value="embeddings.csv", visible=True)
55
 
56
  except Exception as e:
57
  return f"Error processing batch: {e}", gr.update(visible=False)
58
 
59
- # Modo single
60
  smiles = smiles.strip()
61
  if not smiles:
62
  return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
63
  try:
64
  vec = model.encode(smiles, return_torch=True)[0].tolist()
65
- # Salva CSV com header
66
  cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
67
  df_out = pd.DataFrame([[smiles] + vec], columns=cols)
68
  df_out.to_csv("embeddings.csv", index=False)
69
  return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
70
- except Exception as e:
71
  return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)
72
 
73
- # 4) Build the Gradio Blocks interface
 
74
  with gr.Blocks() as demo:
75
  gr.Markdown(
76
  """
@@ -88,7 +115,7 @@ with gr.Blocks() as demo:
88
  generate_btn = gr.Button("Extract Embeddings")
89
 
90
  with gr.Row():
91
- output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=2)
92
  download_csv = gr.File(label="Download embeddings.csv", visible=False)
93
 
94
  generate_btn.click(
 
4
  import pandas as pd
5
  import gradio as gr
6
 
7
+ # 1) Ajusta o path antes de importar o loader
8
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
  INFERENCE_PATH = os.path.join(BASE_DIR, "smi-ted", "inference")
10
  sys.path.insert(0, INFERENCE_PATH)
11
 
12
  from smi_ted_light.load import load_smi_ted
13
 
14
+ # 2) Carrega o modelo SMI-TED Light
15
  MODEL_DIR = os.path.join(INFERENCE_PATH, "smi_ted_light")
16
  model = load_smi_ted(
17
  folder=MODEL_DIR,
 
19
  vocab_filename="bert_vocab_curated.txt",
20
  )
21
 
 
22
  def process_inputs(smiles: str, file_obj):
23
+ # Modo batch
24
  if file_obj is not None:
25
  try:
26
+ # autodetecta delimitador (; ou , etc)
27
  df_in = pd.read_csv(file_obj.name, sep=None, engine='python')
28
 
29
+ # procura coluna "smiles" (case‐insensitive)
30
+ smiles_cols = [c for c in df_in.columns if c.lower() == "smiles"]
31
  if not smiles_cols:
32
  return (
33
  "Error: The CSV must have a column named 'Smiles' with the respective SMILES.",
 
36
  smiles_col = smiles_cols[0]
37
  smiles_list = df_in[smiles_col].astype(str).tolist()
38
 
39
+ out_records = []
40
+ invalid_smiles = []
41
+ embed_dim = None
 
 
42
 
43
+ # para cada SMILES, tenta gerar embedding
44
  for sm in smiles_list:
45
+ try:
46
+ vec = model.encode(sm, return_torch=True)[0].tolist()
47
+ # guarda dimensão do vetor na primeira vez
48
+ if embed_dim is None:
49
+ embed_dim = len(vec)
50
+ # monta registro válido
51
+ record = {"smiles": sm}
52
+ record.update({f"dim_{i}": v for i, v in enumerate(vec)})
53
+ except Exception:
54
+ # marca como inválido
55
+ invalid_smiles.append(sm)
56
+ # se já souber quantos dims, preenche com None
57
+ if embed_dim is not None:
58
+ record = {"smiles": f"SMILES {sm} was invalid"}
59
+ record.update({f"dim_{i}": None for i in range(embed_dim)})
60
+ else:
61
+ # ainda não sabemos quantos dims: só guarda smiles
62
+ record = {"smiles": f"SMILES {sm} was invalid"}
63
+ out_records.append(record)
64
+
65
+ # converte para DataFrame (vai unificar todas as colunas)
66
+ out_df = pd.DataFrame(out_records)
67
  out_df.to_csv("embeddings.csv", index=False)
68
 
69
+ # monta mensagem de saída
70
+ total = len(smiles_list)
71
+ valid = total - len(invalid_smiles)
72
+ if invalid_smiles:
73
+ msg = (
74
+ f"{valid} SMILES were successfully processed, "
75
+ f"{len(invalid_smiles)} had errors:\n"
76
+ + "\n".join(invalid_smiles)
77
+ )
78
+ else:
79
+ msg = f"Processed batch of {valid} SMILES. Download embeddings.csv."
80
+
81
  return msg, gr.update(value="embeddings.csv", visible=True)
82
 
83
  except Exception as e:
84
  return f"Error processing batch: {e}", gr.update(visible=False)
85
 
86
+ # Modo single (sem mudança)
87
  smiles = smiles.strip()
88
  if not smiles:
89
  return "Please enter a SMILES or upload a CSV file.", gr.update(visible=False)
90
  try:
91
  vec = model.encode(smiles, return_torch=True)[0].tolist()
 
92
  cols = ["smiles"] + [f"dim_{i}" for i in range(len(vec))]
93
  df_out = pd.DataFrame([[smiles] + vec], columns=cols)
94
  df_out.to_csv("embeddings.csv", index=False)
95
  return json.dumps(vec), gr.update(value="embeddings.csv", visible=True)
96
+ except Exception:
97
  return f"The following input '{smiles}' is not a valid SMILES", gr.update(visible=False)
98
 
99
+
100
+ # 4) Interface Gradio (sem mudanças)
101
  with gr.Blocks() as demo:
102
  gr.Markdown(
103
  """
 
115
  generate_btn = gr.Button("Extract Embeddings")
116
 
117
  with gr.Row():
118
+ output_msg = gr.Textbox(label="Message / Embedding (JSON)", interactive=False, lines=4)
119
  download_csv = gr.File(label="Download embeddings.csv", visible=False)
120
 
121
  generate_btn.click(