Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thrust::strided_iterator as a thrust::counting_iterator with step #4014

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions thrust/testing/counting_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/sort.h>

#include <cuda/std/__algorithm_>
#include <cuda/std/iterator>
#include <cuda/std/type_traits>

#include <complex>
#include <cstdint>
#include <numeric>

#include <unittest/unittest.h>

Expand Down Expand Up @@ -285,4 +288,65 @@ void TestCountingIteratorDifference()
}
DECLARE_UNITTEST(TestCountingIteratorDifference);

void TestCountingIteratorDynamicStride()
{
auto iter = thrust::make_counting_iterator(0, 2);
static_assert(sizeof(iter) == 2 * sizeof(int));

ASSERT_EQUAL(*iter, 0);
iter++;
ASSERT_EQUAL(*iter, 2);
iter++;
iter++;
ASSERT_EQUAL(*iter, 6);
iter += 5;
ASSERT_EQUAL(*iter, 16);
iter -= 10;
ASSERT_EQUAL(*iter, -4);
}
DECLARE_UNITTEST(TestCountingIteratorDynamicStride);

void TestCountingIteratorStaticStride()
{
auto iter = thrust::make_counting_iterator<2>(0);
static_assert(sizeof(decltype(iter)) == sizeof(int));

ASSERT_EQUAL(*iter, 0);
iter++;
ASSERT_EQUAL(*iter, 2);
iter++;
iter++;
ASSERT_EQUAL(*iter, 6);
iter += 5;
ASSERT_EQUAL(*iter, 16);
iter -= 10;
ASSERT_EQUAL(*iter, -4);
}
DECLARE_UNITTEST(TestCountingIteratorStaticStride);

void TestCountingIteratorPointer()
{
int arr[11];
std::iota(arr, arr + 11, 0);

auto iter = thrust::make_counting_iterator(&arr[2]);

ASSERT_EQUAL(*iter, &arr[2]);
ASSERT_EQUAL(**iter, 2);
iter++;
ASSERT_EQUAL(*iter, &arr[3]);
ASSERT_EQUAL(**iter, 3);
iter++;
iter++;
ASSERT_EQUAL(*iter, &arr[5]);
ASSERT_EQUAL(**iter, 5);
iter += 5;
ASSERT_EQUAL(*iter, &arr[10]);
ASSERT_EQUAL(**iter, 10);
iter -= 10;
ASSERT_EQUAL(*iter, &arr[0]);
ASSERT_EQUAL(**iter, 0);
}
DECLARE_UNITTEST(TestCountingIteratorPointer);

_CCCL_DIAG_POP
82 changes: 82 additions & 0 deletions thrust/testing/strided_iterator.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <thrust/device_vector.h>
#include <thrust/iterator/strided_iterator.h>

#include <cuda/std/array>
#include <cuda/std/utility>

#include <algorithm>
#include <numeric>

#include <unittest/unittest.h>

void TestReadingStridedIterator()
{
thrust::host_vector<int> v(21);
std::iota(v.begin(), v.end(), -4);
auto iter = thrust::make_strided_iterator(v.begin() + 4, 2);

ASSERT_EQUAL(*iter, 0);
iter++;
ASSERT_EQUAL(*iter, 2);
iter++;
iter++;
ASSERT_EQUAL(*iter, 6);
iter += 5;
ASSERT_EQUAL(*iter, 16);
iter -= 10;
ASSERT_EQUAL(*iter, -4);
}
DECLARE_UNITTEST(TestReadingStridedIterator);

