Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from abc import ABC, abstractmethod | |
| from typing import Dict, List | |
| from torch import Tensor | |
| from torch.nn import Module | |
| from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc | |
class CachedComputationProtocol(ABC):
    """Memoizing protocol for computing named lists of tensors.

    Subclasses implement :meth:`compute_output` for each key they can produce.
    Callers go through :meth:`get_output`, which stores each result in the
    caller-supplied ``outputs`` dict, so a given key is computed at most once
    per ``outputs`` dict.
    """

    def get_output(self,
                   key: str,
                   modules: Dict[str, Module],
                   batch: List[Tensor],
                   outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Return the output list for ``key``, computing and caching it on a miss.

        Args:
            key: Name of the output to retrieve.
            modules: Named modules, forwarded unchanged to ``compute_output``.
            batch: Input tensors, forwarded unchanged to ``compute_output``.
            outputs: Cache of already-computed outputs. Mutated in place: on a
                miss, the freshly computed value is stored under ``key``.

        Returns:
            ``outputs[key]`` after it has been populated.
        """
        if key not in outputs:
            # ``outputs`` is forwarded so an implementation can reuse results
            # that are already cached while computing this key.
            outputs[key] = self.compute_output(key, modules, batch, outputs)
        return outputs[key]

    @abstractmethod
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Compute the output list for ``key`` without consulting the cache.

        Must be overridden by subclasses. :meth:`get_output` calls this only on
        a cache miss and stores the returned list in ``outputs[key]`` itself.

        Args:
            key: Name of the output to compute.
            modules: Named modules supplied by the caller of ``get_output``.
            batch: Input tensors supplied by the caller of ``get_output``.
            outputs: The shared cache of already-computed outputs.

        Returns:
            The list of tensors for ``key``.
        """

    def get_output_tensor_func(self, key: str, index: int) -> "TensorCachedComputationFunc":
        """Return a function that yields element ``index`` of the output for ``key``.

        The returned callable takes ``(modules, batch, outputs)`` and goes
        through :meth:`get_output`, so it shares the same cache.
        """
        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]) -> Tensor:
            return self.get_output(key, modules, batch, outputs)[index]

        return func

    def get_output_tensor_list_func(self, key: str) -> "TensorListCachedComputationFunc":
        """Return a function that yields the full output list for ``key``.

        The returned callable takes ``(modules, batch, outputs)`` and goes
        through :meth:`get_output`, so it shares the same cache.
        """
        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
            return self.get_output(key, modules, batch, outputs)

        return func