@@ -388,7 +388,8 @@ def _(a: tf.Tensor, k: int = 0) -> tf.Tensor:
388
388
elif rank == 2 :
389
389
return tf .linalg .diag_part (a , k = k )
390
390
else :
391
- raise ValueError ("Input tensor must be 1D or 2D." )
391
+ error_msg = "Input tensor must be 1D or 2D."
392
+ raise ValueError (error_msg )
392
393
393
394
394
395
@numeric .logical_or .register (tf .Tensor )
@@ -439,14 +440,16 @@ def _(a: tf.Tensor, axis: Union[int, Tuple[int, ...], List[int]]) -> np.ndarray:
439
440
axis = (axis ,)
440
441
441
442
if len (set (axis )) != len (axis ):
442
- raise ValueError ("repeated axis" )
443
+ error_msg = "repeated axis"
444
+ raise ValueError (error_msg )
443
445
444
446
out_ndim = len (axis ) + a .ndim
445
447
446
448
norm_axis = []
447
449
for ax in axis :
448
450
if ax < - out_ndim or ax >= out_ndim :
449
- raise ValueError (f"axis { ax } is out of bounds for array of dimension { out_ndim } " )
451
+ error_msg = f"axis { ax } is out of bounds for array of dimension { out_ndim } "
452
+ raise ValueError (error_msg )
450
453
norm_axis .append (ax + out_ndim if ax < 0 else ax )
451
454
452
455
shape_it = iter (a .shape )
@@ -463,9 +466,11 @@ def _(a: tf.Tensor) -> tf.Tensor:
463
466
@numeric .searchsorted .register (tf .Tensor )
464
467
def _ (a : tf .Tensor , v : tf .Tensor , side : str = "left" , sorter : Optional [tf .Tensor ] = None ) -> tf .Tensor :
465
468
if side not in ["right" , "left" ]:
466
- raise ValueError (f"Invalid value for 'side': { side } . Expected 'right' or 'left'." )
469
+ error_msg = f"Invalid value for 'side': { side } . Expected 'right' or 'left'."
470
+ raise ValueError (error_msg )
467
471
if a .ndim != 1 :
468
- raise ValueError (f"Input tensor 'a' must be 1-D. Received { a .ndim } -D tensor." )
472
+ error_msg = f"Input tensor 'a' must be 1-D. Received { a .ndim } -D tensor."
473
+ raise ValueError (error_msg )
469
474
sorted_a = tf .sort (a )
470
475
return tf .searchsorted (sorted_sequence = sorted_a , values = v , side = side )
471
476
@@ -542,4 +547,4 @@ def tensor(
542
547
device = convert_to_tf_device (device )
543
548
dtype = convert_to_tf_dtype (dtype )
544
549
with tf .device (device ):
545
- return tf .constant (data , dtype = dtype )
550
+ return tf .constant (data , dtype = dtype )
0 commit comments