Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Mar 5, 2025
1 parent e0eaa8d commit 1f1292a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
45 changes: 44 additions & 1 deletion thrust/testing/counting_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/sort.h>

#include <cuda/std/__algorithm_>
#include <cuda/std/iterator>
#include <cuda/std/type_traits>

#include <complex>
#include <cstdint>
#include <numeric>

#include <unittest/unittest.h>

Expand Down Expand Up @@ -260,7 +263,7 @@ DECLARE_UNITTEST(TestCountingIteratorDifference);
void TestCountingIteratorDynamicStride()
{
auto iter = thrust::make_counting_iterator(0, 2);
// static_assert(sizeof(iter) == 2 * sizeof(int));
static_assert(sizeof(iter) == 2 * sizeof(int));

ASSERT_EQUAL(*iter, 0);
iter++;
Expand Down Expand Up @@ -293,4 +296,44 @@ void TestCountingIteratorStaticStride()
}
DECLARE_UNITTEST(TestCountingIteratorStaticStride);

void TestCountingIteratorPointer()
{
int arr[11];
std::iota(arr, arr + 11, 0);

auto iter = thrust::make_counting_iterator(&arr[2]);

ASSERT_EQUAL(*iter, &arr[2]);
ASSERT_EQUAL(**iter, 2);
iter++;
ASSERT_EQUAL(*iter, &arr[3]);
ASSERT_EQUAL(**iter, 3);
iter++;
iter++;
ASSERT_EQUAL(*iter, &arr[5]);
ASSERT_EQUAL(**iter, 5);
iter += 5;
ASSERT_EQUAL(*iter, &arr[10]);
ASSERT_EQUAL(**iter, 10);
iter -= 10;
ASSERT_EQUAL(*iter, &arr[0]);
ASSERT_EQUAL(**iter, 0);
}
DECLARE_UNITTEST(TestCountingIteratorPointer);

void TestCountingIteratorStridedPointer()
{
::cuda::std::array<std::pair<int, double>, 4> arr{{{1, 2}, {3, 4}, {5, 6}, {7, 8}}};

// iterate over all second elements
auto iter = thrust::make_counting_iterator(
&arr[0].second, ::cuda::std::integral_constant<int, sizeof(std::pair<int, double>)>{});

cuda::std::fill(iter, iter + 4, 1337);

const ::cuda::std::array<std::pair<int, double>, 4> reference{{{1, 1337}, {3, 1337}, {5, 1337}, {7, 1337}}};
ASSERT_EQUAL(arr, reference);
}
DECLARE_UNITTEST(TestCountingIteratorStridedPointer);

_CCCL_DIAG_POP
38 changes: 24 additions & 14 deletions thrust/thrust/iterator/counting_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ struct counting_iterator_equal<Difference,
}
};

struct empty
{};

template <typename T>
struct value_holder
{
Expand Down Expand Up @@ -280,10 +283,10 @@ template <typename Incrementable,
typename System = use_default,
typename Traversal = use_default,
typename Difference = use_default,
typename Step = ::cuda::std::monostate>
typename Step = detail::empty>
class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
: public detail::make_counting_iterator_base<Incrementable, System, Traversal, Difference, Step>::type
, public Step
, Step
{
//! \cond
using super_t =
Expand Down Expand Up @@ -330,16 +333,16 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator

private:
template <typename S = Step>
auto step() const -> Incrementable
auto step() const
{
static_assert(!::cuda::std::is_same_v<Step, ::cuda::std::monostate>);
static_assert(!::cuda::std::is_same_v<Step, detail::empty>);
return static_cast<const Step&>(*this)();
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void advance(difference_type n)
{
if constexpr (::cuda::std::is_same_v<Step, ::cuda::std::monostate>)
if constexpr (::cuda::std::is_same_v<Step, detail::empty>)
{
this->base_reference() = static_cast<Incrementable>(this->base_reference() + n);
}
Expand All @@ -352,7 +355,7 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void increment()
{
if constexpr (::cuda::std::is_same_v<Step, ::cuda::std::monostate>)
if constexpr (::cuda::std::is_same_v<Step, detail::empty>)
{
++this->base_reference();
}
Expand All @@ -365,7 +368,7 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void decrement()
{
if constexpr (::cuda::std::is_same_v<Step, ::cuda::std::monostate>)
if constexpr (::cuda::std::is_same_v<Step, detail::empty>)
{
--this->base_reference();
}
Expand All @@ -381,17 +384,21 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
}

// note that we implement equal specially for floating point counting_iterator
template <typename OtherIncrementable, typename OtherSystem, typename OtherTraversal, typename OtherDifference>
template <typename OtherIncrementable,
typename OtherSystem,
typename OtherTraversal,
typename OtherDifference,
typename OtherStep>
_CCCL_HOST_DEVICE bool
equal(counting_iterator<OtherIncrementable, OtherSystem, OtherTraversal, OtherDifference> const& y) const
equal(counting_iterator<OtherIncrementable, OtherSystem, OtherTraversal, OtherDifference, OtherStep> const& y) const
{
using e = detail::counting_iterator_equal<difference_type, Incrementable, OtherIncrementable>;
return e::equal(this->base(), y.base());
}

template <class OtherIncrementable>
template <class OtherIncrementable, typename OtherStep>
_CCCL_HOST_DEVICE difference_type
distance_to(counting_iterator<OtherIncrementable, System, Traversal, Difference> const& y) const
distance_to(counting_iterator<OtherIncrementable, System, Traversal, Difference, OtherStep> const& y) const
{
using d = typename detail::eval_if<
detail::is_numeric<Incrementable>::value,
Expand All @@ -418,8 +425,11 @@ inline _CCCL_HOST_DEVICE counting_iterator<Incrementable> make_counting_iterator
template <typename Incrementable, typename Stride>
inline _CCCL_HOST_DEVICE auto make_counting_iterator(Incrementable x, Stride stride)
{
return counting_iterator<Incrementable, use_default, use_default, use_default, detail::value_holder<Stride>>(
x, {stride});
return counting_iterator<Incrementable,
use_default,
random_access_traversal_tag,
use_default,
detail::value_holder<Stride>>(x, {stride});
}

template <typename Incrementable, typename Stride, Stride Value>
Expand All @@ -428,7 +438,7 @@ make_counting_iterator(Incrementable x, ::cuda::std::integral_constant<Stride, V
{
return counting_iterator<Incrementable,
use_default,
use_default,
random_access_traversal_tag,
use_default,
::cuda::std::integral_constant<Stride, Value>>(x, stride);
}
Expand Down

0 comments on commit 1f1292a

Please sign in to comment.