ptame.data.imagenet_datamodule
from pathlib import Path
from typing import Any

import torch
import torchvision.transforms.v2 as transforms
from importlib_resources import as_file, files
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision.datasets import ImageNet, Imagenette

import ptame

# Per-channel (R, G, B) statistics used by every normalization in this module.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def norm(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Normalize the input tensor with the ImageNet channel mean/std.

    :param x: The input tensor.
    :param reverse: Whether to undo the normalization instead of applying it.
        Defaults to `False`.
    :return: The normalized (or, with `reverse`, de-normalized) tensor.
    """
    if reverse:
        # Normalize computes (x - m) / s; with m' = -m / s and s' = 1 / s this
        # becomes x * s + m, the inverse of the forward normalization.
        return transforms.Normalize(
            tuple(-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)),
            tuple(1 / s for s in IMAGENET_STD),
        )(x)
    return transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)(x)


class ImageNetDataModule(LightningDataModule):
    """`LightningDataModule` for the ImageNet dataset, or its 10-class
    Imagenette subset when `imagenette=True`.

    ImageNet's official "val" split is divided into validation and test
    subsets using index lists shipped in `ptame.datalists`; Imagenette's
    "val" split is halved at random.

    A `LightningDataModule` implements 7 key methods:

    ```python
    def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

    def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

    def train_dataloader(self):
        # return train dataloader

    def val_dataloader(self):
        # return validation dataloader

    def test_dataloader(self):
        # return test dataloader

    def predict_dataloader(self):
        # return predict dataloader

    def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data/ImageNet",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        test: bool = True,
        test_for_train: bool = False,
        imagenette: bool = False,
        imagenette_augs: bool = False,
        pct: float | None = None,
    ) -> None:
        """Initialize an `ImageNetDataModule`.

        :param data_dir: The data directory. Defaults to `"data/ImageNet"`.
        :param batch_size: The global batch size; `setup` divides it evenly
            across devices. Defaults to `64`.
        :param num_workers: The number of dataloader workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        :param test: Whether `test_dataloader` serves the test subset. If
            `False`, it returns the validation dataloader. Defaults to `True`.
        :param test_for_train: Whether `train_dataloader` returns the test
            dataloader instead of the training data. Defaults to `False`.
        :param imagenette: Whether to use Imagenette (10 classes) instead of
            ImageNet. Defaults to `False`.
        :param imagenette_augs: Whether to train with the heavier
            AutoAugment + AugMix pipeline; it is always used when
            `imagenette` is `True`. Defaults to `False`.
        :param pct: If not `None`, the fraction of the training set to keep
            (a random subset). Defaults to `None`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        def to_normalized_tensor() -> list:
            """Return the shared tail of every pipeline: convert to a float32
            image tensor (rescaling integer pixel values) and normalize."""
            return [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            ]

        # Default training augmentation.
        self.train_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                *to_normalized_tensor(),
            ]
        )
        # Heavier training augmentation; used for Imagenette or when
        # `imagenette_augs` is set.
        self.imagenette_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.AutoAugment(),
                transforms.AugMix(),
                *to_normalized_tensor(),
            ]
        )
        # Deterministic evaluation preprocessing.
        self.val_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                *to_normalized_tensor(),
            ]
        )
        self.data_train: Dataset | None = None
        self.data_val: Dataset | None = None
        self.data_test: Dataset | None = None
        # Replaced in `setup` once the number of devices is known.
        self.batch_size_per_device = batch_size

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        :return: 10 when using Imagenette, otherwise 1000 (ImageNet).
        """
        return 10 if self.hparams.imagenette else 1000

    def prepare_data(self) -> None:
        """Download Imagenette into `data_dir` if it is requested and missing.

        Nothing is done for ImageNet, which must already be present under
        `data_dir`.
        """
        if not self.hparams.imagenette:
            return
        # Create intermediate directories too (the default `data_dir` is nested).
        Path(self.hparams.data_dir).mkdir(parents=True, exist_ok=True)
        try:
            # Succeeds only if the dataset has already been downloaded/extracted.
            Imagenette(self.hparams.data_dir, split="val")
        except RuntimeError:
            # torchvision's Imagenette raises if `download=True` is used on an
            # existing dataset directory, so only download when it is missing.
            Imagenette(self.hparams.data_dir, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`,
        `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the number of devices.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = (
                self.hparams.batch_size // self.trainer.world_size
            )

        # Load and split the datasets only once. Use `is None` rather than
        # truthiness, which would call `len()` on the datasets.
        if any(
            d is not None
            for d in (self.data_train, self.data_val, self.data_test)
        ):
            return

        dataset = Imagenette if self.hparams.imagenette else ImageNet
        train_tsfm = (
            self.imagenette_tsfm
            if self.hparams.imagenette or self.hparams.imagenette_augs
            else self.train_tsfm
        )

        self.data_train = dataset(
            self.hparams.data_dir, split="train", transform=train_tsfm
        )
        # NOTE(review): random_split draws from torch's global RNG, so splits
        # only agree across DDP ranks / runs if that RNG is seeded identically.
        if self.hparams.pct is not None:
            # Keep a random `pct` fraction of the dataset loaded above instead
            # of re-reading the training split from disk.
            self.data_train, _ = random_split(
                self.data_train, [self.hparams.pct, 1 - self.hparams.pct]
            )
        if self.hparams.imagenette:
            # Imagenette: halve its "val" split into validation and test sets.
            self.data_val, self.data_test = random_split(
                dataset(
                    self.hparams.data_dir,
                    split="val",
                    transform=self.val_tsfm,
                ),
                [0.5, 0.5],
            )
        else:
            # ImageNet: take val/test subsets of the official "val" split using
            # the index files shipped in the `ptame.datalists` package.
            datalist_source = files(ptame.datalists)
            with as_file(datalist_source.joinpath("val_set.pt")) as f:
                self.data_val = Subset(
                    dataset(
                        self.hparams.data_dir,
                        split="val",
                        transform=self.val_tsfm,
                    ),
                    torch.load(f, weights_only=True),
                )
            with as_file(datalist_source.joinpath("test_set.pt")) as f:
                self.data_test = Subset(
                    dataset(
                        self.hparams.data_dir,
                        split="val",
                        transform=self.val_tsfm,
                    ),
                    torch.load(f, weights_only=True),
                )

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        If `test_for_train` is set, the test dataloader is returned instead.

        :return: The train dataloader.
        """
        if self.hparams.test_for_train:
            return self.test_dataloader()
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        If the `test` hyperparameter is false, the validation dataloader is
        returned instead.

        :return: The test dataloader.
        """
        if not self.hparams.test:
            return self.val_dataloader()
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: str | None = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`,
        `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the
        datamodule state.

        :return: A dictionary containing the datamodule state that you want to
            save.
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule
        state given datamodule `state_dict()`.

        :param state_dict: The datamodule state returned by
            `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    _ = ImageNetDataModule()
