Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -52,10 +52,12 @@ atexit.register(cleanup_temp_files)
|
|
52 |
def random_properties():
|
53 |
return known_labels[all_properties].sample(1).values.tolist()[0]
|
54 |
|
55 |
-
def load_model(model_choice):
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
59 |
|
60 |
# Create a flagged folder if it doesn't exist
|
61 |
flagged_folder = "flagged"
|
@@ -279,7 +281,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
|
|
279 |
```
|
280 |
""")
|
281 |
|
282 |
-
model_state = gr.State(
|
283 |
|
284 |
with gr.Row():
|
285 |
CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
|
@@ -312,7 +314,12 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
|
|
312 |
def switch_model(choice):
|
313 |
# Convert display name back to internal name
|
314 |
internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
|
315 |
-
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
|
318 |
|
|
|
52 |
def random_properties():
|
53 |
return known_labels[all_properties].sample(1).values.tolist()[0]
|
54 |
|
55 |
+
# def load_model(model_choice):
|
56 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
57 |
+
|
58 |
+
model_all = load_graph_decoder(path='model_all')
|
59 |
+
model_labeled = load_graph_decoder(path='model_labeled')
|
60 |
+
# return (model, device)
|
61 |
|
62 |
# Create a flagged folder if it doesn't exist
|
63 |
flagged_folder = "flagged"
|
|
|
281 |
```
|
282 |
""")
|
283 |
|
284 |
+
model_state = gr.State(model_labeled)
|
285 |
|
286 |
with gr.Row():
|
287 |
CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
|
|
|
314 |
def switch_model(choice):
|
315 |
# Convert display name back to internal name
|
316 |
internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
|
317 |
+
if internal_name == 'model_labeled':
|
318 |
+
return model_labeled
|
319 |
+
elif internal_name == 'model_all':
|
320 |
+
return model_all
|
321 |
+
else:
|
322 |
+
raise ValueError('Not support model', internal_name)
|
323 |
|
324 |
model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
|
325 |
|