#pragma once

#include <ATen/core/Tensor.h>
#include <c10/macros/Export.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include <c10/util/Optional.h>

namespace at {
namespace native {
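
// Function-pointer type for the dispatch stub that selects which fused
// scaled-dot-product-attention backend to use for the given inputs;
// the returned int64_t encodes the chosen backend.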
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
    const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale);

DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
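
// Batched matmul of a against b with b's last two dims transposed (a @ b^T).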
TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b);
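
// Applies attn_mask (when present) to attn_scores and takes a softmax over
// the last dimension; mask_type selects how the mask is interpreted.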
TORCH_API Tensor masked_softmax(
    Tensor& attn_scores,
    c10::optional<Tensor> attn_mask,
    const Tensor& query,
    c10::optional<int64_t> mask_type = {});
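
// Kernel that splits a packed QKV projection into separate Q, K and V
// buffers, adding the QKV bias and rescaling Q by 1 / sqrt(dim_per_head).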
using transform_bias_rescale_qkv_fn = void(*)(
    at::ScalarType type,
    void* _q_k_v,
    const void* _qkv,
    const void* _qkv_bias,
    int64_t B,
    int64_t T,
    int64_t D,
    int64_t num_head);

DECLARE_DISPATCH(transform_bias_rescale_qkv_fn, transform_bias_rescale_qkv_stub);
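
// Permutes dims 1 and 2 of a (the "0213" transform) to collapse the head
// dimension, then computes a GEMM against b transposed and adds bias c.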
TORCH_API Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query);
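
// Batched matmul a @ b, with the result written into out.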
TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b);
TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape);
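
// Computes the packed QKV projection of query/key/value against qkv_weight.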
TORCH_API Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight);
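
// Forward flash-attention kernel: writes the attention result into output
// and the per-row logsumexp needed for the backward pass.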
using flash_attention_fn = void (*)(
    const Tensor& output, const Tensor& logsumexp,
    const Tensor& query, const Tensor& key, const Tensor& value,
    double dropout_p, bool is_causal,
    c10::optional<Tensor> attn_mask,
    c10::optional<double> scale);
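
// Backward flash-attention kernel: computes grad_q, grad_k and grad_v from
// grad_out using the saved out and logsumexp.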
using flash_attention_backward_fn = void (*)(
    const Tensor& grad_q, const Tensor& grad_k,
    const Tensor& grad_v, const Tensor& grad_out,
    const Tensor& query, const Tensor& key,
    const Tensor& value, const Tensor& out, const Tensor& logsumexp,
    double dropout_p, bool is_causal,
    c10::optional<Tensor> attn_mask,
    c10::optional<double> scale);

DECLARE_DISPATCH(flash_attention_fn, flash_attention_kernel);
DECLARE_DISPATCH(flash_attention_backward_fn, flash_attention_backward_kernel);
} // namespace native
} // namespace at