"""Streamlit front end for generating molecules with a pretrained DrugGEN model.

The sidebar form lets the user pick a model variant, the number of molecules to
generate and an RNG seed. On submission the matching config object is handed to
``Trainer``; the results of ``Trainer.inference()`` are displayed and the
generated molecules file is offered for download.
"""
import random  # NOTE(review): not referenced anywhere in this script

import streamlit as st

from trainer import Trainer


class DrugGENConfig:
    """Default settings shared by every DrugGEN model variant.

    This is a plain attribute bag passed to ``Trainer``; the meaning of each
    field is defined by ``Trainer`` (not visible here). Subclasses override
    ``submodel`` and add ``inference_model``.
    """

    submodel = 'CrossLoss'
    act = 'relu'
    z_dim = 16
    max_atom = 45
    lambda_gp = 1
    dim = 128
    depth = 1
    heads = 8
    dec_depth = 1
    dec_heads = 8
    dec_dim = 128
    mlp_ratio = 3
    warm_up_steps = 0
    dis_select = 'mlp'
    init_type = 'normal'
    batch_size = 128
    epoch = 50
    g_lr = 0.00001
    d_lr = 0.00001
    g2_lr = 0.00001
    d2_lr = 0.00001
    dropout = 0.
    dec_dropout = 0.
    n_critic = 1
    beta1 = 0.9
    beta2 = 0.999
    resume_iters = None
    clipping_value = 2
    features = False
    test_iters = 10_000
    num_test_epoch = 30_000
    inference_sample_num = 1000  # overwritten from the UI before inference
    num_workers = 1
    mode = "inference"
    inference_iterations = 100
    inf_batch_size = 1
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    # inference_model is intentionally not defined here; each subclass sets it.
    sample_dir = 'experiments/samples'
    result_dir = "experiments/tboard_output"
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = 'data/chembl_train.smi'
    drug_raw_file = "data/akt_train.smi"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = "data/akt_test.smi"
    log_sample_step = 1000
    set_seed = True
    seed = 1  # overwritten from the UI before inference
    resume = False
    resume_epoch = None
    resume_iter = None
    resume_directory = None


class ProtConfig(DrugGENConfig):
    """Config for the "Prot" model variant."""

    submodel = "Prot"
    inference_model = "experiments/models/Prot"


class CrossLossConfig(DrugGENConfig):
    """Config for the "CrossLoss" model variant."""

    submodel = "CrossLoss"
    inference_model = "experiments/models/CrossLoss"


class NoTargetConfig(DrugGENConfig):
    """Config for the "NoTarget" model variant."""

    submodel = "NoTarget"
    inference_model = "experiments/models/NoTarget"


# Maps the radio-button label to the config instance used for that model.
model_configs = {
    "Prot": ProtConfig(),
    "CrossLoss": CrossLossConfig(),
    "NoTarget": NoTargetConfig(),
}

with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    st.write(
        "[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868) \n"
        "[![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)"
    )

    # Widgets inside a form only report new values once the submit button is pressed.
    with st.form("model_selection_from"):
        model_name = st.radio(
            "Select a model to make inference (Prot and CrossLoss models are pretrained on AKT1 inhibitors)",
            ('Prot', 'CrossLoss', 'NoTarget'),
        )
        molecule_num_input = st.number_input(
            'Number of molecules to generate',
            min_value=1, max_value=100_000, value=1000, step=1,
        )
        seed_input = st.number_input(
            "Input an RNG seed for reproducibility",
            min_value=0, value=42, step=1,
        )
        submitted = st.form_submit_button("Start Computing")

if submitted:
    # Apply the user's choices to the selected config before building the trainer.
    config = model_configs[model_name]
    config.inference_sample_num = molecule_num_input
    config.seed = seed_input

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        trainer = Trainer(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = trainer.inference()

    st.success(f"Inference of {model_name} took {results['runtime']:.2f} seconds.")

    with st.expander("Expand to see scores"):
        st.success(f"Validity: {results['fraction_valid']}")
        st.success(f"Uniqueness: {results['uniqueness']}")
        st.success(f"Novelty: {results['novelty']}")

    # Explicit encoding so the read does not depend on the host's locale default.
    with open(f'experiments/inference/{model_name}/inference_drugs.txt', encoding="utf-8") as f:
        inference_drugs = f.read()

    st.download_button(
        label="Click to download generated molecules",
        data=inference_drugs,
        file_name=f'{model_name}_inference.smi',
        mime="text/plain",
    )
else:
    st.warning("Please select a model to make inference")