@@ -24,9 +24,14 @@ def _(
24
24
axis : Optional [Union [int , Tuple [int , ...]]] = None ,
25
25
keepdims : bool = False ,
26
26
) -> tf .Tensor :
27
- if ord is None :
28
- ord = "euclidean"
29
27
rank = tf .rank (a )
28
+
29
+ if ord is None :
30
+ if axis is None and rank == 2 :
31
+ ord = "fro"
32
+ else :
33
+ ord = 2
34
+
30
35
if rank == 2 and axis is None :
31
36
axis = (0 , 1 )
32
37
@@ -49,41 +54,75 @@ def _(
49
54
if rank != 2 :
50
55
error_msg = "ord=-1 is only supported for 2D tensors"
51
56
raise ValueError (error_msg )
52
- return tf .reduce_min (tf .reduce_sum (tf .abs (a ), axis = axis [0 ]), keepdims = keepdims )
57
+ result = tf .reduce_min (tf .reduce_sum (tf .abs (a ), axis = axis [0 ]), keepdims = keepdims )
58
+ if keepdims :
59
+ result = tf .reshape (result , [1 , 1 ])
60
+ return result
53
61
54
62
if ord == 1 and isinstance (axis , tuple ) and len (axis ) != 1 :
55
63
if rank != 2 :
56
64
error_msg = "ord=1 is only supported for 2D tensors"
57
65
raise ValueError (error_msg )
58
- return tf .reduce_max (tf .reduce_sum (tf .abs (a ), axis = axis [0 ]), keepdims = keepdims )
66
+ result = tf .reduce_max (tf .reduce_sum (tf .abs (a ), axis = axis [0 ]), keepdims = keepdims )
67
+ if keepdims :
68
+ result = tf .reshape (result , [1 , 1 ])
69
+ return result
59
70
60
71
if ord == - 2 and isinstance (axis , tuple ) and len (axis ) != 1 :
61
72
if rank != 2 :
62
73
error_msg = "ord=-2 is only supported for 2D tensors"
63
74
raise ValueError (error_msg )
64
75
s = tf .linalg .svd (a , compute_uv = False )
65
- return tf .reduce_min (s , axis = - 1 )
76
+ result = tf .reduce_min (s , axis = - 1 )
77
+ if keepdims :
78
+ result = tf .reshape (result , [1 , 1 ])
79
+ return result
66
80
67
81
if ord == 2 and isinstance (axis , tuple ) and len (axis ) != 1 :
68
82
if rank != 2 :
69
83
error_msg = "ord=2 is only supported for 2D tensors"
70
84
raise ValueError (error_msg )
71
85
s = tf .linalg .svd (a , compute_uv = False )
72
- return tf .reduce_max (s , axis = - 1 )
86
+ result = tf .reduce_max (s , axis = - 1 )
87
+ if keepdims :
88
+ result = tf .reshape (result , [1 , 1 ])
89
+ return result
73
90
74
91
if ord == float ("inf" ) and isinstance (axis , tuple ) and len (axis ) != 1 :
75
92
if rank != 2 :
76
93
error_msg = "ord=inf is only supported for 2D tensors"
77
94
raise ValueError (error_msg )
78
- return tf .reduce_max (tf .reduce_sum (tf .abs (a ), axis = axis [1 ]), keepdims = keepdims )
95
+ result = tf .reduce_max (tf .reduce_sum (tf .abs (a ), axis = axis [1 ]), keepdims = keepdims )
96
+ if keepdims :
97
+ result = tf .reshape (result , [1 , 1 ])
98
+ return result
79
99
80
100
if ord == - float ("inf" ) and isinstance (axis , tuple ) and len (axis ) != 1 :
81
101
if rank != 2 :
82
102
error_msg = "ord=-inf is only supported for 2D tensors"
83
103
raise ValueError (error_msg )
84
- return tf .reduce_min (tf .reduce_sum (tf .abs (a ), axis = axis [1 ]), keepdims = keepdims )
104
+ result = tf .reduce_min (tf .reduce_sum (tf .abs (a ), axis = axis [1 ]), keepdims = keepdims )
105
+ if keepdims :
106
+ result = tf .reshape (result , [1 , 1 ])
107
+ return result
85
108
86
- return tf .linalg .norm (a , ord = ord , axis = axis , keepdims = keepdims )
109
+ try :
110
+ return tf .linalg .norm (a , ord = ord , axis = axis , keepdims = keepdims )
111
+ except (TypeError , ValueError ):
112
+ if axis is not None :
113
+ if ord == 2 :
114
+ squared = tf .square (a )
115
+ sum_squares = tf .reduce_sum (squared , axis = axis , keepdims = keepdims )
116
+ return tf .sqrt (sum_squares )
117
+ elif ord == 1 :
118
+ return tf .reduce_sum (tf .abs (a ), axis = axis , keepdims = keepdims )
119
+ elif ord == float ("inf" ):
120
+ return tf .reduce_max (tf .abs (a ), axis = axis , keepdims = keepdims )
121
+ elif ord == - float ("inf" ):
122
+ return tf .reduce_min (tf .abs (a ), axis = axis , keepdims = keepdims )
123
+
124
+ error_msg = f"Unsupported combination of ord={ ord } and axis={ axis } "
125
+ raise ValueError (error_msg )
87
126
88
127
89
128
@linalg .cholesky .register
0 commit comments