Skip to content

Commit 0e4ff9e

Browse files
ystadepre-commit-ci[bot]burgholzer
authored
🚸 Support for adding tests to NALAC (cda-tum#629)
## Description This PR contains modifications that were necessary for the tests of the Neutral Atom Logical Array Compiler, see cda-tum/mqt-qmap#470. ## Checklist: - [x] The pull request only contains commits that are related to it. - [x] I have added appropriate tests and documentation. - [x] I have made sure that all CI jobs on GitHub pass. - [x] The pull request introduces no new warnings and follows the project's style guidelines. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: burgholzer <burgholzer@me.com>
1 parent 7eb0ab7 commit 0e4ff9e

16 files changed

+113
-72
lines changed

include/mqt-core/Permutation.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ class Permutation : public std::map<Qubit, Qubit> {
3030
}
3131
return t;
3232
}
33+
34+
[[nodiscard]] auto apply(const Qubit qubit) const -> Qubit {
35+
if (empty()) {
36+
return qubit;
37+
}
38+
return at(qubit);
39+
}
3340
};
3441
} // namespace qc
3542

include/mqt-core/na/NAComputation.hpp

-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class NAComputation {
5757
}
5858
auto clear(const bool clearInitialPositions = true) -> void {
5959
operations.clear();
60-
initialPositions.clear();
6160
if (clearInitialPositions) {
6261
initialPositions.clear();
6362
}

include/mqt-core/na/NADefinitions.hpp

+9
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,12 @@ template <> struct std::hash<na::FullOpType> {
123123
return qc::combineHash(h1, h2);
124124
}
125125
};
126+
127+
/// Hash function for Point, e.g., for use in unordered_map
128+
template <> struct std::hash<na::Point> {
129+
std::size_t operator()(const na::Point& p) const noexcept {
130+
const std::size_t h1 = std::hash<decltype(p.x)>{}(p.x);
131+
const std::size_t h2 = std::hash<decltype(p.y)>{}(p.y);
132+
return qc::combineHash(h1, h2);
133+
}
134+
};

