The torch.flatten() method flattens an input by reshaping it into a one-dimensional tensor. It supports both real and complex-valued input tensors. It accepts a tensor as input and, by default, returns a one-dimensional tensor that holds the same elements in the same (row-major) order. The optional start_dim and end_dim arguments limit flattening to a range of dimensions.
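As a quick check of the claim that complex-valued tensors are handled like any other tensor, here is a minimal sketch. The tensor values are arbitrary illustrations, and the complex64 dtype assumes PyTorch's default float32 setting has not been changed.

```python
import torch

# A 2 x 2 complex-valued tensor (values chosen only for illustration).
z = torch.tensor([[1 + 2j, 3 - 1j], [0 + 1j, 4 + 0j]])

# Flatten all dimensions into one; element order follows the rows.
flat_z = torch.flatten(z)

print(flat_z.shape)  # torch.Size([4])
print(flat_z.dtype)  # torch.complex64 under the default float32 setting
```

The dtype is preserved: flattening only changes the shape, not the element type.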
When flattening is what you mean, torch.flatten() states that intent more explicitly than view() or reshape(), which makes the code easier to read.
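To make the comparison with view() and reshape() concrete, the sketch below flattens the same tensor three ways and then flattens only the trailing dimensions. It also shows one practical difference: view() requires a memory layout compatible with the new shape, so it fails on a non-contiguous tensor, while torch.flatten() falls back to copying when needed.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# Full flatten: for a contiguous tensor, all three give the same result.
a = torch.flatten(t)
b = t.reshape(-1)
c = t.view(-1)
print(torch.equal(a, b), torch.equal(a, c))  # True True

# Partial flatten: flatten() names the dims, reshape() needs shape arithmetic.
d = torch.flatten(t, start_dim=1)
e = t.reshape(t.shape[0], -1)
print(d.shape, torch.equal(d, e))  # torch.Size([2, 12]) True

# A transposed tensor is non-contiguous: nt.view(-1) would raise a
# RuntimeError here, but torch.flatten() still works (it copies if needed).
nt = t.transpose(0, 1)
print(nt.is_contiguous())  # False
print(torch.flatten(nt).shape)  # torch.Size([24])
```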
Syntax
torch.flatten(input, start_dim=0, end_dim=-1)
Parameters
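- input (Tensor): The input tensor.
- start_dim (int, optional): The first dimension to flatten. Defaults to 0. Negative values count from the last dimension.
- end_dim (int, optional): The last dimension to flatten (inclusive). Defaults to -1, the last dimension.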
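The two dimension arguments determine the output shape in a simple way: dimensions before start_dim and after end_dim are kept, and everything from start_dim through end_dim (inclusive) is merged into one dimension whose size is the product of the merged sizes. The sketch below encodes that rule in a hypothetical helper, flatten_shape (not part of PyTorch), and checks it against torch.flatten for a few argument pairs. It assumes an input with at least one dimension and does not validate out-of-range arguments.

```python
import math
import torch

def flatten_shape(shape, start_dim=0, end_dim=-1):
    """Hypothetical helper: predicts torch.flatten's output shape.

    Assumes len(shape) >= 1 and valid start_dim <= end_dim (after wrapping).
    """
    ndim = len(shape)
    # Negative dims count from the end, as in torch.flatten.
    start = start_dim % ndim
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

x = torch.zeros(2, 3, 4, 5)
for s, e in [(0, -1), (1, 2), (-3, -2), (2, 3)]:
    out = torch.flatten(x, start_dim=s, end_dim=e)
    # The predicted shape should match what torch.flatten actually returns.
    assert tuple(out.shape) == flatten_shape(x.shape, s, e)
    print(s, e, tuple(out.shape))
```

Note that (1, 2) and (-3, -2) name the same pair of dimensions for a 4-D tensor, so they produce the same shape.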
Example 1: Flattening a tensor
import torch
tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
flattened_tensor = torch.flatten(tensor)
print(flattened_tensor)
Output
tensor([1, 2, 3, 4, 5, 6])
Example 2: Flattening specific dimensions
Let’s say we have a tensor of shape (2, 3, 4), and we want to flatten only the last two dimensions:
import torch
tensor = torch.arange(24).reshape(2, 3, 4)
print("Original Tensor:\n", tensor)
flattened_tensor = torch.flatten(tensor, start_dim=1, end_dim=2)
print("\nFlattened Tensor:\n", flattened_tensor)
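Output
Original Tensor:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

Flattened Tensor:
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
The result has shape (2, 12): the first dimension is kept, and the last two dimensions (sizes 3 and 4) are merged into one of size 12.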
The torch.flatten() method is especially helpful when you want to flatten only specific dimensions of a tensor, or when you want your code to state its intent clearly.
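A common case of flattening only specific dimensions is preparing a batch of images for a fully connected layer: the batch dimension must be kept while the channel and spatial dimensions are merged. The sketch below assumes a hypothetical batch of 8 RGB images of 32 x 32 pixels and an arbitrary 10-unit output layer. It also uses torch.nn.Flatten, the module form, whose start_dim defaults to 1 rather than 0.

```python
import torch
import torch.nn as nn

# Hypothetical batch: 8 RGB images, 32 x 32 pixels, shape (N, C, H, W).
images = torch.randn(8, 3, 32, 32)

# Keep the batch dimension (dim 0) and merge C, H, W into one feature axis.
features = torch.flatten(images, start_dim=1)
print(features.shape)  # torch.Size([8, 3072])

# nn.Flatten does the same inside a model; its start_dim defaults to 1.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
logits = model(images)
print(logits.shape)  # torch.Size([8, 10])
```

Using start_dim=1 here, instead of the function's default of 0, is what keeps each image's features in a separate row.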

Krunal Lathiya is a seasoned Computer Science expert with over eight years in the tech industry. He has deep knowledge of Data Science and Machine Learning and is versed in Python, JavaScript, PHP, R, and Golang. He is skilled in frameworks such as Angular and React and platforms such as Node.js, and his expertise spans both front-end and back-end development. His proficiency in Machine Learning frameworks such as PyTorch and TensorFlow is a testament to his versatility and commitment to the craft.