An overview of data layout strategies employed by the Triton compiler for efficient GPU computation.
Data Layout 是 Triton 编译器中的核心概念之一. 它定义了张量在全局内存, 共享内存和寄存器中的储存方式, 直接决定每个 CTA, Warp 和线程怎么访问数据. 常见的 Layout 属性包括 #blocked, 用于描述全局内存中的数据布局; #shared, 用于描述共享内存中的数据布局; 还有 #mma, 这是专门针对 Tensor Core 的数据布局. 一种比较新的 Layout 是 Linear Layout, 提供了一种统一的方式描述不同内存空间中的数据布局, 也是这篇文档的重点.
在 TableGen 中, Blocked Layout 的编码方式如下定义, 编码方式反映的是每一个原始数据会被分配到哪一个线程上:
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> {
let parameters = (
ins
ArrayRefParameter<"unsigned">:$sizePerThread,
ArrayRefParameter<"unsigned">:$threadsPerWarp,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
// CGALayout is optional in the textual IR. If omitted, we infer it to be a
// CGA with a single CTA (i.e. the trivial map onto dim0..dimn-1)
"CGAEncodingAttr":$CGALayout
)
}举几个例子, 对于一个 16x16 的矩阵, 如果采用以下的 Blocked Layout:
#blocked = #ttg.blocked_layout<{
sizePerThread = [2, 2]
threadsPerWarp = [8, 4]
warpsPerCTA = [1, 2]
order = [1, 0]
}>那么它的编码结果就是:
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]每个位置上的数字代表该位置上的数据被分配到哪个线程上. 在这个例子中整个矩阵被分配到同一个 CTA 中, 2 个 Warp 上, 每个 Warp 有 32 个线程, 因此总共有 64 个线程. sizePerThread 向量表明每个线程在不同维度上会处理多少数据, 例如这里表示每个线程会处理内存中相邻的 2x2 个数据, 所以最左上角开始的 2x2 个元素全被分配到第一个线程上. threadsPerWarp 向量表示每个 Warp 在不同维度上有多少线程, 在这里沿着 dim=0, 也就是水平方向有 4 个线程, 2*4*2=16 刚好等于原始矩阵的列数 (多一个乘 2 是因为这个方向上有两个 Warp). 沿着 dim=1, 也就是垂直方向有 8 个线程, 8*2=16 也刚好等于原始矩阵的行数. warpsPerCTA 向量表示每个 CTA 在不同维度上有多少 Warp, 这里沿着 dim=0 方向有 2 个 Warp, 沿着 dim=1 方向有 1 个 Warp (水平排布). order 表示维度变化的快慢顺序, 在这里沿着水平方向的分量 0 小于垂直方向的分量 1, 表示水平方向变化更快, 这实际上就是行主序 (row-major) 的存储方式.
再看一个例子, 将一个 32x32 的矩阵分配到 4 个 CTA 上, 每个 CTA 有 2 个 Warp, 每个 Warp 里面仍然是 32 个线程.
#cta_layout = #ttg.cta_layout<{
ctasPerCluster = [2, 2]
ctasSplitNum = [2, 2]
ctaOrder = [1, 0]
}>
#blocked = #ttg.blocked_layout<{
sizePerThread = [2, 2]
threadsPerWarp = [8, 4]
warpsPerCTA = [1, 2]
order = [1, 0]
CTALayout = #cta_layout
}>编码结果如下:
CTA [0,0] CTA [0,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
CTA [1,0] CTA [1,1]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ]
... ...
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]
[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ]前几个参数和第一个例子中一样, 唯一新增的是 CTALayout, 它定义了多个 CTA 在不同维度上的分布方式. 这里 ctasPerCluster 表示每个 Cluster 在不同维度上有多少个 CTA, 这里沿着两个方向上各有 2 个 CTA, 一共就是 4 个. ctasSplitNum 表示每个维度上 CTA 的划分方式, 这里沿着两个方向上都划分成 2 份. ctaOrder 表示 CTA 分布的快慢顺序, 这里和前面的 order 一样, 沿着水平方向变化更快. 柘同样是行主序的分布方式.
不过在新版 Triton 中, 已经使用 Linear Layout 来统一描述不同内存空间中的数据布局, 这一段介绍已经显得不适用了.