File size: 4,927 Bytes
7afa2e2 e9af19b 6cc3c63 750cfac 89df0a5 6cc3c63 8c02a88 e513fba 8c02a88 e513fba 5f7ed50 8c02a88 26a7403 8c02a88 6cc3c63 491cb22 7afa2e2 6a4bbaf 6cc3c63 e9af19b 6cc3c63 6a4bbaf 6cc3c63 1037049 6cc3c63 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf 6cc3c63 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf 491cb22 e9af19b 6a4bbaf e9af19b 491cb22 6cc3c63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import spaces
import gradio as gr
import torch
import torch.nn as nn
import random
from rdkit import Chem
from rdkit.Chem import Draw
from graph_decoder.diffusion_model import GraphDiT
# Alphabet sampled by the placeholder SMILES generator.
# NOTE(review): 'H' is not an organic-subset atom in SMILES, so random strings
# containing it will likely fail to parse/sanitize in RDKit — confirm intent.
ATOM_SYMBOLS = ['C', 'N', 'O', 'H']

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory holding the GraphDiT config and data-info files.
path = 'model_labeled'
model = GraphDiT(
    model_config_path=f"{path}/config.yaml",
    data_info_path=f"{path}/data.meta.json",
    model_dtype=torch.float32,
)
# NOTE(review): unlike the commented-out loader at the bottom of this file,
# no model.init_model(path) / model.disable_grads() call is made here —
# confirm whether checkpoint weights are meant to be loaded.
model.to(device)
def generate_random_smiles(length=10, symbols=None):
    """Return a random string of atom symbols standing in for a generated SMILES.

    This is a placeholder generator. Symbols are drawn uniformly and with
    replacement, so the result is not guaranteed to be a valid SMILES string.

    Args:
        length: Number of symbols to sample.
        symbols: Sequence of atom symbols to sample from. When omitted (None),
            the module-level ``ATOM_SYMBOLS`` is used, so existing callers
            behave exactly as before.

    Returns:
        The sampled symbols concatenated into a single string.
    """
    # Resolve the default lazily so callers can supply their own alphabet
    # (e.g. one that always yields valid linear chains).
    alphabet = ATOM_SYMBOLS if symbols is None else symbols
    return ''.join(random.choices(alphabet, k=length))
@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Produce a candidate polymer (SMILES text and image) for the requested barriers.

    This is currently a placeholder. The GraphDiT call is disabled and a random
    atom string from ``generate_random_smiles`` is used instead, so the slider
    values and ``guidance_scale`` do not yet influence the result.

    Args:
        CH4, CO2, H2, N2, O2: Target gas-barrier values from the UI sliders.
        guidance_scale: Guidance strength from the UI slider (currently unused).

    Returns:
        A 2-tuple matching the two Gradio outputs:
        - ``(smiles, image)`` on success;
        - ``("Generation failed", None)`` when RDKit cannot parse the
          candidate or any step raises.
    """
    # Property vector with a leading batch dimension of 1, staged for the
    # disabled model call below.
    properties = torch.tensor([CH4, CO2, H2, N2, O2], dtype=torch.float32).unsqueeze(0)
    print('in generate_polymer')
    try:
        # Generate a random SMILES string (this is a placeholder)
        generated_molecule = generate_random_smiles()
        # model.generate(properties, device)
        mol = Chem.MolFromSmiles(generated_molecule)
        if mol is not None:
            standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
            img = Draw.MolToImage(mol)
            return standardized_smiles, img
        # An unparsable string must still yield one value per Gradio output,
        # so fall through to the shared failure return below.
        print(f"Generated string is not a valid molecule: {generated_molecule!r}")
    except Exception as e:
        print(f"Error in generation: {e}")
    return "Generation failed", None
# Build the Gradio UI: gas-barrier sliders, a guidance-scale slider, and a
# button wired to generate_polymer whose results fill the SMILES/image outputs.
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with Random Neural Network")

    # Target gas-barrier values, arranged in one row.
    with gr.Row():
        CH4_input = gr.Slider(minimum=0, maximum=100, value=2.5, label="CH₄ (Barrier)")
        CO2_input = gr.Slider(minimum=0, maximum=100, value=15.4, label="CO₂ (Barrier)")
        H2_input = gr.Slider(minimum=0, maximum=100, value=21.0, label="H₂ (Barrier)")
        N2_input = gr.Slider(minimum=0, maximum=100, value=1.5, label="N₂ (Barrier)")
        O2_input = gr.Slider(minimum=0, maximum=100, value=2.8, label="O₂ (Barrier)")

    guidance_scale = gr.Slider(minimum=1, maximum=3, value=2, label="Guidance Scale")
    generate_btn = gr.Button("Generate Polymer")

    # Outputs: the SMILES text and a rendered image of the molecule.
    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Input order must match generate_polymer's parameter order.
    generate_btn.click(
        fn=generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image],
    )

if __name__ == "__main__":
    iface.launch()
# import spaces
# import gradio as gr
# import torch
# from rdkit import Chem
# from rdkit.Chem import Draw
# # from graph_decoder.diffusion_model import GraphDiT
# # Load the model
# def load_graph_decoder(path='model_labeled'):
# model = GraphDiT(
# model_config_path=f"{path}/config.yaml",
# data_info_path=f"{path}/data.meta.json",
# model_dtype=torch.float32,
# )
# model.init_model(path)
# model.disable_grads()
# return model
# # model = load_graph_decoder()
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# @spaces.GPU
# def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
# properties = [CH4, CO2, H2, N2, O2]
# try:
# model = load_graph_decoder()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print('enter function')
# generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
# if generated_molecule is not None:
# mol = Chem.MolFromSmiles(generated_molecule)
# if mol is not None:
# standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
# img = Draw.MolToImage(mol)
# return standardized_smiles, img
# except Exception as e:
# print(f"Error in generation: {e}")
# return "Generation failed", None
# # Create the Gradio interface
# with gr.Blocks(title="Simplified Polymer Design") as iface:
# gr.Markdown("## Polymer Design with GraphDiT")
# with gr.Row():
# CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
# CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
# H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
# N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
# O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
# guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
# generate_btn = gr.Button("Generate Polymer")
# with gr.Row():
# result_smiles = gr.Textbox(label="Generated SMILES")
# result_image = gr.Image(label="Molecule Visualization", type="pil")
# generate_btn.click(
# generate_polymer,
# inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
# outputs=[result_smiles, result_image]
# )
# if __name__ == "__main__":
# iface.launch() |