drbh commited on
Commit
3683745
·
1 Parent(s): 2c2dcc7

fix: refactors and adjust name

Browse files
torch-ext/torch_binding.cpp CHANGED
@@ -5,7 +5,7 @@
5
 
6
 
7
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
- ops.def("adam_atan2("
9
  "Tensor(a!)[] params, "
10
  "Tensor(b!)[] grads, "
11
  "Tensor(c!)[] exp_avgs, "
 
5
 
6
 
7
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
+ ops.def("adam_atan2_cuda_impl_("
9
  "Tensor(a!)[] params, "
10
  "Tensor(b!)[] grads, "
11
  "Tensor(c!)[] exp_avgs, "
torch-ext/torch_binding.h CHANGED
@@ -16,14 +16,3 @@ void adam_atan2_cuda_impl_(
16
  const double weight_decay);
17
 
18
  }
19
-
20
- // void adam_atan2_cuda_impl_(
21
- // std::vector<at::Tensor, std::allocator<at::Tensor> > params,
22
- // std::vector<at::Tensor, std::allocator<at::Tensor> > grads,
23
- // std::vector<at::Tensor, std::allocator<at::Tensor> > exp_avgs,
24
- // std::vector<at::Tensor, std::allocator<at::Tensor> > exp_avg_sqs,
25
- // std::vector<at::Tensor, std::allocator<at::Tensor> > state_steps,
26
- // const double lr,
27
- // const double beta1,
28
- // const double beta2,
29
- // const double weight_decay);
 
16
  const double weight_decay);
17
 
18
  }