PyTorch torch.allclose() Method

The torch.allclose() method checks whether all elements of two tensors are approximately equal within a given tolerance. For each pair of elements it tests |input - other| <= atol + rtol * |other|, and it returns True (a plain Python bool) only if that condition holds for every element. This makes it useful in unit tests and for verifying numerical computations where small rounding errors are expected.
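To make the condition concrete, here is a minimal sketch that reproduces the elementwise test by hand and compares it with torch.allclose(). The helper manual_allclose is hypothetical and only for illustration. It ignores equal_nan and the special handling of infinities that the real function performs.

```python
import torch

def manual_allclose(x, y, rtol=1e-05, atol=1e-08):
    # Hypothetical helper: applies |x - y| <= atol + rtol * |y| elementwise.
    # It does not model equal_nan or infinities, unlike torch.allclose.
    diff = (x - y).abs()
    tolerance = atol + rtol * y.abs()
    return bool((diff <= tolerance).all())

a = torch.tensor([1.0, 2.0, 3.0])
close = torch.tensor([1.0, 2.0, 3.00001])  # gap ~1e-05, inside the default bound (~3e-05)
far = torch.tensor([1.0, 2.0, 3.0001])     # gap ~1e-04, outside the default bound

print(manual_allclose(a, close), torch.allclose(a, close))  # True True
print(manual_allclose(a, far), torch.allclose(a, far))      # False False
```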

Syntax

torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
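Because the result is a Python bool rather than a tensor, it can go straight into an assert statement. The sketch below assumes a simple floating-point sum as the "computed" value and checks it against the expected one with the default tolerances spelled out.

```python
import torch

# Sum in floating point, then compare with the expected value
x = torch.tensor([0.1, 0.2]).sum()
y = torch.tensor(0.3)

result = torch.allclose(x, y, rtol=1e-05, atol=1e-08)
print(type(result), result)  # <class 'bool'> True

# A bool return value fits directly into a test assertion
assert torch.allclose(x, y), "x and y differ beyond the tolerance"
```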

Parameters

  1. input (Tensor): The first tensor to compare.
  2. other (Tensor): The second tensor to compare. It should have the same shape as input.
  3. rtol (float, optional): The relative tolerance, multiplied by |other|. Default: 1e-05.
  4. atol (float, optional): The absolute tolerance. Default: 1e-08.
  5. equal_nan (bool, optional): If True, two NaNs in the same position are treated as equal. Default: False.
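Since rtol is scaled by |other| only, the check is not symmetric: swapping the two arguments can change the answer. A small sketch with atol set to 0.0 so that only the relative term matters (the values 10.0 and 11.05 are arbitrary, chosen so the gap of 1.05 falls between the two bounds):

```python
import torch

a = torch.tensor([10.0])
b = torch.tensor([11.05])

# Allowed gap is atol + rtol * |other|, so the second argument sets the bound
print(torch.allclose(a, b, rtol=0.1, atol=0.0))  # bound = 0.1 * 11.05 = 1.105 -> True
print(torch.allclose(b, a, rtol=0.1, atol=0.0))  # bound = 0.1 * 10.0 = 1.0 -> False
```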

Example 1: How to use the torch.allclose() method

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0001])

# The last elements differ by about 1e-4, more than the default allowed gap (~3e-5)
print(torch.allclose(a, b))

Output

False
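When the result is False, it helps to see which elements failed. torch.isclose() applies the same test elementwise and returns a boolean tensor, so a sketch like this can pinpoint the offending position and compare the actual gap against the allowed one (computed here with the default rtol and atol):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0001])

# torch.isclose performs the same check per element
mask = torch.isclose(a, b)
print(mask.tolist())  # [True, True, False]

# Actual gap vs. allowed gap (atol + rtol * |other|) for the failing element
diff = (a - b).abs()
allowed = 1e-08 + 1e-05 * b.abs()
print(diff[~mask], allowed[~mask])
```

For the last element the actual gap is roughly 1e-4 while the allowed gap is roughly 3e-5, which is why torch.allclose() returns False.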

Example 2: Demonstrating the effect of rtol and atol

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

# Default tolerances: the allowed gap at the last element is only about 3.1e-05
print(torch.allclose(a, b))  # False

# Raising the relative tolerance: allowed gap = 1e-08 + 0.05 * 3.1 ≈ 0.155 > 0.1
print(torch.allclose(a, b, rtol=0.05))  # True

Output

False
True
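The same tensors can also be made "close" through atol instead of rtol. This sketch uses the fact that the gap at the last element is about 0.1: an absolute tolerance of 0.2 covers it, while 0.05 does not, because the default rtol adds only 1e-05 * 3.1 on top.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

print(torch.allclose(a, b, atol=0.2))   # True
print(torch.allclose(a, b, atol=0.05))  # False
```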

Example 3: Handling NaN values

import torch

a = torch.tensor([1.0, 2.0, float('nan')])
b = torch.tensor([1.0, 2.0, float('nan')])

# Without setting equal_nan=True
print(torch.allclose(a, b))

# Setting equal_nan=True
print(torch.allclose(a, b, equal_nan=True))

Output

False
True
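Note that equal_nan=True only matches a NaN against a NaN at the same index. If the NaNs sit in different positions, each one is compared with an ordinary number and the check still fails, as this short sketch shows:

```python
import torch

a = torch.tensor([1.0, float('nan')])
b = torch.tensor([float('nan'), 1.0])

# NaN vs. 1.0 at both positions, so equal_nan does not help
print(torch.allclose(a, b, equal_nan=True))  # False
```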

That’s it!
