File size: 1,031 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#pragma once

#include <ATen/ATen.h>

namespace at::caching {

// Cached-tensor tracking.
//
// Some systems (at present only cudagraphs) keep a static output tensor alive
// whose TensorImpl is the same object from one iteration to the next. For such
// tensors:
//   * caching dtype conversions is not valid, and
//   * they carry an extra reference count, which would otherwise get in the
//     way of buffer inplacing and other checks on tensor uniqueness.
// When none of these systems is in use, the enabled flag remains false and the
// hash lookup is skipped.

// Sets the enabled flag described above.
TORCH_API void set_cached_tensors_enabled(bool enable);

// Adds `t` to, or removes `t` from, the tensors treated as cached.
TORCH_API void add_cached_tensor(const at::Tensor& t);
TORCH_API void remove_cached_tensor(const at::Tensor& t);

// Reports whether `t` is currently treated as a cached tensor.
TORCH_API bool is_cached_tensor(const at::Tensor& t);

// Use count of `t` for gradient buffer stealing. The count is adjusted for
// tensors persisted by cudagraphs, in the same way the reference count is
// adjusted for tensors that have hooks.
// NOTE(review): presumably this discounts the extra reference described above;
// confirm against the definition.
TORCH_API size_t adjusted_use_count(const at::Tensor& t);

} // namespace at::caching