The torch.flatten() function flattens a tensor by reshaping it. By default, it collapses every dimension and returns a one-dimensional tensor; when start_dim and end_dim are given, only the dimensions in that range are collapsed and the others are left unchanged. It accepts a torch tensor as input, supports both real and complex-valued tensors, and returns a torch tensor.
torch.flatten(input, start_dim=0, end_dim=-1)
- input (Tensor): The input tensor.
- start_dim (int): The first dimension to flatten. Defaults to 0.
- end_dim (int): The last dimension to flatten (inclusive). Defaults to -1, the last dimension.
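To make the effect of start_dim and end_dim concrete, here is a minimal sketch that compares the resulting shapes. The tensor shape (2, 3, 4, 5) and the variable names are chosen only for illustration.

```python
import torch

x = torch.zeros(2, 3, 4, 5)

# Defaults (start_dim=0, end_dim=-1) collapse every dimension into one.
all_flat = torch.flatten(x)

# Flatten only dimensions 1 and 2; dimensions 0 and 3 are kept.
middle_flat = torch.flatten(x, start_dim=1, end_dim=2)

# Negative indices count from the end, so this flattens the last two dimensions.
tail_flat = torch.flatten(x, start_dim=-2)

print(all_flat.shape)     # torch.Size([120])
print(middle_flat.shape)  # torch.Size([2, 12, 5])
print(tail_flat.shape)    # torch.Size([2, 3, 20])
```

The flattened dimension always has a size equal to the product of the sizes it replaces (3 * 4 = 12 and 4 * 5 = 20 here).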
Example 1: Flattening a tensor
import torch

tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
flattened_tensor = torch.flatten(tensor)
print(flattened_tensor)
tensor([1, 2, 3, 4, 5, 6])
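The same call works on complex-valued tensors. The short sketch below uses made-up values and assumes PyTorch's default dtype settings, under which Python complex numbers become complex64.

```python
import torch

# Illustrative 2x2 complex tensor; the values are arbitrary.
z = torch.tensor([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]])
flat_z = torch.flatten(z)

print(flat_z.shape)         # torch.Size([4])
print(flat_z.is_complex())  # True
```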
Example 2: Flattening specific dimensions
Let’s say we have a tensor of shape (2, 3, 4), and we want to flatten only the last two dimensions:
import torch

tensor = torch.arange(24).reshape(2, 3, 4)
print("Original Tensor:\n", tensor)

flattened_tensor = torch.flatten(tensor, start_dim=1, end_dim=2)
print("\nFlattened Tensor:\n", flattened_tensor)
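As a quick check on Example 2, the sketch below inspects the result's shape and compares it with an equivalent reshape call. The comparison with reshape is added here for illustration and is not part of the original example.

```python
import torch

tensor = torch.arange(24).reshape(2, 3, 4)
flattened_tensor = torch.flatten(tensor, start_dim=1, end_dim=2)

# Dimension 0 (size 2) is kept; dimensions 1 and 2 merge into 3 * 4 = 12.
print(flattened_tensor.shape)  # torch.Size([2, 12])

# The first row holds the values 0 through 11, in their original order.
print(flattened_tensor[0])

# For this input, flattening dims 1..2 gives the same values as reshape(2, -1).
same_as_reshape = torch.equal(flattened_tensor, tensor.reshape(2, -1))
print(same_as_reshape)  # True
```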
The torch.flatten() function is particularly helpful when you want to flatten only specific dimensions of a tensor, or when you want to make that intent explicit in your code.
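One common situation where flattening specific dimensions helps is preparing a batch for a fully connected layer. The following is a minimal sketch under assumed shapes: the batch of 8 images with 3 channels and 4x4 pixels, and the 10 output features, are hypothetical values picked for illustration.

```python
import torch

# Hypothetical batch: 8 images, 3 channels, 4x4 pixels.
batch = torch.randn(8, 3, 4, 4)

# Keep the batch dimension (dim 0) and flatten everything after it.
features = torch.flatten(batch, start_dim=1)

# A linear layer expects inputs of shape (batch, in_features).
linear = torch.nn.Linear(3 * 4 * 4, 10)
logits = linear(features)

print(features.shape)  # torch.Size([8, 48])
print(logits.shape)    # torch.Size([8, 10])
```

Writing start_dim=1 states directly that the batch dimension is preserved, which is harder to see at a glance in a call such as batch.reshape(8, -1).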