def norm(x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Apply (or undo) ImageNet mean/std normalization.

    :param x: The input tensor.
    :param reverse: If `True`, invert a previous normalization instead of
        applying one. Defaults to `False`.
    :return: The normalized tensor, or the de-normalized one when `reverse`
        is set.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    if reverse:
        # Normalize computes (x - m) / s; with m' = -m / s and s' = 1 / s this
        # becomes x * s + m, the inverse of the forward normalization.
        mean, std = (
            tuple(-m / s for m, s in zip(mean, std)),
            tuple(1 / s for s in std),
        )
    return transforms.Normalize(mean, std)(x)
Normalize the input tensor.
Parameters
- x: The input tensor.
- reverse: Whether to reverse the normalization. Defaults to
False.
Returns
The normalized tensor.
class ImageNetDataModule(LightningDataModule):
    """`LightningDataModule` for the ImageNet dataset, or its 10-class
    Imagenette subset when `imagenette=True`.

    ImageNet's official "val" split is divided into validation and test
    subsets using index lists shipped in `ptame.datalists`; Imagenette's
    "val" split is halved at random.

    A `LightningDataModule` implements 7 key methods:

    ```python
    def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

    def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

    def train_dataloader(self):
        # return train dataloader

    def val_dataloader(self):
        # return validation dataloader

    def test_dataloader(self):
        # return test dataloader

    def predict_dataloader(self):
        # return predict dataloader

    def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data/ImageNet",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        test: bool = True,
        test_for_train: bool = False,
        imagenette: bool = False,
        imagenette_augs: bool = False,
        pct: float | None = None,
    ) -> None:
        """Initialize an `ImageNetDataModule`.

        :param data_dir: The data directory. Defaults to `"data/ImageNet"`.
        :param batch_size: The global batch size; `setup` divides it evenly
            across devices. Defaults to `64`.
        :param num_workers: The number of dataloader workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        :param test: Whether `test_dataloader` serves the test subset. If
            `False`, it returns the validation dataloader. Defaults to `True`.
        :param test_for_train: Whether `train_dataloader` returns the test
            dataloader instead of the training data. Defaults to `False`.
        :param imagenette: Whether to use Imagenette (10 classes) instead of
            ImageNet. Defaults to `False`.
        :param imagenette_augs: Whether to train with the heavier
            AutoAugment + AugMix pipeline; it is always used when
            `imagenette` is `True`. Defaults to `False`.
        :param pct: If not `None`, the fraction of the training set to keep
            (a random subset). Defaults to `None`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # Per-channel (R, G, B) normalization statistics.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        def to_normalized_tensor() -> list:
            """Return the shared tail of every pipeline: convert to a float32
            image tensor (rescaling integer pixel values) and normalize."""
            return [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean, std),
            ]

        # Default training augmentation.
        self.train_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                *to_normalized_tensor(),
            ]
        )
        # Heavier training augmentation; used for Imagenette or when
        # `imagenette_augs` is set.
        self.imagenette_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.AutoAugment(),
                transforms.AugMix(),
                *to_normalized_tensor(),
            ]
        )
        # Deterministic evaluation preprocessing.
        self.val_tsfm = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                *to_normalized_tensor(),
            ]
        )
        self.data_train: Dataset | None = None
        self.data_val: Dataset | None = None
        self.data_test: Dataset | None = None
        # Replaced in `setup` once the number of devices is known.
        self.batch_size_per_device = batch_size

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        :return: 10 when using Imagenette, otherwise 1000 (ImageNet).
        """
        return 10 if self.hparams.imagenette else 1000

    def prepare_data(self) -> None:
        """Download Imagenette into `data_dir` if it is requested and missing.

        Nothing is done for ImageNet, which must already be present under
        `data_dir`.
        """
        if not self.hparams.imagenette:
            return
        # Create intermediate directories too (the default `data_dir` is nested).
        Path(self.hparams.data_dir).mkdir(parents=True, exist_ok=True)
        try:
            # Succeeds only if the dataset has already been downloaded/extracted.
            Imagenette(self.hparams.data_dir, split="val")
        except RuntimeError:
            # torchvision's Imagenette raises if `download=True` is used on an
            # existing dataset directory, so only download when it is missing.
            Imagenette(self.hparams.data_dir, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`,
        `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the number of devices.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = (
                self.hparams.batch_size // self.trainer.world_size
            )

        # Load and split the datasets only once. Use `is None` rather than
        # truthiness, which would call `len()` on the datasets.
        if any(
            d is not None
            for d in (self.data_train, self.data_val, self.data_test)
        ):
            return

        dataset = Imagenette if self.hparams.imagenette else ImageNet
        train_tsfm = (
            self.imagenette_tsfm
            if self.hparams.imagenette or self.hparams.imagenette_augs
            else self.train_tsfm
        )

        self.data_train = dataset(
            self.hparams.data_dir, split="train", transform=train_tsfm
        )
        # NOTE(review): random_split draws from torch's global RNG, so splits
        # only agree across DDP ranks / runs if that RNG is seeded identically.
        if self.hparams.pct is not None:
            # Keep a random `pct` fraction of the dataset loaded above instead
            # of re-reading the training split from disk.
            self.data_train, _ = random_split(
                self.data_train, [self.hparams.pct, 1 - self.hparams.pct]
            )
        if self.hparams.imagenette:
            # Imagenette: halve its "val" split into validation and test sets.
            self.data_val, self.data_test = random_split(
                dataset(
                    self.hparams.data_dir,
                    split="val",
                    transform=self.val_tsfm,
                ),
                [0.5, 0.5],
            )
        else:
            # ImageNet: take val/test subsets of the official "val" split using
            # the index files shipped in the `ptame.datalists` package.
            datalist_source = files(ptame.datalists)
            with as_file(datalist_source.joinpath("val_set.pt")) as f:
                self.data_val = Subset(
                    dataset(
                        self.hparams.data_dir,
                        split="val",
                        transform=self.val_tsfm,
                    ),
                    torch.load(f, weights_only=True),
                )
            with as_file(datalist_source.joinpath("test_set.pt")) as f:
                self.data_test = Subset(
                    dataset(
                        self.hparams.data_dir,
                        split="val",
                        transform=self.val_tsfm,
                    ),
                    torch.load(f, weights_only=True),
                )

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        If `test_for_train` is set, the test dataloader is returned instead.

        :return: The train dataloader.
        """
        if self.hparams.test_for_train:
            return self.test_dataloader()
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        If the `test` hyperparameter is false, the validation dataloader is
        returned instead.

        :return: The test dataloader.
        """
        if not self.hparams.test:
            return self.val_dataloader()
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: str | None = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`,
        `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the
        datamodule state.

        :return: A dictionary containing the datamodule state that you want to
            save.
        """
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule
        state given datamodule `state_dict()`.

        :param state_dict: The datamodule state returned by
            `self.state_dict()`.
        """
        pass
LightningDataModule for the ImageNet dataset (or its 10-class Imagenette subset).
A LightningDataModule implements 7 key methods:
def prepare_data(self):
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
# Download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# Things to do on every process in DDP.
# Load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def predict_dataloader(self):
# return predict dataloader
def teardown(self, stage):
# Called on every process in DDP.
# Clean up after fit or test.
This allows you to share a full dataset without explaining how to download, split, transform and process the data.
Read the docs: https://lightning.ai/docs/pytorch/latest/data/datamodule.html
def __init__(
    self,
    data_dir: str = "data/ImageNet",
    batch_size: int = 64,
    num_workers: int = 0,
    pin_memory: bool = False,
    test: bool = True,
    test_for_train: bool = False,
    imagenette: bool = False,
    imagenette_augs: bool = False,
    pct: float | None = None,
) -> None:
    """Initialize an `ImageNetDataModule`.

    :param data_dir: The data directory. Defaults to `"data/ImageNet"`.
    :param batch_size: The global batch size; `setup` divides it evenly
        across devices. Defaults to `64`.
    :param num_workers: The number of dataloader workers. Defaults to `0`.
    :param pin_memory: Whether to pin memory. Defaults to `False`.
    :param test: Whether `test_dataloader` serves the test subset. If
        `False`, it returns the validation dataloader. Defaults to `True`.
    :param test_for_train: Whether `train_dataloader` returns the test
        dataloader instead of the training data. Defaults to `False`.
    :param imagenette: Whether to use Imagenette (10 classes) instead of
        ImageNet. Defaults to `False`.
    :param imagenette_augs: Whether to train with the heavier
        AutoAugment + AugMix pipeline; it is always used when `imagenette`
        is `True`. Defaults to `False`.
    :param pct: If not `None`, the fraction of the training set to keep
        (a random subset). Defaults to `None`.
    """
    super().__init__()

    # this line allows to access init params with 'self.hparams' attribute
    # also ensures init params will be stored in ckpt
    self.save_hyperparameters(logger=False)

    # Per-channel (R, G, B) normalization statistics.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    def to_normalized_tensor() -> list:
        """Return the shared tail of every pipeline: convert to a float32
        image tensor (rescaling integer pixel values) and normalize."""
        return [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std),
        ]

    # Default training augmentation.
    self.train_tsfm = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            *to_normalized_tensor(),
        ]
    )
    # Heavier training augmentation; used for Imagenette or when
    # `imagenette_augs` is set.
    self.imagenette_tsfm = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.AutoAugment(),
            transforms.AugMix(),
            *to_normalized_tensor(),
        ]
    )
    # Deterministic evaluation preprocessing.
    self.val_tsfm = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            *to_normalized_tensor(),
        ]
    )
    self.data_train: Dataset | None = None
    self.data_val: Dataset | None = None
    self.data_test: Dataset | None = None
    # Replaced in `setup` once the number of devices is known.
    self.batch_size_per_device = batch_size
Initialize an ImageNetDataModule.
Parameters
- data_dir: The data directory. Defaults to "data/ImageNet".
- batch_size: The batch size. Defaults to 64.
- num_workers: The number of workers. Defaults to 0.
- pin_memory: Whether to pin memory. Defaults to False.
@property
def num_classes(self) -> int:
    """Get the number of classes.

    :return: 10 when using Imagenette, otherwise 1000 (ImageNet).
    """
    return 10 if self.hparams.imagenette else 1000
Get the number of classes.
Returns
The number of classes: 10 for Imagenette, 1000 for ImageNet.
def prepare_data(self) -> None:
    """Download Imagenette into `data_dir` if it is requested and missing.

    Nothing is done for ImageNet, which must already be present under
    `data_dir`.
    """
    if not self.hparams.imagenette:
        return
    # Create intermediate directories too (the default `data_dir` is nested).
    Path(self.hparams.data_dir).mkdir(parents=True, exist_ok=True)
    try:
        # Succeeds only if the dataset has already been downloaded/extracted.
        Imagenette(self.hparams.data_dir, split="val")
    except RuntimeError:
        # torchvision's Imagenette raises if `download=True` is used on an
        # existing dataset directory, so only download when it is missing.
        Imagenette(self.hparams.data_dir, download=True)
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
DO NOT set state to the model (use setup instead)
since this is NOT called on every device
Example::
def prepare_data(self):
# good
download_data()
tokenize()
etc()
# bad
self.split = data_split
self.some_state = some_other_state()
In a distributed environment, prepare_data can be called in two ways
(controlled by the prepare_data_per_node attribute)
- Once per node. This is the default and is only called on LOCAL_RANK=0.
- Once in total. Only called on GLOBAL_RANK=0.
Example::
# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.prepare_data_per_node = True
# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.prepare_data_per_node = False
This is called before requesting the dataloaders:
model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
def setup(self, stage: str | None = None) -> None:
    """Load data. Set variables: `self.data_train`, `self.data_val`,
    `self.data_test`.

    This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
    `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
    `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
    `self.setup()` once the data is prepared and available for use.

    :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
    :raises RuntimeError: If the batch size is not divisible by the number of devices.
    """
    # Divide batch size by the number of devices.
    if self.trainer is not None:
        if self.hparams.batch_size % self.trainer.world_size != 0:
            raise RuntimeError(
                f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
            )
        self.batch_size_per_device = (
            self.hparams.batch_size // self.trainer.world_size
        )

    # Load and split the datasets only once. Use `is None` rather than
    # truthiness, which would call `len()` on the datasets.
    if any(
        d is not None
        for d in (self.data_train, self.data_val, self.data_test)
    ):
        return

    dataset = Imagenette if self.hparams.imagenette else ImageNet
    train_tsfm = (
        self.imagenette_tsfm
        if self.hparams.imagenette or self.hparams.imagenette_augs
        else self.train_tsfm
    )

    self.data_train = dataset(
        self.hparams.data_dir, split="train", transform=train_tsfm
    )
    # NOTE(review): random_split draws from torch's global RNG, so splits
    # only agree across DDP ranks / runs if that RNG is seeded identically.
    if self.hparams.pct is not None:
        # Keep a random `pct` fraction of the dataset loaded above instead
        # of re-reading the training split from disk.
        self.data_train, _ = random_split(
            self.data_train, [self.hparams.pct, 1 - self.hparams.pct]
        )
    if self.hparams.imagenette:
        # Imagenette: halve its "val" split into validation and test sets.
        self.data_val, self.data_test = random_split(
            dataset(
                self.hparams.data_dir,
                split="val",
                transform=self.val_tsfm,
            ),
            [0.5, 0.5],
        )
    else:
        # ImageNet: take val/test subsets of the official "val" split using
        # the index files shipped in the `ptame.datalists` package.
        datalist_source = files(ptame.datalists)
        with as_file(datalist_source.joinpath("val_set.pt")) as f:
            self.data_val = Subset(
                dataset(
                    self.hparams.data_dir,
                    split="val",
                    transform=self.val_tsfm,
                ),
                torch.load(f, weights_only=True),
            )
        with as_file(datalist_source.joinpath("test_set.pt")) as f:
            self.data_test = Subset(
                dataset(
                    self.hparams.data_dir,
                    split="val",
                    transform=self.val_tsfm,
                ),
                torch.load(f, weights_only=True),
            )
Load data. Set variables: self.data_train, self.data_val,
self.data_test.
This method is called by Lightning before trainer.fit(), trainer.validate(), trainer.test(), and
trainer.predict(), so be careful not to execute things like random split twice! Also, it is called after
self.prepare_data() and there is a barrier in between which ensures that all the processes proceed to
self.setup() once the data is prepared and available for use.
Parameters
- stage: The stage to set up. Either
"fit", "validate", "test", or "predict". Defaults to None.
def train_dataloader(self) -> DataLoader[Any]:
    """Build the dataloader used for training.

    When the `test_for_train` hyperparameter is set, the test dataloader is
    returned in its place. Otherwise the training set is served in shuffled
    order.

    :return: The train dataloader.
    """
    hparams = self.hparams
    if not hparams.test_for_train:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            shuffle=True,
            num_workers=hparams.num_workers,
            pin_memory=hparams.pin_memory,
        )
    return self.test_dataloader()
Create and return the train dataloader.
Returns
The train dataloader.
def val_dataloader(self) -> DataLoader[Any]:
    """Build the validation dataloader (fixed order, no shuffling).

    :return: The validation dataloader.
    """
    loader_options = {
        "batch_size": self.batch_size_per_device,
        "num_workers": self.hparams.num_workers,
        "pin_memory": self.hparams.pin_memory,
        "shuffle": False,
    }
    return DataLoader(self.data_val, **loader_options)
Create and return the validation dataloader.
Returns
The validation dataloader.
261 def test_dataloader(self) -> DataLoader[Any]: 262 """Create and return the test dataloader. 263 264 :return: The test dataloader. 265 """ 266 if not self.hparams.test: 267 return self.val_dataloader() 268 return DataLoader( 269 dataset=self.data_test, 270 batch_size=self.batch_size_per_device, 271 num_workers=self.hparams.num_workers, 272 pin_memory=self.hparams.pin_memory, 273 shuffle=False, 274 )
Create and return the test dataloader.
Returns
The test dataloader.
276 def teardown(self, stage: str | None = None) -> None: 277 """Lightning hook for cleaning up after `trainer.fit()`, 278 `trainer.validate()`, `trainer.test()`, and `trainer.predict()`. 279 280 :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. 281 Defaults to ``None``. 282 """ 283 pass
Lightning hook for cleaning up after trainer.fit(),
trainer.validate(), trainer.test(), and trainer.predict().
Parameters
- stage: The stage being torn down. Either
"fit", "validate", "test", or "predict". Defaults to None.
def state_dict(self) -> dict[Any, Any]:
    """Return the datamodule state to store in a checkpoint.

    This datamodule keeps no checkpointable state, so a fresh empty
    dictionary is returned on every call.

    :return: An empty dictionary.
    """
    return dict()
Called when saving a checkpoint. Implement to generate and save the datamodule state.
Returns
A dictionary containing the datamodule state that you want to save.
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore datamodule state from a checkpoint.

    The matching `state_dict()` stores nothing, so there is nothing to
    restore and the argument is ignored.

    :param state_dict: The datamodule state returned by `self.state_dict()`.
    """
    return None
Called when loading a checkpoint. Implement to reload datamodule
state given datamodule state_dict().
Parameters
- state_dict: The datamodule state returned by
self.state_dict().