# liuganghuggingface's picture
# Upload app.py with huggingface_hub
# de8e581 verified
# raw
# history blame
# 4.95 kB
# import spaces
import gradio as gr
import torch
import torch.nn as nn
import random
from rdkit import Chem
from rdkit.Chem import Draw
from graph_decoder.diffusion_model import GraphDiT
# Alphabet sampled by the placeholder SMILES generator below.
ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Model directory; in this file it is only used by the commented-out GraphDiT
# loading code (config.yaml / data.meta.json are read from it there).
path = 'model_labeled'
# model = GraphDiT(
# model_config_path=f"{path}/config.yaml",
# data_info_path=f"{path}/data.meta.json",
# model_dtype=torch.float32
# )
# model.to(device)
def generate_random_smiles(length=10):
    """Return a random string of ``length`` symbols sampled from ATOM_SYMBOLS.

    Symbols are drawn independently with replacement. This is a stand-in for a
    real generator: the result may not parse as SMILES, so callers should
    validate it (e.g. with RDKit) before use.
    """
    picked_atoms = random.choices(ATOM_SYMBOLS, k=length)
    return "".join(picked_atoms)
# @spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Gradio callback: generate a (placeholder) molecule for the requested targets.

    Args:
        CH4, CO2, H2, N2, O2: Target values from the UI sliders. They are packed
            into a (1, 5) float32 tensor for the conditional model.
        guidance_scale: Guidance strength from the UI. It is currently unused
            because the model call is stubbed out.

    Returns:
        A ``(smiles, image)`` pair matching the two Gradio outputs. On success
        this is the canonical SMILES and a PIL image of the molecule. If
        generation or parsing fails it is ``("Generation failed", None)``.
    """
    print('in generate_polymer')
    try:
        # Conditioning vector for the model; unused until model.generate is wired in.
        properties = torch.tensor([CH4, CO2, H2, N2, O2], dtype=torch.float32).unsqueeze(0)
        # Generate a random SMILES string (this is a placeholder)
        generated_molecule = generate_random_smiles()
        # model.generate(properties, device)
        mol = Chem.MolFromSmiles(generated_molecule)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None, not by
            # raising. Return an explicit pair so the two-output binding never
            # receives a bare None.
            print(f"Error in generation: could not parse SMILES {generated_molecule!r}")
            return "Generation failed", None
        standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
        img = Draw.MolToImage(mol)
        return standardized_smiles, img
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing the app.
        print(f"Error in generation: {e}")
        return "Generation failed", None
# Create the Gradio interface
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with Random Neural Network")
    with gr.Row():
        # Target gas permeabilities. The unit is the Barrer; the labels
        # previously misspelled it as "Barrier".
        CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrer)")
        CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrer)")
        H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrer)")
        N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrer)")
        O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrer)")
    guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
    generate_btn = gr.Button("Generate Polymer")
    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")
    # Input order must match generate_polymer's parameter order. The callback
    # returns a (smiles, image) pair for the two outputs.
    generate_btn.click(
        generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image],
    )

if __name__ == "__main__":
    iface.launch()
# import spaces
# import gradio as gr
# import torch
# from rdkit import Chem
# from rdkit.Chem import Draw
# # from graph_decoder.diffusion_model import GraphDiT
# # Load the model
# def load_graph_decoder(path='model_labeled'):
# model = GraphDiT(
# model_config_path=f"{path}/config.yaml",
# data_info_path=f"{path}/data.meta.json",
# model_dtype=torch.float32,
# )
# model.init_model(path)
# model.disable_grads()
# return model
# # model = load_graph_decoder()
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# @spaces.GPU
# def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
# properties = [CH4, CO2, H2, N2, O2]
# try:
# model = load_graph_decoder()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print('enter function')
# generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
# if generated_molecule is not None:
# mol = Chem.MolFromSmiles(generated_molecule)
# if mol is not None:
# standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
# img = Draw.MolToImage(mol)
# return standardized_smiles, img
# except Exception as e:
# print(f"Error in generation: {e}")
# return "Generation failed", None
# # Create the Gradio interface
# with gr.Blocks(title="Simplified Polymer Design") as iface:
# gr.Markdown("## Polymer Design with GraphDiT")
# with gr.Row():
# CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
# CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
# H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
# N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
# O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
# guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
# generate_btn = gr.Button("Generate Polymer")
# with gr.Row():
# result_smiles = gr.Textbox(label="Generated SMILES")
# result_image = gr.Image(label="Molecule Visualization", type="pil")
# generate_btn.click(
# generate_polymer,
# inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
# outputs=[result_smiles, result_image]
# )
# if __name__ == "__main__":
# iface.launch()