File size: 2,321 Bytes
7afa2e2 e9af19b 750cfac 26a9b58 e9af19b 6a4bbaf e9af19b 26a9b58 e9af19b 491cb22 6a4bbaf 491cb22 7afa2e2 6a4bbaf e9af19b 6a4bbaf e9af19b 26a9b58 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf e9af19b 6a4bbaf 491cb22 e9af19b 6a4bbaf e9af19b 491cb22 6a4bbaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import spaces
import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from graph_decoder.diffusion_model import GraphDiT
# Load the model
def load_graph_decoder(path='model_labeled'):
    """Construct a GraphDiT decoder from the files stored under ``path``.

    ``{path}/config.yaml`` is passed as ``model_config_path`` and
    ``{path}/data.meta.json`` as ``data_info_path``; the model is built with
    ``torch.float32`` as its dtype.

    Args:
        path: Checkpoint directory. It is also handed to ``init_model``,
            presumably so the weights can be loaded from it — confirm against
            GraphDiT.

    Returns:
        The GraphDiT instance after ``init_model(path)`` and
        ``disable_grads()`` have been called on it.
    """
    config_file = f"{path}/config.yaml"
    meta_file = f"{path}/data.meta.json"
    decoder = GraphDiT(
        model_config_path=config_file,
        data_info_path=meta_file,
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    # Inference-only app; presumably this switches off gradient tracking — confirm in GraphDiT.
    decoder.disable_grads()
    return decoder
# Build the decoder once at import time; generate_polymer reuses this global.
model = load_graph_decoder()
# Run on the GPU when CUDA reports it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Generate one polymer for the requested gas-permeability targets.

    Args:
        CH4, CO2, H2, N2, O2: Target values from the UI sliders. They are
            passed to ``model.generate`` as a list in exactly this order.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.

    Returns:
        A ``(smiles, image)`` tuple matching the two Gradio outputs. On
        success, ``smiles`` is the SMILES re-serialised by
        ``Chem.MolToSmiles(..., isomericSmiles=True)`` and ``image`` is the
        result of ``Draw.MolToImage``. On every failure (generation raised,
        no molecule was produced, or RDKit could not parse the SMILES) it
        returns ``("Generation failed", None)`` rather than ``None``, so the
        UI always receives two values.
    """
    failure = ("Generation failed", None)
    properties = [CH4, CO2, H2, N2, O2]
    try:
        # The move happens inside the @spaces.GPU-decorated call; presumably
        # the GPU is only attached while such a call runs (ZeroGPU) — confirm.
        model.to(device)
        generated_molecule, _ = model.generate(
            properties, device=device, guide_scale=guidance_scale
        )
        if generated_molecule is None:
            print("Error in generation: model returned no molecule")
            return failure
        mol = Chem.MolFromSmiles(generated_molecule)
        if mol is None:
            # RDKit returns None (not an exception) for unparseable SMILES.
            print(f"Error in generation: invalid SMILES {generated_molecule!r}")
            return failure
        standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
        img = Draw.MolToImage(mol)
    except Exception as e:
        # Top-level UI boundary: report the error instead of failing the request.
        print(f"Error in generation: {e}")
        return failure
    return standardized_smiles, img
# Create the Gradio interface
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with GraphDiT")
    # Target permeability per gas, labelled in Barrer (the gas-permeability
    # unit). The order of these inputs must match generate_polymer's
    # parameter order, because click() passes them positionally.
    with gr.Row():
        CH4_input = gr.Slider(minimum=0, maximum=100, value=2.5, label="CH₄ (Barrer)")
        CO2_input = gr.Slider(minimum=0, maximum=100, value=15.4, label="CO₂ (Barrer)")
        H2_input = gr.Slider(minimum=0, maximum=100, value=21.0, label="H₂ (Barrer)")
        N2_input = gr.Slider(minimum=0, maximum=100, value=1.5, label="N₂ (Barrer)")
        O2_input = gr.Slider(minimum=0, maximum=100, value=2.8, label="O₂ (Barrer)")
    guidance_scale = gr.Slider(minimum=1, maximum=3, value=2, label="Guidance Scale")
    generate_btn = gr.Button("Generate Polymer")
    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        # generate_polymer returns the image from Draw.MolToImage, hence type="pil".
        result_image = gr.Image(label="Molecule Visualization", type="pil")
    generate_btn.click(
        fn=generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image],
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()