from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Union

from pytorch_ie.core import Document
from pytorch_ie.core.taskmodule import (
    IterableTaskEncodingDataset,
    TaskEncoding,
    TaskEncodingDataset,
    TaskModule,
)
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Sampler
from typing_extensions import TypeAlias

from .components.sampler import ImbalancedDatasetSampler

DocumentType = TypeVar("DocumentType", bound=Document)
InputEncoding = TypeVar("InputEncoding")
TargetEncoding = TypeVar("TargetEncoding")
DatasetType: TypeAlias = Union[
    TaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
    IterableTaskEncodingDataset[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
]


class PieDataModule(LightningDataModule, Generic[DocumentType, InputEncoding, TargetEncoding]):
    """A simple LightningDataModule for PIE document datasets.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
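
    Example:
        A minimal usage sketch; `my_taskmodule` and `document_splits` are assumed to be a
        prepared taskmodule and a dict mapping split names to document sequences:

            datamodule = PieDataModule(
                taskmodule=my_taskmodule,
                dataset=document_splits,
                batch_size=32,  # extra kwargs are forwarded to the DataLoaders
            )
            datamodule.setup(stage="fit")
            train_loader = datamodule.train_dataloader()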
    """

    def __init__(
        self,
        taskmodule: TaskModule[DocumentType, InputEncoding, TargetEncoding, Any, Any, Any],
        dataset: Dict[str, Sequence[DocumentType]],
        data_config_path: Optional[str] = None,
        train_split: Optional[str] = "train",
        val_split: Optional[str] = "validation",
        test_split: Optional[str] = "test",
        show_progress_for_encode: bool = False,
        train_sampler: Optional[str] = None,
        **dataloader_kwargs,
    ):
        super().__init__()

        self.taskmodule = taskmodule
        self.config_path = data_config_path
        self.dataset = dataset
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.show_progress_for_encode = show_progress_for_encode
        self.train_sampler_name = train_sampler
        self.dataloader_kwargs = dataloader_kwargs

        self._data: Dict[str, DatasetType] = {}

    @property
    def num_train(self) -> int:
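        """The number of encoded training examples; requires that setup() has been called."""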
        if self.train_split is None:
            raise ValueError("no train_split assigned")
        data_train = self._data.get(self.train_split, None)
        if data_train is None:
            raise ValueError("can not get train size if setup() was not yet called")
        if isinstance(data_train, IterableTaskEncodingDataset):
            raise TypeError("IterableTaskEncodingDataset has no length")
        return len(data_train)

    def setup(self, stage: str):
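        """Encode the document splits required for the given stage into task encoding datasets."""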
        if stage == "fit":
            split_names = [self.train_split, self.val_split]
        elif stage == "validate":
            split_names = [self.val_split]
        elif stage == "test":
            split_names = [self.test_split]
        else:
            raise NotImplementedError(f"not implemented for stage={stage}")

        for split in split_names:
            if split is None or split not in self.dataset:
                continue
            task_encoding_dataset = self.taskmodule.encode(
                self.dataset[split],
                encode_target=True,
                as_dataset=True,
                show_progress=self.show_progress_for_encode,
            )
            if not isinstance(
                task_encoding_dataset,
                (TaskEncodingDataset, IterableTaskEncodingDataset),
            ):
                raise TypeError(
                    f"taskmodule.encode did not return a (Iterable)TaskEncodingDataset, but: {type(task_encoding_dataset)}"
                )
            self._data[split] = task_encoding_dataset

    def data_split(self, split: Optional[str] = None) -> DatasetType:
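        """Return the encoded dataset for the given split, raising if it is not available."""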
        if split is None or split not in self._data:
            raise ValueError(f"data for split={split} not available")
        return self._data[split]

    def get_train_sampler(
        self,
        sampler_name: str,
        dataset: DatasetType,
    ) -> Sampler:
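        """Create a training sampler by name; currently only "imbalanced_dataset" is supported."""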
        if sampler_name == "imbalanced_dataset":
            # for now, this works only with targets that have a single entry
            return ImbalancedDatasetSampler(
                dataset, callback_get_label=lambda ds: [x.targets[0] for x in ds]
            )
        else:
            raise ValueError(f"unknown sampler name: {sampler_name}")

    def train_dataloader(self):
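        """Create the training DataLoader, optionally with a sampler selected via train_sampler."""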
        ds = self.data_split(self.train_split)
        if self.train_sampler_name is not None:
            sampler = self.get_train_sampler(sampler_name=self.train_sampler_name, dataset=ds)
        else:
            sampler = None
        return DataLoader(
            dataset=ds,
            sampler=sampler,
            collate_fn=self.taskmodule.collate,
            # don't shuffle streamed (iterable) datasets or when a sampler is used
            shuffle=not (isinstance(ds, IterableTaskEncodingDataset) or sampler is not None),
            **self.dataloader_kwargs,
        )

    def val_dataloader(self):
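        """Create the validation DataLoader (no shuffling)."""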
        return DataLoader(
            dataset=self.data_split(self.val_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )

    def test_dataloader(self):
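        """Create the test DataLoader (no shuffling)."""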
        return DataLoader(
            dataset=self.data_split(self.test_split),
            collate_fn=self.taskmodule.collate,
            shuffle=False,
            **self.dataloader_kwargs,
        )