liuganghuggingface commited on
Commit
1678cdd
·
verified ·
1 Parent(s): 1ff1c11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -1
app.py CHANGED
@@ -17,7 +17,42 @@ from datetime import datetime
17
  import json
18
 
19
  from evaluator import Evaluator
20
- from loader import load_graph_decoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Load the CSV data
23
  known_labels = pd.read_csv('data/known_labels.csv')
 
17
  import json
18
 
19
  from evaluator import Evaluator
20
+ # from loader import load_graph_decoder
21
+
22
+ ### loader
23
+ from graph_decoder.diffusion_model import GraphDiT
24
+ def count_parameters(model):
25
+ r"""
26
+ Returns the number of trainable parameters and number of all parameters in the model.
27
+ """
28
+ trainable_params, all_param = 0, 0
29
+ for param in model.parameters():
30
+ num_params = param.numel()
31
+ all_param += num_params
32
+ if param.requires_grad:
33
+ trainable_params += num_params
34
+
35
+ return trainable_params, all_param
36
+
37
+ def load_graph_decoder(path='model_labeled'):
38
+ model_config_path = f"{path}/config.yaml"
39
+ data_info_path = f"{path}/data.meta.json"
40
+
41
+ model = GraphDiT(
42
+ model_config_path=model_config_path,
43
+ data_info_path=data_info_path,
44
+ # model_dtype=torch.float16,
45
+ model_dtype=torch.float32,
46
+ )
47
+ model.init_model(path)
48
+ model.disable_grads()
49
+
50
+ trainable_params, all_param = count_parameters(model)
51
+ param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
52
+ path, trainable_params, all_param, 100 * trainable_params / all_param
53
+ )
54
+ print(param_stats)
55
+ return model
56
 
57
  # Load the CSV data
58
  known_labels = pd.read_csv('data/known_labels.csv')