sdp_utils_cpp.h (forked from pytorch/pytorch, 503 lines, 15.4 KB)
#pragma once
#include <ATen/Context.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymFloat.h>
#include <c10/util/string_view.h>
#include <c10/util/Array.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <sstream>
namespace sdp {
constexpr int32_t num_backends = 4;
enum class SDPBackend {
error = -1,
math = 0,
flash_attention = 1,
efficient_attention = 2,
cudnn_attention = 3
};
// Note: if this enum changes, make sure to update the templated enum in
// mem_eff/kernel_forward.h and mem_eff/kernel_backward.h
enum class CustomMaskType {
NoCustomMask = 0,
CausalFromTopLeft = 1,
CausalFromBottomRight = 2,
NumCustomMaskTypes,
};
struct sdp_params {
at::Tensor query;
at::Tensor key;
at::Tensor value;
c10::optional<at::Tensor> attn_mask;
double dropout;
bool is_causal;
};
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
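// select_sdp_backend_cpp is defined out of line. As a rough, illustrative
// sketch (not the actual implementation), it composes the check_* predicates
// below for each fused backend and falls back to the math backend when any
// constraint is violated:
//
//   if (check_runtime_disabled_flash(params, /*debug=*/false) &&
//       check_tensor_shapes(params, false) &&
//       check_for_dropout(params, false) && /* ... */ true) {
//     return SDPBackend::flash_attention;
//   }
//   return SDPBackend::math;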
inline c10::SymFloat calculate_scale(
const at::Tensor& query,
c10::optional<double> scale) {
const auto softmax_scale = scale.has_value()
? scale.value()
: (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
return c10::SymFloat(softmax_scale);
}
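// When no explicit scale is given, the softmax scale defaults to
// 1 / sqrt(head_dim). Illustrative example (shapes are hypothetical):
//
//   at::Tensor q = at::randn({2, 8, 128, 64});           // head_dim = 64
//   c10::SymFloat s = calculate_scale(q, c10::nullopt);   // 1 / sqrt(64) = 0.125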
using c10::array_of;
inline bool input_requires_grad(sdp_params const& params) {
const bool any_inputs_require_grad = params.query.requires_grad() ||
params.key.requires_grad() || params.value.requires_grad();
const bool gradmode_enabled = at::GradMode::is_enabled();
return any_inputs_require_grad && gradmode_enabled;
}
inline bool has_for_nested_inputs(sdp_params const& params) {
return
(params.query.is_nested() && params.query.layout() == c10::kStrided) ||
(params.key.is_nested() && params.key.layout() == c10::kStrided) ||
(params.value.is_nested() && params.value.layout() == c10::kStrided);
}
inline bool has_for_dense_inputs(sdp_params const& params) {
return !params.query.is_nested() || !params.key.is_nested() || !params.value.is_nested();
}
inline bool has_only_dense_inputs(sdp_params const& params) {
return !params.query.is_nested() && !params.key.is_nested() && !params.value.is_nested();
}
template <typename dtype_vector>
inline bool check_tensor_dtype(
sdp_params const& params,
dtype_vector allowed_dtypes,
bool debug) {
auto query_dtype = params.query.dtype();
if (!(query_dtype == params.key.dtype() &&
query_dtype == params.value.dtype() &&
(std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
allowed_dtypes.end()))) {
if (debug) {
TORCH_WARN(
"Expected query, key and value to all be of dtype: {",
c10::Join(", ", allowed_dtypes),
"}. Got ",
"Query dtype: ",
params.query.dtype(),
", Key dtype: ",
params.key.dtype(),
", and Value dtype: ",
params.value.dtype(),
" instead.");
}
return false;
}
return true;
}
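// Illustrative call (the allow-list and `params` below are hypothetical; each
// kernel supplies its own list of supported dtypes):
//
//   const auto allowed =
//       c10::array_of<at::ScalarType>(at::ScalarType::Float, at::ScalarType::Double);
//   bool ok = check_tensor_dtype(params, allowed, /*debug=*/false);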
inline bool try_broadcast_param_size(
const c10::SymInt q_size,
const c10::SymInt k_size,
const c10::SymInt v_size,
c10::string_view param_name,
bool debug) {
auto max_size = std::max({q_size, k_size, v_size});
if ((q_size != max_size && q_size != 1) ||
(k_size != max_size && k_size != 1) ||
(v_size != max_size && v_size != 1)) {
if (debug) {
TORCH_WARN(
"Both fused kernels require query, key and value to have broadcastable ",
param_name,
"got Query ",
param_name,
q_size,
", Key ",
param_name,
k_size,
", Value ",
param_name,
v_size,
" instead.");
}
return false;
}
return true;
}
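// Illustrative behavior: sizes (8, 1, 8) broadcast to 8 and pass, while
// sizes (8, 4, 8) fail because 4 is neither the max size nor 1.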
inline bool check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
at::Tensor const& param,
c10::string_view param_name,
bool debug) {
const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
const at::Tensor& sizes = nt_tensor_impl->get_nested_sizes();
auto num_head_dims = nt_tensor_impl->opt_size(1);
if (!num_head_dims.has_value()) {
// num_head_dims is ragged
if (debug) {
TORCH_WARN(
"Fused kernels do not support ragged num_head_dims, ",
param_name,
"has a ragged num_heads.");
}
return false;
}
auto* sizes_ptr = sizes.data_ptr<int64_t>();
const int64_t n_tensors = param.size(0);
const int64_t size_tensor_stride = sizes.stride(0);
// This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
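// Each row of `sizes` holds a component's (num_heads, seq_len, head_dim), so
// offset + 1 below reads that sample's sequence length.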
for (const auto i : c10::irange(n_tensors)) {
if (sizes_ptr[(i * size_tensor_stride) + 1] == 0) {
if (debug) {
TORCH_WARN(
"Fused kernels do not support seq_len == 0, ",
param_name,
"has a seq len of 0.");
}
return false;
}
}
return true;
}
inline bool check_for_seq_len_0_nested_tensor(sdp_params const& params, bool debug) {
// When this function is called we are assured that the nested tensor has dim == 4
bool q_is_safe = params.query.is_nested()
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
params.query, "query ", debug)
: true;
// short circuit if any is unsafe
if (!q_is_safe) {
return false;
}
bool k_is_safe = params.key.is_nested()
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
params.key, "key ", debug)
: true;
if (!k_is_safe) {
return false;
}
bool v_is_safe = params.value.is_nested()
? check_for_seq_len_0_and_consistent_head_dim_nested_tensor_helper(
params.value, "value ", debug)
: true;
if (!v_is_safe) {
return false;
}
// We now know none of the inputs have ragged num_heads, so we can safely
// access .size(1)
auto q_num_heads = params.query.size(1);
auto k_num_heads = params.key.size(1);
auto v_num_heads = params.value.size(1);
bool same_num_heads =
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
if (!same_num_heads) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
return try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
}
return true;
}
inline bool check_nested_tensor(sdp_params const& params, bool debug) {
// Return false if the inputs contain a nested tensor
if (!has_only_dense_inputs(params)) {
if (debug) {
TORCH_WARN(
"Both fused kernels of cpp version currently do not support Nested Tensor inputs.");
}
return false;
}
return true;
}
inline bool check_for_dropout(sdp_params const& params, bool debug) {
if (params.dropout > 0.0) {
if (debug) {
TORCH_WARN("Both fused kernels do not support non-zero dropout.");
}
return false;
}
return true;
}
inline bool check_requires_grad_and_nested(sdp_params const& params, bool debug) {
if (input_requires_grad(params)) {
if (debug) {
TORCH_WARN(
"Memory efficient attention currently doesn't support training with NT inputs.");
}
return false;
}
return true;
}
inline bool check_for_attn_mask(sdp_params const& params, bool debug) {
if (params.attn_mask.has_value()) {
if (debug) {
TORCH_WARN("Both fused kernels do not support non-null attn_mask.");
}
return false;
}
return true;
}
inline bool check_attn_mask_shape(sdp_params const& params, bool debug) {
auto attn_mask = params.attn_mask;
if (!attn_mask.has_value()) {
return true;
}
if (attn_mask.value().requires_grad()) {
return false;
}
auto batchSize = params.query.sym_size(0);
auto qSize = params.query.sym_size(2);
auto kvSize = params.key.sym_size(2);
auto num_head = params.query.sym_size(1);
if (attn_mask.value().sym_size(-2) != qSize && attn_mask.value().sym_size(-2) != 1) {
return false;
}
if (attn_mask.value().sym_size(-1) != kvSize && attn_mask.value().sym_size(-1) != 1) {
return false;
}
if (attn_mask.value().dim() == 2) {
return true;
} else if (attn_mask.value().dim() == 4) {
if ((attn_mask.value().sym_size(0) == 1 || attn_mask.value().sym_size(0) == batchSize)
&& (attn_mask.value().sym_size(1) == 1 || attn_mask.value().sym_size(1) == num_head)) {
return true;
}
}
if (debug) {
TORCH_WARN("Please use the following attn mask shapes: ",
"2d - ({Q_seq_len, 1} x {KV_seq_len, 1}); ",
"4d - ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1})");
}
return false;
}
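// Illustrative examples for a query of shape [B, H, Q, D] and key of shape
// [B, H, K, D]: 2-D masks of shape [Q, K], [1, K], or [Q, 1] and 4-D masks of
// shape [B, H, Q, K], [1, 1, Q, K], or [B, 1, Q, K] pass this check; any other
// rank, a non-broadcastable dimension, or a mask requiring grad is rejected.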
inline bool check_tensor_shapes(sdp_params const& params, bool debug) {
auto query_dim = params.query.dim();
if (!(query_dim == params.key.dim() && query_dim == params.value.dim() &&
(query_dim == 4))) {
if (debug) {
TORCH_WARN(
"Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: ",
query_dim,
", Key dim: ",
params.key.dim(),
", Value dim: ",
params.value.dim(),
" instead.");
}
return false;
}
return true;
}
inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
const auto nt_tensor_impl = at::native::get_nested_tensor_impl(param);
auto seq_len = nt_tensor_impl->opt_size(2);
if (!seq_len.has_value()) {
if (debug) {
TORCH_WARN(
"For both fused kernels, if one of key/value batch_size requires "
"broadcasting and the other does not, then the other must have a ",
"consistent seq_len dim.")
}
return false;
}
return true;
}
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional
auto q_batch_size = params.query.sym_size(0);
auto k_batch_size = params.key.sym_size(0);
auto v_batch_size = params.value.sym_size(0);
bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
auto q_num_heads = params.query.sym_size(1);
auto k_num_heads = params.key.sym_size(1);
auto v_num_heads = params.value.sym_size(1);
bool same_num_heads =
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
if (!(same_batch_size && same_num_heads)) {
if (debug) {
TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ",
"Query.sizes(): ",
params.query.sizes(),
", Key sizes(): ",
params.key.sizes(),
", Value sizes(): ",
params.value.sizes(),
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
}
return false;
}
return true;
}
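// Illustrative fix for dense inputs whose key/value need broadcasting before
// this check (all names are hypothetical):
//
//   key = key.expand({batch_size, num_heads, kv_len, head_dim});
//   value = value.expand({batch_size, num_heads, kv_len, head_dim});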
inline bool check_batch_size_nested(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional
auto q_batch_size = params.query.sym_size(0);
auto k_batch_size = params.key.sym_size(0);
auto v_batch_size = params.value.sym_size(0);
bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
// num_heads logic for nested input is checked in
// check_for_seq_len_0_nested_tensor as there is handling there to make sure
// num_heads is not ragged
bool broadcastable_batch_size = true;
if (!same_batch_size) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
// try to broadcast batchsize
broadcastable_batch_size = try_broadcast_param_size(
q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);
// if only one of k or v require broadcasting of batch size, the other
// must have a consistent seq_len dim
if (broadcastable_batch_size) {
if (k_batch_size == 1 && v_batch_size != 1 &&
!check_safe_kv_broadcast(params.value, debug)) {
return false;
}
if (v_batch_size == 1 && k_batch_size != 1 &&
!check_safe_kv_broadcast(params.key, debug)) {
return false;
}
}
}
return broadcastable_batch_size;
}
inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
// In some cases people will pass in zero-sized tensors; this will
// cause the fused path to error with an unaligned mask.
bool zero_seq_len_q = params.query.sym_size(-2) == 0;
bool zero_seq_len_k = params.key.sym_size(-2) == 0;
if (zero_seq_len_q || zero_seq_len_k) {
if (debug) {
TORCH_WARN(
"Both fused kernels do not support zero seq_len_q or seq_len_kv.");
}
return false;
}
return true;
}
template<bool ignore_singleton_dim>
inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
// The stride checking for NestedTensors is done within the kernel,
// and .contiguous() will be called if needed.
// This function checks that the last dimension of the inputs to
// fused_attention has stride 1.
bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
// https://github.com/pytorch/pytorch/issues/116333
// If the head_dim is size 1 the stride won't matter, but we
// check this condition before padding the head_dim to 1
if (ignore_singleton_dim){
qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
}
bool mask_stride_equal_1 = params.attn_mask.has_value()
? params.attn_mask.value().sym_stride(-1) == 1
: true;
if (!(qkv_strides_equal_1 && mask_stride_equal_1)) {
if (debug) {
std::ostringstream epilogue_message;
if (params.attn_mask.has_value()) {
epilogue_message << ", Attn_mask.stride(-1): "
<< params.attn_mask.value().sym_stride(-1);
}
epilogue_message << " instead.";
TORCH_WARN(
"Both fused kernels require the last dimension of the input to have stride 1. ",
"Got Query.stride(-1): ",
params.query.sym_stride(-1),
", Key.stride(-1): ",
params.key.sym_stride(-1),
", Value.stride(-1): ",
params.value.sym_stride(-1),
epilogue_message.str());
}
return false;
}
return true;
}
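// Illustrative example (tensor name is hypothetical): a [B, S, H, D] tensor
// transposed to [B, H, S, D] via q.transpose(1, 2) keeps stride(-1) == 1 and
// passes this check, whereas a strided last-dim slice such as
// q.slice(/*dim=*/-1, 0, head_dim, /*step=*/2) has stride(-1) == 2 and fails.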
inline bool check_runtime_disabled_flash(sdp_params const& params, bool debug) {
// We check the global context to see if the user has explicitly turned off
// the flash SDP kernels
if (!at::globalContext().userEnabledFlashSDP()) {
if (debug) {
TORCH_WARN("Flash attention has been runtime disabled.");
}
return false;
}
return true;
}
inline bool check_runtime_disabled_mem_efficient(sdp_params const& params, bool debug) {
// We check the global context to see if the user has explicitly turned off
// the mem_efficient SDP kernels
if (!at::globalContext().userEnabledMemEfficientSDP()) {
if (debug) {
TORCH_WARN("Memory Efficient attention has been runtime disabled.");
}
return false;
}
return true;
}
} // namespace sdp