File size: 387 Bytes
146b945 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
#pragma once
#include <torch/torch.h>
namespace adam_atan2 {
void adam_atan2_cuda_impl_(
std::vector<at::Tensor> params,
std::vector<at::Tensor> grads,
std::vector<at::Tensor> exp_avgs,
std::vector<at::Tensor> exp_avg_sqs,
std::vector<at::Tensor> state_steps,
const double lr,
const double beta1,
const double beta2,
const double weight_decay);
}
|