template <typename Vector>
void TestWritingStridedIterator()
{
// iterate over all second elements (runtime stride)
{
Vector v(10);
auto iter = thrust::make_strided_iterator(v.begin(), 2);
ASSERT_EQUAL(v, (Vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
*iter = 33;
ASSERT_EQUAL(v, (Vector{33, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
auto iter2 = iter + 1;
*iter2 = 34;
ASSERT_EQUAL(v, (Vector{33, 0, 34, 0, 0, 0, 0, 0, 0, 0}));
thrust::fill(iter + 2, iter + 4, 42);
ASSERT_EQUAL(v, (Vector{33, 0, 34, 0, 42, 0, 42, 0, 0, 0}));
}

// iterate over all second elements (static stride)
{
Vector v(10);
auto iter = thrust::make_strided_iterator<2>(v.begin());
thrust::fill(iter, iter + 3, 42);
ASSERT_EQUAL(v, (Vector{42, 0, 42, 0, 42, 0, 0, 0, 0, 0}));
}
}
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestWritingStridedIterator);

void TestWritingStridedIteratorToStructMember()
{
using pair = ::cuda::std::pair<int, double>;
using arr_of_pairs = ::cuda::std::array<pair, 4>;
const auto data = arr_of_pairs{{{1, 2}, {3, 4}, {5, 6}, {7, 8}}};
const auto reference = arr_of_pairs{{{1, 1337}, {3, 1337}, {5, 1337}, {7, 1337}}};
constexpr auto stride = sizeof(pair) / sizeof(double);

// iterate over all second elements (runtime stride)
{
auto arr = data;
auto iter = thrust::make_strided_iterator(&arr[0].second, stride);
thrust::fill(iter, iter + 4, 1337);
ASSERT_EQUAL(arr == reference, true);
}

// iterate over all second elements (static stride)
{
auto arr = data;
auto iter = thrust::make_strided_iterator<stride>(&arr[0].second);
thrust::fill(iter, iter + 4, 1337);
ASSERT_EQUAL(arr == reference, true);
}
}
DECLARE_UNITTEST(TestWritingStridedIteratorToStructMember);
138 changes: 121 additions & 17 deletions thrust/thrust/iterator/counting_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

THRUST_NAMESPACE_BEGIN

template <typename Incrementable, typename System, typename Traversal, typename Difference>
template <typename Incrementable, typename System, typename Traversal, typename Difference, typename StrideHolder>
class counting_iterator;

namespace detail
Expand All @@ -60,7 +60,7 @@ template <typename Number>
using counting_iterator_difference_type =
::cuda::std::_If<::cuda::std::is_integral_v<Number> && sizeof(Number) < sizeof(int), int, ::cuda::std::ptrdiff_t>;

template <typename Incrementable, typename System, typename Traversal, typename Difference>
template <typename Incrementable, typename System, typename Traversal, typename Difference, typename StrideHolder>
struct make_counting_iterator_base
{
using system =
Expand All @@ -75,14 +75,17 @@ struct make_counting_iterator_base
// to the internal state of an iterator causes subtle bugs (consider the temporary
// iterator created in the expression *(iter + i)) and has no compelling use case
using type =
iterator_adaptor<counting_iterator<Incrementable, System, Traversal, Difference>,
iterator_adaptor<counting_iterator<Incrementable, System, Traversal, Difference, StrideHolder>,
Incrementable,
Incrementable,
system,
traversal,
Incrementable,
difference>;
};

struct empty
{};
} // namespace detail

//! \addtogroup iterators
Expand Down Expand Up @@ -164,14 +167,17 @@ struct make_counting_iterator_base
//!
//! \see make_counting_iterator
template <typename Incrementable,
typename System = use_default,
typename Traversal = use_default,
typename Difference = use_default>
typename System = use_default,
typename Traversal = use_default,
typename Difference = use_default,
typename StrideHolder = detail::empty>
class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
: public detail::make_counting_iterator_base<Incrementable, System, Traversal, Difference>::type
: public detail::make_counting_iterator_base<Incrementable, System, Traversal, Difference, StrideHolder>::type
, StrideHolder
{
//! \cond
using super_t = typename detail::make_counting_iterator_base<Incrementable, System, Traversal, Difference>::type;
using super_t =
typename detail::make_counting_iterator_base<Incrementable, System, Traversal, Difference, StrideHolder>::type;
friend class iterator_core_access;

public:
Expand All @@ -187,12 +193,14 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
//! Copy constructor copies the value of another counting_iterator with related System type.
//!
//! \param rhs The \p counting_iterator to copy.
template <class OtherSystem,
detail::enable_if_convertible_t<
typename iterator_system<counting_iterator<Incrementable, OtherSystem, Traversal, Difference>>::type,
typename iterator_system<super_t>::type,
int> = 0>
_CCCL_HOST_DEVICE counting_iterator(counting_iterator<Incrementable, OtherSystem, Traversal, Difference> const& rhs)
template <
class OtherSystem,
detail::enable_if_convertible_t<
typename iterator_system<counting_iterator<Incrementable, OtherSystem, Traversal, Difference, StrideHolder>>::type,
typename iterator_system<super_t>::type,
int> = 0>
_CCCL_HOST_DEVICE
counting_iterator(counting_iterator<Incrementable, OtherSystem, Traversal, Difference, StrideHolder> const& rhs)
: super_t(rhs.base())
{}

Expand All @@ -204,18 +212,72 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator
: super_t(x)
{}

_CCCL_HOST_DEVICE explicit counting_iterator(Incrementable x, StrideHolder stride)
: super_t(x)
, StrideHolder(stride)
{}

//! \cond

private:
template <typename S = StrideHolder>
auto stride() const
{
return static_cast<const S&>(*this).value;
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void advance(difference_type n)
{
if constexpr (::cuda::std::is_same_v<StrideHolder, detail::empty>)
{
this->base_reference() = static_cast<Incrementable>(this->base_reference() + n);
}
else
{
this->base_reference() += n * stride();
}
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void increment()
{
if constexpr (::cuda::std::is_same_v<StrideHolder, detail::empty>)
{
++this->base_reference();
}
else
{
this->base_reference() += stride();
}
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void decrement()
{
if constexpr (::cuda::std::is_same_v<StrideHolder, detail::empty>)
{
--this->base_reference();
}
else
{
this->base_reference() -= stride();
}
}

_CCCL_HOST_DEVICE reference dereference() const
{
return this->base_reference();
}

// note that we implement equal specially for floating point counting_iterator
template <typename OtherSystem, typename OtherTraversal, typename OtherDifference>
template <typename OtherSystem,
typename OtherTraversal,
typename OtherDifference,
typename OtherStrideHolder>
_CCCL_HOST_DEVICE bool
equal(counting_iterator<Incrementable, OtherSystem, OtherTraversal, OtherDifference> const& y) const
equal(counting_iterator<Incrementable, OtherSystem, OtherTraversal, OtherDifference, OtherStrideHolder> const& y)
const
{
if constexpr (::cuda::is_floating_point_v<Incrementable>)
{
Expand All @@ -229,7 +291,7 @@ class _CCCL_DECLSPEC_EMPTY_BASES counting_iterator

template <typename OtherSystem, typename OtherTraversal, typename OtherDifference>
_CCCL_HOST_DEVICE difference_type
distance_to(counting_iterator<Incrementable, OtherSystem, OtherTraversal, OtherDifference> const& y) const
distance_to(counting_iterator<Incrementable, OtherSystem, OtherTraversal, OtherDifference, StrideHolder> const& y) const
{
if constexpr (::cuda::std::is_integral<Incrementable>::value)
{
Expand All @@ -255,6 +317,48 @@ inline _CCCL_HOST_DEVICE counting_iterator<Incrementable> make_counting_iterator
return counting_iterator<Incrementable>(x);
}

namespace detail
{
// Holds a runtime stride
template <typename T>
struct runtime_stride_holder
{
T value;
};

// Holds a compile-time stride
// (we cannot use ::cuda::std::integral_constant, because it has a conversion operator to T that causes an ambiguity
// with operator+(counting_iterator, counting_iterator::difference_type) in any expression `counting_iterator +
// integral`.
template <typename T, T Value>
struct compile_time_stride_holder
{
static constexpr T value = Value;
};
} // namespace detail

//! Constructs a counting_iterator with a runtime stride
template <typename Incrementable, typename Stride>
_CCCL_HOST_DEVICE auto make_counting_iterator(Incrementable x, Stride stride)
{
return counting_iterator<Incrementable,
use_default,
random_access_traversal_tag,
use_default,
detail::runtime_stride_holder<Stride>>(x, {stride});
}

//! Constructs a counting_iterator with a compile-time stride
template <auto Stride, typename Incrementable>
_CCCL_HOST_DEVICE auto make_counting_iterator(Incrementable x)
{
return counting_iterator<Incrementable,
use_default,
random_access_traversal_tag,
use_default,
detail::compile_time_stride_holder<decltype(Stride), Stride>>(x, {});
}

//! \} // end fancyiterators
//! \} // end iterators

Expand Down
Loading