From f9031f68b7d83f5cb75d65b6a31a8add2d5a815d Mon Sep 17 00:00:00 2001 From: WrathfulSpatula Date: Sat, 5 Oct 2024 10:41:53 -0400 Subject: [PATCH] OutProbs() --- CMakeLists.txt | 2 +- include/pinvoke_api.hpp | 2 ++ src/pinvoke_api.cpp | 54 ++++++++++++++++++++++++++++++++++++++--- src/wasm_api.cpp | 20 +++++++++++++-- 4 files changed, 72 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index daf626b7f..73579c5b4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required (VERSION 3.9) -project (Qrack VERSION 9.9.65 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX) +project (Qrack VERSION 9.10.0 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX) # Installation commands include (GNUInstallDirs) diff --git a/include/pinvoke_api.hpp b/include/pinvoke_api.hpp index 26478c935..a51e5ea20 100644 --- a/include/pinvoke_api.hpp +++ b/include/pinvoke_api.hpp @@ -128,9 +128,11 @@ MICROSOFT_QUANTUM_DECL void Dump(_In_ uintq sid, _In_ ProbAmpCallback callback); #if FPPOW < 6 MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ float* ket); MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ float* ket); +MICROSOFT_QUANTUM_DECL void OutProbs(_In_ uintq sid, _In_ float* ket); #else MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ double* ket); MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ double* ket); +MICROSOFT_QUANTUM_DECL void OutProbs(_In_ uintq sid, _In_ double* ket); #endif MICROSOFT_QUANTUM_DECL size_t random_choice(_In_ uintq sid, _In_ size_t n, _In_reads_(n) double* p); diff --git a/src/pinvoke_api.cpp b/src/pinvoke_api.cpp index 381449025..c5da9c5e2 100644 --- a/src/pinvoke_api.cpp +++ b/src/pinvoke_api.cpp @@ -1068,7 +1068,7 @@ MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ real1_s* ket) } /** - * (External API) Set state vector for the selected simulator ID. + * (External API) Get state vector for the selected simulator ID. */ MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ real1_s* ket) { @@ -1097,6 +1097,34 @@ MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ real1_s* ket) #endif } +/** + * (External API) Get basis dimension probabilities for the selected simulator ID. + */ +MICROSOFT_QUANTUM_DECL void OutProb(_In_ uintq sid, _In_ real1_s* probs) +{ + SIMULATOR_LOCK_GUARD_VOID(sid) + +#if (FPPOW == 5) || (FPPOW == 6) + try { + simulator->GetProbs(probs); + } catch (const std::exception& ex) { + simulatorErrors[sid] = 1; + std::cout << ex.what() << std::endl; + } +#else + const size_t maxQPower = (size_t)simulator->GetMaxQPower(); + std::unique_ptr _probs(new real1[maxQPower]); + try { + simulator->GetProbs(_probs.get()); + } catch (const std::exception& ex) { + simulatorErrors[sid] = 1; + std::cout << ex.what() << std::endl; + return; + } + std::transform(_probs, _probs + maxQPower, probs, [](real1 c) { return (real1_s)c; }); +#endif +} + /** * (External API) Select from a distribution of "n" elements according the discrete probabilities in "d." */ @@ -2439,6 +2467,18 @@ MICROSOFT_QUANTUM_DECL void ProbAll(_In_ uintq sid, _In_ uintq n, _In_reads_(n) _q[i] = shards[simulator.get()][q[i]]; } + bool isOutProbs = false; + if (_q.size() == simulator->GetQubitCount()) { + isOutProbs = true; + for (size_t i = 0U; i < _q.size(); ++i) { + if (_q[i] == i) { + continue; + } + isOutProbs = false; + break; + } + } + #if (FPPOW < 5) || (FPPOW > 6) const bitCapIntOcl npow = pow2Ocl(n); std::unique_ptr _p(new real1[npow]); @@ -2446,10 +2486,18 @@ MICROSOFT_QUANTUM_DECL void ProbAll(_In_ uintq sid, _In_ uintq n, _In_reads_(n) try { #if (FPPOW < 5) || (FPPOW > 6) - simulator->ProbBitsAll(_q, _p.get()); + if (isOutProbs) { + simulator->GetProbs(_p.get()); + } else { + simulator->ProbBitsAll(_q, _p.get()); + } std::transform(_p.get(), _p.get() + npow, p, [](real1 c) { return (real1_s)c; }); #else - simulator->ProbBitsAll(_q, p); + if (isOutProbs) { + simulator->GetProbs(p); + } else { + simulator->ProbBitsAll(_q, p); + } #endif } catch (const std::exception& ex) { simulatorErrors[sid] = 1; diff --git a/src/wasm_api.cpp b/src/wasm_api.cpp index 05a22f9ec..efc567c0a 100644 --- a/src/wasm_api.cpp +++ b/src/wasm_api.cpp @@ -1738,12 +1738,28 @@ std::vector ProbAll(quid sid, std::vector q) return std::vector(); } - for (size_t i = 0; i < q.size(); ++i) { + for (size_t i = 0U; i < q.size(); ++i) { q[i] = shards[simulator.get()][q[i]]; } + bool isOutProbs = false; + if (q.size() == simulator->GetQubitCount()) { + isOutProbs = true; + for (size_t i = 0U; i < q.size(); ++i) { + if (q[i] == i) { + continue; + } + isOutProbs = false; + break; + } + } + std::vector p(pow2Ocl(q.size())); - simulator->ProbBitsAll(q, &(p[0U])); + if (isOutProbs) { + simulator->GetProbs(&(p[0U])); + } else { + simulator->ProbBitsAll(q, &(p[0U])); + } return p; }