From a73d1e66272ab24cdee1e4a386702951474b4e69 Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Tue, 20 Aug 2024 18:02:44 -0400 Subject: [PATCH] Refactor QUnit (inlining) --- include/qunit.hpp | 49 ++++++++++++++++++++++++++++++++++++++++++++--- src/qunit.cpp | 47 --------------------------------------------- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/include/qunit.hpp b/include/qunit.hpp index a36f278a2..6ffa3e095 100644 --- a/include/qunit.hpp +++ b/include/qunit.hpp @@ -466,9 +466,52 @@ class QUnit : public QParity, public QInterface { protected: virtual complex GetAmplitudeOrProb(bitCapInt perm, bool isProb); - virtual void XBase(bitLenInt target); - virtual void YBase(bitLenInt target); - virtual void ZBase(bitLenInt target); + virtual void XBase(bitLenInt target) + { + if (target >= qubitCount) { + throw std::invalid_argument("QUnit::XBase qubit index parameter must be within allocated qubit bounds!"); + } + + QEngineShard& shard = shards[target]; + + if (shard.unit) { + shard.unit->X(shard.mapped); + } + + std::swap(shard.amp0, shard.amp1); + } + + virtual void YBase(bitLenInt target) + { + if (target >= qubitCount) { + throw std::invalid_argument("QUnit::YBase qubit index parameter must be within allocated qubit bounds!"); + } + + QEngineShard& shard = shards[target]; + + if (shard.unit) { + shard.unit->Y(shard.mapped); + } + + const complex_x Y0 = shard.amp0; + shard.amp0 = -I_CMPLX_X * shard.amp1; + shard.amp1 = I_CMPLX_X * Y0; + } + + virtual void ZBase(bitLenInt target) + { + if (target >= qubitCount) { + throw std::invalid_argument("QUnit::ZBase qubit index parameter must be within allocated qubit bounds!"); + } + + QEngineShard& shard = shards[target]; + + if (shard.unit) { + shard.unit->Z(shard.mapped); + } + + shard.amp1 = -shard.amp1; + } virtual real1_f ProbBase(bitLenInt qubit); virtual bool TrySeparateClifford(bitLenInt qubit); diff --git a/src/qunit.cpp b/src/qunit.cpp index bd3dff46a..641f353c2 100644 --- a/src/qunit.cpp +++ b/src/qunit.cpp @@ -2143,53 +2143,6 @@ void QUnit::IS(bitLenInt target) shard.amp1 = -I_CMPLX_X * shard.amp1; } -void QUnit::XBase(bitLenInt target) -{ - if (target >= qubitCount) { - throw std::invalid_argument("QUnit::XBase qubit index parameter must be within allocated qubit bounds!"); - } - - QEngineShard& shard = shards[target]; - - if (shard.unit) { - shard.unit->X(shard.mapped); - } - - std::swap(shard.amp0, shard.amp1); -} - -void QUnit::YBase(bitLenInt target) -{ - if (target >= qubitCount) { - throw std::invalid_argument("QUnit::YBase qubit index parameter must be within allocated qubit bounds!"); - } - - QEngineShard& shard = shards[target]; - - if (shard.unit) { - shard.unit->Y(shard.mapped); - } - - const complex_x Y0 = shard.amp0; - shard.amp0 = -I_CMPLX_X * shard.amp1; - shard.amp1 = I_CMPLX_X * Y0; -} - -void QUnit::ZBase(bitLenInt target) -{ - if (target >= qubitCount) { - throw std::invalid_argument("QUnit::ZBase qubit index parameter must be within allocated qubit bounds!"); - } - - QEngineShard& shard = shards[target]; - - if (shard.unit) { - shard.unit->Z(shard.mapped); - } - - shard.amp1 = -shard.amp1; -} - #define CTRLED_GEN_WRAP(ctrld) \ ApplyEitherControlled( \ controlVec, { target }, \