PyTorch torch.chunk() Method

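PyTorch torch.chunk() method attempts to split a tensor into a specified number of chunks. Each chunk is a view of the input tensor, so no data is copied. All chunks have the same size except possibly the last one, which can be smaller, and the method may return fewer chunks than requested when the size along the split dimension does not divide evenly.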
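To see what "view" means in practice, here is a minimal sketch (the tensor x and the names first and second are illustrative) that writes into one chunk and then inspects the original tensor and the chunk's memory address:

```python
import torch

x = torch.arange(6)  # int64 values 0..5
first, second = torch.chunk(x, 2)

# Writing into a chunk writes into the original tensor, because no data is copied.
first[0] = 100
print(x.tolist())  # [100, 1, 2, 3, 4, 5]

# The second chunk starts 3 elements into x's underlying storage.
print(second.data_ptr() == x.data_ptr() + 3 * x.element_size())  # True
```

If you need independent copies, call .clone() on each chunk.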

Syntax

torch.chunk(input, chunks, dim=0)
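The same operation is also available as a tensor method, tensor.chunk(chunks, dim=0). A short sketch comparing the two forms (the tensor t is illustrative):

```python
import torch

t = torch.arange(8)

# torch.chunk(t, 4) and t.chunk(4) perform the same split.
as_function = torch.chunk(t, 4)
as_method = t.chunk(4)

print(type(as_function).__name__)  # tuple
print([c.tolist() for c in as_method])  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(all(torch.equal(a, b) for a, b in zip(as_function, as_method)))  # True
```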

Parameters

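  1. input (Tensor): The tensor to split.
  2. chunks (int): The number of chunks to return. Fewer chunks may be returned if the size along dim is not divisible by chunks.
  3. dim (int, optional): The dimension along which to split the tensor. Defaults to 0.

Return value

A tuple of tensors, each of which is a view of input.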
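Why can chunks return fewer pieces than requested? A minimal sketch, assuming each chunk holds ceil(size / chunks) elements along dim and the last chunk takes whatever remains; the helper chunk_sizes is hypothetical and exists only to compare that rule with what torch.chunk actually returns:

```python
import torch

def chunk_sizes(n, chunks):
    # Hypothetical helper: every chunk gets ceil(n / chunks) elements,
    # except the last one, which gets whatever is left over.
    step = (n + chunks - 1) // chunks
    return [min(step, n - start) for start in range(0, n, step)]

for n, k in [(6, 3), (5, 3), (6, 4), (7, 4)]:
    actual = [c.numel() for c in torch.chunk(torch.arange(n), k)]
    print(n, k, actual, actual == chunk_sizes(n, k))

# 6 3 [2, 2, 2] True
# 5 3 [2, 2, 1] True
# 6 4 [2, 2, 2] True      <- only 3 chunks, although 4 were requested
# 7 4 [2, 2, 2, 1] True
```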

Example 1: Basic usage with a 1D tensor

import torch

tensor_1 = torch.tensor([1, 2, 3, 4, 5, 6])
chunks_1 = torch.chunk(tensor_1, chunks=3)

print("Original Tensor:")
print(tensor_1)
print("\nChunks:")
for c in chunks_1:
    print(c)

Output

Original Tensor:
tensor([1, 2, 3, 4, 5, 6])

Chunks:
tensor([1, 2])
tensor([3, 4])
tensor([5, 6])
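Since the chunks cover the input without overlap, concatenating them along the same dimension restores the original tensor. A short sketch reusing tensor_1 from Example 1 (note that torch.cat allocates a new tensor rather than returning a view):

```python
import torch

tensor_1 = torch.tensor([1, 2, 3, 4, 5, 6])
chunks_1 = torch.chunk(tensor_1, chunks=3)

# Concatenating the chunks along dim 0 gives back the original values.
restored = torch.cat(chunks_1, dim=0)
print(torch.equal(restored, tensor_1))  # True
print([tuple(c.shape) for c in chunks_1])  # [(2,), (2,), (2,)]
```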

Example 2: Using the method with a 2D tensor

import torch

tensor_2 = torch.tensor([
  [1, 2],
  [3, 4],
  [5, 6],
  [7, 8]
])
chunks_2 = torch.chunk(tensor_2, chunks=2, dim=0)

print("Original Tensor:")
print(tensor_2)
print("\nChunks along Dimension 0:")
for c in chunks_2:
    print(c)

Output

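Original Tensor:
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

Chunks along Dimension 0:
tensor([[1, 2],
        [3, 4]])
tensor([[5, 6],
        [7, 8]])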
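Asking for a number of chunks that does not divide the row count shows the "fewer chunks" behavior on a 2D tensor. The sketch below contrasts torch.chunk with torch.tensor_split (available in PyTorch 1.8 and later), which always returns exactly the requested number of pieces by spreading the remainder over the first ones:

```python
import torch

tensor_2 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

# 3 chunks along dim 0: each chunk gets ceil(4 / 3) = 2 rows,
# so only 2 chunks come back.
print(len(torch.chunk(tensor_2, 3, dim=0)))  # 2

# torch.tensor_split returns exactly 3 pieces with 2, 1 and 1 rows.
pieces = torch.tensor_split(tensor_2, 3, dim=0)
print([p.shape[0] for p in pieces])  # [2, 1, 1]
```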

Example 3: Splitting along a different dimension

import torch

tensor_3 = torch.tensor([
  [1, 2, 3, 4],
  [5, 6, 7, 8]
])
chunks_3 = torch.chunk(tensor_3, chunks=2, dim=1)

print("Original Tensor:")
print(tensor_3)
print("\nChunks along Dimension 1:")
for c in chunks_3:
    print(c)

Output

Original Tensor:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

Chunks along Dimension 1:
tensor([[1, 2],
        [5, 6]])
tensor([[3, 4],
        [7, 8]])
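Splitting along the last dimension with a negative dim is a common pattern, for example when a single projection produces queries, keys and values side by side. The sketch below assumes such a fused tensor; the name qkv and the sizes are illustrative, and the random values stand in for real model output:

```python
import torch

batch, seq_len, d_model = 2, 5, 8

# Illustrative fused tensor: queries, keys and values stored side by side.
qkv = torch.randn(batch, seq_len, 3 * d_model)

# dim=-1 counts from the end, so this splits the last dimension into 3 parts.
q, k, v = qkv.chunk(3, dim=-1)

print(tuple(q.shape))  # (2, 5, 8)
# The pieces are views with the parent's strides, so they are not contiguous.
print(q.is_contiguous())  # False
```

Call .contiguous() on a piece if a later operation requires contiguous memory.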

That’s it!
