Enzo Reis de Oliveira commited on
Commit
fe92162
·
1 Parent(s): e1406e0

Adding CSV option download

Browse files
Files changed (1) hide show
  1. app.py +42 -21
app.py CHANGED
@@ -1,54 +1,75 @@
1
- import os, sys
2
-
3
- BASE_DIR = os.path.dirname(__file__)
4
- INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
5
- sys.path.append(INFERENCE_DIR)
6
 
 
7
  import gradio as gr
8
- from smi_ted_light.load import load_smi_ted
9
 
 
 
 
 
10
 
11
  # 2) Caminho onde estão pesos e vocabulário
12
  MODEL_DIR = os.path.join("smi-ted", "inference", "smi_ted_light")
13
 
14
- # 3) Carrega o modelo SMITED (Light)
15
- # Se você renomeou o .pt ou o vocab, ajuste aqui.
16
  model = load_smi_ted(
17
  folder=MODEL_DIR,
18
  ckpt_filename="smi-ted-Light_40.pt",
19
  vocab_filename="bert_vocab_curated.txt",
20
  )
21
 
22
- # 4) Função utilizada pela interface
23
  def gerar_embedding(smiles: str):
24
  """
25
- Recebe uma string SMILES e devolve o embedding (lista de 768 floats).
26
- Em caso de erro, devolve um dicionário com a mensagem.
 
 
27
  """
28
  smiles = smiles.strip()
29
  if not smiles:
30
- return {"erro": "digite uma sequência SMILES primeiro"}
31
 
32
  try:
33
- # model.encode devolve tensor shape (1, 768) quando return_torch=True
34
  vetor_torch = model.encode(smiles, return_torch=True)[0]
35
- return vetor_torch.tolist() # JSON‑serializável
36
- except Exception as e:
37
- return {"erro": str(e)}
 
 
 
 
 
 
 
 
 
38
 
 
 
 
 
39
 
40
- # 5) Define a interface Gradio
41
  demo = gr.Interface(
42
  fn=gerar_embedding,
43
  inputs=gr.Textbox(label="SMILES", placeholder="Ex.: CCO"),
44
- outputs=gr.JSON(label="Embedding (lista de floats)"),
45
- title="SMI‑TED Embedding Generator",
 
 
 
46
  description=(
47
  "Cole uma sequência SMILES e receba o embedding gerado pelo modelo "
48
- "SMITED Light treinado pela IBM Research."
 
49
  ),
50
  )
51
 
52
- # 6) Roda localmente ou no Hugging Face Space
53
  if __name__ == "__main__":
54
  demo.launch()
 
1
+ import os
2
+ import sys
3
+ import tempfile
 
 
4
 
5
+ import pandas as pd
6
  import gradio as gr
7
+ from smi_ted_light.load import load_smi_ted
8
 
9
+ # 1) Ajusta o path para o inference do SMI-TED
10
+ BASE_DIR = os.path.dirname(__file__)
11
+ INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference")
12
+ sys.path.append(INFERENCE_DIR)
13
 
14
  # 2) Caminho onde estão pesos e vocabulário
15
  MODEL_DIR = os.path.join("smi-ted", "inference", "smi_ted_light")
16
 
17
+ # 3) Carrega o modelo SMI-TED (Light)
 
18
  model = load_smi_ted(
19
  folder=MODEL_DIR,
20
  ckpt_filename="smi-ted-Light_40.pt",
21
  vocab_filename="bert_vocab_curated.txt",
22
  )
23
 
24
+ # 4) Função utilizada pela interface Gradio
25
  def gerar_embedding(smiles: str):
26
  """
27
+ Recebe uma string SMILES e devolve:
28
+ - embedding (lista de 768 floats)
29
+ - caminho para um CSV com esse embedding, pronto para download
30
+ Em caso de erro, devolve um dicionário com a mensagem e nenhum arquivo.
31
  """
32
  smiles = smiles.strip()
33
  if not smiles:
34
+ return {"erro": "digite uma sequência SMILES primeiro"}, None
35
 
36
  try:
37
+ # model.encode devolve tensor shape (1, 768)
38
  vetor_torch = model.encode(smiles, return_torch=True)[0]
39
+ embedding = vetor_torch.tolist()
40
+
41
+ # Cria um CSV temporário com uma única linha (o embedding)
42
+ df = pd.DataFrame([embedding])
43
+ tmp = tempfile.NamedTemporaryFile(
44
+ delete=False,
45
+ suffix=".csv",
46
+ prefix="embedding_",
47
+ )
48
+ csv_path = tmp.name
49
+ df.to_csv(csv_path, index=False)
50
+ tmp.close()
51
 
52
+ return embedding, csv_path
53
+
54
+ except Exception as e:
55
+ return {"erro": str(e)}, None
56
 
57
+ # 5) Define a interface Gradio com dois outputs: JSON e arquivo para download
58
  demo = gr.Interface(
59
  fn=gerar_embedding,
60
  inputs=gr.Textbox(label="SMILES", placeholder="Ex.: CCO"),
61
+ outputs=[
62
+ gr.JSON(label="Embedding (lista de floats)"),
63
+ gr.File(label="Baixar embedding em CSV"),
64
+ ],
65
+ title="SMI-TED Embedding Generator",
66
  description=(
67
  "Cole uma sequência SMILES e receba o embedding gerado pelo modelo "
68
+ "SMI-TED Light treinado pela IBM Research. "
69
+ "Você também pode baixar o embedding em CSV."
70
  ),
71
  )
72
 
73
+ # 6) Roda localmente ou no HF Space
74
  if __name__ == "__main__":
75
  demo.launch()