Skip to content

Commit

Permalink
Get tests working
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Mar 5, 2025
1 parent 2336323 commit 6d1b3ae
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 43 deletions.
73 changes: 50 additions & 23 deletions thrust/testing/strided_iterator.cu
Original file line number Diff line number Diff line change
@@ -1,55 +1,82 @@
#include <thrust/device_vector.h>
#include <thrust/iterator/strided_iterator.h>

#include <cuda/std/__numeric/iota.h>
#include <cuda/std/array>
#include <cuda/std/utility>

#include <algorithm>
#include <numeric>

#include <unittest/unittest.h>

void TestStridedIterator()
void TestReadingStridedIterator()
{
thrust::host_vector<int> v(21);
std::iota(v.begin(), v.end(), -4);
auto iter = thrust::make_strided_iterator(v.begin() + 4, 2);

ASSERT_EQUAL(*iter, 0);
iter++;
ASSERT_EQUAL(*iter, 2);
iter++;
iter++;
ASSERT_EQUAL(*iter, 6);
iter += 5;
ASSERT_EQUAL(*iter, 16);
iter -= 10;
ASSERT_EQUAL(*iter, -4);
}
DECLARE_UNITTEST(TestReadingStridedIterator);

template <typename Vector>
void TestWritingStridedIterator()
{
// iterate over all second elements (runtime stride)
{
thrust::device_vector<int> v(10);
Vector v(10);
auto iter = thrust::make_strided_iterator(v.begin(), 2);
cuda::std::fill(iter, iter + 3, 42);
ASSERT_EQUAL(v, (thrust::device_vector{42, 0, 42, 0, 42, 0, 0, 0, 0, 0}));
ASSERT_EQUAL(v, (Vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
*iter = 33;
ASSERT_EQUAL(v, (Vector{33, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
auto iter2 = iter + 1;
*iter2 = 34;
ASSERT_EQUAL(v, (Vector{33, 0, 34, 0, 0, 0, 0, 0, 0, 0}));
thrust::fill(iter + 2, iter + 4, 42);
ASSERT_EQUAL(v, (Vector{33, 0, 34, 0, 42, 0, 42, 0, 0, 0}));
}

// iterate over all second elements (static stride)
{
thrust::device_vector<int> v(10);
Vector v(10);
auto iter = thrust::make_strided_iterator<2>(v.begin());
cuda::std::fill(iter, iter + 3, 42);
ASSERT_EQUAL(v, (thrust::device_vector{42, 0, 42, 0, 42, 0, 0, 0, 0, 0}));
thrust::fill(iter, iter + 3, 42);
ASSERT_EQUAL(v, (Vector{42, 0, 42, 0, 42, 0, 0, 0, 0, 0}));
}
}
DECLARE_UNITTEST(TestStridedIterator);
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestWritingStridedIterator);

void TestStridedIteratorStruct()
void TestWritingStridedIteratorToStructMember()
{
using arr_of_pairs = ::cuda::std::array<::cuda::std::pair<int, double>, 4>;
const auto reference = arr_of_pairs{{{1, 1337}, {3, 1337}, {5, 1337}, {7, 1337}}};
using pair = ::cuda::std::pair<int, double>;
using arr_of_pairs = ::cuda::std::array<pair, 4>;
const auto data = arr_of_pairs{{{1, 2}, {3, 4}, {5, 6}, {7, 8}}};
const auto reference = arr_of_pairs{{{1, 1337}, {3, 1337}, {5, 1337}, {7, 1337}}};
constexpr auto stride = sizeof(pair) / sizeof(double);

// iterate over all second elements (runtime stride)
{
auto arr = arr_of_pairs{{{1, 2}, {3, 4}, {5, 6}, {7, 8}}};
auto iter = thrust::make_strided_iterator(&arr[0].second, sizeof(::cuda::std::pair<int, double>));

cuda::std::fill(iter, iter + 4, 1337);

auto arr = data;
auto iter = thrust::make_strided_iterator(&arr[0].second, stride);
thrust::fill(iter, iter + 4, 1337);
ASSERT_EQUAL(arr == reference, true);
}

// iterate over all second elements (static stride)
{
auto arr = ::cuda::std::array<::cuda::std::pair<int, double>, 4>{{{1, 2}, {3, 4}, {5, 6}, {7, 8}}};
auto iter = thrust::make_strided_iterator<sizeof(::cuda::std::pair<int, double>)>(&arr[0].second);

cuda::std::fill(iter, iter + 4, 1337);

auto arr = data;
auto iter = thrust::make_strided_iterator<stride>(&arr[0].second);
thrust::fill(iter, iter + 4, 1337);
ASSERT_EQUAL(arr == reference, true);
}
}
DECLARE_UNITTEST(TestStridedIteratorStruct);
DECLARE_UNITTEST(TestWritingStridedIteratorToStructMember);
127 changes: 107 additions & 20 deletions thrust/thrust/iterator/strided_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,48 +14,135 @@
#endif // no system header

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

THRUST_NAMESPACE_BEGIN

//! \addtogroup iterators
//! \{

//! \addtogroup fancyiterator Fancy Iterators
//! \ingroup iterators
//! \{

//! Holds a runtime value
template <typename T>
struct runtime_value
{
T value;
};

//! Holds a compile-time value
// we cannot use ::cuda::std::integral_constant, because it has a conversion operator to T that causes an ambiguity
// with operator+(counting_iterator, counting_iterator::difference_type) in any expression `counting_iterator +
// integral`.
template <auto Value>
struct compile_time_value
{
static constexpr decltype(Value) value = Value;
};

namespace detail
{
struct deref
template <typename T>
_CCCL_INLINE_VAR constexpr bool is_compile_time_value = false;

template <auto Value>
_CCCL_INLINE_VAR constexpr bool is_compile_time_value<compile_time_value<Value>> = true;
} // namespace detail

//! A \p strided_iterator wraps another iterator and moves it by a specified stride each time it is incremented or
//! decremented.
//!
//! \param RandomAccessIterator A random access iterator
//! \param StrideHolder Either a \ref runtime_value or a \ref compile_time_value specifying the stride
template <typename RandomAccessIterator, typename StrideHolder>
class _CCCL_DECLSPEC_EMPTY_BASES strided_iterator
: public iterator_adaptor<strided_iterator<RandomAccessIterator, StrideHolder>, RandomAccessIterator>
, StrideHolder
{
template <typename It>
_CCCL_HOST_DEVICE auto operator()(It it) const -> it_reference_t<It>
//! \cond
using super_t = iterator_adaptor<strided_iterator, RandomAccessIterator>;
friend class iterator_core_access;

public:
using difference_type = typename super_t::difference_type;
//! \endcond

static_assert(::cuda::std::random_access_iterator<RandomAccessIterator>,
"The iterator underlying a strided_iterator must be a random access iterator.");
static_assert(::cuda::std::is_same_v<iterator_traversal_t<RandomAccessIterator>, random_access_traversal_tag>);
static_assert(::cuda::std::is_convertible_v<decltype(StrideHolder::value), difference_type>,
"The stride must be convertible to the iterator's difference_type");

strided_iterator() = default;

//! Creates a strided_iterator from an existing iterator and a stride.
_CCCL_HOST_DEVICE strided_iterator(RandomAccessIterator it, StrideHolder stride = {})
: super_t(it)
, StrideHolder(stride)
{}

static constexpr bool has_static_stride = detail::is_compile_time_value<StrideHolder>;

//! Returns either the \ref runtime_value or the \ref compile_time_value holding the stride's value
_CCCL_HOST_DEVICE const auto& stride_holder() const
{
return *it;
return static_cast<const StrideHolder&>(*this);
}
};
} // namespace detail

//! \addtogroup iterators
//! \{
//! Returns the stride's value
_CCCL_HOST_DEVICE auto stride() const -> difference_type
{
return static_cast<detail::it_difference_t<RandomAccessIterator>>(stride_holder().value);
}

//! \addtogroup fancyiterator Fancy Iterators
//! \ingroup iterators
//! \{
private:
//! \cond
_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void advance(difference_type n)
{
this->base_reference() += n * stride();
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void increment()
{
this->base_reference() += stride();
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE void decrement()
{
this->base_reference() -= stride();
}

template <typename Iterator, typename StrideHolder = detail::empty>
using strided_iterator =
transform_iterator<detail::deref,
counting_iterator<Iterator, use_default, random_access_traversal_tag, use_default, StrideHolder>>;
template <typename OtherStrideHolder>
_CCCL_HOST_DEVICE bool equal(strided_iterator<RandomAccessIterator, OtherStrideHolder> const& other) const
{
return this->base() == other.base();
}

_CCCL_HOST_DEVICE difference_type distance_to(strided_iterator const& other) const
{
const difference_type dist = other.base() - this->base();
_CCCL_ASSERT(dist % stride() == 0, "Underlying iterator difference must be divisible by the stride");
return dist / stride();
}
//! \endcond
};

//! Constructs a strided_iterator with a runtime stride
template <typename Iterator, typename Stride>
_CCCL_HOST_DEVICE auto make_strided_iterator(Iterator it, Stride stride)
{
return strided_iterator<Iterator, detail::runtime_stride_holder<Stride>>(
make_counting_iterator(it, stride), detail::deref{});
return strided_iterator<Iterator, runtime_value<Stride>>(it, {stride});
}

//! Constructs a strided_iterator with a compile-time stride
template <auto Stride, typename Iterator>
_CCCL_HOST_DEVICE auto make_strided_iterator(Iterator it)
{
return strided_iterator<Iterator, detail::compile_time_stride_holder<decltype(Stride), Stride>>(
make_counting_iterator<Stride>(it), detail::deref{});
return strided_iterator<Iterator, compile_time_value<Stride>>(it, {});
}

//! \} // end fancyiterators
Expand Down

0 comments on commit 6d1b3ae

Please sign in to comment.