liuganghuggingface commited on
Commit
d33f846
·
verified ·
1 Parent(s): f04a2e9

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -52,10 +52,12 @@ atexit.register(cleanup_temp_files)
52
  def random_properties():
53
  return known_labels[all_properties].sample(1).values.tolist()[0]
54
 
55
- def load_model(model_choice):
56
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
- model = load_graph_decoder(path=model_choice)
58
- return (model, device)
 
 
59
 
60
  # Create a flagged folder if it doesn't exist
61
  flagged_folder = "flagged"
@@ -279,7 +281,7 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
279
  ```
280
  """)
281
 
282
- model_state = gr.State(lambda: load_model("model_all"))
283
 
284
  with gr.Row():
285
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
@@ -312,7 +314,12 @@ with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
312
  def switch_model(choice):
313
  # Convert display name back to internal name
314
  internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
315
- return load_model(internal_name)
 
 
 
 
 
316
 
317
  model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
318
 
 
52
  def random_properties():
53
  return known_labels[all_properties].sample(1).values.tolist()[0]
54
 
55
+ # def load_model(model_choice):
56
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+
58
+ model_all = load_graph_decoder(path='model_all')
59
+ model_labeled = load_graph_decoder(path='model_labeled')
60
+ # return (model, device)
61
 
62
  # Create a flagged folder if it doesn't exist
63
  flagged_folder = "flagged"
 
281
  ```
282
  """)
283
 
284
+ model_state = gr.State(model_labeled)
285
 
286
  with gr.Row():
287
  CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
 
314
  def switch_model(choice):
315
  # Convert display name back to internal name
316
  internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
317
+ if internal_name == 'model_labeled':
318
+ return model_labeled
319
+ elif internal_name == 'model_all':
320
+ return model_all
321
+ else:
322
+ raise ValueError('Not support model', internal_name)
323
 
324
  model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
325