Skip to content

Commit 4043e15

Browse files
authored
[PyOV] Extend Python API with STFT-15 (openvinotoolkit#27142)
### Details: - Extend Python API with STFT-15 ### Tickets: - 147160
1 parent ebdf1fc commit 4043e15

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

src/bindings/python/src/openvino/runtime/opset15/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from openvino.runtime.opset15.ops import bitwise_left_shift
1717
from openvino.runtime.opset15.ops import bitwise_right_shift
1818
from openvino.runtime.opset15.ops import slice_scatter
19+
from openvino.runtime.opset15.ops import stft

src/bindings/python/src/openvino/runtime/opset15/ops.py

+24
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,27 @@ def slice_scatter(
303303
inputs = as_nodes(data, updates, start, stop, step, axes, name=name)
304304

305305
return _get_node_factory_opset15().create("SliceScatter", inputs)
306+
307+
308+
@nameable_op
309+
def stft(
310+
data: NodeInput,
311+
window: NodeInput,
312+
frame_size: NodeInput,
313+
frame_step: NodeInput,
314+
transpose_frames: bool,
315+
name: Optional[str] = None,
316+
) -> Node:
317+
"""Return a node which generates STFT operation.
318+
319+
:param data: The node providing input data.
320+
:param window: The node providing window data.
321+
:param frame_size: The node with scalar value representing the size of Fourier Transform.
322+
:param frame_step: The distance (number of samples) between successive window frames.
323+
:param transpose_frames: Flag to set output shape layout. If true the `frames` dimension is at out_shape[2],
324+
otherwise it is at out_shape[1].
325+
:param name: The optional name for the created output node.
326+
:return: The new node performing STFT operation.
327+
"""
328+
inputs = as_nodes(data, window, frame_size, frame_step, name=name)
329+
return _get_node_factory_opset15().create("STFT", inputs)

src/bindings/python/tests/test_graph/test_create_op.py

+16
Original file line numberDiff line numberDiff line change
@@ -2486,6 +2486,22 @@ def test_slice_scatter():
24862486
assert node_default_axes.get_output_shape(0) == data_shape
24872487

24882488

2489+
def test_stft():
2490+
data_shape = [4, 48]
2491+
data = ov.parameter(data_shape, name="input", dtype=np.float32)
2492+
window = ov.parameter([7], name="window", dtype=np.float32)
2493+
frame_size = ov.constant(np.array(11, dtype=np.int32))
2494+
frame_step = ov.constant(np.array(3, dtype=np.int32))
2495+
transpose_frames = True
2496+
2497+
op = ov_opset15.stft(data, window, frame_size, frame_step, transpose_frames)
2498+
2499+
assert op.get_type_name() == "STFT"
2500+
assert op.get_output_size() == 1
2501+
assert op.get_output_element_type(0) == Type.f32
2502+
assert op.get_output_shape(0) == [4, 13, 6, 2]
2503+
2504+
24892505
def test_parameter_get_attributes():
24902506
parameter = ov.parameter([2, 2], dtype=np.float32, name="InputData")
24912507
parameter_attributes = parameter.get_attributes()

0 commit comments

Comments
 (0)