Update app.py
Browse files
app.py
CHANGED
@@ -64,7 +64,6 @@ def load_model(model_choice):
|
|
64 |
model = GraphDiT(
|
65 |
model_config_path=model_config_path,
|
66 |
data_info_path=data_info_path,
|
67 |
-
# model_dtype=torch.float16,
|
68 |
model_dtype=torch.float32,
|
69 |
)
|
70 |
### test
|
@@ -113,7 +112,7 @@ def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_ti
|
|
113 |
|
114 |
for _ in range(repeating_time):
|
115 |
# try:
|
116 |
-
model.to(device)
|
117 |
# generated_molecule, img_list = model.generate(properties, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
118 |
generated_molecule = 'CC'
|
119 |
img_list = []
|
|
|
64 |
model = GraphDiT(
|
65 |
model_config_path=model_config_path,
|
66 |
data_info_path=data_info_path,
|
|
|
67 |
model_dtype=torch.float32,
|
68 |
)
|
69 |
### test
|
|
|
112 |
|
113 |
for _ in range(repeating_time):
|
114 |
# try:
|
115 |
+
# model.to(device)
|
116 |
# generated_molecule, img_list = model.generate(properties, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
|
117 |
generated_molecule = 'CC'
|
118 |
img_list = []
|