PyTorch torch.squeeze() Method

The PyTorch torch.squeeze() method “returns a tensor with all specified dimensions of input of size 1 removed.” It is typically used to drop size-1 dimensions, which add no information about the data, for example turning a tensor of shape (1, 3) into one of shape (3,).
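According to the PyTorch documentation, the tensor returned by squeeze() shares its storage with the input, so no data is copied. The following minimal sketch illustrates this with an arbitrary example tensor x:

```python
import torch

x = torch.zeros(1, 3)   # shape (1, 3)
y = torch.squeeze(x)    # shape (3,), a view of the same data

# Writing through the squeezed tensor is visible in the original tensor
y[0] = 5.0
print(x)                              # tensor([[5., 0., 0.]])
print(x.data_ptr() == y.data_ptr())   # True: both point at the same memory
```

If an independent copy is needed, calling .clone() on the result is one option.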

Syntax

torch.squeeze(input, dim=None)
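The same operation is also available as a method on the tensor itself (Tensor.squeeze()). A short sketch comparing the two forms on an illustrative tensor of shape (2, 1, 3):

```python
import torch

x = torch.zeros(2, 1, 3)

# Function form and method form produce the same shape
print(torch.squeeze(x).shape)  # torch.Size([2, 3])
print(x.squeeze().shape)       # torch.Size([2, 3])

# dim can also be passed positionally
print(x.squeeze(1).shape)      # torch.Size([2, 3])
```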

Parameters

  1. input: The input tensor.
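  2. dim (int or tuple of ints, optional): The dimension(s) to squeeze. If dim is specified, only those dimensions are considered, and any of them whose size is not 1 is left unchanged. If dim is not specified (None, the default), all dimensions of size 1 are removed. Passing a tuple of dimensions requires PyTorch 2.0 or later.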
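To make the dim behavior concrete, here is a small sketch using an arbitrary tensor of shape (2, 1, 3). Squeezing a dimension whose size is not 1 does not raise an error; it simply leaves the tensor unchanged:

```python
import torch

x = torch.zeros(2, 1, 3)

# Dimension 1 has size 1, so it is removed
print(torch.squeeze(x, dim=1).shape)  # torch.Size([2, 3])

# Dimension 0 has size 2, so squeezing it is a no-op
print(torch.squeeze(x, dim=0).shape)  # torch.Size([2, 1, 3])
```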

Example 1: Squeeze all dimensions of size 1

import torch

# Create a tensor with dimensions of size 1
a = torch.tensor([[1, 2, 3]]).unsqueeze(0)  # shape (1, 1, 3)
print(a.size())

# Remove all dimensions of size 1
squeezed_tensor = torch.squeeze(a)
print(squeezed_tensor.size())

Output

torch.Size([1, 1, 3])

torch.Size([3])
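Because squeeze() without dim removes every size-1 dimension, it can also remove a batch dimension that happens to have size 1. The sketch below uses a hypothetical model output of shape (batch_size=1, features=1) to show the difference that passing dim makes:

```python
import torch

# Hypothetical output: a batch containing 1 sample with 1 feature
batch = torch.zeros(1, 1)

# Without dim, both dimensions are removed and a 0-d tensor remains
print(torch.squeeze(batch).shape)         # torch.Size([])

# With dim=1, only the feature dimension is removed; the batch dimension is kept
print(torch.squeeze(batch, dim=1).shape)  # torch.Size([1])
```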

Example 2: Squeeze a specific dimension

import torch

# Create a tensor with dimensions of size 1
a = torch.tensor([[1, 2, 3]]).unsqueeze(0)  # shape (1, 1, 3)
print(a.size())

# Remove only dimension 0 (dimension 1 also has size 1 but is kept)
squeezed_tensor = torch.squeeze(a, dim=0)
print(squeezed_tensor.size())

Output

torch.Size([1, 1, 3])

torch.Size([1, 3])
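Like most PyTorch dimension arguments, dim also accepts negative values that count from the last dimension. A brief sketch on the same (1, 1, 3) tensor, where dim=-2 and dim=1 refer to the same axis:

```python
import torch

a = torch.tensor([[1., 2., 3.]]).unsqueeze(0)  # shape (1, 1, 3)

# dim=-2 counts from the end, so here it is the same axis as dim=1
print(torch.squeeze(a, dim=-2).shape)  # torch.Size([1, 3])
print(torch.squeeze(a, dim=1).shape)   # torch.Size([1, 3])
```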

Example 3: Squeeze multiple specific dimensions

import torch

# Create a tensor with dimensions of size 1
a = torch.tensor([[1, 2, 3]]).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 1, 3)
print(a.size())

# Remove dimensions 0 and 1 (a tuple for dim requires PyTorch 2.0+)
squeezed_tensor = torch.squeeze(a, dim=(0, 1))
print(squeezed_tensor.size())

Output

torch.Size([1, 1, 1, 3])

torch.Size([1, 3])
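On PyTorch versions older than 2.0, where dim accepts only a single integer, a similar result can be obtained by chaining calls. This sketch squeezes the higher index first so that the lower index still refers to the same axis afterwards:

```python
import torch

a = torch.tensor([[1., 2., 3.]]).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 1, 3)

# Squeeze dimension 1 first, then dimension 0
b = a.squeeze(1).squeeze(0)
print(b.shape)  # torch.Size([1, 3])
```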

That’s it!
