Spaces:
Running
Running
import streamlit as st | |
from trainer import Trainer | |
import random | |
class DrugGENConfig:
    """Default settings shared by every DrugGEN model variant.

    The app instantiates a subclass and passes it to ``Trainer(config)``,
    overriding ``inference_sample_num`` and ``seed`` per request.
    Subclasses supply ``submodel`` and ``inference_model``. The latter is
    deliberately absent here, so every variant must name its own path.
    """

    # Model variant / architecture settings
    submodel = 'CrossLoss'
    act = 'relu'
    z_dim = 16
    max_atom = 45
    lambda_gp = 1
    dim = 128
    depth = 1
    heads = 8
    dec_depth = 1
    dec_heads = 8
    dec_dim = 128
    mlp_ratio = 3
    warm_up_steps = 0
    dis_select = 'mlp'
    init_type = 'normal'

    # Optimisation settings
    batch_size = 128
    epoch = 50
    g_lr = 1e-5
    d_lr = 1e-5
    g2_lr = 1e-5
    d2_lr = 1e-5
    dropout = 0.0
    dec_dropout = 0.0
    n_critic = 1
    beta1 = 0.9
    beta2 = 0.999
    resume_iters = None
    clipping_value = 2
    features = False
    test_iters = 10_000
    num_test_epoch = 30_000

    # Inference settings (the app overrides inference_sample_num per request)
    inference_sample_num = 1000
    num_workers = 1
    mode = "inference"
    inference_iterations = 100
    inf_batch_size = 1

    # Data and output locations (relative to the working directory)
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    # NOTE: inference_model is intentionally left to the subclasses.
    sample_dir = 'experiments/samples'
    result_dir = "experiments/tboard_output"
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = 'data/chembl_train.smi'
    drug_raw_file = "data/akt_train.smi"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = "data/akt_test.smi"
    log_sample_step = 1000

    # Reproducibility / resumption (the app overrides seed per request)
    set_seed = True
    seed = 1
    resume = False
    resume_epoch = None
    resume_iter = None
    resume_directory = None
class ProtConfig(DrugGENConfig):
    """DrugGEN-Prot variant of the shared defaults.

    ``inference_model`` points at ``experiments/models/Prot``, presumably the
    checkpoint location Trainer reads (confirm against Trainer).
    """

    inference_model = "experiments/models/Prot"
    submodel = "Prot"
class CrossLossConfig(DrugGENConfig):
    """DrugGEN-CrossLoss variant of the shared defaults.

    ``inference_model`` points at ``experiments/models/CrossLoss``, presumably
    the checkpoint location Trainer reads (confirm against Trainer).
    """

    inference_model = "experiments/models/CrossLoss"
    submodel = "CrossLoss"
class NoTargetConfig(DrugGENConfig):
    """DrugGEN-NoTarget variant of the shared defaults.

    ``inference_model`` points at ``experiments/models/NoTarget``, presumably
    the checkpoint location Trainer reads (confirm against Trainer).
    """

    inference_model = "experiments/models/NoTarget"
    submodel = "NoTarget"
# One ready-made config instance per variant, keyed by its ``submodel`` name
# ("Prot", "CrossLoss", "NoTarget"). The UI strips the "DrugGEN-" prefix from
# the radio choice and looks the result up here.
model_configs = {
    config_cls.submodel: config_cls()
    for config_cls in (ProtConfig, CrossLossConfig, NoTargetConfig)
}
# Streamlit UI. The sidebar form collects the model choice, sample count and
# seed. After submission, inference runs and its scores plus a download
# button are shown.
# NOTE(review): indentation was lost in the scraped source. Results are placed
# in the main area (outside the sidebar); confirm this matches the intended layout.
with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    # Links need visible text: a markdown link with empty text ("[](url)")
    # renders as nothing clickable.
    st.write("[arXiv](https://arxiv.org/abs/2302.07868) [GitHub](https://github.com/HUBioDataLab/DrugGEN)")
    with st.form("model_selection_from"):
        # Adjacent string literals keep the label text independent of the
        # source indentation. The original stray `",` here was a syntax error.
        model_name = st.radio(
            "\n"
            "Select a model to make inference (DrugGEN-Prot and DrugGEN-CrossLoss models design molecules to target the AKT1 protein)\n"
            "- **DrugGEN-Prot**: composed of two GANs, incorporates protein features to the transformer decoder module of GAN2 (together with the de novo molecules generated by GAN1) to direct the target centric molecule design.\n"
            "- **DrugGEN-CrossLoss**: composed of one GAN, the input of the GAN1 generator is the real molecules dataset and the GAN1 discriminator compares the generated molecules with the real inhibitors of the given target.\n"
            "- **DrugGEN-NoTarget**: composed of one GAN, focuses on learning the chemical properties from the ChEMBL training dataset, no target-specific generation.\n",
            ('DrugGEN-Prot', 'DrugGEN-CrossLoss', 'DrugGEN-NoTarget'),
        )
        # Options carry a "DrugGEN-" prefix, but model_configs is keyed by the bare name.
        model_name = model_name.replace("DrugGEN-", "")
        molecule_num_input = st.number_input('Number of molecules to generate', min_value=1, max_value=100_000, value=1000, step=1)
        seed_input = st.number_input("Input an RNG seed for reproducibility", min_value=0, value=42, step=1)
        submitted = st.form_submit_button("Start Computing")

if submitted:
    config = model_configs[model_name]
    # Per-request overrides of the class-level defaults (instance attributes).
    config.inference_sample_num = molecule_num_input
    config.seed = seed_input

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        trainer = Trainer(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = trainer.inference()

    st.success(f"Inference of {model_name} took {results['runtime']:.2f} seconds.")
    with st.expander("Expand to see scores"):
        st.success(f"Validity: {results['fraction_valid']}")
        st.success(f"Uniqueness: {results['uniqueness']}")
        st.success(f"Novelty: {results['novelty']}")

    # Assumes Trainer.inference() writes the generated molecules to this path
    # (confirm against Trainer). The encoding is explicit so the read does not
    # depend on the host locale.
    with open(f'experiments/inference/{model_name}/inference_drugs.txt', encoding='utf-8') as f:
        inference_drugs = f.read()
    st.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'{model_name}_inference.smi', mime="text/plain")
else:
    st.warning("Please select a model to make inference")