Skip to content

Commit 88a1504

Browse files
authored
Rebase onto main (#165)
1 parent 8f6053d commit 88a1504

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Change log
88
==========
99

10+
0.12.2 (unreleased)
11+
-------------------
12+
13+
**Bug fix**
14+
15+
- Value propagation of string tensors no longer raises an erroneous ``ValueError`` in some instances.
16+
17+
1018
0.12.1 (2024-06-18)
1119
-------------------
1220

src/spox/_value_prop.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ def __str__(self):
7575

7676
def check(self) -> bool:
7777
if isinstance(self.type, Tensor):
78-
return (
78+
if not (
7979
isinstance(self.value, np.ndarray)
80-
and self.value.dtype.type is self.type.dtype.type
8180
and Shape.from_simple(self.value.shape) <= self.type._shape
82-
)
81+
):
82+
return False
83+
# Strings need some special handling
84+
if self.value.dtype == object and self.type.dtype == str:
85+
return True
86+
return self.value.dtype.type is self.type.dtype.type
8387
elif isinstance(self.type, Sequence):
8488
return isinstance(self.value, list) and all(
8589
elem.type._subtype(self.type.elem_type) for elem in self.value

tests/test_value_propagation.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import spox
66
import spox._future
77
import spox.opset.ai.onnx.ml.v3 as ml
8-
import spox.opset.ai.onnx.v17 as op
8+
import spox.opset.ai.onnx.v20 as op
99
from spox import Var, _type_system
1010
from spox._graph import arguments, results
1111
from spox._shape import Shape
@@ -205,3 +205,11 @@ def test_value_propagation_does_not_fail_on_unseen_opsets(value_prop_backend):
205205
)
206206

207207
spox.inline(model)(X=op.const(["Test Test"], dtype=np.str_))
208+
209+
210+
def test_strings(value_prop_backend):
211+
x, y = op.const("foo"), op.const("bar")
212+
assert op.string_concat(x, y)._value.value == "foobar" # type: ignore
213+
214+
x, y = op.const(["foo"]), op.const(["bar"])
215+
np.testing.assert_equal(op.string_concat(x, y)._value.value, np.array(["foobar"])) # type: ignore

0 commit comments

Comments
 (0)