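The torch.where() method in PyTorch “returns a tensor of elements selected from either input or other, depending on the condition.” For each position, it takes the value from input where condition is True and the value from other where condition is False.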
Syntax
torch.where(condition, input, other, *, out=None)
Parameters
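- condition (BoolTensor): When True (nonzero), yield input; otherwise, yield other.
- input (Tensor or Scalar): Value (if input is a scalar) or values selected at indices where condition is True.
- other (Tensor or Scalar): Value (if other is a scalar) or values selected at indices where condition is False.
- out (Tensor, optional): The output tensor.
The tensors condition, input, and other must be broadcastable to a common shape, and the result has that broadcast shape.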
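The parameter descriptions boil down to one element-wise rule: out[i] = input[i] if condition[i] else other[i]. As a minimal sketch, the hypothetical helper where_reference below (not part of PyTorch) spells that rule out in plain Python for same-shape 1-D tensors and is compared against torch.where. It deliberately skips broadcasting, which the real function handles.

```python
import torch


def where_reference(cond, a, b):
    # Hypothetical helper for illustration only: applies the selection rule
    # out[i] = a[i] if cond[i] else b[i] to same-shape 1-D tensors.
    # Unlike torch.where, it does not broadcast.
    picked = [x if c else y for c, x, y in zip(cond.tolist(), a.tolist(), b.tolist())]
    return torch.tensor(picked, dtype=a.dtype)


cond = torch.tensor([True, False, True, False])
a = torch.tensor([10, 20, 30, 40])
b = torch.tensor([-1, -2, -3, -4])

print(where_reference(cond, a, b).tolist())  # [10, -2, 30, -4]
print(torch.where(cond, a, b).tolist())      # [10, -2, 30, -4]
```

The loop version is only a readability aid; torch.where performs the same selection in a single vectorized call.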
Example 1: Basic Usage of the torch.where() method
import torch
# Define a sample tensor
tensor = torch.tensor([[1, 2], [3, 4]])
# Create a condition tensor
condition = tensor > 2
# Apply the where function
result = torch.where(condition, tensor, torch.zeros_like(tensor))
print(result)
Output
tensor([[0, 0],
        [3, 4]])
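Example 1 builds a full-size tensor with torch.zeros_like() just to supply the replacement value. Because torch.where broadcasts its arguments, a 0-dim tensor works as well; the sketch below checks that both forms give the same result. It uses torch.tensor(0) rather than a bare Python 0 on the assumption that you may be on an older PyTorch release without scalar support for these arguments; on versions that document input and other as "Tensor or Scalar", passing 0 directly should also work.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4]])
condition = tensor > 2

# Full-size replacement tensor, as in Example 1
with_zeros_like = torch.where(condition, tensor, torch.zeros_like(tensor))

# A 0-dim tensor broadcasts against `condition` and `tensor`,
# so no intermediate full-size zeros tensor has to be built
with_scalar_tensor = torch.where(condition, tensor, torch.tensor(0))

print(torch.equal(with_zeros_like, with_scalar_tensor))  # True
print(with_scalar_tensor.tolist())                       # [[0, 0], [3, 4]]
```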
Example 2: Advanced Usage – Replacing Negative Values
Suppose you have a tensor, and you want to replace all negative values with zeros:
import torch
# Sample tensor with negative values
tensor = torch.tensor([[1, -2], [-3, 4]])
# Apply the where function
result = torch.where(tensor < 0, torch.zeros_like(tensor), tensor)
print(result)
Output
tensor([[1, 0],
        [0, 4]])
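For the specific case in Example 2, replacing negatives with zero is the same as clamping at 0, so torch.clamp(tensor, min=0) gives an identical result. The sketch below confirms that, then shows where torch.where goes further: the replacement can come from another tensor position by position. The fallback tensor here is made up purely for illustration.

```python
import torch

tensor = torch.tensor([[1, -2], [-3, 4]])

# Replacing negatives with zero, as in Example 2
via_where = torch.where(tensor < 0, torch.zeros_like(tensor), tensor)

# For this particular replacement, clamping at 0 gives the same result
via_clamp = torch.clamp(tensor, min=0)
print(torch.equal(via_where, via_clamp))  # True

# torch.where can also take each replacement from a second tensor
# (illustrative values), which clamp cannot express
fallback = torch.tensor([[10, 20], [30, 40]])
replaced = torch.where(tensor < 0, fallback, tensor)
print(replaced.tolist())  # [[1, 20], [30, 4]]
```

Prefer clamp when the rule really is a bound; reach for torch.where when the condition or the replacement values are arbitrary.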
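The torch.where() function is similar to NumPy's numpy.where(): it chooses element-wise between two values based on a boolean mask. This makes it useful for tasks such as masking out entries, replacing invalid or out-of-range values, and combining two tensors without writing Python loops.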
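The similarity to numpy.where() extends to the one-argument form: called with only a condition, both functions return the indices of the True (nonzero) elements instead of selecting values. The PyTorch documentation states that torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True); the sketch below checks that on the tensor from Example 2.

```python
import torch

tensor = torch.tensor([[1, -2], [-3, 4]])

# Single-argument form: a tuple of index tensors, one per dimension
rows, cols = torch.where(tensor < 0)
print(rows.tolist(), cols.tolist())  # [0, 1] [1, 0]

# Documented as identical to torch.nonzero(..., as_tuple=True)
same = torch.nonzero(tensor < 0, as_tuple=True)
print(all(torch.equal(x, y) for x, y in zip((rows, cols), same)))  # True
```

Here (0, 1) and (1, 0) are exactly the positions of -2 and -3.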
Krunal Lathiya is a seasoned Computer Science expert with over eight years in the tech industry and deep knowledge of Data Science and Machine Learning. He is versed in Python, JavaScript, PHP, R, and Golang, and skilled in frameworks like Angular and React and platforms such as Node.js. His expertise spans both front-end and back-end development, and his proficiency in Machine Learning frameworks like PyTorch and TensorFlow is a testament to his versatility and commitment to the craft.