The PyTorch torch.where() method “returns a tensor of elements selected from either input or other, depending on the condition.”
torch.where(condition, input, other, *, out=None)
- condition (BoolTensor): When True (nonzero), yield input, otherwise yield other.
- input (Tensor or Scalar): Value (if the input is a scalar) or values selected at indices where condition is True.
- other (Tensor or Scalar): Value (if other is a scalar) or values selected at indices where condition is False.
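The following is a minimal sketch of how the three arguments combine. It assumes a PyTorch version that accepts a Python scalar for other, as the parameter list above states. The names values and result are illustrative only. PyTorch broadcasts condition, input, and other to a common shape before selecting, so a column mask can choose between a row of values and a single fill value:

```python
import torch

# condition has shape (2, 1); values has shape (1, 3); other is a Python scalar.
condition = torch.tensor([[True], [False]])
values = torch.tensor([[1.0, 2.0, 3.0]])

# All three arguments are broadcast to shape (2, 3):
# row 0 takes entries from values (condition is True),
# row 1 takes the scalar -1.0 (condition is False).
result = torch.where(condition, values, -1.0)
print(result.shape)  # torch.Size([2, 3])
print(result)
```

Because values is float32 and the fill value is a Python float, the result stays float32.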
Example 1: Basic Usage of the torch.where() method
import torch

# Define a sample tensor
tensor = torch.tensor([[1, 2], [3, 4]])

# Create a condition tensor
condition = tensor > 2

# Apply the where function
result = torch.where(condition, tensor, torch.zeros_like(tensor))
print(result)

tensor([[0, 0],
        [3, 4]])
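To make the selection rule in Example 1 concrete, the sketch below builds the same result a second way, with boolean-mask indexing on a copy. The variable names via_where and via_mask are illustrative. The comparison also shows that torch.where() returns a new tensor and leaves the original unchanged:

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])
condition = tensor > 2

# torch.where builds a new tensor; `tensor` itself is not modified.
via_where = torch.where(condition, tensor, torch.zeros_like(tensor))

# Equivalent construction with boolean-mask indexing: copy the tensor,
# then overwrite every position where the condition is False.
via_mask = tensor.clone()
via_mask[~condition] = 0

print(torch.equal(via_where, via_mask))  # True
print(tensor)  # still tensor([[1, 2], [3, 4]]) in value
```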
Example 2: Advanced Usage – Replacing Negative Values
Suppose you have a tensor, and you want to replace all negative values with zeros:
import torch

# Sample tensor with negative values
tensor = torch.tensor([[1, -2], [-3, 4]])

# Apply the where function
result = torch.where(tensor < 0, torch.zeros_like(tensor), tensor)
print(result)

tensor([[1, 0],
        [0, 4]])
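For the specific rule in Example 2 ("replace negatives with zero"), clamping from below produces the same integer tensor. The sketch below checks this. The names via_where and via_clamp are illustrative:

```python
import torch

tensor = torch.tensor([[1, -2], [-3, 4]])

# Example 2 as written: keep non-negative entries, replace negatives with 0.
via_where = torch.where(tensor < 0, torch.zeros_like(tensor), tensor)

# For this particular rule, clamping from below gives the same tensor.
via_clamp = tensor.clamp(min=0)

print(torch.equal(via_where, via_clamp))  # True
```

torch.clamp() only covers fixed lower and upper bounds. torch.where() handles any condition and any replacement values, including a second tensor.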
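The torch.where() function mirrors numpy.where(). Both select elements from two sources according to a boolean condition, and both broadcast their arguments to a common shape. It is useful wherever values must be masked or replaced conditionally without a Python loop, such as zeroing negative entries as in Example 2.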
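Like numpy.where(), torch.where() also accepts a condition on its own. In that form it returns a tuple of index tensors, one per dimension, and the PyTorch documentation describes it as identical to torch.nonzero(condition, as_tuple=True). The sketch below reuses the tensor from Example 2. The names rows and cols are illustrative:

```python
import torch

tensor = torch.tensor([[1, -2], [-3, 4]])

# One-argument form: indices of the True entries, one tensor per dimension.
rows, cols = torch.where(tensor < 0)
print(rows, cols)  # tensor([0, 1]) tensor([1, 0])

# Documented equivalent: torch.nonzero(condition, as_tuple=True).
nz_rows, nz_cols = torch.nonzero(tensor < 0, as_tuple=True)
print(torch.equal(rows, nz_rows) and torch.equal(cols, nz_cols))  # True

# The index tuple can be used to read back the selected values.
print(tensor[rows, cols])  # tensor([-2, -3])
```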