Spaces:
Sleeping
Sleeping
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */
/*
  This code is partially adopted from https://github.com/1ytic/pytorch-edit-distance
 */
/**
 * Python-facing wrapper around LevenshteinDistanceCuda.
 *
 * Each of the four input tensors is validated with CHECK_INPUT (a macro
 * defined outside this view). The tensors are then forwarded unchanged and
 * in the same order to LevenshteinDistanceCuda, whose result is returned
 * as-is.
 *
 * NOTE(review): CHECK_INPUT presumably asserts a contiguous CUDA tensor, and
 * the *_length tensors presumably hold per-row sequence lengths. Confirm
 * both against the macro definition and the CUDA kernel.
 *
 * @param source         first sequence batch
 * @param target         second sequence batch
 * @param source_length  lengths associated with `source`
 * @param target_length  lengths associated with `target`
 * @return whatever tensor LevenshteinDistanceCuda produces
 */
torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  // Validate all inputs before delegating. The parameter names are kept
  // as-is because the macro may stringify them into its error messages.
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);

  torch::Tensor result =
      LevenshteinDistanceCuda(source, target, source_length, target_length);
  return result;
}
/**
 * Python-facing wrapper around GenerateDeletionLabelCuda.
 *
 * Validates `source` and `operations` with CHECK_INPUT (defined outside
 * this view), then delegates to GenerateDeletionLabelCuda and returns its
 * tensor unchanged.
 *
 * NOTE(review): `operations` is presumably an edit-operation tensor derived
 * from the Levenshtein alignment. Confirm against the caller and kernel.
 *
 * @param source      sequence batch the labels are generated for
 * @param operations  per-position operation tensor consumed by the kernel
 * @return the tensor returned by GenerateDeletionLabelCuda
 */
torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  // Check order is preserved: `source` is reported before `operations`.
  CHECK_INPUT(source);
  CHECK_INPUT(operations);

  torch::Tensor labels = GenerateDeletionLabelCuda(source, operations);
  return labels;
}
/**
 * Python-facing wrapper around GenerateInsertionLabelCuda.
 *
 * Validates `target` and `operations` with CHECK_INPUT (defined outside
 * this view), then delegates to GenerateInsertionLabelCuda. The resulting
 * pair of tensors is returned unchanged.
 *
 * NOTE(review): the meaning of each element of the returned pair is set by
 * the CUDA implementation and is not visible here. Confirm before relying
 * on the order.
 *
 * @param target      sequence batch the labels are generated for
 * @param operations  per-position operation tensor consumed by the kernel
 * @return the pair returned by GenerateInsertionLabelCuda
 */
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  // Check order is preserved: `target` is reported before `operations`.
  CHECK_INPUT(target);
  CHECK_INPUT(operations);

  std::pair<torch::Tensor, torch::Tensor> labels =
      GenerateInsertionLabelCuda(target, operations);
  return labels;
}
// Python module definition.
//
// Registers the three wrapper functions of this file under snake_case
// Python names. The third argument of each def is the docstring pybind11
// attaches to the Python function. module_::def returns the module by
// reference, so the registrations are chained in their original order.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance")
      .def("generate_deletion_labels", &GenerateDeletionLabel,
           "Generate Deletion Label")
      .def("generate_insertion_labels", &GenerateInsertionLabel,
           "Generate Insertion Label");
}