PyTorch torch.split() Method

The PyTorch torch.split() method splits a tensor into chunks along a given dimension. Each chunk is a view of the original tensor, not a copy.
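Because the chunks are views, writing into a chunk also changes the original tensor. The short sketch below shows this, and uses torch.clone() to get an independent copy when that sharing is not wanted. The variable names are only for illustration.

```python
import torch

# Split a 1-D tensor of 6 elements into chunks of size 2
original = torch.arange(6)
chunks = torch.split(original, 2)

# Each chunk is a view, so writing into it writes into the original storage
chunks[1][0] = 100
print(original.tolist())  # [0, 1, 100, 3, 4, 5]

# .clone() produces an independent copy; changing it leaves the original alone
independent = chunks[0].clone()
independent[0] = -1
print(original[0].item())  # 0
```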

If split_size_or_sections is an integer, the tensor is split into equally sized chunks of that size (if possible). If the size of the tensor along dimension dim is not divisible by split_size_or_sections, the last chunk is smaller.
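As a quick check of the integer case, the sketch below prints the chunk sizes for a split size that does not divide the length evenly, and for a split size larger than the whole dimension (which yields a single chunk).

```python
import torch

data = torch.arange(10)

# 10 is not divisible by 4, so the last chunk holds the remaining 2 elements
sizes = [chunk.numel() for chunk in torch.split(data, 4)]
print(sizes)  # [4, 4, 2]

# A split size larger than the dimension returns one chunk with everything
whole = torch.split(data, 20)
print(len(whole), whole[0].numel())  # 1 10
```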

If split_size_or_sections is a list, the tensor is split into len(split_size_or_sections) chunks whose sizes along dim are given by the list entries. The entries must add up to the size of the tensor along dim.
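The sketch below illustrates the sum requirement for the list form: sizes that add up to the dimension length work, while sizes that do not are rejected. PyTorch reports the mismatch as a RuntimeError; the exact error message varies between releases, so it is only printed here.

```python
import torch

data = torch.arange(10)

# 1 + 6 + 3 == 10, the length of dim 0, so this split succeeds
parts = torch.split(data, [1, 6, 3])
print([p.tolist() for p in parts])  # [[0], [1, 2, 3, 4, 5, 6], [7, 8, 9]]

# 2 + 4 + 3 == 9, which does not match the length of dim 0
try:
    torch.split(data, [2, 4, 3])
    mismatch_rejected = False
except RuntimeError as err:
    mismatch_rejected = True
    print("RuntimeError:", err)
```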

Syntax

torch.split(tensor, split_size_or_sections, dim=0)
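The same operation is also available as a method on the tensor itself, tensor.split(split_size_or_sections, dim=0). A minimal comparison of the two forms:

```python
import torch

tensor = torch.arange(10)

# Function form and method form produce identical chunks
via_function = torch.split(tensor, 3)
via_method = tensor.split(3)

same = len(via_function) == len(via_method) and all(
    torch.equal(a, b) for a, b in zip(via_function, via_method)
)
print(len(via_function), same)  # 4 True
```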

Parameters

  1. tensor: The tensor you want to split.
  2. split_size_or_sections: Either an integer (the size of every chunk except possibly the last) or a list of integers (the size of each individual chunk).
  3. dim: The dimension along which to split the tensor. It defaults to 0.
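To see what the dim parameter changes, the sketch below splits the same 3×4 matrix once along rows (the default dim=0) and once along columns (dim=1).

```python
import torch

matrix = torch.arange(12).reshape(3, 4)

# dim=0 (the default) splits rows; dim=1 splits columns
row_chunks = torch.split(matrix, 2)
col_chunks = torch.split(matrix, 2, dim=1)

print([tuple(c.shape) for c in row_chunks])  # [(2, 4), (1, 4)]
print([tuple(c.shape) for c in col_chunks])  # [(3, 2), (3, 2)]
print(col_chunks[1].tolist())  # [[2, 3], [6, 7], [10, 11]]
```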

Example 1: Splitting with an integer size

Here, we split a tensor of size (10,) into chunks of size 3 along the first dimension.

import torch

tensor = torch.arange(10)
result = torch.split(tensor, 3)

for sub_tensor in result:
  print(sub_tensor)

Output

tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])
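Since the chunks cover the whole tensor without overlap, concatenating them back along the same dimension with torch.cat() restores the original. This is a handy sanity check after splitting:

```python
import torch

tensor = torch.arange(10)
result = torch.split(tensor, 3)

# Three full chunks of size 3, then one chunk with the leftover element
print([len(c) for c in result])  # [3, 3, 3, 1]

# Concatenating along the split dimension reproduces the original tensor
restored = torch.cat(result, dim=0)
print(torch.equal(restored, tensor))  # True
```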

Example 2: Splitting with a list of sizes

Here, we split a tensor of size (10,) into chunks of sizes 2, 4, and 4 along the first dimension. The sizes add up to 10, the length of that dimension.

import torch

tensor = torch.arange(10)
result = torch.split(tensor, [2, 4, 4])

for sub_tensor in result:
  print(sub_tensor)

Output

tensor([0, 1])
tensor([2, 3, 4, 5])
tensor([6, 7, 8, 9])
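A list of sizes also works on multi-dimensional tensors. The sketch below uses a hypothetical 2×6 batch and splits its last dimension (dim=-1) into pieces of 1, 2, and 3 columns, unpacking the three chunks directly into variables.

```python
import torch

# Hypothetical batch: 2 rows with 6 values each
batch = torch.arange(12).reshape(2, 6)

# dim=-1 is the last dimension; 1 + 2 + 3 equals its length of 6
first, middle, last = torch.split(batch, [1, 2, 3], dim=-1)

print(first.tolist())   # [[0], [6]]
print(middle.tolist())  # [[1, 2], [7, 8]]
print(last.tolist())    # [[3, 4, 5], [9, 10, 11]]
```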

That’s it!
