ptame.data.imagenet_datamodule

  1from pathlib import Path
  2from typing import Any
  3
  4import torch
  5import torchvision.transforms.v2 as transforms
  6from importlib_resources import as_file, files
  7from lightning import LightningDataModule
  8from torch.utils.data import DataLoader, Dataset, Subset, random_split
  9from torchvision.datasets import ImageNet, Imagenette
 10
 11import ptame
 12
 13
 14def norm(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
 15    """Normalize the input tensor.
 16
 17    :param x: The input tensor.
 18    :param reverse: Whether to reverse the normalization. Defaults to `False`.
 19    :return: The normalized tensor.
 20    """
 21    if reverse:
 22        return transforms.Normalize(
 23            (-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225),
 24            (1 / 0.229, 1 / 0.224, 1 / 0.225),
 25        )(x)
 26    else:
 27        return transforms.Normalize(
 28            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
 29        )(x)
 30
 31
 32class ImageNetDataModule(LightningDataModule):
 33    """`LightningDataModule` for the imagenet dataset.
 34
 35    A `LightningDataModule` implements 7 key methods:
 36
 37    ```python
 38        def prepare_data(self):
 39        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
 40        # Download data, pre-process, split, save to disk, etc...
 41
 42        def setup(self, stage):
 43        # Things to do on every process in DDP.
 44        # Load data, set variables, etc...
 45
 46        def train_dataloader(self):
 47        # return train dataloader
 48
 49        def val_dataloader(self):
 50        # return validation dataloader
 51
 52        def test_dataloader(self):
 53        # return test dataloader
 54
 55        def predict_dataloader(self):
 56        # return predict dataloader
 57
 58        def teardown(self, stage):
 59        # Called on every process in DDP.
 60        # Clean up after fit or test.
 61    ```
 62
 63    This allows you to share a full dataset without explaining how to download,
 64    split, transform and process the data.
 65
 66    Read the docs:
 67        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
 68    """
 69
 70    def __init__(
 71        self,
 72        data_dir: str = "data/ImageNet",
 73        batch_size: int = 64,
 74        num_workers: int = 0,
 75        pin_memory: bool = False,
 76        test: bool = True,
 77        test_for_train: bool = False,
 78        imagenette: bool = False,
 79        imagenette_augs: bool = False,
 80        pct: float | None = None,
 81    ) -> None:
 82        """Initialize a `MNISTDataModule`.
 83
 84        :param data_dir: The data directory. Defaults to `"data/"`.
 85        :param batch_size: The batch size. Defaults to `64`.
 86        :param num_workers: The number of workers. Defaults to `0`.
 87        :param pin_memory: Whether to pin memory. Defaults to `False`.
 88        """
 89        super().__init__()
 90
 91        # this line allows to access init params with 'self.hparams' attribute
 92        # also ensures init params will be stored in ckpt
 93        self.save_hyperparameters(logger=False)
 94
 95        # data transformations
 96        self.train_tsfm = transforms.Compose(
 97            [
 98                transforms.Resize(256),
 99                transforms.RandomCrop(224),
100                transforms.RandomHorizontalFlip(),
101                transforms.ToImage(),
102                transforms.ToDtype(torch.float32, scale=True),
103                transforms.Normalize(
104                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
105                ),
106            ]
107        )
108
109        self.imagenette_tsfm = transforms.Compose(
110            [
111                transforms.Resize(256),
112                transforms.RandomCrop(224),
113                transforms.RandomHorizontalFlip(),
114                transforms.AutoAugment(),
115                transforms.AugMix(),
116                transforms.ToImage(),
117                transforms.ToDtype(torch.float32, scale=True),
118                transforms.Normalize(
119                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
120                ),
121            ]
122        )
123        self.val_tsfm = transforms.Compose(
124            [
125                transforms.Resize(256),
126                transforms.CenterCrop(224),
127                transforms.ToImage(),
128                transforms.ToDtype(torch.float32, scale=True),
129                transforms.Normalize(
130                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
131                ),
132            ]
133        )
134        self.data_train: Dataset | None = None
135        self.data_val: Dataset | None = None
136        self.data_test: Dataset | None = None
137        self.batch_size_per_device = batch_size
138
139    @property
140    def num_classes(self) -> int:
141        """Get the number of classes.
142
143        :return: The number of imagenet classes (1000).
144        """
145        if self.hparams.imagenette:
146            return 10
147        else:
148            return 1000
149
150    def prepare_data(self):
151        if self.hparams.imagenette:
152            if not (p := Path(self.hparams.data_dir)).is_dir():
153                p.mkdir(parents=False, exist_ok=False)
154                Imagenette(self.hparams.data_dir, download=True)
155
156    def setup(self, stage: str | None = None) -> None:
157        """Load data. Set variables: `self.data_train`, `self.data_val`,
158        `self.data_test`.
159
160        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
161        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
162        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
163        `self.setup()` once the data is prepared and available for use.
164
165        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
166        """
167        # Divide batch size by the number of devices.
168        if self.trainer is not None:
169            if self.hparams.batch_size % self.trainer.world_size != 0:
170                raise RuntimeError(
171                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
172                )
173            self.batch_size_per_device = (
174                self.hparams.batch_size // self.trainer.world_size
175            )
176        dataset = Imagenette if self.hparams.imagenette else ImageNet
177        train_tsfm = (
178            self.train_tsfm
179            if not (self.hparams.imagenette or self.hparams.imagenette_augs)
180            else self.imagenette_tsfm
181        )
182        # load and split datasets only if not loaded already
183        if not self.data_train and not self.data_val and not self.data_test:
184            self.data_train = dataset(
185                self.hparams.data_dir, split="train", transform=train_tsfm
186            )
187            if self.hparams.pct is not None:
188                self.data_train, _ = random_split(
189                    dataset(
190                        self.hparams.data_dir,
191                        split="train",
192                        transform=train_tsfm,
193                    ),
194                    [self.hparams.pct, 1 - self.hparams.pct],
195                )
196            if self.hparams.imagenette:
197                self.data_val, self.data_test = random_split(
198                    dataset(
199                        self.hparams.data_dir,
200                        split="val",
201                        transform=self.val_tsfm,
202                    ),
203                    [0.5, 0.5],
204                )
205            else:
206                datalist_source = files(ptame.datalists)
207                with as_file(datalist_source.joinpath("val_set.pt")) as f:
208                    self.data_val = Subset(
209                        dataset(
210                            self.hparams.data_dir,
211                            split="val",
212                            transform=self.val_tsfm,
213                        ),
214                        torch.load(
215                            f,
216                            weights_only=True,
217                        ),
218                    )
219                with as_file(datalist_source.joinpath("test_set.pt")) as f:
220                    self.data_test = Subset(
221                        dataset(
222                            self.hparams.data_dir,
223                            split="val",
224                            transform=self.val_tsfm,
225                        ),
226                        torch.load(
227                            f,
228                            weights_only=True,
229                        ),
230                    )
231
232    def train_dataloader(self) -> DataLoader[Any]:
233        """Create and return the train dataloader.
234
235        :return: The train dataloader.
236        """
237        if self.hparams.test_for_train:
238            return self.test_dataloader()
239        return DataLoader(
240            dataset=self.data_train,
241            batch_size=self.batch_size_per_device,
242            num_workers=self.hparams.num_workers,
243            pin_memory=self.hparams.pin_memory,
244            shuffle=True,
245        )
246
247    def val_dataloader(self) -> DataLoader[Any]:
248        """Create and return the validation dataloader.
249
250        :return: The validation dataloader.
251        """
252        return DataLoader(
253            dataset=self.data_val,
254            batch_size=self.batch_size_per_device,
255            num_workers=self.hparams.num_workers,
256            pin_memory=self.hparams.pin_memory,
257            shuffle=False,
258        )
259
260    def test_dataloader(self) -> DataLoader[Any]:
261        """Create and return the test dataloader.
262
263        :return: The test dataloader.
264        """
265        if not self.hparams.test:
266            return self.val_dataloader()
267        return DataLoader(
268            dataset=self.data_test,
269            batch_size=self.batch_size_per_device,
270            num_workers=self.hparams.num_workers,
271            pin_memory=self.hparams.pin_memory,
272            shuffle=False,
273        )
274
275    def teardown(self, stage: str | None = None) -> None:
276        """Lightning hook for cleaning up after `trainer.fit()`,
277        `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.
278
279        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
280            Defaults to ``None``.
281        """
282        pass
283
284    def state_dict(self) -> dict[Any, Any]:
285        """Called when saving a checkpoint. Implement to generate and save the
286        datamodule state.
287
288        :return: A dictionary containing the datamodule state that you want to
289            save.
290        """
291        return {}
292
293    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
294        """Called when loading a checkpoint. Implement to reload datamodule
295        state given datamodule `state_dict()`.
296
297        :param state_dict: The datamodule state returned by
298            `self.state_dict()`.
299        """
300        pass
301
302
if __name__ == "__main__":
    # Quick smoke check: construct the datamodule with default arguments.
    ImageNetDataModule()
def norm(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
def norm(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Apply (or undo) ImageNet mean/std normalization.

    :param x: The input tensor (presumably a 3-channel image tensor, since the
        statistics below have three entries — confirm with callers).
    :param reverse: If `True`, invert the normalization instead of applying
        it. Defaults to `False`.
    :return: The normalized (or de-normalized) tensor.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    if not reverse:
        return transforms.Normalize(mean, std)(x)
    # Normalize computes (x - m) / s; with m' = -m / s and s' = 1 / s it
    # computes x * s + m instead, i.e. the inverse mapping.
    inv_mean = tuple(-m / s for m, s in zip(mean, std))
    inv_std = tuple(1 / s for s in std)
    return transforms.Normalize(inv_mean, inv_std)(x)

Normalize the input tensor.

Parameters
  • x: The input tensor.
  • reverse: Whether to reverse the normalization. Defaults to False.
Returns

The normalized tensor.

class ImageNetDataModule(lightning.pytorch.core.datamodule.LightningDataModule):
 33class ImageNetDataModule(LightningDataModule):
 34    """`LightningDataModule` for the imagenet dataset.
 35
 36    A `LightningDataModule` implements 7 key methods:
 37
 38    ```python
 39        def prepare_data(self):
 40        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
 41        # Download data, pre-process, split, save to disk, etc...
 42
 43        def setup(self, stage):
 44        # Things to do on every process in DDP.
 45        # Load data, set variables, etc...
 46
 47        def train_dataloader(self):
 48        # return train dataloader
 49
 50        def val_dataloader(self):
 51        # return validation dataloader
 52
 53        def test_dataloader(self):
 54        # return test dataloader
 55
 56        def predict_dataloader(self):
 57        # return predict dataloader
 58
 59        def teardown(self, stage):
 60        # Called on every process in DDP.
 61        # Clean up after fit or test.
 62    ```
 63
 64    This allows you to share a full dataset without explaining how to download,
 65    split, transform and process the data.
 66
 67    Read the docs:
 68        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
 69    """
 70
 71    def __init__(
 72        self,
 73        data_dir: str = "data/ImageNet",
 74        batch_size: int = 64,
 75        num_workers: int = 0,
 76        pin_memory: bool = False,
 77        test: bool = True,
 78        test_for_train: bool = False,
 79        imagenette: bool = False,
 80        imagenette_augs: bool = False,
 81        pct: float | None = None,
 82    ) -> None:
 83        """Initialize a `MNISTDataModule`.
 84
 85        :param data_dir: The data directory. Defaults to `"data/"`.
 86        :param batch_size: The batch size. Defaults to `64`.
 87        :param num_workers: The number of workers. Defaults to `0`.
 88        :param pin_memory: Whether to pin memory. Defaults to `False`.
 89        """
 90        super().__init__()
 91
 92        # this line allows to access init params with 'self.hparams' attribute
 93        # also ensures init params will be stored in ckpt
 94        self.save_hyperparameters(logger=False)
 95
 96        # data transformations
 97        self.train_tsfm = transforms.Compose(
 98            [
 99                transforms.Resize(256),
100                transforms.RandomCrop(224),
101                transforms.RandomHorizontalFlip(),
102                transforms.ToImage(),
103                transforms.ToDtype(torch.float32, scale=True),
104                transforms.Normalize(
105                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
106                ),
107            ]
108        )
109
110        self.imagenette_tsfm = transforms.Compose(
111            [
112                transforms.Resize(256),
113                transforms.RandomCrop(224),
114                transforms.RandomHorizontalFlip(),
115                transforms.AutoAugment(),
116                transforms.AugMix(),
117                transforms.ToImage(),
118                transforms.ToDtype(torch.float32, scale=True),
119                transforms.Normalize(
120                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
121                ),
122            ]
123        )
124        self.val_tsfm = transforms.Compose(
125            [
126                transforms.Resize(256),
127                transforms.CenterCrop(224),
128                transforms.ToImage(),
129                transforms.ToDtype(torch.float32, scale=True),
130                transforms.Normalize(
131                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
132                ),
133            ]
134        )
135        self.data_train: Dataset | None = None
136        self.data_val: Dataset | None = None
137        self.data_test: Dataset | None = None
138        self.batch_size_per_device = batch_size
139
140    @property
141    def num_classes(self) -> int:
142        """Get the number of classes.
143
144        :return: The number of imagenet classes (1000).
145        """
146        if self.hparams.imagenette:
147            return 10
148        else:
149            return 1000
150
151    def prepare_data(self):
152        if self.hparams.imagenette:
153            if not (p := Path(self.hparams.data_dir)).is_dir():
154                p.mkdir(parents=False, exist_ok=False)
155                Imagenette(self.hparams.data_dir, download=True)
156
157    def setup(self, stage: str | None = None) -> None:
158        """Load data. Set variables: `self.data_train`, `self.data_val`,
159        `self.data_test`.
160
161        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
162        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
163        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
164        `self.setup()` once the data is prepared and available for use.
165
166        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
167        """
168        # Divide batch size by the number of devices.
169        if self.trainer is not None:
170            if self.hparams.batch_size % self.trainer.world_size != 0:
171                raise RuntimeError(
172                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
173                )
174            self.batch_size_per_device = (
175                self.hparams.batch_size // self.trainer.world_size
176            )
177        dataset = Imagenette if self.hparams.imagenette else ImageNet
178        train_tsfm = (
179            self.train_tsfm
180            if not (self.hparams.imagenette or self.hparams.imagenette_augs)
181            else self.imagenette_tsfm
182        )
183        # load and split datasets only if not loaded already
184        if not self.data_train and not self.data_val and not self.data_test:
185            self.data_train = dataset(
186                self.hparams.data_dir, split="train", transform=train_tsfm
187            )
188            if self.hparams.pct is not None:
189                self.data_train, _ = random_split(
190                    dataset(
191                        self.hparams.data_dir,
192                        split="train",
193                        transform=train_tsfm,
194                    ),
195                    [self.hparams.pct, 1 - self.hparams.pct],
196                )
197            if self.hparams.imagenette:
198                self.data_val, self.data_test = random_split(
199                    dataset(
200                        self.hparams.data_dir,
201                        split="val",
202                        transform=self.val_tsfm,
203                    ),
204                    [0.5, 0.5],
205                )
206            else:
207                datalist_source = files(ptame.datalists)
208                with as_file(datalist_source.joinpath("val_set.pt")) as f:
209                    self.data_val = Subset(
210                        dataset(
211                            self.hparams.data_dir,
212                            split="val",
213                            transform=self.val_tsfm,
214                        ),
215                        torch.load(
216                            f,
217                            weights_only=True,
218                        ),
219                    )
220                with as_file(datalist_source.joinpath("test_set.pt")) as f:
221                    self.data_test = Subset(
222                        dataset(
223                            self.hparams.data_dir,
224                            split="val",
225                            transform=self.val_tsfm,
226                        ),
227                        torch.load(
228                            f,
229                            weights_only=True,
230                        ),
231                    )
232
233    def train_dataloader(self) -> DataLoader[Any]:
234        """Create and return the train dataloader.
235
236        :return: The train dataloader.
237        """
238        if self.hparams.test_for_train:
239            return self.test_dataloader()
240        return DataLoader(
241            dataset=self.data_train,
242            batch_size=self.batch_size_per_device,
243            num_workers=self.hparams.num_workers,
244            pin_memory=self.hparams.pin_memory,
245            shuffle=True,
246        )
247
248    def val_dataloader(self) -> DataLoader[Any]:
249        """Create and return the validation dataloader.
250
251        :return: The validation dataloader.
252        """
253        return DataLoader(
254            dataset=self.data_val,
255            batch_size=self.batch_size_per_device,
256            num_workers=self.hparams.num_workers,
257            pin_memory=self.hparams.pin_memory,
258            shuffle=False,
259        )
260
261    def test_dataloader(self) -> DataLoader[Any]:
262        """Create and return the test dataloader.
263
264        :return: The test dataloader.
265        """
266        if not self.hparams.test:
267            return self.val_dataloader()
268        return DataLoader(
269            dataset=self.data_test,
270            batch_size=self.batch_size_per_device,
271            num_workers=self.hparams.num_workers,
272            pin_memory=self.hparams.pin_memory,
273            shuffle=False,
274        )
275
276    def teardown(self, stage: str | None = None) -> None:
277        """Lightning hook for cleaning up after `trainer.fit()`,
278        `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.
279
280        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
281            Defaults to ``None``.
282        """
283        pass
284
285    def state_dict(self) -> dict[Any, Any]:
286        """Called when saving a checkpoint. Implement to generate and save the
287        datamodule state.
288
289        :return: A dictionary containing the datamodule state that you want to
290            save.
291        """
292        return {}
293
294    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
295        """Called when loading a checkpoint. Implement to reload datamodule
296        state given datamodule `state_dict()`.
297
298        :param state_dict: The datamodule state returned by
299            `self.state_dict()`.
300        """
301        pass

LightningDataModule for the ImageNet dataset (or its 10-class Imagenette subset).

A LightningDataModule implements 7 key methods:

    def prepare_data(self):
    # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
    # Download data, pre-process, split, save to disk, etc...

    def setup(self, stage):
    # Things to do on every process in DDP.
    # Load data, set variables, etc...

    def train_dataloader(self):
    # return train dataloader

    def val_dataloader(self):
    # return validation dataloader

    def test_dataloader(self):
    # return test dataloader

    def predict_dataloader(self):
    # return predict dataloader

    def teardown(self, stage):
    # Called on every process in DDP.
    # Clean up after fit or test.

This allows you to share a full dataset without explaining how to download, split, transform and process the data.

Read the docs: https://lightning.ai/docs/pytorch/latest/data/datamodule.html

ImageNetDataModule( data_dir: str = 'data/ImageNet', batch_size: int = 64, num_workers: int = 0, pin_memory: bool = False, test: bool = True, test_for_train: bool = False, imagenette: bool = False, imagenette_augs: bool = False, pct: float | None = None)
 71    def __init__(
 72        self,
 73        data_dir: str = "data/ImageNet",
 74        batch_size: int = 64,
 75        num_workers: int = 0,
 76        pin_memory: bool = False,
 77        test: bool = True,
 78        test_for_train: bool = False,
 79        imagenette: bool = False,
 80        imagenette_augs: bool = False,
 81        pct: float | None = None,
 82    ) -> None:
 83        """Initialize a `MNISTDataModule`.
 84
 85        :param data_dir: The data directory. Defaults to `"data/"`.
 86        :param batch_size: The batch size. Defaults to `64`.
 87        :param num_workers: The number of workers. Defaults to `0`.
 88        :param pin_memory: Whether to pin memory. Defaults to `False`.
 89        """
 90        super().__init__()
 91
 92        # this line allows to access init params with 'self.hparams' attribute
 93        # also ensures init params will be stored in ckpt
 94        self.save_hyperparameters(logger=False)
 95
 96        # data transformations
 97        self.train_tsfm = transforms.Compose(
 98            [
 99                transforms.Resize(256),
100                transforms.RandomCrop(224),
101                transforms.RandomHorizontalFlip(),
102                transforms.ToImage(),
103                transforms.ToDtype(torch.float32, scale=True),
104                transforms.Normalize(
105                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
106                ),
107            ]
108        )
109
110        self.imagenette_tsfm = transforms.Compose(
111            [
112                transforms.Resize(256),
113                transforms.RandomCrop(224),
114                transforms.RandomHorizontalFlip(),
115                transforms.AutoAugment(),
116                transforms.AugMix(),
117                transforms.ToImage(),
118                transforms.ToDtype(torch.float32, scale=True),
119                transforms.Normalize(
120                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
121                ),
122            ]
123        )
124        self.val_tsfm = transforms.Compose(
125            [
126                transforms.Resize(256),
127                transforms.CenterCrop(224),
128                transforms.ToImage(),
129                transforms.ToDtype(torch.float32, scale=True),
130                transforms.Normalize(
131                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
132                ),
133            ]
134        )
135        self.data_train: Dataset | None = None
136        self.data_val: Dataset | None = None
137        self.data_test: Dataset | None = None
138        self.batch_size_per_device = batch_size

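Initialize an ImageNetDataModule.

Parameters
  • data_dir: The data directory. Defaults to "data/ImageNet".
  • batch_size: The global batch size, divided evenly across devices in setup(). Defaults to 64.
  • num_workers: The number of workers. Defaults to 0.
  • pin_memory: Whether to pin memory. Defaults to False.
  • test: Whether test_dataloader() serves the test subset; if False, it returns the validation dataloader instead. Defaults to True.
  • test_for_train: Whether train_dataloader() returns the test dataloader instead of the training one. Defaults to False.
  • imagenette: Use the 10-class Imagenette dataset instead of ImageNet. Defaults to False.
  • imagenette_augs: Use the AutoAugment + AugMix training transforms for ImageNet as well (Imagenette always uses them). Defaults to False.
  • pct: If set, train on a random fraction pct of the training set. Defaults to None (full training set).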
train_tsfm
imagenette_tsfm
val_tsfm
data_train: torch.utils.data.dataset.Dataset | None
data_val: torch.utils.data.dataset.Dataset | None
data_test: torch.utils.data.dataset.Dataset | None
batch_size_per_device
num_classes: int
@property
def num_classes(self) -> int:
    """Get the number of classes.

    :return: `10` when using Imagenette, otherwise `1000` (ImageNet).
    """
    return 10 if self.hparams.imagenette else 1000

Get the number of classes.

Returns

The number of classes: 10 for Imagenette, otherwise 1000 (ImageNet).

def prepare_data(self):
def prepare_data(self) -> None:
    """Download Imagenette into `data_dir` if that directory is missing.

    Nothing is downloaded for ImageNet; it must already be present under
    `data_dir`. Lightning runs this hook on a single process (once per node
    by default), so downloading here is safe.
    """
    if not self.hparams.imagenette:
        return
    root = Path(self.hparams.data_dir)
    # Download only when the directory does not exist yet, so an existing
    # extraction is never downloaded over.
    if not root.is_dir():
        # parents=True so a nested default like "data/ImageNet" works even
        # when the parent "data/" directory does not exist yet.
        root.mkdir(parents=True)
        Imagenette(self.hparams.data_dir, download=True)

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

DO NOT set state to the model (use setup instead), since this is NOT called on every device.

Example::

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In a distributed environment, prepare_data can be called in two ways, controlled by the prepare_data_per_node flag:

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.
  2. Once in total. Only called on GLOBAL_RANK=0.

Example::

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
def setup(self, stage: str | None = None) -> None:
def setup(self, stage: str | None = None) -> None:
    """Load data. Set variables: `self.data_train`, `self.data_val`,
    `self.data_test`.

    This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
    `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
    `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
    `self.setup()` once the data is prepared and available for use.

    Random splits (the `pct` training subset and the Imagenette val/test
    halves) use a fixed-seed generator, so every process and every run gets
    the same partition.

    :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
    :raises RuntimeError: If the batch size is not divisible by the number of
        devices.
    """
    # Divide batch size by the number of devices.
    if self.trainer is not None:
        if self.hparams.batch_size % self.trainer.world_size != 0:
            raise RuntimeError(
                f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
            )
        self.batch_size_per_device = (
            self.hparams.batch_size // self.trainer.world_size
        )

    # Load and split datasets only once, even if setup() is called again.
    if (
        self.data_train is not None
        or self.data_val is not None
        or self.data_test is not None
    ):
        return

    dataset = Imagenette if self.hparams.imagenette else ImageNet
    use_heavy_augs = self.hparams.imagenette or self.hparams.imagenette_augs
    train_tsfm = self.imagenette_tsfm if use_heavy_augs else self.train_tsfm

    self.data_train = dataset(
        self.hparams.data_dir, split="train", transform=train_tsfm
    )
    if self.hparams.pct is not None:
        # Keep a random `pct` fraction of the training set. A fresh fixed-seed
        # generator keeps the subset identical across DDP ranks and runs
        # instead of depending on the global RNG state.
        self.data_train, _ = random_split(
            self.data_train,
            [self.hparams.pct, 1 - self.hparams.pct],
            generator=torch.Generator().manual_seed(0),
        )

    # A single "val" split backs both the validation and test subsets.
    val_full = dataset(
        self.hparams.data_dir, split="val", transform=self.val_tsfm
    )
    if self.hparams.imagenette:
        # Halve Imagenette's val split into val and test subsets; the fixed
        # seed keeps test samples out of validation across runs.
        self.data_val, self.data_test = random_split(
            val_full,
            [0.5, 0.5],
            generator=torch.Generator().manual_seed(0),
        )
    else:
        # ImageNet val/test partitions come from index files shipped in
        # `ptame.datalists` (presumably sequences of indices into the val
        # split — confirm against how they were generated).
        datalist_source = files(ptame.datalists)
        with as_file(datalist_source.joinpath("val_set.pt")) as f:
            val_indices = torch.load(f, weights_only=True)
        with as_file(datalist_source.joinpath("test_set.pt")) as f:
            test_indices = torch.load(f, weights_only=True)
        self.data_val = Subset(val_full, val_indices)
        self.data_test = Subset(val_full, test_indices)

Load data. Set variables: self.data_train, self.data_val, self.data_test.

This method is called by Lightning before trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict(), so be careful not to execute things like random split twice! Also, it is called after self.prepare_data() and there is a barrier in between which ensures that all the processes proceed to self.setup() once the data is prepared and available for use.

Parameters
  • stage: The stage to setup. Either "fit", "validate", "test", or "predict". Defaults to None.
def train_dataloader(self) -> torch.utils.data.dataloader.DataLoader[typing.Any]:
def train_dataloader(self) -> DataLoader[Any]:
    """Build the dataloader used for training.

    When ``hparams.test_for_train`` is truthy, no training loader is built;
    the loader returned by ``self.test_dataloader()`` is handed back instead.
    Otherwise ``self.data_train`` is served in shuffled mini-batches of
    ``self.batch_size_per_device`` samples.

    :return: The train dataloader.
    """
    hp = self.hparams
    if not hp.test_for_train:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_per_device,
            shuffle=True,  # training order is randomized every epoch
            num_workers=hp.num_workers,
            pin_memory=hp.pin_memory,
        )
    # Debug/evaluation mode: reuse the test loader in place of training data.
    return self.test_dataloader()

Create and return the train dataloader.

Returns

The train dataloader.

def val_dataloader(self) -> torch.utils.data.dataloader.DataLoader[typing.Any]:
def val_dataloader(self) -> DataLoader[Any]:
    """Build the dataloader used for validation.

    Serves ``self.data_val`` in its stored order (no shuffling) using
    ``self.batch_size_per_device`` samples per batch and the worker /
    pinned-memory settings taken from ``self.hparams``.

    :return: The validation dataloader.
    """
    hp = self.hparams
    loader_kwargs = dict(
        batch_size=self.batch_size_per_device,
        num_workers=hp.num_workers,
        pin_memory=hp.pin_memory,
        shuffle=False,  # deterministic order keeps validation metrics comparable
    )
    return DataLoader(self.data_val, **loader_kwargs)

Create and return the validation dataloader.

Returns

The validation dataloader.

def test_dataloader(self) -> torch.utils.data.dataloader.DataLoader[typing.Any]:
261    def test_dataloader(self) -> DataLoader[Any]:
262        """Create and return the test dataloader.
263
264        :return: The test dataloader.
265        """
266        if not self.hparams.test:
267            return self.val_dataloader()
268        return DataLoader(
269            dataset=self.data_test,
270            batch_size=self.batch_size_per_device,
271            num_workers=self.hparams.num_workers,
272            pin_memory=self.hparams.pin_memory,
273            shuffle=False,
274        )

Create and return the test dataloader.

Returns

The test dataloader.

def teardown(self, stage: str | None = None) -> None:
276    def teardown(self, stage: str | None = None) -> None:
277        """Lightning hook for cleaning up after `trainer.fit()`,
278        `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.
279
280        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
281            Defaults to ``None``.
282        """
283        pass

Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().

Parameters
  • stage: The stage being torn down. Either "fit", "validate", "test", or "predict". Defaults to None.
def state_dict(self) -> dict[typing.Any, typing.Any]:
def state_dict(self) -> dict[Any, Any]:
    """Return the datamodule state to store in a checkpoint.

    Lightning calls this when saving a checkpoint. This datamodule does not
    persist any state, so a new empty dictionary is produced on every call.

    :return: An empty dictionary.
    """
    return dict()

Called when saving a checkpoint. Implement to generate and save the datamodule state.

Returns

A dictionary containing the datamodule state that you want to save.

def load_state_dict(self, state_dict: dict[str, typing.Any]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore datamodule state when a checkpoint is loaded.

    Lightning passes in the mapping previously produced by
    ``self.state_dict()``. That mapping is always empty here, so nothing is
    restored and the argument is ignored.

    :param state_dict: The datamodule state returned by ``self.state_dict()``.
    """
    return None

Called when loading a checkpoint. Implement to reload datamodule state given datamodule state_dict().

Parameters
  • state_dict: The datamodule state returned by self.state_dict().