PyTorch torch.Tensor.view() Method

The view() method in PyTorch reshapes a tensor without copying or changing its data. torch.Tensor.view() returns a new tensor that shares the same underlying data as the input tensor but has a different shape.
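The short sketch below illustrates what "same data, different shape" means in practice. The variable names are illustrative only. data_ptr() reports the memory address of a tensor's first element, so equal values indicate that no copy was made.

```python
import torch

# A 1D tensor with six elements
original = torch.tensor([1, 2, 3, 4, 5, 6])

# view() returns a new tensor object with a different shape...
viewed = original.view(2, 3)

print(original.shape)  # torch.Size([6])
print(viewed.shape)    # torch.Size([2, 3])

# ...but no data is copied: both tensors start at the same memory address
print(viewed.data_ptr() == original.data_ptr())  # True
print(viewed is original)                        # False
```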

Syntax

tensor.view(*shape)

Parameters

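shape (torch.Size, a tuple of ints, or separate ints): The shape you want the tensor to have. The product of the dimensions must equal the number of elements in the original tensor.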
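As a minimal sketch of the calling conventions, the target shape can be written as separate integers, as a single tuple, or as a torch.Size taken from another tensor. All three calls below produce the same 2 x 3 view; the names t, a, b, and c are illustrative.

```python
import torch

t = torch.arange(6)  # tensor([0, 1, 2, 3, 4, 5])

# The target shape can be given as separate ints...
a = t.view(2, 3)
# ...as a single tuple...
b = t.view((2, 3))
# ...or as a torch.Size, e.g. the shape of another tensor
c = t.view(torch.zeros(2, 3).shape)

print(a.shape, b.shape, c.shape)
print(torch.equal(a, b) and torch.equal(b, c))  # True
```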

Important Note

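The desired view must be compatible with the original tensor’s size and stride. It must contain the same number of elements, and it must be expressible over the tensor’s existing memory layout without copying. In practice, view() often fails on non-contiguous tensors (for example, the result of transpose() or t()) and raises a RuntimeError. In that case, use reshape(), which copies the data when necessary, or call contiguous() before view().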
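The following sketch shows a stride-incompatible case and the two usual workarounds. The error text printed by the except blocks varies between PyTorch versions, so it is shown only for illustration.

```python
import torch

# Transposing swaps the strides, so the result is no longer contiguous
t = torch.arange(6).view(2, 3).t()  # values: [[0, 3], [1, 4], [2, 5]]
print(t.is_contiguous())  # False

try:
    t.view(6)  # cannot be expressed with the existing strides
except RuntimeError as err:
    print("view() failed:", err)

# reshape() returns a view when it can and copies the data when it must
flat = t.reshape(6)
print(flat)  # tensor([0, 3, 1, 4, 2, 5])

# Alternatively, make a contiguous copy first, then call view()
flat2 = t.contiguous().view(6)
print(torch.equal(flat, flat2))  # True

# A mismatched element count always fails, regardless of strides
try:
    torch.arange(6).view(4)
except RuntimeError as err:
    print("view() failed:", err)
```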

Example 1: Reshaping a 1D tensor to 2D in PyTorch

import torch

tensor = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped_tensor = tensor.view(2, 3)

print(reshaped_tensor)

Output

tensor([[1, 2, 3],
        [4, 5, 6]])

Example 2: Reshaping a 2D tensor to 3D

import torch

tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
reshaped_tensor = tensor.view(1, 3, 2)

print(reshaped_tensor)

Output

tensor([[[1, 2],
         [3, 4],
         [5, 6]]])

Example 3: Using -1 to infer a dimension size

If you use -1 for a particular dimension in the view() method, PyTorch will automatically compute the correct size for that dimension based on the tensor’s total number of elements and the sizes of the other dimensions you’ve specified.

import torch

tensor = torch.tensor([1, 2, 3, 4, 5, 6])
reshaped_tensor = tensor.view(-1, 3)

print(reshaped_tensor)

Output

tensor([[1, 2, 3],
        [4, 5, 6]])
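A few more cases of -1, sketched with the same six-element tensor as Example 3: -1 can appear in any position, -1 on its own flattens the tensor, only one dimension may be inferred, and the inferred size must come out to a whole number.

```python
import torch

tensor = torch.tensor([1, 2, 3, 4, 5, 6])

print(tensor.view(3, -1).shape)     # torch.Size([3, 2])
print(tensor.view(-1).shape)        # torch.Size([6]), i.e. flattened
print(tensor.view(1, -1, 2).shape)  # torch.Size([1, 3, 2])

# Only one dimension may be -1
try:
    tensor.view(-1, -1)
except RuntimeError as err:
    print("view() failed:", err)

# 6 elements cannot be split into rows of 4
try:
    tensor.view(-1, 4)
except RuntimeError as err:
    print("view() failed:", err)
```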

Be careful when modifying tensors created with view(). The result shares its underlying data with the original tensor, so an in-place change to either one is visible in the other. If you need an independent copy, call clone() on the result.
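A minimal sketch of this shared-data behavior, with illustrative names: writing through the view changes the original, writing to the original changes the view, and clone() breaks the link.

```python
import torch

base = torch.tensor([1, 2, 3, 4, 5, 6])
viewed = base.view(2, 3)

# An in-place change through the view is visible in the original...
viewed[0, 0] = 100
print(base.tolist())    # [100, 2, 3, 4, 5, 6]

# ...and a change to the original is visible through the view
base[5] = -1
print(viewed.tolist())  # [[100, 2, 3], [4, 5, -1]]

# clone() copies the data, so later changes do not propagate back
independent = base.view(2, 3).clone()
independent[0, 0] = 0
print(base[0].item())   # 100
```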
