Skip to content

Commit a4a0195

Browse files
LucaLumettipytorchmergebot
authored andcommitted
Fix torch.where signature mismatch that was caused by torchgen (pytorch#91627)
Fixes pytorch#91003 Pull Request resolved: pytorch#91627 Approved by: https://github.com/albanD
1 parent accecd7 commit a4a0195

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

torch/_torch_docs.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -12505,34 +12505,34 @@ def merge_dicts(*dicts):
1250512505
add_docstr(
1250612506
torch.where,
1250712507
r"""
12508-
where(condition, x, y, *, out=None) -> Tensor
12508+
where(condition, input, other, *, out=None) -> Tensor
1250912509
12510-
Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`.
12510+
Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`.
1251112511
1251212512
The operation is defined as:
1251312513
1251412514
.. math::
1251512515
\text{out}_i = \begin{cases}
12516-
\text{x}_i & \text{if } \text{condition}_i \\
12517-
\text{y}_i & \text{otherwise} \\
12516+
\text{input}_i & \text{if } \text{condition}_i \\
12517+
\text{other}_i & \text{otherwise} \\
1251812518
\end{cases}
1251912519
"""
1252012520
+ r"""
1252112521
.. note::
12522-
The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be :ref:`broadcastable <broadcasting-semantics>`.
12522+
The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable <broadcasting-semantics>`.
1252312523
1252412524
Arguments:
12525-
condition (BoolTensor): When True (nonzero), yield x, otherwise yield y
12526-
x (Tensor or Scalar): value (if :attr:`x` is a scalar) or values selected at indices
12525+
condition (BoolTensor): When True (nonzero), yield input, otherwise yield other
12526+
input (Tensor or Scalar): value (if :attr:`input` is a scalar) or values selected at indices
1252712527
where :attr:`condition` is ``True``
12528-
y (Tensor or Scalar): value (if :attr:`y` is a scalar) or values selected at indices
12528+
other (Tensor or Scalar): value (if :attr:`other` is a scalar) or values selected at indices
1252912529
where :attr:`condition` is ``False``
1253012530
1253112531
Keyword args:
1253212532
{out}
1253312533
1253412534
Returns:
12535-
Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y`
12535+
Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other`
1253612536
1253712537
Example::
1253812538

torchgen/api/python.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
227227
# s/self/input/ outside method bindings
228228
# [old codegen] TODO: remove this? doesn't rename in codegen, it's just
229229
# for the parse string
230-
if name == "self" and type_str == "Tensor" and not method:
230+
if name == "self" and type_str in ["Tensor", "Number"] and not method:
231231
name = "input"
232232

233233
# add default

0 commit comments

Comments
 (0)