diff --git a/cpp/include/cugraph/utilities/shuffle_comm.cuh b/cpp/include/cugraph/utilities/shuffle_comm.cuh index 98fa2cb1706..d0b44021428 100644 --- a/cpp/include/cugraph/utilities/shuffle_comm.cuh +++ b/cpp/include/cugraph/utilities/shuffle_comm.cuh @@ -145,43 +145,58 @@ compute_tx_rx_counts_offsets_ranks(raft::comms::comms_t const& comm, template struct key_group_id_less_t { - KeyToGroupIdOp key_to_group_id_op{}; - int pivot{}; + key_group_id_less_t(KeyToGroupIdOp op, int pivot_) : key_to_group_id_op(::std::move(op)), pivot(pivot_) {} __device__ bool operator()(key_type k) const { return key_to_group_id_op(k) < pivot; } + +private: + KeyToGroupIdOp key_to_group_id_op; + int pivot; }; template struct value_group_id_less_t { - ValueToGroupIdOp value_to_group_id_op{}; - int pivot{}; + value_group_id_less_t(ValueToGroupIdOp op, int pivot_) : value_to_group_id_op(::std::move(op)), pivot(pivot_) {} __device__ bool operator()(value_type v) const { return value_to_group_id_op(v) < pivot; } + +private: + ValueToGroupIdOp value_to_group_id_op; + int pivot; }; template struct kv_pair_group_id_less_t { - KeyToGroupIdOp key_to_group_id_op{}; - int pivot{}; + kv_pair_group_id_less_t(KeyToGroupIdOp op, int pivot_) : key_to_group_id_op(::std::move(op)), pivot(pivot_) {} __device__ bool operator()(thrust::tuple t) const { return key_to_group_id_op(thrust::get<0>(t)) < pivot; } + +private: + KeyToGroupIdOp key_to_group_id_op; + int pivot; }; template struct value_group_id_greater_equal_t { - ValueToGroupIdOp value_to_group_id_op{}; - int pivot{}; + value_group_id_greater_equal_t(ValueToGroupIdOp op, int pivot_) : value_to_group_id_op(::std::move(op)), pivot(pivot_) {} __device__ bool operator()(value_type v) const { return value_to_group_id_op(v) >= pivot; } + +private: + ValueToGroupIdOp value_to_group_id_op; + int pivot; }; template struct kv_pair_group_id_greater_equal_t { - KeyToGroupIdOp key_to_group_id_op{}; - int pivot{}; + kv_pair_group_id_greater_equal_t(KeyToGroupIdOp op, int pivot_) : key_to_group_id_op(::std::move(op)), pivot(pivot_) {} __device__ bool operator()(thrust::tuple t) const { return key_to_group_id_op(thrust::get<0>(t)) >= pivot; } + +private: + KeyToGroupIdOp key_to_group_id_op; + int pivot; }; template