PyTorch's torch.split() method splits a tensor into chunks. Each chunk is a view of the original tensor, so no data is copied.
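Because each chunk is a view, writing into a chunk also changes the original tensor. Here is a minimal sketch that shows this shared memory. The tensor x and the value 100 are arbitrary illustrative choices.

```python
import torch

x = torch.arange(6)
chunks = torch.split(x, 2)  # three views: [0, 1], [2, 3], [4, 5]

# Writing into a chunk writes into the original tensor's memory
chunks[1][0] = 100

print(x.tolist())  # [0, 1, 100, 3, 4, 5]
```

If you need independent copies instead, call .clone() on a chunk before modifying it.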
If split_size_or_sections is an integer, the tensor is split into equally sized chunks of that size (if possible). If the tensor's size along the given dimension dim is not divisible by split_size_or_sections, the last chunk is smaller.
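For a dimension of length n and an integer split size s, this rule gives ceil(n / s) chunks: all but the last have length s, and the last has length n - s * (k - 1), where k is the number of chunks. The following small check compares that arithmetic against torch.split() for n = 10 and s = 3. The values are illustrative.

```python
import math
import torch

n, s = 10, 3
x = torch.arange(n)
chunks = torch.split(x, s)

k = math.ceil(n / s)  # expected number of chunks: ceil(10 / 3) = 4
sizes = [c.numel() for c in chunks]

print(len(chunks), k)  # 4 4
print(sizes)           # [3, 3, 3, 1]

# Every chunk has size s except the last, which holds the remainder
assert sizes == [s] * (k - 1) + [n - s * (k - 1)]
```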
If split_size_or_sections is a list of integers, the tensor is split into len(split_size_or_sections) chunks whose sizes along dim are given by the list. The sizes must add up to the tensor's size along dim.
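If the listed sizes do not add up to the length of dim, PyTorch raises a RuntimeError. The sketch below catches that error generically and prints its message, because the exact wording can differ between PyTorch versions.

```python
import torch

x = torch.arange(10)

# 2 + 4 + 4 == 10, so this split is valid
ok = torch.split(x, [2, 4, 4])
print([c.tolist() for c in ok])  # [[0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]

# 2 + 4 + 3 == 9, which does not match the length 10, so this raises
raised = False
try:
    torch.split(x, [2, 4, 3])
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```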
Syntax
torch.split(tensor, split_size_or_sections, dim=0)
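torch.split() also has a Tensor method counterpart, so x.split(3) gives the same chunks as torch.split(x, 3). This quick check compares both forms and passes dim by keyword.

```python
import torch

x = torch.arange(10)

a = torch.split(x, 3, dim=0)  # function form
b = x.split(3, dim=0)         # equivalent Tensor method form

print(len(a), len(b))                                # 4 4
print(all(torch.equal(p, q) for p, q in zip(a, b)))  # True
```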
Parameters
- tensor: The tensor you want to split.
- split_size_or_sections: Either an integer or a list of integers.
- dim: The dimension along which to split the tensor. By default, it’s 0.
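Splitting along a dimension other than 0 works the same way. This sketch uses an illustrative 2 x 6 matrix and splits it into column blocks of width 2 with dim=1.

```python
import torch

m = torch.arange(12).reshape(2, 6)
# m = [[0, 1, 2, 3, 4, 5],
#      [6, 7, 8, 9, 10, 11]]

cols = torch.split(m, 2, dim=1)  # three 2 x 2 blocks of columns

for block in cols:
    print(block.shape, block.tolist())
# torch.Size([2, 2]) [[0, 1], [6, 7]]
# torch.Size([2, 2]) [[2, 3], [8, 9]]
# torch.Size([2, 2]) [[4, 5], [10, 11]]
```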
Example 1: Splitting with an integer size
Here, we split a tensor of size (10,) into chunks of size 3 along the first dimension.
import torch
tensor = torch.arange(10)
result = torch.split(tensor, 3)
for sub_tensor in result:
    print(sub_tensor)
Output
tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])
Example 2: Splitting with a list of sizes
Here, we split a tensor of size (10,) into chunks of sizes 2, 4, and 4 along the first dimension.
import torch
tensor = torch.arange(10)
result = torch.split(tensor, [2, 4, 4])
for sub_tensor in result:
    print(sub_tensor)
Output
tensor([0, 1])
tensor([2, 3, 4, 5])
tensor([6, 7, 8, 9])
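Because torch.split() returns a tuple of tensors, a split with a known number of sections can be unpacked directly into names. In this sketch, the names head, middle, and tail are purely illustrative.

```python
import torch

x = torch.arange(10)

# torch.split returns a tuple, so it can be unpacked in one line
head, middle, tail = torch.split(x, [2, 4, 4])

print(head.tolist())    # [0, 1]
print(middle.tolist())  # [2, 3, 4, 5]
print(tail.tolist())    # [6, 7, 8, 9]
```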
That’s it!
Krunal Lathiya is a seasoned computer science expert with over eight years in the tech industry and deep knowledge of data science and machine learning. He is versed in Python, JavaScript, PHP, R, and Golang, and skilled in frameworks such as Angular and React and platforms such as Node.js. His expertise spans both front-end and back-end development. His proficiency in machine learning frameworks such as PyTorch and TensorFlow is a testament to his versatility and commitment to the craft.