import os, sys BASE_DIR = os.path.dirname(__file__) INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference") sys.path.append(INFERENCE_DIR) import gradio as gr from smi_ted_light.load import load_smi_ted # 1) Ajusta o path para o inference do SMI-TED BASE_DIR = os.path.dirname(__file__) INFERENCE_DIR = os.path.join(BASE_DIR, "smi-ted", "inference") sys.path.append(INFERENCE_DIR) # 2) Caminho onde estão pesos e vocabulário MODEL_DIR = os.path.join("smi-ted", "inference", "smi_ted_light") # 3) Carrega o modelo SMI-TED (Light) model = load_smi_ted( folder=MODEL_DIR, ckpt_filename="smi-ted-Light_40.pt", vocab_filename="bert_vocab_curated.txt", ) # 4) Função utilizada pela interface Gradio def gerar_embedding(smiles: str): """ Recebe uma string SMILES e devolve: - embedding (lista de 768 floats) - caminho para um CSV com esse embedding, pronto para download Em caso de erro, devolve um dicionário com a mensagem e nenhum arquivo. """ smiles = smiles.strip() if not smiles: return {"erro": "digite uma sequência SMILES primeiro"}, None try: # model.encode devolve tensor shape (1, 768) vetor_torch = model.encode(smiles, return_torch=True)[0] embedding = vetor_torch.tolist() # Cria um CSV temporário com uma única linha (o embedding) df = pd.DataFrame([embedding]) tmp = tempfile.NamedTemporaryFile( delete=False, suffix=".csv", prefix="embedding_", ) csv_path = tmp.name df.to_csv(csv_path, index=False) tmp.close() return embedding, csv_path except Exception as e: return {"erro": str(e)}, None # 5) Define a interface Gradio com dois outputs: JSON e arquivo para download demo = gr.Interface( fn=gerar_embedding, inputs=gr.Textbox(label="SMILES", placeholder="Ex.: CCO"), outputs=[ gr.JSON(label="Embedding (lista de floats)"), gr.File(label="Baixar embedding em CSV"), ], title="SMI-TED Embedding Generator", description=( "Cole uma sequência SMILES e receba o embedding gerado pelo modelo " "SMI-TED Light treinado pela IBM Research. " "Você também pode baixar o embedding em CSV." ), ) # 6) Roda localmente ou no HF Space if __name__ == "__main__": demo.launch()