Skip to content

Commit 06c5abf

Browse files
fix: fixing types and special cases for diagonal observables
1 parent 4d85519 commit 06c5abf

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

mpqp/core/instruction/measurement/expectation_value.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -112,17 +112,21 @@ def __init__(self, observable: Matrix | list[Real] | PauliString):
112112

113113
self._matrix = np.array(observable)
114114

115+
else:
116+
self._is_diagonal = True
117+
self._diag_elements = observable.real
118+
115119
# correspond to isinstance(observable, list)
116120
else:
117121
self._is_diagonal = True
118-
self._diag_elements = observable
122+
self._diag_elements = np.array(observable)
119123

120124
@property
121125
def matrix(self) -> Matrix:
122126
"""The matrix representation of the observable."""
123127
if self._matrix is None:
124128
if self.is_diagonal:
125-
self._matrix = np.diag(self._diag_elements)
129+
self._matrix = np.diag(self.diagonal_elements)
126130
else:
127131
self._matrix = self.pauli_string.to_matrix()
128132
matrix = copy.deepcopy(self._matrix).astype(np.complex64)
@@ -134,7 +138,7 @@ def pauli_string(self) -> PauliString:
134138
if self._pauli_string is None:
135139
if self.is_diagonal:
136140
self._pauli_string = PauliString.from_diagonal_elements(
137-
self._diag_elements
141+
self.diagonal_elements
138142
)
139143
else:
140144
self._pauli_string = PauliString.from_matrix(self.matrix)
@@ -146,7 +150,7 @@ def diagonal_elements(self) -> npt.NDArray[np.float32]:
146150
"""The diagonal elements of the matrix representing the observable (diagonal or not)."""
147151
if self._diag_elements is None:
148152
self._diag_elements = np.diagonal(self.matrix)
149-
return self._diag_elements
153+
return copy.deepcopy(self._diag_elements).real
150154

151155
@matrix.setter
152156
def matrix(self, matrix: Matrix):

mpqp/tools/obs_decomposition.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def get_monomial(self) -> PauliStringMonomial:
307307
def compute_coefficients_diagonal_case(
308308
m: list[bool],
309309
current_node: DiagPauliNode,
310-
diag_elements: npt.NDArray[np.float64],
310+
diag_elements: npt.NDArray[np.float32],
311311
monomial_list: list[PauliStringMonomial],
312312
):
313313
"""Computes coefficients for the current node in the pauli tree based on the
@@ -356,7 +356,7 @@ def update_tree_diagonal_case(current_node: DiagPauliNode, m: list[bool]):
356356
def generate_and_explore_node_diagonal_case(
357357
m: list[bool],
358358
current_node: DiagPauliNode,
359-
diag_elements: npt.NDArray[np.float64],
359+
diag_elements: npt.NDArray[np.float32],
360360
n: int,
361361
monomials: list[PauliStringMonomial],
362362
progression: Optional[list[int]] = None,
@@ -404,7 +404,7 @@ def generate_and_explore_node_diagonal_case(
404404

405405
@typechecked
406406
def decompose_diagonal_observable_ptdr(
407-
diag_elements: list[Real] | npt.NDArray[np.float64], print_progression: bool = False
407+
diag_elements: list[Real] | npt.NDArray[np.float32], print_progression: bool = False
408408
) -> PauliString:
409409
"""Decomposes a diagonal observable into a Pauli string representation.
410410
@@ -480,7 +480,7 @@ def generate_hadamard(n: int) -> npt.NDArray[np.int8]:
480480

481481

482482
def compute_coefficients_walsh(
483-
H_matrix: npt.NDArray[np.int8], diagonal_elements: npt.NDArray[np.float64]
483+
H_matrix: npt.NDArray[np.int8], diagonal_elements: npt.NDArray[np.float32]
484484
) -> list[float]:
485485
"""Computes the coefficients using the Walsh-Hadamard transform.
486486
@@ -505,7 +505,7 @@ def compute_coefficients_walsh(
505505

506506
@typechecked
507507
def decompose_diagonal_observable_walsh_hadamard(
508-
diag_elements: list[Real] | npt.NDArray[np.float64],
508+
diag_elements: list[Real] | npt.NDArray[np.float32],
509509
) -> PauliString:
510510
"""Decomposes the observable represented by the diagonal elements into a
511511
Pauli string using the Walsh-Hadamard transform.
@@ -535,7 +535,7 @@ def decompose_diagonal_observable_walsh_hadamard(
535535
if TYPE_CHECKING:
536536
assert isinstance(m, PauliStringMonomial)
537537
if c != 0.0:
538-
m.coef = c
538+
m.coef = c.real
539539
final_monomials.append(m)
540540

541541
return PauliString(final_monomials)

0 commit comments

Comments
 (0)