diff --git a/include/mqt-core/operations/Expression.hpp b/include/mqt-core/operations/Expression.hpp index 85829e316..4ddd43b6b 100644 --- a/include/mqt-core/operations/Expression.hpp +++ b/include/mqt-core/operations/Expression.hpp @@ -91,6 +91,10 @@ class Term { coeff /= rhs; return *this; } + Term& operator/=(const std::int64_t rhs) { + coeff /= static_cast(rhs); + return *this; + } [[nodiscard]] bool totalAssignment(const VariableAssignment& assignment) const { return assignment.find(getVar()) != assignment.end(); @@ -291,6 +295,16 @@ class Expression { return *this; } + Expression& operator/=(int64_t rhs) { + if (rhs == 0) { + throw std::runtime_error("Trying to divide expression by 0!"); + } + std::for_each(terms.begin(), terms.end(), + [&](auto& term) { term /= T{static_cast(rhs)}; }); + constant = U{double{constant} / static_cast(rhs)}; + return *this; + } + [[nodiscard]] Expression operator-() const { Expression e; e.terms.reserve(terms.size()); @@ -451,6 +465,12 @@ inline Expression operator/(Expression lhs, const U& rhs) { return lhs; } +template +inline Expression operator/(Expression lhs, int64_t rhs) { + lhs /= rhs; + return lhs; +} + template inline Expression operator*(const T& lhs, Expression rhs) { return rhs * lhs; diff --git a/include/mqt-core/zx/FunctionalityConstruction.hpp b/include/mqt-core/zx/FunctionalityConstruction.hpp index cd00ce1f2..66b1709ba 100644 --- a/include/mqt-core/zx/FunctionalityConstruction.hpp +++ b/include/mqt-core/zx/FunctionalityConstruction.hpp @@ -5,6 +5,7 @@ #include "zx/ZXDiagram.hpp" #include +#include namespace zx { class FunctionalityConstruction { @@ -28,6 +29,16 @@ class FunctionalityConstruction { std::vector& qubits, const PiExpression& phase = PiExpression(), EdgeType type = EdgeType::Simple); + static void + addRz(ZXDiagram& diag, const PiExpression& phase, Qubit target, + std::vector& qubits, + const std::optional& unconvertedPhase = std::nullopt); + static void addRx(ZXDiagram& diag, const PiExpression& phase, Qubit target, + std::vector& qubits); + static void + addRy(ZXDiagram& diag, const PiExpression& phase, Qubit target, + std::vector& qubits, + const std::optional& unconvertedPhase = std::nullopt); static void addCnot(ZXDiagram& diag, Qubit ctrl, Qubit target, std::vector& qubits, EdgeType type = EdgeType::Simple); @@ -35,14 +46,30 @@ class FunctionalityConstruction { Qubit target, std::vector& qubits); static void addSwap(ZXDiagram& diag, Qubit target, Qubit target2, std::vector& qubits); - static void addRzz(ZXDiagram& diag, const PiExpression& phase, Qubit target, - Qubit target2, std::vector& qubits); - static void addRxx(ZXDiagram& diag, const PiExpression& phase, Qubit target, - Qubit target2, std::vector& qubits); - static void addRzx(ZXDiagram& diag, const PiExpression& phase, Qubit target, - Qubit target2, std::vector& qubits); + static void + addRzz(ZXDiagram& diag, const PiExpression& phase, Qubit target, + Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase = std::nullopt); + static void + addRxx(ZXDiagram& diag, const PiExpression& phase, Qubit target, + Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase = std::nullopt); + static void + addRzx(ZXDiagram& diag, const PiExpression& phase, Qubit target, + Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase = std::nullopt); static void addDcx(ZXDiagram& diag, Qubit qubit1, Qubit qubit2, std::vector& qubits); + static void + addXXplusYY(ZXDiagram& diag, const PiExpression& theta, + const PiExpression& beta, Qubit qubit0, Qubit qubit1, + std::vector& qubits, + const std::optional& unconvertedBeta = std::nullopt); + static void + addXXminusYY(ZXDiagram& diag, const PiExpression& theta, + const PiExpression& beta, Qubit qubit0, Qubit qubit1, + std::vector& qubits, + const std::optional& unconvertedBeta = std::nullopt); static void addCcx(ZXDiagram& diag, Qubit ctrl0, Qubit ctrl1, Qubit target, std::vector& qubits); static op_it parseOp(ZXDiagram& diag, op_it it, op_it end, diff --git a/src/zx/FunctionalityConstruction.cpp b/src/zx/FunctionalityConstruction.cpp index 3c4ae4031..5c4d863a8 100644 --- a/src/zx/FunctionalityConstruction.cpp +++ b/src/zx/FunctionalityConstruction.cpp @@ -1,12 +1,12 @@ #include "zx/FunctionalityConstruction.hpp" -#include "Definitions.hpp" #include "operations/OpType.hpp" #include "zx/Rational.hpp" #include "zx/ZXDefinitions.hpp" #include "zx/ZXDiagram.hpp" #include +#include #include #include #include @@ -32,8 +32,7 @@ bool FunctionalityConstruction::checkSwap(const op_it& it, const op_it& end, return false; } -void FunctionalityConstruction::addZSpider(ZXDiagram& diag, - const zx::Qubit qubit, +void FunctionalityConstruction::addZSpider(ZXDiagram& diag, const Qubit qubit, std::vector& qubits, const PiExpression& phase, const EdgeType type) { @@ -63,6 +62,42 @@ void FunctionalityConstruction::addXSpider(ZXDiagram& diag, const Qubit qubit, qubits[q] = newVertex; } +void FunctionalityConstruction::addRz( + ZXDiagram& diag, const PiExpression& phase, const Qubit target, + std::vector& qubits, + const std::optional& unconvertedPhase) { + if (unconvertedPhase.has_value()) { + diag.addGlobalPhase( + PiExpression(PiRational(-unconvertedPhase.value() / 2))); + } else { + diag.addGlobalPhase(-(phase / 2)); + } + addZSpider(diag, target, qubits, phase); +} + +void FunctionalityConstruction::addRx(ZXDiagram& diag, + const PiExpression& phase, + const Qubit target, + std::vector& qubits) { + addXSpider(diag, target, qubits, phase); +} + +void FunctionalityConstruction::addRy( + ZXDiagram& diag, const PiExpression& phase, const Qubit target, + std::vector& qubits, + const std::optional& unconvertedPhase) { + if (unconvertedPhase.has_value()) { + diag.addGlobalPhase( + PiExpression(PiRational(-unconvertedPhase.value() / 2))); + } else { + diag.addGlobalPhase(-(phase / 2)); + } + addXSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); + addZSpider(diag, target, qubits, phase + PiRational(1, 1)); + addXSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); + addZSpider(diag, target, qubits, PiExpression(PiRational(1, 1))); +} + void FunctionalityConstruction::addCnot(ZXDiagram& diag, const Qubit ctrl, const Qubit target, std::vector& qubits, @@ -88,63 +123,129 @@ void FunctionalityConstruction::addCphase(ZXDiagram& diag, addZSpider(diag, target, qubits, newPhase); } -void FunctionalityConstruction::addRzz(ZXDiagram& diag, - const PiExpression& phase, - const Qubit target, const Qubit target2, - std::vector& qubits) { +void FunctionalityConstruction::addRzz( + ZXDiagram& diag, const PiExpression& phase, const Qubit target, + const Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase) { addZSpider(diag, target, qubits); addZSpider(diag, target2, qubits); const auto midX = - diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), zx::VertexType::X); - const auto midZ = diag.addVertex(-1, -1, phase, zx::VertexType::Z); + diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), VertexType::X); + const auto midZ = diag.addVertex(-1, -1, phase, VertexType::Z); diag.addEdge(qubits[static_cast(target)], midX); diag.addEdge(qubits[static_cast(target2)], midX); diag.addEdge(midX, midZ); - diag.addGlobalPhase(-phase / 2.0); + + if (unconvertedPhase.has_value()) { + diag.addGlobalPhase( + PiExpression(PiRational(-unconvertedPhase.value() / 2))); + } else { + diag.addGlobalPhase(-(phase / 2)); + } } -void FunctionalityConstruction::addRxx(ZXDiagram& diag, - const PiExpression& phase, - const Qubit target, const Qubit target2, - std::vector& qubits) { +void FunctionalityConstruction::addRxx( + ZXDiagram& diag, const PiExpression& phase, const Qubit target, + const Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase) { addXSpider(diag, target, qubits); addXSpider(diag, target2, qubits); const auto midZ = - diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), zx::VertexType::Z); - const auto midX = diag.addVertex(-1, -1, phase, zx::VertexType::X); + diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), VertexType::Z); + const auto midX = diag.addVertex(-1, -1, phase, VertexType::X); diag.addEdge(qubits[static_cast(target)], midZ); diag.addEdge(qubits[static_cast(target2)], midZ); diag.addEdge(midZ, midX); - diag.addGlobalPhase(-phase / 2.0); + + if (unconvertedPhase.has_value()) { + diag.addGlobalPhase( + PiExpression(PiRational(-unconvertedPhase.value() / 2))); + } else { + diag.addGlobalPhase(-(phase / 2)); + } } -void FunctionalityConstruction::addRzx(ZXDiagram& diag, - const PiExpression& phase, - const Qubit target, const Qubit target2, - std::vector& qubits) { +void FunctionalityConstruction::addRzx( + ZXDiagram& diag, const PiExpression& phase, const Qubit target, + const Qubit target2, std::vector& qubits, + const std::optional& unconvertedPhase) { addZSpider(diag, target, qubits); addXSpider(diag, target2, qubits); const auto midX = - diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), zx::VertexType::X); - const auto midZ = diag.addVertex(-1, -1, phase, zx::VertexType::Z); + diag.addVertex(-1, -1, PiExpression(PiRational(0, 1)), VertexType::X); + const auto midZ = diag.addVertex(-1, -1, phase, VertexType::Z); diag.addEdge(qubits[static_cast(target)], midX); diag.addEdge(qubits[static_cast(target2)], midX, EdgeType::Hadamard); diag.addEdge(midX, midZ); - diag.addGlobalPhase(-phase / 2.0); + + if (unconvertedPhase.has_value()) { + diag.addGlobalPhase( + PiExpression(PiRational(-unconvertedPhase.value() / 2))); + } else { + diag.addGlobalPhase(-(phase / 2)); + } } -void FunctionalityConstruction::addDcx(zx::ZXDiagram& diag, - const zx::Qubit qubit1, - const zx::Qubit qubit2, +void FunctionalityConstruction::addDcx(ZXDiagram& diag, const Qubit qubit1, + const Qubit qubit2, std::vector& qubits) { addCnot(diag, qubit1, qubit2, qubits); addCnot(diag, qubit2, qubit1, qubits); } +void FunctionalityConstruction::addXXplusYY( + ZXDiagram& diag, const PiExpression& theta, const PiExpression& beta, + const Qubit qubit0, const Qubit qubit1, std::vector& qubits, + const std::optional& unconvertedBeta) { + addRz(diag, beta, qubit1, qubits, unconvertedBeta); + addRz(diag, PiExpression(PiRational(1, 2)), qubit1, qubits); + addRz(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRx(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + addRz(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + addCnot(diag, qubit0, qubit1, qubits); + addRy(diag, theta / 2, qubit0, qubits); + addRy(diag, theta / 2, qubit1, qubits); + addCnot(diag, qubit0, qubit1, qubits); + addRz(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRx(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRz(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + if (unconvertedBeta.has_value()) { + addRz(diag, -beta, qubit1, qubits, -unconvertedBeta.value()); + } else { + addRz(diag, -beta, qubit1, qubits); + } + + addRz(diag, PiExpression(-PiRational(1, 2)), qubit1, qubits); +} + +void FunctionalityConstruction::addXXminusYY( + ZXDiagram& diag, const PiExpression& theta, const PiExpression& beta, + const Qubit qubit0, const Qubit qubit1, std::vector& qubits, + const std::optional& unconvertedBeta) { + if (unconvertedBeta.has_value()) { + addRz(diag, -beta, qubit1, qubits, -unconvertedBeta.value()); + } else { + addRz(diag, -beta, qubit1, qubits); + } + addRz(diag, PiExpression(PiRational(1, 2)), qubit1, qubits); + addRz(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRx(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + addRz(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + addCnot(diag, qubit0, qubit1, qubits); + addRy(diag, -theta / 2, qubit0, qubits); + addRy(diag, theta / 2, qubit1, qubits); + addCnot(diag, qubit0, qubit1, qubits); + addRz(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRx(diag, PiExpression(PiRational(-1, 2)), qubit0, qubits); + addRz(diag, PiExpression(PiRational(1, 2)), qubit0, qubits); + addRz(diag, beta, qubit1, qubits, unconvertedBeta); + addRz(diag, PiExpression(-PiRational(1, 2)), qubit1, qubits); +} + void FunctionalityConstruction::addSwap(ZXDiagram& diag, const Qubit target, const Qubit target2, std::vector& qubits) { @@ -201,7 +302,7 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, if (!op->isControlled()) { // single qubit gates - const auto target = static_cast(p.at(op->getTargets().front())); + const auto target = static_cast(p.at(op->getTargets().front())); switch (op->getType()) { case qc::OpType::GPhase: { const auto& param = parseParam(op.get(), 0); @@ -212,10 +313,12 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, addZSpider(diag, target, qubits, PiExpression(PiRational(1, 1))); break; case qc::OpType::RZ: { - const auto& param = parseParam(op.get(), 0); - diag.addGlobalPhase(-param / 2.0); - - addZSpider(diag, target, qubits, parseParam(op.get(), 0)); + const auto& phase = parseParam(op.get(), 0); + if (phase.isConstant()) { + addRz(diag, phase, target, qubits, op->getParameter().at(0)); + } else { + addRz(diag, phase, target, qubits); + } break; } case qc::OpType::P: @@ -225,25 +328,22 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, addXSpider(diag, target, qubits, PiExpression(PiRational(1, 1))); break; case qc::OpType::RX: - addXSpider(diag, target, qubits, parseParam(op.get(), 0)); + addRx(diag, parseParam(op.get(), 0), target, qubits); break; case qc::OpType::Y: diag.addGlobalPhase(PiExpression{-PiRational(1, 2)}); - addZSpider(diag, target, qubits, PiExpression(PiRational(1, 1))); addXSpider(diag, target, qubits, PiExpression(PiRational(1, 1))); break; - case qc::OpType::RY: - diag.addGlobalPhase( - PiExpression(-PiRational(op->getParameter().front()) / 2 + - PiRational(1, 2) + PiRational(3, 2))); - - addXSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); - addZSpider(diag, target, qubits, - parseParam(op.get(), 0) + PiRational(1, 1)); - addXSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); - addZSpider(diag, target, qubits, PiExpression(PiRational(3, 1))); + case qc::OpType::RY: { + const auto& phase = parseParam(op.get(), 0); + if (phase.isConstant()) { + addRy(diag, phase, target, qubits, op->getParameter().at(0)); + } else { + addRy(diag, phase, target, qubits); + } break; + } case qc::OpType::T: addZSpider(diag, target, qubits, PiExpression(PiRational(1, 4))); break; @@ -273,12 +373,12 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, parseParam(op.get(), 1) + PiRational(3, 1)); break; case qc::OpType::SWAP: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto target2 = static_cast(p.at(op->getTargets()[1])); addSwap(diag, target, target2, qubits); break; } case qc::OpType::iSWAP: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto target2 = static_cast(p.at(op->getTargets()[1])); addZSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); addZSpider(diag, target2, qubits, PiExpression(PiRational(1, 2))); addZSpider(diag, target, qubits, PiExpression(), EdgeType::Hadamard); @@ -289,45 +389,88 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, break; } case qc::OpType::RZZ: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); - addRzz(diag, parseParam(op.get(), 0), target, target2, qubits); + const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto& phase = parseParam(op.get(), 0); + if (phase.isConstant()) { + addRzz(diag, phase, target, target2, qubits, op->getParameter().at(0)); + } else { + addRzz(diag, phase, target, target2, qubits); + } break; } case qc::OpType::RXX: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); - addRxx(diag, parseParam(op.get(), 0), target, target2, qubits); + const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto& phase = parseParam(op.get(), 0); + if (phase.isConstant()) { + addRxx(diag, phase, target, target2, qubits, op->getParameter().at(0)); + } else { + addRxx(diag, phase, target, target2, qubits); + } break; } case qc::OpType::RZX: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); - addRzx(diag, parseParam(op.get(), 0), target, target2, qubits); + const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto& phase = parseParam(op.get(), 0); + if (phase.isConstant()) { + addRzx(diag, phase, target, target2, qubits, op->getParameter().at(0)); + } else { + addRzx(diag, phase, target, target2, qubits); + } break; } case qc::OpType::RYY: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto target2 = static_cast(p.at(op->getTargets()[1])); const auto param = parseParam(op.get(), 0); addXSpider(diag, target, qubits, PiExpression(PiRational(1, 2))); addXSpider(diag, target2, qubits, PiExpression(PiRational(1, 2))); - addRzz(diag, parseParam(op.get(), 0), target, target2, qubits); + if (param.isConstant()) { + addRzz(diag, param, target, target2, qubits, op->getParameter().at(0)); + } else { + addRzz(diag, param, target, target2, qubits); + } addXSpider(diag, target2, qubits, PiExpression(-PiRational(1, 2))); addXSpider(diag, target, qubits, PiExpression(-PiRational(1, 2))); break; } case qc::OpType::DCX: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto target2 = static_cast(p.at(op->getTargets()[1])); addDcx(diag, target, target2, qubits); break; } case qc::OpType::ECR: { - const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto target2 = static_cast(p.at(op->getTargets()[1])); addRzx(diag, PiExpression(PiRational(1, 4)), target, target2, qubits); addXSpider(diag, target, qubits); addRzx(diag, PiExpression(-PiRational(1, 4)), target, target2, qubits); break; } + case qc::OpType::XXplusYY: { + const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto& betaExpr = parseParam(op.get(), 0); + if (betaExpr.isConstant()) { + addXXplusYY(diag, betaExpr, parseParam(op.get(), 1), target, target2, + qubits, op->getParameter().at(0)); + } else { + addXXplusYY(diag, betaExpr, parseParam(op.get(), 1), target, target2, + qubits); + } + break; + } + case qc::OpType::XXminusYY: { + const auto target2 = static_cast(p.at(op->getTargets()[1])); + const auto& betaExpr = parseParam(op.get(), 0); + if (betaExpr.isConstant()) { + addXXminusYY(diag, betaExpr, parseParam(op.get(), 1), target, target2, + qubits, op->getParameter().at(0)); + } else { + addXXminusYY(diag, betaExpr, parseParam(op.get(), 1), target, target2, + qubits); + } + break; + } case qc::OpType::H: addZSpider(diag, target, qubits, PiExpression(), EdgeType::Hadamard); break; @@ -346,9 +489,9 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, } } else if (op->getNcontrols() == 1 && op->getNtargets() == 1) { // two-qubit controlled gates - const auto target = static_cast(p.at(op->getTargets().front())); + const auto target = static_cast(p.at(op->getTargets().front())); const auto ctrl = - static_cast(p.at((*op->getControls().begin()).qubit)); + static_cast(p.at((*op->getControls().begin()).qubit)); switch (op->getType()) { // TODO: any gate can be controlled case qc::OpType::X: // check if swap @@ -376,21 +519,19 @@ FunctionalityConstruction::parseOp(ZXDiagram& diag, op_it it, op_it end, break; case qc::OpType::T: - addCphase(diag, zx::PiExpression{PiRational(1, 4)}, ctrl, target, qubits); + addCphase(diag, PiExpression{PiRational(1, 4)}, ctrl, target, qubits); break; case qc::OpType::S: - addCphase(diag, zx::PiExpression{PiRational(1, 2)}, ctrl, target, qubits); + addCphase(diag, PiExpression{PiRational(1, 2)}, ctrl, target, qubits); break; case qc::OpType::Tdg: - addCphase(diag, zx::PiExpression{PiRational(-1, 4)}, ctrl, target, - qubits); + addCphase(diag, PiExpression{PiRational(-1, 4)}, ctrl, target, qubits); break; case qc::OpType::Sdg: - addCphase(diag, zx::PiExpression{PiRational(-1, 2)}, ctrl, target, - qubits); + addCphase(diag, PiExpression{PiRational(-1, 2)}, ctrl, target, qubits); break; default: throw ZXException("Unsupported Controlled Operation: " + @@ -518,6 +659,8 @@ bool FunctionalityConstruction::transformableToZX(const qc::Operation* op) { case qc::OpType::RYY: case qc::OpType::DCX: case qc::OpType::ECR: + case qc::OpType::XXplusYY: + case qc::OpType::XXminusYY: return true; default: return false; @@ -554,13 +697,13 @@ PiExpression FunctionalityConstruction::parseParam(const qc::Operation* op, if (const auto* symbOp = dynamic_cast(op)) { return toPiExpr(symbOp->getParameter(i)); } - return PiExpression{zx::PiRational{op->getParameter().at(i)}}; + return PiExpression{PiRational{op->getParameter().at(i)}}; } PiExpression FunctionalityConstruction::toPiExpr(const qc::SymbolOrNumber& param) { if (std::holds_alternative(param)) { - return zx::PiExpression{zx::PiRational{std::get(param)}}; + return PiExpression{PiRational{std::get(param)}}; } - return std::get(param).convert(); + return std::get(param).convert(); } } // namespace zx diff --git a/test/zx/test_zx_functionality.cpp b/test/zx/test_zx_functionality.cpp index f08753c31..5eaded765 100644 --- a/test/zx/test_zx_functionality.cpp +++ b/test/zx/test_zx_functionality.cpp @@ -304,3 +304,75 @@ TEST_F(ZXFunctionalityTest, ISWAP) { EXPECT_TRUE(d.globalPhaseIsZero()); EXPECT_TRUE(d.connected(d.getInput(0), d.getOutput(0))); } + +TEST_F(ZXFunctionalityTest, XXplusYY) { + const auto theta = zx::PI / 4.; + const auto beta = zx::PI / 2.; + + qc = qc::QuantumComputation(2); + qc.xx_plus_yy(theta, beta, 0, 1); + + auto qcPrime = qc::QuantumComputation(2); + qcPrime.rz(beta, 1); + qcPrime.rz(-qc::PI_2, 0); + qcPrime.sx(0); + qcPrime.rz(qc::PI_2, 0); + qcPrime.s(1); + qcPrime.cx(0, 1); + qcPrime.ry(theta / 2, 0); + qcPrime.ry(theta / 2, 1); + qcPrime.cx(0, 1); + qcPrime.rz(-qc::PI_2, 0); + qcPrime.sdg(1); + qcPrime.sxdg(0); + qcPrime.rz(qc::PI_2, 0); + qcPrime.rz(-beta, 1); + + auto d = zx::FunctionalityConstruction::buildFunctionality(&qc); + + auto dPrime = zx::FunctionalityConstruction::buildFunctionality(&qcPrime); + + d.concat(dPrime.invert()); + + zx::fullReduce(d); + + EXPECT_TRUE(d.isIdentity()); + EXPECT_TRUE(d.globalPhaseIsZero()); + EXPECT_TRUE(d.connected(d.getInput(0), d.getOutput(0))); +} + +TEST_F(ZXFunctionalityTest, XXminusYY) { + const auto theta = zx::PI / 4.; + const auto beta = -zx::PI / 2.; + + qc = qc::QuantumComputation(2); + qc.xx_minus_yy(theta, beta, 0, 1); + + auto qcPrime = qc::QuantumComputation(2); + qcPrime.rz(-beta, 1); + qcPrime.rz(-qc::PI_2, 0); + qcPrime.sx(0); + qcPrime.rz(qc::PI_2, 0); + qcPrime.s(1); + qcPrime.cx(0, 1); + qcPrime.ry(-theta / 2, 0); + qcPrime.ry(theta / 2, 1); + qcPrime.cx(0, 1); + qcPrime.sdg(1); + qcPrime.rz(-qc::PI_2, 0); + qcPrime.sxdg(0); + qcPrime.rz(qc::PI_2, 0); + qcPrime.rz(beta, 1); + + auto d = zx::FunctionalityConstruction::buildFunctionality(&qc); + + auto dPrime = zx::FunctionalityConstruction::buildFunctionality(&qcPrime); + + d.concat(dPrime.invert()); + + zx::fullReduce(d); + + EXPECT_TRUE(d.isIdentity()); + EXPECT_TRUE(d.globalPhaseIsZero()); + EXPECT_TRUE(d.connected(d.getInput(0), d.getOutput(0))); +}