PyTorch torch.stack() Method

The PyTorch torch.stack() method is “used to concatenate a sequence of tensors along a new dimension.” Every tensor in the sequence must have exactly the same shape, and the result has one more dimension than the inputs.
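To make the "new dimension" idea concrete, here is a small sketch comparing torch.stack() with torch.cat(). Stacking gives the same result as first adding an axis to each tensor with unsqueeze(dim) and then concatenating along that axis. The tensor names a and b are illustrative only.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# stack inserts a new dimension: two (3,) tensors become one (2, 3) tensor
stacked = torch.stack((a, b), dim=0)

# cat joins along an existing dimension: two (3,) tensors become one (6,) tensor
concatenated = torch.cat((a, b), dim=0)

# Equivalent formulation: add the new axis yourself, then concatenate along it
via_cat = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)

print(stacked.shape)                  # torch.Size([2, 3])
print(concatenated.shape)             # torch.Size([6])
print(torch.equal(stacked, via_cat))  # True

# Shapes must match exactly; a (3,) tensor cannot be stacked with a (2,) tensor
try:
    torch.stack((a, torch.tensor([1.0, 2.0])))
except RuntimeError as err:
    print("RuntimeError:", err)
```

Use torch.cat() when the axis you want to grow already exists, and torch.stack() when each input should become one slice of a brand-new axis.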

Syntax

torch.stack(tensors, dim=0, *, out=None)
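In current PyTorch releases, out is a keyword-only argument that defaults to None, so it must be passed as out=... rather than positionally. A minimal sketch, assuming a preallocated buffer that already has the result's shape and dtype (the name buffer is illustrative):

```python
import torch

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])

# Preallocate a buffer with the shape the result will have: (2 tensors, 2 elements)
buffer = torch.empty(2, 2)

# out must be passed by keyword; the stacked values are written into buffer
result = torch.stack((x, y), dim=0, out=buffer)

print(buffer)
```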

Parameters

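  1. tensors: A sequence (for example, a tuple or list) of tensors, all with the same shape.
  2. dim: The index at which the new dimension is inserted. For inputs with n dimensions, it must be an integer from 0 to n (inclusive); negative values from -(n + 1) to -1 count from the end. Defaults to 0.
  3. out (optional, keyword-only): A tensor to write the result into. Defaults to None.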
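The valid range for dim can be checked directly. The sketch below stacks four hypothetical 2x3 tensors (four, so that the new dimension's size stands out) at every legal position, including negative indices, and then shows that an out-of-range value is rejected.

```python
import torch

# Four hypothetical 2x3 inputs; the new dimension will have size 4
inputs = [torch.zeros(2, 3) for _ in range(4)]

# 2-D inputs give a 3-D result, so dim may range from -3 to 2
shapes = {dim: tuple(torch.stack(inputs, dim=dim).shape) for dim in range(-3, 3)}
for dim, shape in shapes.items():
    print(dim, shape)

# dim=3 is outside the allowed range for 2-D inputs
rejected = False
try:
    torch.stack(inputs, dim=3)
except (IndexError, RuntimeError) as err:
    rejected = True
    print(type(err).__name__, err)
```

A negative dim d behaves like d + n + 1, so dim=-1 always appends the new dimension at the end.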

Example 1: Stacking 1D Tensors Along a New Dimension

import torch

tensor1 = torch.Tensor([1, 2, 3])
tensor2 = torch.Tensor([4, 5, 6])
tensor3 = torch.Tensor([7, 8, 9])

# Stack the tensors along a new dimension at index 0
stacked_tensor_1d = torch.stack((tensor1, tensor2, tensor3), dim=0)

print(stacked_tensor_1d)

print(stacked_tensor_1d.shape)

Output

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

torch.Size([3, 3])
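Changing dim changes how the same inputs are arranged. As a quick sketch reusing the three tensors from Example 1, stacking at dim=1 turns each input into a column rather than a row, which for 1D inputs is the transpose of the dim=0 result.

```python
import torch

tensor1 = torch.Tensor([1, 2, 3])
tensor2 = torch.Tensor([4, 5, 6])
tensor3 = torch.Tensor([7, 8, 9])

# dim=1: each input becomes a column instead of a row
stacked_cols = torch.stack((tensor1, tensor2, tensor3), dim=1)
print(stacked_cols)

# For 1-D inputs, stacking on dim=1 is the transpose of stacking on dim=0
stacked_rows = torch.stack((tensor1, tensor2, tensor3), dim=0)
print(torch.equal(stacked_cols, stacked_rows.T))  # True
```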

Example 2: Stacking 2D Tensors Along a New Dimension at Index 1

Here, we take two 2D tensors and stack them along a new dimension at index 1.

import torch

# Create two 2x3 tensors
tensor4 = torch.Tensor([[1, 2, 3], [4, 5, 6]])
tensor5 = torch.Tensor([[7, 8, 9], [10, 11, 12]])

# Stack the tensors along a new dimension at index 1
stacked_tensor_2d_dim1 = torch.stack((tensor4, tensor5), dim=1)

print(stacked_tensor_2d_dim1)
print(stacked_tensor_2d_dim1.shape)

Output

tensor([[[ 1.,  2.,  3.],
         [ 7.,  8.,  9.]],

        [[ 4.,  5.,  6.],
         [10., 11., 12.]]])

torch.Size([2, 2, 3])
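One way to read the dim=1 result is to index along the new dimension. In the sketch below, selecting position k on dim 1 gives back the k-th input, while indexing dim 0 shows row i of every input paired together.

```python
import torch

tensor4 = torch.Tensor([[1, 2, 3], [4, 5, 6]])
tensor5 = torch.Tensor([[7, 8, 9], [10, 11, 12]])
stacked_tensor_2d_dim1 = torch.stack((tensor4, tensor5), dim=1)

# Selecting index k along the new dimension (dim=1) returns the k-th input
print(torch.equal(stacked_tensor_2d_dim1[:, 0], tensor4))  # True
print(torch.equal(stacked_tensor_2d_dim1[:, 1], tensor5))  # True

# Indexing dim 0 pairs up row 0 of every input: [[1, 2, 3], [7, 8, 9]]
print(stacked_tensor_2d_dim1[0])
```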

Example 3: Stacking 3D Tensors Along a New Dimension at Index 2

import torch

# Create two 2x2x2 tensors
tensor6 = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
tensor7 = torch.Tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

# Stack the tensors along a new dimension at index 2
stacked_tensor_3d_dim2 = torch.stack((tensor6, tensor7), dim=2)

print(stacked_tensor_3d_dim2)
print(stacked_tensor_3d_dim2.shape)

Output

tensor([[[[ 1.,  2.],
          [ 9., 10.]],

         [[ 3.,  4.],
          [11., 12.]]],


        [[[ 5.,  6.],
          [13., 14.]],

         [[ 7.,  8.],
          [15., 16.]]]])

torch.Size([2, 2, 2, 2])
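Stacking can also be undone. torch.unbind() (a separate PyTorch function, not covered above) removes a dimension and returns a tuple of the slices along it, so unbinding the same dim as Example 3 should recover the original inputs. A short sketch:

```python
import torch

tensor6 = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
tensor7 = torch.Tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
stacked_tensor_3d_dim2 = torch.stack((tensor6, tensor7), dim=2)

# Each innermost 2x2 block pairs matching rows: [[1, 2], [9, 10]]
print(stacked_tensor_3d_dim2[0, 0])

# unbind along dim=2 splits the result back into one tensor per input
recovered = torch.unbind(stacked_tensor_3d_dim2, dim=2)

print(len(recovered))                      # 2
print(torch.equal(recovered[0], tensor6))  # True
print(torch.equal(recovered[1], tensor7))  # True
```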

In all three examples, the shape of the result is the input shape with the number of stacked tensors inserted at position dim: torch.Size([3, 3]), torch.Size([2, 2, 3]), and torch.Size([2, 2, 2, 2]).