include/mqt-core/na/operations/NAGlobalOperation.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class NAGlobalOperation : public NAOperation {
2828
[[nodiscard]] auto getParams() const -> const std::vector<qc::fp>& {
2929
return params;
3030
}
31+
[[nodiscard]] auto getType() const -> FullOpType { return type; }
3132
[[nodiscard]] auto isGlobalOperation() const -> bool override { return true; }
3233
[[nodiscard]] auto toString() const -> std::string override;
3334
[[nodiscard]] auto clone() const -> std::unique_ptr<NAOperation> override {

include/mqt-core/operations/CompoundOperation.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ class CompoundOperation final : public Operation {
6262

6363
std::vector<std::unique_ptr<Operation>>& getOps() { return ops; }
6464

65-
[[nodiscard]] std::set<Qubit> getUsedQubits() const override;
65+
[[nodiscard]] auto getUsedQubitsPermuted(const Permutation& perm) const
66+
-> std::set<Qubit> override;
6667

6768
[[nodiscard]] auto commutesAtQubit(const Operation& other,
6869
const Qubit& qubit) const -> bool override;

include/mqt-core/operations/NonUnitaryOperation.hpp

-13
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,6 @@ class NonUnitaryOperation final : public Operation {
4747
std::vector<Bit>& getClassics() { return classics; }
4848
[[nodiscard]] std::size_t getNclassics() const { return classics.size(); }
4949

50-
[[nodiscard]] std::set<Qubit> getUsedQubits() const override {
51-
const auto& opTargets = getTargets();
52-
return {opTargets.begin(), opTargets.end()};
53-
}
54-
55-
[[nodiscard]] const Controls& getControls() const override {
56-
throw QFRException("Cannot get controls from non-unitary operation.");
57-
}
58-
59-
[[nodiscard]] Controls& getControls() override {
60-
throw QFRException("Cannot get controls from non-unitary operation.");
61-
}
62-
6350
void addDepthContribution(std::vector<std::size_t>& depths) const override;
6451

6552
void addControl(const Control /*c*/) override {

include/mqt-core/operations/Operation.hpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,10 @@ class Operation {
6868
[[nodiscard]] const std::string& getName() const { return name; }
6969
[[nodiscard]] virtual OpType getType() const { return type; }
7070

71-
[[nodiscard]] virtual std::set<Qubit> getUsedQubits() const {
72-
const auto& opTargets = getTargets();
73-
const auto& opControls = getControls();
74-
std::set<Qubit> usedQubits = {opTargets.begin(), opTargets.end()};
75-
for (const auto& control : opControls) {
76-
usedQubits.insert(control.qubit);
77-
}
78-
return usedQubits;
79-
}
71+
[[nodiscard]] virtual auto
72+
getUsedQubitsPermuted(const Permutation& perm) const -> std::set<Qubit>;
73+
74+
[[nodiscard]] auto getUsedQubits() const -> std::set<Qubit>;
8075

8176
[[nodiscard]] std::unique_ptr<Operation> getInverted() const {
8277
auto op = clone();

src/na/NAComputation.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ auto NAComputation::toString() const -> std::string {
99
std::stringstream ss;
1010
ss << "init at ";
1111
for (const auto& p : initialPositions) {
12-
ss << "(" << p->x << ", " << p->y << ")"
13-
<< ", ";
12+
ss << *p << ", ";
13+
}
14+
if (ss.tellp() == 8) {
15+
ss.seekp(-1, std::ios_base::end);
16+
} else {
17+
ss.seekp(-2, std::ios_base::end);
1418
}
15-
ss.seekp(-2, std::ios_base::end);
1619
ss << ";\n";
1720
for (const auto& op : operations) {
18-
ss << op->toString();
21+
ss << *op;
1922
}
2023
return ss.str();
2124
}

src/na/operations/NALocalOperation.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@ auto NALocalOperation::toString() const -> std::string {
1717
ss << ")";
1818
}
1919
ss << " at ";
20-
for (const auto& p : positions) {
21-
ss << *p << ", ";
20+
if (positions.empty()) {
21+
ss.seekp(-1, std::ios_base::end);
22+
} else {
23+
for (const auto& p : positions) {
24+
ss << *p << ", ";
25+
}
26+
ss.seekp(-2, std::ios_base::end);
2227
}
23-
ss.seekp(-2, std::ios_base::end);
2428
ss << ";\n";
2529
return ss.str();
2630
}

src/operations/CompoundOperation.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,11 @@ void CompoundOperation::dumpOpenQASM(std::ostream& of,
160160
}
161161
}
162162

163-
std::set<Qubit> CompoundOperation::getUsedQubits() const {
163+
auto CompoundOperation::getUsedQubitsPermuted(const Permutation& perm) const
164+
-> std::set<Qubit> {
164165
std::set<Qubit> usedQubits{};
165166
for (const auto& op : ops) {
166-
usedQubits.merge(op->getUsedQubits());
167+
usedQubits.merge(op->getUsedQubitsPermuted(perm));
167168
}
168169
return usedQubits;
169170
}

src/operations/Operation.cpp

+36-37
Original file line numberDiff line numberDiff line change
@@ -121,51 +121,35 @@ bool Operation::equals(const Operation& op, const Permutation& perm1,
121121
return false;
122122
}
123123

124-
// check controls
125-
if (nc1 != 0U) {
126-
Controls controls1{};
127-
if (perm1.empty()) {
128-
controls1 = getControls();
129-
} else {
130-
for (const auto& control : getControls()) {
131-
controls1.emplace(perm1.at(control.qubit), control.type);
132-
}
124+
if (isDiagonalGate()) {
125+
// check pos. controls and targets together
126+
const auto& usedQubits1 = getUsedQubitsPermuted(perm1);
127+
const auto& usedQubits2 = op.getUsedQubitsPermuted(perm2);
128+
if (usedQubits1 != usedQubits2) {
129+
return false;
133130
}
134131

135-
Controls controls2{};
136-
if (perm2.empty()) {
137-
controls2 = op.getControls();
138-
} else {
139-
for (const auto& control : op.getControls()) {
140-
controls2.emplace(perm2.at(control.qubit), control.type);
132+
std::set<Qubit> negControls1{};
133+
for (const auto& control : getControls()) {
134+
if (control.type == Control::Type::Neg) {
135+
negControls1.emplace(perm1.apply(control.qubit));
141136
}
142137
}
143-
144-
if (controls1 != controls2) {
145-
return false;
146-
}
147-
}
148-
149-
// check targets
150-
std::set<Qubit> targets1{};
151-
if (perm1.empty()) {
152-
targets1 = {getTargets().begin(), getTargets().end()};
153-
} else {
154-
for (const auto& target : getTargets()) {
155-
targets1.emplace(perm1.at(target));
138+
std::set<Qubit> negControls2{};
139+
for (const auto& control : op.getControls()) {
140+
if (control.type == Control::Type::Neg) {
141+
negControls2.emplace(perm2.apply(control.qubit));
142+
}
156143
}
144+
return negControls1 == negControls2;
157145
}
158-
159-
std::set<Qubit> targets2{};
160-
if (perm2.empty()) {
161-
targets2 = {op.getTargets().begin(), op.getTargets().end()};
162-
} else {
163-
for (const auto& target : op.getTargets()) {
164-
targets2.emplace(perm2.at(target));
165-
}
146+
// check controls
147+
if (nc1 != 0U &&
148+
perm1.apply(getControls()) != perm2.apply(op.getControls())) {
149+
return false;
166150
}
167151

168-
return targets1 == targets2;
152+
return perm1.apply(getTargets()) == perm2.apply(op.getTargets());
169153
}
170154

171155
void Operation::addDepthContribution(std::vector<std::size_t>& depths) const {
@@ -198,4 +182,19 @@ auto Operation::isInverseOf(const Operation& other) const -> bool {
198182
return operator==(*other.getInverted());
199183
}
200184

185+
auto Operation::getUsedQubitsPermuted(const qc::Permutation& perm) const
186+
-> std::set<Qubit> {
187+
std::set<Qubit> usedQubits;
188+
for (const auto& target : getTargets()) {
189+
usedQubits.emplace(perm.apply(target));
190+
}
191+
for (const auto& control : getControls()) {
192+
usedQubits.emplace(perm.apply(control.qubit));
193+
}
194+
return usedQubits;
195+
}
196+
197+
auto Operation::getUsedQubits() const -> std::set<Qubit> {
198+
return getUsedQubitsPermuted({});
199+
}
201200
} // namespace qc

test/datastructures/test_layer.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88
#include <gtest/gtest.h>
99
#include <memory>
10+
#include <stdexcept>
1011
#include <tuple>
1112
#include <vector>
1213

@@ -39,7 +40,7 @@ TEST(Layer, ExecutableSet1) {
3940
EXPECT_EQ(layer.getExecutableSet().size(), 1); // layer (1)
4041
std::shared_ptr<Layer::DAGVertex> v = *(layer.getExecutableSet()).begin();
4142
v->execute();
42-
EXPECT_ANY_THROW(v->execute());
43+
EXPECT_THROW(v->execute(), std::logic_error);
4344
EXPECT_EQ(layer.getExecutableSet().size(), 3); // layer (2)
4445
v = *(layer.getExecutableSet()).begin();
4546
v->execute();

test/na/test_nacomputation.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,11 @@ TEST(NAComputation, General) {
4646
"move (0, 1), (1, 1) to (4, 1), (5, 1);\n"
4747
"store (4, 1), (5, 1) to (4, 0), (5, 0);\n");
4848
}
49+
50+
TEST(NAComputation, EmptyPrint) {
51+
const NAComputation qc;
52+
std::stringstream ss;
53+
ss << qc;
54+
EXPECT_EQ(ss.str(), "init at;\n");
55+
}
4956
} // namespace na

test/na/test_naoperation.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <gtest/gtest.h>
99
#include <memory>
10+
#include <sstream>
1011
#include <vector>
1112

1213
namespace na {
@@ -49,4 +50,12 @@ TEST(NAOperation, LocalOperation) {
4950
EXPECT_ANY_THROW(
5051
NALocalOperation(FullOpType{qc::RY, 1}, std::make_shared<Point>(0, 0)));
5152
}
53+
54+
TEST(NAOperation, EmptyPrint) {
55+
const NALocalOperation op(FullOpType{qc::RY, 0}, std::vector{qc::PI_2},
56+
std::vector<std::shared_ptr<Point>>{});
57+
std::stringstream ss;
58+
ss << op;
59+
EXPECT_EQ(ss.str(), "ry(1.5708) at;\n");
60+
}
5261
} // namespace na

test/test_operation.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "Definitions.hpp"
2+
#include "Permutation.hpp"
23
#include "operations/AodOperation.hpp"
34
#include "operations/CompoundOperation.hpp"
45
#include "operations/Expression.hpp"
@@ -139,6 +140,24 @@ TEST(Operation, IsDiagonalGate) {
139140
EXPECT_TRUE(op2.isDiagonalGate());
140141
}
141142

143+
TEST(Operation, Equality) {
144+
const qc::StandardOperation op1(0, qc::Z);
145+
const qc::StandardOperation op2(1, 0, qc::Z);
146+
const qc::StandardOperation op3(0, 1, qc::Z);
147+
const qc::StandardOperation op4({0, qc::Control::Type::Neg}, 1, qc::Z);
148+
EXPECT_FALSE(op1 == op2);
149+
EXPECT_TRUE(op2 == op3);
150+
EXPECT_TRUE(op3 == op2);
151+
EXPECT_FALSE(op2 == op4);
152+
153+
EXPECT_TRUE(op2.equals(op3, qc::Permutation{{{0, 0}, {1, 2}}},
154+
qc::Permutation{{{0, 2}, {1, 0}}}));
155+
EXPECT_FALSE(
156+
op2.equals(op3, qc::Permutation{{{0, 0}, {1, 2}}}, qc::Permutation{}));
157+
EXPECT_FALSE(op2.equals(op4, qc::Permutation{{{0, 0}, {1, 2}}},
158+
qc::Permutation{{{0, 2}, {1, 0}}}));
159+
}
160+
142161
TEST(StandardOperation, Move) {
143162
const qc::StandardOperation moveOp({0, 1}, qc::OpType::Move);
144163
EXPECT_EQ(moveOp.getTargets().size(), 2);

test/unittests/test_qfr_functionality.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -1706,7 +1706,6 @@ TEST_F(QFRFunctionality, addControlClassicControlledOperation) {
17061706
TEST_F(QFRFunctionality, addControlNonUnitaryOperation) {
17071707
auto op = NonUnitaryOperation(0U, Measure);
17081708

1709-
EXPECT_THROW(static_cast<void>(op.getControls()), QFRException);
17101709
EXPECT_THROW(op.addControl(1), QFRException);
17111710
EXPECT_THROW(op.removeControl(1), QFRException);
17121711
EXPECT_THROW(op.clearControls(), QFRException);

0 commit comments

Comments
 (0)