File size: 1,023 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

// Device-dispatch shim for the fused bias + activation op.
//
// This function does no validation and no computation of its own. It hands
// every argument, unchanged and in order, to DISPATCH_DEVICE_IMPL (from
// pytorch_device_registry.hpp), keyed by the name `bias_act_op_impl`.
// NOTE(review): the macro presumably picks a device-specific implementation
// registered elsewhere under that key, based on the tensors' device. Confirm
// this in pytorch_device_registry.hpp.
//
// The macro token-pastes and stringifies its first argument. The key must
// therefore remain the literal identifier `bias_act_op_impl`.
torch::Tensor bias_act_op_impl(const torch::Tensor &input,
                               const torch::Tensor &bias,
                               const torch::Tensor &xref,
                               const torch::Tensor &yref,
                               const torch::Tensor &dy, int grad, int dim,
                               int act, float alpha, float gain, float clamp) {
  torch::Tensor result = DISPATCH_DEVICE_IMPL(
      bias_act_op_impl,
      input, bias, xref, yref, dy,  // tensor operands
      grad, dim, act,               // integer mode selectors
      alpha, gain, clamp);          // float scalars
  return result;
}

// Public entry point for the fused bias + activation op.
//
// This is a pure forwarder. All eleven arguments go, unchanged and in the same
// order, to bias_act_op_impl, and its result is returned as-is.
// NOTE(review): the meaning of grad/dim/act and of alpha/gain/clamp is defined
// by the registered device kernels, not here. Check those kernels before
// documenting specific semantics.
torch::Tensor bias_act(const torch::Tensor &input, const torch::Tensor &bias,
                       const torch::Tensor &xref, const torch::Tensor &yref,
                       const torch::Tensor &dy, int grad, int dim, int act,
                       float alpha, float gain, float clamp) {
  torch::Tensor output = bias_act_op_impl(input, bias, xref, yref, dy,
                                          grad, dim, act,
                                          alpha, gain, clamp);
  return output;
}