An in-depth exploration of the Triton dialect within MLIR, examining its structure, operations, and how it facilitates GPU programming.
这是阅读 Triton 源码系列对 Triton Dialect 代码结构的梳理.
Triton 中定义类型的文件在 include/triton/Dialect/Triton/IR/TritonTypes.td 中. Triton 在 MLIR 原有的类型中选择了一些作为基本数据类型, 例如
def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;Triton 没有保留 MLIR 中所有的基本数据类型, 是因为 GPU 运算实际上只支持它列举出来的这些类型, 其他都是不需要的. Triton 还定义了它自己的指针类型.
def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
// ...
let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
// ...
}
class TT_PtrOf<list[Type] pointeeType> : DialectType<...>;指针类型接受两个参数, 一个是指针指向的数据类型, 另一个是地址空间. 指针可以指向标量类型, 但更常见的是指向一块张量, 也就是 def TT_TensorPtr : TT_PtrOf<[TT_Tensor]> 地址空间用来标记指针所指的内存位于哪一个内存层次, 是全局显存 (Global Memory), 还是片上共享内存 (Shared Memory)?
Triton 这个层级常用的复合类型是张量 (Tensor), 它屏蔽掉了很多数据 Layout 的细节. 对每一种标量类型都可以构造出对应的张量类型.
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_IntTensor : RankedTensorOf<[TT_Int]>;
def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>;
def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>;并且 Triton 将某一种类型本身和其构成的张量类型都归为 xxxLike 的类型族, 例如 FloatLike 类型族包含了 Float 和 FloatTensor 两种类型.
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;因此 Triton 的数据类型分为几大类, 包括 IntLike, FloatLike, PtrLike, TensorPtr 四种.
Triton 中定义属性的文件在 include/triton/Dialect/Triton/IR/TritonAttrDefs.td 中. 这些属性大多为访存提供的辅助信息, 例如缓存策略, 内存语义, Padding 方式等.
def TT_CacheModifierAttr : I32EnumAttr<
"CacheModifier", "",
[
I32EnumAttrCase<"NONE", 1, "none">,
I32EnumAttrCase<"CA", 2, "ca">,
I32EnumAttrCase<"CG", 3, "cg">,
I32EnumAttrCase<"WB", 4, "wb">,
I32EnumAttrCase<"CS", 5, "cs">,
I32EnumAttrCase<"WT", 6, "wt">,
I32EnumAttrCase<"CV", 7, "cv">,
]> {
}这是缓存策略的枚举, 例如 CA 表示 Cache All, 数据在 L1 和 L2 都会被缓存; CG 表示 Cache Global, 数据只会被缓存到 L2; CS 表示 Cache Streaming, 数据只会被缓存到 L1; CV 表示 Cache Volatile, 每次读取数据都绕过 L1/L2, 直接从显存读取; WB 表示 Write Back, 写操作会先写到缓存, 后续同步到主存; WT 表示 Write Through, 写操作会直接写到显存, 同时更新缓存.
def TT_MemSemanticAttr : I32EnumAttr<
"MemSemantic", "",
[
I32EnumAttrCase<"RELAXED", 1, "relaxed">,
I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
I32EnumAttrCase<"RELEASE", 3, "release">,
I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
]> {
}这是内存语义的枚举, RELAXED 使用松散模型, 只保证操作是原子的, 不保证指令之间的先后顺序; ACQUIRE 获取语义, 后续的读写指令不能被重排到该指令之前; RELEASE 释放语义, 之前的读写指令不能被重排到该指令之后; ACQUIRE_RELEASE 同时具备获取和释放语义, 是最强的内存语义.
def TT_EvictionPolicyAttr : I32EnumAttr<
"EvictionPolicy", "",
[
I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
]> {
}这是缓存驱逐的策略, NORMAL 表示默认的驱逐策略; EVICT_FIRST 表示优先淘汰所标记的数据; EVICT_LAST 表示优先保留所标记的数据.
def TT_PaddingOptionAttr : I32EnumAttr<
"PaddingOption", "",
[
I32EnumAttrCase<"PAD_ZERO", 1, "zero">,
I32EnumAttrCase<"PAD_NAN", 2, "nan">
]> {
}这是 Padding 的选项, PAD_ZERO 表示使用 0 进行填充; PAD_NAN 表示使用 NaN 进行填充. 还有一些其他的辅助属性, 例如舍入模式等, 这里就不一一列举了.
Triton 中定义操作的文件在 include/triton/Dialect/Triton/IR/TritonOps.td 中. 这里面的指令抽象层级都很高, 除了各种算术操作以外, 主要是一些和访存和张量操作相关的操作. 以下是一些重要的操作.
def TT_LoadOp : TT_Op<"load", [
// ...
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
// ...
let arguments = (
ins
AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
Optional<TT_BoolLike>:$mask,
Optional<TT_Type>:$other,
DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
OptionalAttr<TT_PaddingOptionAttr>:$padding,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
);
let results = (outs TT_Type:$result);
// ...
}