# File size: 938 Bytes
# c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import abc
import torch.nn as nn

class Sharder(abc.ABC):
    """
    Interface for user-defined sharding strategies.

    Implement this interface to express advanced sharding strategies that
    cannot easily be composed from a `ShardingSpec`.

    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could
    take an object of the `Sharder` and call `shard` to shard the module,
    then replace the original module with the sharded module returned.
    """

    @abc.abstractmethod
    def shard(self, module: nn.Module) -> nn.Module:
        """
        Shard a module according to this sharder's strategy.

        Subclasses must override this method; the base declaration
        carries no implementation of its own.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.nn.Module` object that represents a module
            that's already been sharded.
        """