kxhit commited on
Commit
57cc7d0
·
1 Parent(s): 329487d

make get_conf pickle

Browse files
Files changed (1) hide show
  1. dust3r/cloud_opt/commons.py +40 -10
dust3r/cloud_opt/commons.py CHANGED
@@ -4,6 +4,8 @@
4
  # --------------------------------------------------------
5
  # utility functions for global alignment
6
  # --------------------------------------------------------
 
 
7
  import torch
8
  import torch.nn as nn
9
  import numpy as np
@@ -45,18 +47,46 @@ def get_imshapes(edges, pred_i, pred_j):
45
  return imshapes
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def get_conf_trf(mode):
49
- if mode == 'log':
50
- def conf_trf(x): return x.log()
51
- elif mode == 'sqrt':
52
- def conf_trf(x): return x.sqrt()
53
- elif mode == 'm1':
54
- def conf_trf(x): return x-1
55
- elif mode in ('id', 'none'):
56
- def conf_trf(x): return x
57
- else:
58
  raise ValueError(f'bad mode for {mode=}')
59
- return conf_trf
 
60
 
61
 
62
  def l2_dist(a, b, weight):
 
4
  # --------------------------------------------------------
5
  # utility functions for global alignment
6
  # --------------------------------------------------------
7
+ import xmlrpc.client
8
+
9
  import torch
10
  import torch.nn as nn
11
  import numpy as np
 
47
  return imshapes
48
 
49
 
50
+ # def get_conf_trf(mode):
51
+ # if mode == 'log':
52
+ # def conf_trf(x): return x.log()
53
+ # elif mode == 'sqrt':
54
+ # def conf_trf(x): return x.sqrt()
55
+ # elif mode == 'm1':
56
+ # def conf_trf(x): return x-1
57
+ # elif mode in ('id', 'none'):
58
+ # def conf_trf(x): return x
59
+ # else:
60
+ # raise ValueError(f'bad mode for {mode=}')
61
+ # return conf_trf
62
+
63
+
64
+ def conf_trf_log(x):
65
+ return x.log()
66
+
67
+ def conf_trf_sqrt(x):
68
+ return x.sqrt()
69
+
70
+ def conf_trf_m1(x):
71
+ return x - 1
72
+
73
+ def conf_trf_id(x):
74
+ return x
75
+
76
+ # Mapping of modes to their corresponding functions
77
+ conf_trf_map = {
78
+ 'log': conf_trf_log,
79
+ 'sqrt': conf_trf_sqrt,
80
+ 'm1': conf_trf_m1,
81
+ 'id': conf_trf_id,
82
+ 'none': conf_trf_id
83
+ }
84
+
85
  def get_conf_trf(mode):
86
+ if mode not in conf_trf_map:
 
 
 
 
 
 
 
 
87
  raise ValueError(f'bad mode for {mode=}')
88
+ return conf_trf_map[mode]
89
+
90
 
91
 
92
  def l2_dist(a, b, weight):