File size: 2,244 Bytes
e9af19b
 
750cfac
 
e9af19b
 
6a4bbaf
e9af19b
 
6a4bbaf
 
e9af19b
 
 
 
 
491cb22
6a4bbaf
 
491cb22
6a4bbaf
e9af19b
 
6a4bbaf
e9af19b
6a4bbaf
e9af19b
 
 
 
 
 
6a4bbaf
 
 
e9af19b
6a4bbaf
e9af19b
6a4bbaf
 
 
 
e9af19b
6a4bbaf
 
 
 
 
 
e9af19b
6a4bbaf
e9af19b
 
6a4bbaf
 
491cb22
e9af19b
6a4bbaf
 
 
e9af19b
 
491cb22
6a4bbaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging

import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import Draw

from graph_decoder.diffusion_model import GraphDiT

# Load the model
def load_graph_decoder(path='model_labeled'):
    """Construct a float32 GraphDiT from the checkpoint directory *path*.

    The directory is expected to contain ``config.yaml`` (model config) and
    ``data.meta.json`` (data info). After construction, the weights are
    initialised from *path* and ``disable_grads()`` is called (presumably
    turning off gradient tracking for inference-only use — TODO confirm).

    Returns the ready-to-use GraphDiT instance.
    """
    config_file = f"{path}/config.yaml"
    data_info_file = f"{path}/data.meta.json"

    decoder = GraphDiT(
        model_config_path=config_file,
        data_info_path=data_info_file,
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    decoder.disable_grads()
    return decoder

# Built once at import time and shared by every Gradio request; prefer CUDA when present.
model = load_graph_decoder()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Generate a polymer for the requested gas values and render it.

    Gradio callback for the "Generate Polymer" button.

    Args:
        CH4, CO2, H2, N2, O2: Slider values, passed to ``model.generate`` as a
            property list in exactly this order.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.

    Returns:
        ``(smiles, image)`` on success. *smiles* is the SMILES re-emitted by
        RDKit (``Chem.MolToSmiles`` with ``isomericSmiles=True``), and *image*
        is the ``Draw.MolToImage`` depiction. ``("Generation failed", None)``
        is returned instead in three cases: generation raises, the model
        returns ``None``, or RDKit cannot parse the output. The cause is
        logged in each case.
    """
    log = logging.getLogger(__name__)
    properties = [CH4, CO2, H2, N2, O2]

    try:
        model.to(device)
        generated_molecule, _ = model.generate(
            properties, device=device, guide_scale=guidance_scale
        )
        if generated_molecule is None:
            log.warning("Model returned no molecule for properties %s", properties)
        else:
            mol = Chem.MolFromSmiles(generated_molecule)
            if mol is None:
                log.warning("RDKit could not parse generated SMILES %r", generated_molecule)
            else:
                standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
                img = Draw.MolToImage(mol)
                return standardized_smiles, img
    except Exception:
        # UI boundary: show a failure message instead of erroring the request,
        # but record the full traceback so the cause can be diagnosed.
        log.exception("Error in generation")

    return "Generation failed", None

# Create the Gradio interface.
# The gas sliders set step=0.1 explicitly. Without a step, Gradio derives one
# from the range (1 for 0-100), which leaves the fractional defaults below
# off the slider grid and makes tenths unreachable by dragging.
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with GraphDiT")

    with gr.Row():
        CH4_input = gr.Slider(0, 100, value=2.5, step=0.1, label="CH₄ (Barrier)")
        CO2_input = gr.Slider(0, 100, value=15.4, step=0.1, label="CO₂ (Barrier)")
        H2_input = gr.Slider(0, 100, value=21.0, step=0.1, label="H₂ (Barrier)")
        N2_input = gr.Slider(0, 100, value=1.5, step=0.1, label="N₂ (Barrier)")
        O2_input = gr.Slider(0, 100, value=2.8, step=0.1, label="O₂ (Barrier)")
        # Forwarded to generate_polymer as its guidance_scale argument.
        guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")

    generate_btn = gr.Button("Generate Polymer")

    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        # type="pil" matches the PIL image returned by Draw.MolToImage.
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Input order must match generate_polymer's positional parameters.
    generate_btn.click(
        generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image]
    )

def _launch_app():
    """Start the Gradio server for the module-level ``iface`` app."""
    iface.launch()


if __name__ == "__main__":
    _launch_app()