PyTorch Data Transforms

In PyTorch, data transformation is a mechanism for processing data during loading, converting raw data into a format suitable for model training. It is primarily accomplished using tools provided by torchvision.transforms.


Data transforms can not only perform basic data preprocessing (such as normalization, resizing, etc.), but also help with data augmentation (such as random cropping, flipping, etc.), thereby improving the model's generalization ability.


Why are Data Transforms Needed?


Data Preprocessing:

  • Adjust data format, size, and range to make them suitable for model input.
  • For example, images need to be resized to a fixed size, converted to tensor format, and normalized to [0, 1].

Data Augmentation:

  • Apply transformations to data during training to increase diversity.
  • For example, increase data sample variations through random rotation, flipping, and cropping to avoid overfitting.

Flexibility:

  • By defining a series of transform operations, data can be processed dynamically, simplifying the complexity of data loading.

In PyTorch, the torchvision.transforms module provides various transformation operations for image processing.


Basic Transform Operations

| Transform Function Name | Description | Example |
| --- | --- | --- |
| transforms.ToTensor() | Converts PIL images or NumPy arrays to PyTorch tensors, automatically scaling pixel values from [0, 255] to [0, 1]. | transform = transforms.ToTensor() |
| transforms.Normalize(mean, std) | Normalizes each channel as (input - mean) / std; when mean and std match the data's statistics, the result has roughly zero mean and unit variance. | transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| transforms.Resize(size) | Resizes images to ensure consistent input sizes for the network. | transform = transforms.Resize((256, 256)) |
| transforms.CenterCrop(size) | Crops a region of the specified size from the center of the image. | transform = transforms.CenterCrop(224) |

1. ToTensor

Converts PIL images or NumPy arrays to PyTorch tensors, changing the layout from H x W x C to C x H x W.

It also scales pixel values from [0, 255] to [0, 1].

    from torchvision import transforms

    transform = transforms.ToTensor()

2. Normalize

Standardizes a tensor image channel by channel using the given mean and standard deviation: output = (input - mean) / std.

Commonly used for image data. When mean and std are the dataset's own per-channel statistics, the output has approximately zero mean and unit variance.

    transform = transforms.Normalize(mean=[0.5], std=[0.5])  # Maps [0, 1] inputs to [-1, 1]

3. Resize

Adjusts the size of the image.

    transform = transforms.Resize((128, 128))  # Resize image to 128x128

4. CenterCrop

Crops a region of the specified size from the center of the image.

    transform = transforms.CenterCrop(128)  # Crop a 128x128 region

Data Augmentation Operations

| Transform Function Name | Description | Example |
| --- | --- | --- |
| transforms.RandomHorizontalFlip(p) | Randomly flips the image horizontally. | transform = transforms.RandomHorizontalFlip(p=0.5) |
| transforms.RandomRotation(degrees) | Randomly rotates the image. | transform = transforms.RandomRotation(degrees=45) |
| transforms.ColorJitter(brightness, contrast, saturation, hue) | Adjusts the brightness, contrast, saturation, and hue of the image. | transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1) |
| transforms.RandomCrop(size) | Randomly crops a region of the specified size. | transform = transforms.RandomCrop(224) |
| transforms.RandomResizedCrop(size) | Randomly crops the image and resizes it to the specified size. | transform = transforms.RandomResizedCrop(224) |

1. RandomCrop

Randomly crops a region of the specified size from the image.

    transform = transforms.RandomCrop(128)

2. RandomHorizontalFlip

Flips the image horizontally with a given probability.

    transform = transforms.RandomHorizontalFlip(p=0.5)  # Flip with 50% probability

3. RandomRotation

Rotates the image by a random angle drawn from the given range.

    transform = transforms.RandomRotation(degrees=30)  # Random rotation between -30 and +30 degrees

4. ColorJitter

Randomly changes the brightness, contrast, saturation, or hue of the image.

    transform = transforms.ColorJitter(brightness=0.5, contrast=0.5)

Composing Transforms

| Transform Function Name | Description | Example |
| --- | --- | --- |
| transforms.Compose() | Combines multiple transforms together, applying them sequentially in order. | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((256, 256))]) |

Combine multiple transforms using transforms.Compose.

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

Custom Transforms

If the functionality provided by transforms does not meet your needs, you can implement custom classes or functions.

Example

    class CustomTransform:
        def __call__(self, x):
            # Custom transform logic can be defined here
            return x * 2

    transform = CustomTransform()

Examples

Applying Transforms to an Image Dataset

Load the MNIST dataset and apply transforms.

Example

    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # Load dataset
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

    # Use DataLoader
    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

    # View transformed data
    for images, labels in train_loader:
        print("Image tensor size:", images.size())  # [batch_size, 1, 128, 128]
        break

The output result is:

    Image tensor size: torch.Size([32, 1, 128, 128])

Visualizing Transform Effects

The following code displays the first five MNIST samples after augmentation so that the effect of the transforms can be inspected.

Example

    import matplotlib.pyplot as plt
    from torchvision import datasets, transforms

    # Augmentation pipeline to visualize
    transform_augment = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor()
    ])

    # Load dataset
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_augment)

    # Display the first five images
    def show_images(dataset):
        fig, axs = plt.subplots(1, 5, figsize=(15, 5))
        for i in range(5):
            image, label = dataset[i]
            axs[i].imshow(image.squeeze(0), cmap='gray')  # Convert (1, H, W) to (H, W)
            axs[i].set_title(f"Label: {label}")
            axs[i].axis('off')
        plt.show()

    show_images(dataset)

Displayed as shown below:


Image 1
