PyTorch torch.cat() Method

The PyTorch torch.cat() method concatenates a given sequence of tensors along a given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor of size (0,).
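A quick way to see the shape rule in action is to take two tensors that differ in only one dimension. The following is a minimal sketch; the tensor sizes are arbitrary choices for illustration.

```python
import torch

a = torch.zeros(2, 3)
b = torch.zeros(2, 4)

# The shapes differ only in dim 1, so concatenating along dim 1 works.
ok = torch.cat((a, b), dim=1)
print(ok.shape)  # torch.Size([2, 7])

# Along dim 0, the sizes in dim 1 (3 vs 4) would have to match, so this fails.
raised = False
try:
    torch.cat((a, b), dim=0)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```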

Syntax

torch.cat(tensors, dim=0, *, out=None)
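In this signature, tensors can be any sequence of tensors and dim can be passed positionally or by keyword. The sketch below shows equivalent calls; the torch.concat line assumes PyTorch 1.10 or later, where it exists as an alias of torch.cat.

```python
import torch

a = torch.tensor([[1, 2]])
b = torch.tensor([[3, 4]])

# tensors can be any sequence of tensors, such as a tuple or a list.
from_tuple = torch.cat((a, b), 0)      # dim passed positionally
from_list = torch.cat([a, b], dim=0)   # dim passed by keyword

# torch.concat is an alias of torch.cat (PyTorch 1.10 and later).
from_alias = torch.concat([a, b])

print(torch.equal(from_tuple, from_list))   # True
print(torch.equal(from_tuple, from_alias))  # True
```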

Parameters

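  1. tensors: A sequence of tensors (for example, a tuple or a list) to concatenate. They must have the same number of dimensions and the same size in every dimension except dim.
  2. dim: The dimension along which the tensors are concatenated. Defaults to 0; negative values count back from the last dimension.
  3. out: An optional output tensor to write the result into. It is keyword-only, so it must be passed as out=....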
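The dim and out parameters can be sketched as follows. The tensor names and sizes are illustrative; buf is preallocated here with the shape and dtype the result will have.

```python
import torch

x = torch.ones(2, 2)
y = torch.zeros(2, 1)

# dim=-1 refers to the last dimension, which is dim 1 for 2-D tensors.
z = torch.cat([x, y], dim=-1)
print(z.shape)  # torch.Size([2, 3])

# out is keyword-only; the result is written into the preallocated buf.
buf = torch.empty(2, 3)
torch.cat([x, y], dim=1, out=buf)
print(torch.equal(buf, z))  # True
```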

Example 1: Concatenating 1-D Tensors

import torch

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.tensor([5, 6])

result = torch.cat((a, b, c), dim=0)

print(result)

Output

tensor([1, 2, 3, 4, 5, 6])
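Because tensors accepts any sequence, the pieces are often collected in a list first, for example inside a loop. This small sketch builds the same kind of result from a list comprehension and relies on the default dim=0.

```python
import torch

# Three 1-D chunks: [0, 1], [2, 3], [4, 5]
chunks = [torch.arange(i * 2, i * 2 + 2) for i in range(3)]

result = torch.cat(chunks)  # dim defaults to 0
print(result)  # tensor([0, 1, 2, 3, 4, 5])
```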

Example 2: Concatenating 2-D Tensors Along Rows (dim=0)

import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6]])

result = torch.cat((a, b), dim=0)

print(result)

Output

tensor([[1, 2],
        [3, 4],
        [5, 6]])
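In this example b is already 2-D with shape (1, 2). If the new row starts out as a 1-D tensor, it has to gain a leading dimension first, since torch.cat requires all inputs to have the same number of dimensions. A minimal sketch using unsqueeze:

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
row = torch.tensor([5, 6])          # 1-D, shape (2,)

# unsqueeze(0) turns the row into shape (1, 2) so it matches a's 2 dimensions.
result = torch.cat((a, row.unsqueeze(0)), dim=0)
print(result.shape)  # torch.Size([3, 2])
```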

Example 3: Concatenating 2-D Tensors Along Columns (dim=1)

import torch

a = torch.tensor([[1, 2]])
b = torch.tensor([[3, 4]])

result = torch.cat((a, b), dim=1)

print(result)

Output

tensor([[1, 2, 3, 4]])
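When concatenating along dim=1, the row counts must match, but the column counts may differ. A small sketch with illustrative values:

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])  # shape (2, 2)
b = torch.tensor([[5], [6]])        # shape (2, 1)

# Both have 2 rows, so they can be joined side by side along dim=1.
result = torch.cat((a, b), dim=1)
print(result)
# tensor([[1, 2, 5],
#         [3, 4, 6]])
```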

Example 4: Concatenating 3-D Tensors

import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 3, 4)

result = torch.cat((a, b), dim=0)

print(result.shape)  # values are random, so inspect the shape instead

Output

torch.Size([4, 3, 4])
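The same rule holds for the other dimensions of a 3-D tensor: the sizes along dim add up, and every other dimension must match. This sketch uses arbitrary random tensors and also checks that the inputs appear in the result in the order given.

```python
import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 5, 4)
c = torch.randn(2, 3, 1)

# dim=1: sizes 3 and 5 add up to 8; dims 0 and 2 must match.
mid = torch.cat((a, b), dim=1)
print(mid.shape)  # torch.Size([2, 8, 4])

# dim=-1 (the same as dim=2 here): sizes 4 and 1 add up to 5.
last = torch.cat((a, c), dim=-1)
print(last.shape)  # torch.Size([2, 3, 5])

# Inputs are placed in order, so slicing along dim 1 recovers them.
print(torch.equal(mid[:, :3], a), torch.equal(mid[:, 3:], b))  # True True
```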

That’s it!
