Skip to content

Commit

Permalink
OutProbs()
Browse files Browse the repository at this point in the history
  • Loading branch information
WrathfulSpatula committed Oct 5, 2024
1 parent c959241 commit f9031f6
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required (VERSION 3.9)
project (Qrack VERSION 9.9.65 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)
project (Qrack VERSION 9.10.0 DESCRIPTION "High Performance Quantum Bit Simulation" LANGUAGES CXX)

# Installation commands
include (GNUInstallDirs)
Expand Down
2 changes: 2 additions & 0 deletions include/pinvoke_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ MICROSOFT_QUANTUM_DECL void Dump(_In_ uintq sid, _In_ ProbAmpCallback callback);
#if FPPOW < 6
MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ float* ket);
MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ float* ket);
MICROSOFT_QUANTUM_DECL void OutProbs(_In_ uintq sid, _In_ float* ket);
#else
MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ double* ket);
MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ double* ket);
MICROSOFT_QUANTUM_DECL void OutProbs(_In_ uintq sid, _In_ double* ket);
#endif

MICROSOFT_QUANTUM_DECL size_t random_choice(_In_ uintq sid, _In_ size_t n, _In_reads_(n) double* p);
Expand Down
54 changes: 51 additions & 3 deletions src/pinvoke_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ MICROSOFT_QUANTUM_DECL void InKet(_In_ uintq sid, _In_ real1_s* ket)
}

/**
* (External API) Set state vector for the selected simulator ID.
* (External API) Get state vector for the selected simulator ID.
*/
MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ real1_s* ket)
{
Expand Down Expand Up @@ -1097,6 +1097,34 @@ MICROSOFT_QUANTUM_DECL void OutKet(_In_ uintq sid, _In_ real1_s* ket)
#endif
}

/**
* (External API) Get basis dimension probabilities for the selected simulator ID.
*/
MICROSOFT_QUANTUM_DECL void OutProb(_In_ uintq sid, _In_ real1_s* probs)
{
SIMULATOR_LOCK_GUARD_VOID(sid)

#if (FPPOW == 5) || (FPPOW == 6)
try {
simulator->GetProbs(probs);
} catch (const std::exception& ex) {
simulatorErrors[sid] = 1;
std::cout << ex.what() << std::endl;
}
#else
const size_t maxQPower = (size_t)simulator->GetMaxQPower();
std::unique_ptr<real1[]> _probs(new real1[maxQPower]);
try {
simulator->GetProbs(_probs.get());
} catch (const std::exception& ex) {
simulatorErrors[sid] = 1;
std::cout << ex.what() << std::endl;
return;
}
std::transform(_probs, _probs + maxQPower, probs, [](real1 c) { return (real1_s)c; });
#endif
}

/**
* (External API) Select from a distribution of "n" elements according the discrete probabilities in "d."
*/
Expand Down Expand Up @@ -2439,17 +2467,37 @@ MICROSOFT_QUANTUM_DECL void ProbAll(_In_ uintq sid, _In_ uintq n, _In_reads_(n)
_q[i] = shards[simulator.get()][q[i]];
}

bool isOutProbs = false;
if (_q.size() == simulator->GetQubitCount()) {
isOutProbs = true;
for (size_t i = 0U; i < _q.size(); ++i) {
if (_q[i] == i) {
continue;
}
isOutProbs = false;
break;
}
}

#if (FPPOW < 5) || (FPPOW > 6)
const bitCapIntOcl npow = pow2Ocl(n);
std::unique_ptr<real1[]> _p(new real1[npow]);
#endif

try {
#if (FPPOW < 5) || (FPPOW > 6)
simulator->ProbBitsAll(_q, _p.get());
if (isOutProbs) {
simulator->GetProbs(_p.get());
} else {
simulator->ProbBitsAll(_q, _p.get());
}
std::transform(_p.get(), _p.get() + npow, p, [](real1 c) { return (real1_s)c; });
#else
simulator->ProbBitsAll(_q, p);
if (isOutProbs) {
simulator->GetProbs(p);
} else {
simulator->ProbBitsAll(_q, p);
}
#endif
} catch (const std::exception& ex) {
simulatorErrors[sid] = 1;
Expand Down
20 changes: 18 additions & 2 deletions src/wasm_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1738,12 +1738,28 @@ std::vector<real1> ProbAll(quid sid, std::vector<bitLenInt> q)
return std::vector<real1>();
}

for (size_t i = 0; i < q.size(); ++i) {
for (size_t i = 0U; i < q.size(); ++i) {
q[i] = shards[simulator.get()][q[i]];
}

bool isOutProbs = false;
if (q.size() == simulator->GetQubitCount()) {
isOutProbs = true;
for (size_t i = 0U; i < q.size(); ++i) {
if (q[i] == i) {
continue;
}
isOutProbs = false;
break;
}
}

std::vector<real1> p(pow2Ocl(q.size()));
simulator->ProbBitsAll(q, &(p[0U]));
if (isOutProbs) {
simulator->GetProbs(&(p[0U]));
} else {
simulator->ProbBitsAll(q, &(p[0U]));
}

return p;
}
Expand Down

0 comments on commit f9031f6

Please sign in to comment.