forked from Dao-AILab/flash-attention
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnamed_barrier.hpp
More file actions
76 lines (62 loc) · 2.99 KB
/
Copy pathnamed_barrier.hpp
File metadata and controls
76 lines (62 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
#pragma once
#include "cutlass/arch/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// cutlass::arch::NamedBarrier::sync/arrive are only enabled Sm90 even though they work
// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80.
CUTLASS_DEVICE
static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) {
static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
}
CUTLASS_DEVICE
static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads));
cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);
}
CUTLASS_DEVICE
static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) {
static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);
uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;
cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
}
CUTLASS_DEVICE
static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {
uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);
cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);
asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class FwdNamedBarriers {
QueryEmpty = 0,
ProducerWG = 1,
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6,
AppendKV = 7,
QueryRotated = 8,
};
enum class BwdNamedBarriers {
KVEmpty = 0,
PdS = 1,
// This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
dQEmptyWG1 = 4,
dQEmptyWG2 = 5,
dQEmptyWG3 = 6,
dQFullWG1 = 7,
dQFullWG2 = 8,
dQFullWG3 = 9,
};
} // flash