Spaces:
Running
Running
File size: 5,820 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>
#include "csrc_dipu/diopirt/diopirt_impl.h"
using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif
// Device-dispatch stub for the sigmoid focal loss forward pass.
// All arguments are handed unchanged to DISPATCH_DEVICE_IMPL together with
// this function's own name as the dispatch key.
// NOTE(review): the macro is defined outside this file (presumably in
// pytorch_device_registry.hpp) and presumably selects a registered
// implementation based on the tensors' device — confirm there.
void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
DISPATCH_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, input, target, weight,
output, gamma, alpha);
}
// Device-dispatch stub for the sigmoid focal loss backward pass.
// All arguments (including grad_input) are handed unchanged to
// DISPATCH_DEVICE_IMPL together with this function's own name as the
// dispatch key.
// NOTE(review): the macro is defined outside this file; it presumably picks a
// registered per-device implementation — confirm in the device registry
// header.
void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor grad_input,
float gamma, float alpha) {
DISPATCH_DEVICE_IMPL(sigmoid_focal_loss_backward_impl, input, target, weight,
grad_input, gamma, alpha);
}
// Device-dispatch stub for the softmax focal loss forward pass.
// All arguments are handed unchanged to DISPATCH_DEVICE_IMPL together with
// this function's own name as the dispatch key.
// NOTE(review): the macro is defined outside this file; it presumably picks a
// registered per-device implementation — confirm in the device registry
// header.
void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, input, target, weight,
output, gamma, alpha);
}
// Device-dispatch stub for the softmax focal loss backward pass.
// Unlike the sigmoid variant, this one takes an extra `buff` tensor, which is
// forwarded along with the other arguments to DISPATCH_DEVICE_IMPL under this
// function's own name as the dispatch key.
// NOTE(review): the role of `buff` is not visible here (presumably scratch
// space for the backend kernel) — confirm against the device implementations.
void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, input, target, weight,
buff, grad_input, gamma, alpha);
}
#ifdef MMCV_WITH_DIOPI
// DIOPI front end for the sigmoid focal loss forward pass.
//
// Control flow:
//   1. If DIOPI reports that `input` lives on the host, call the regular
//      device-dispatched implementation directly.
//   2. Otherwise try diopiSigmoidFocalLossMmcv on the current DIPU stream; a
//      diopiSuccess status finishes the call.
//   3. If the DIOPI symbol is null or the call did not succeed, log a warning,
//      run the implementation on CPU copies of every tensor, and copy the CPU
//      result back into `output`.
//
// NOTE(review): the null check on the function address presumably relies on
// the DIOPI op being declared as a weak symbol — confirm in
// diopi/functions_mmcv.h.
void sigmoid_focal_loss_forward_diopi(Tensor input, Tensor target,
                                      Tensor weight, Tensor output, float gamma,
                                      float alpha) {
  auto input_handle = toDiopiTensorHandle(input);
  diopiDevice_t input_device;
  diopiGetTensorDevice(input_handle, &input_device);
  if (input_device == diopi_host) {
    // Host tensors skip DIOPI entirely.
    sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma,
                                    alpha);
    return;
  }

  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ctx_handle = &ctx;
  auto target_handle = toDiopiTensorHandle(target);
  auto weight_handle = toDiopiTensorHandle(weight);
  auto output_handle = toDiopiTensorHandle(output);

  // Short-circuit evaluation: the DIOPI op is only invoked when its symbol is
  // non-null, and we are done only if it reports success.
  if (reinterpret_cast<void *>(diopiSigmoidFocalLossMmcv) != nullptr &&
      diopiSigmoidFocalLossMmcv(ctx_handle, output_handle, input_handle,
                                target_handle, weight_handle, gamma,
                                alpha) == diopiSuccess) {
    return;
  }

  LOG(WARNING)
      << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
  auto host_input = input.cpu();
  auto host_target = target.cpu();
  auto host_weight = weight.cpu();
  auto host_output = output.cpu();
  sigmoid_focal_loss_forward_impl(host_input, host_target, host_weight,
                                  host_output, gamma, alpha);
  // Propagate the CPU result back to the caller's tensor.
  output.copy_(host_output);
}
// DIOPI front end for the sigmoid focal loss backward pass.
//
// Control flow:
//   1. If DIOPI reports that `input` lives on the host, call the regular
//      device-dispatched backward implementation directly.
//   2. Otherwise try diopiSigmoidFocalLossBackwardMmcv on the current DIPU
//      stream; a diopiSuccess status finishes the call.
//   3. If the DIOPI symbol is null or the call did not succeed, log a warning,
//      run the backward implementation on CPU copies of every tensor, and copy
//      the CPU gradient back into `grad_input`.
//
// NOTE(review): the null check on the function address presumably relies on
// the DIOPI op being declared as a weak symbol — confirm in
// diopi/functions_mmcv.h.
void sigmoid_focal_loss_backward_diopi(Tensor input, Tensor target,
                                       Tensor weight, Tensor grad_input,
                                       float gamma, float alpha) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    // Host tensors skip DIOPI entirely.
    sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
                                     alpha);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto target_p = toDiopiTensorHandle(target);
  auto weight_p = toDiopiTensorHandle(weight);
  auto grad_input_p = toDiopiTensorHandle(grad_input);
  if (reinterpret_cast<void *>(diopiSigmoidFocalLossBackwardMmcv) != nullptr) {
    auto ret = diopiSigmoidFocalLossBackwardMmcv(
        ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
    if (ret == diopiSuccess) return;
  }
  // The message names the backward op so that fallbacks logged from this path
  // are not mistaken for forward-pass fallbacks.
  LOG(WARNING)
      << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_backward_impl";
  auto input_cpu = input.cpu();
  auto target_cpu = target.cpu();
  auto weight_cpu = weight.cpu();
  auto grad_input_cpu = grad_input.cpu();
  sigmoid_focal_loss_backward_impl(input_cpu, target_cpu, weight_cpu,
                                   grad_input_cpu, gamma, alpha);
  // Propagate the CPU gradient back to the caller's tensor.
  grad_input.copy_(grad_input_cpu);
  return;
}
#endif
// Public entry point for the sigmoid focal loss forward pass.
// Builds with MMCV_WITH_DIOPI defined route through the DIOPI front end;
// all other builds call the device-dispatched implementation directly.
void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
                                Tensor output, float gamma, float alpha) {
#if defined(MMCV_WITH_DIOPI)
  sigmoid_focal_loss_forward_diopi(input, target, weight, output, gamma,
                                   alpha);
#else
  sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma,
                                  alpha);
#endif
}
// Public entry point for the sigmoid focal loss backward pass.
// Builds with MMCV_WITH_DIOPI defined route through the DIOPI front end;
// all other builds call the device-dispatched implementation directly.
void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
                                 Tensor grad_input, float gamma, float alpha) {
#if defined(MMCV_WITH_DIOPI)
  sigmoid_focal_loss_backward_diopi(input, target, weight, grad_input, gamma,
                                    alpha);
#else
  sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
                                   alpha);
#endif
}
// Public entry point for the softmax focal loss forward pass.
// Forwards every argument unchanged to softmax_focal_loss_forward_impl; unlike
// the sigmoid entry points in this file there is no MMCV_WITH_DIOPI branch.
void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
softmax_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
}
// Public entry point for the softmax focal loss backward pass.
// Forwards every argument (including `buff`) unchanged to
// softmax_focal_loss_backward_impl; unlike the sigmoid entry points in this
// file there is no MMCV_WITH_DIOPI branch.
void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha) {
softmax_focal_loss_backward_impl(input, target, weight, buff, grad_input,
gamma, alpha);
}
|