HomeArchiveBlog


Original contents are licensed under CC BY-NC 4.0. All rights reserved © 2026 Kai.
Back to Archives
Code Reading: Triton Dialect

An in-depth exploration of the Triton dialect within MLIR, examining its structure, operations, and how it facilitates GPU programming.

Sat Dec 20 2025
Wed Dec 31 2025
TritonMLIRGPU Compiler
On this page
  • Code Reading: Triton Dialect
    • 方言类型
    • 方言属性
    • 方言操作

Code Reading: Triton Dialect

这是阅读 Triton 源码系列对 Triton Dialect 代码结构的梳理.

方言类型

Triton 中定义类型的文件在 include/triton/Dialect/Triton/IR/TritonTypes.td 中. Triton 在 MLIR 原有的类型中选择了一些作为基本数据类型, 例如

def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;

Triton 没有保留 MLIR 中所有的基本数据类型, 是因为 GPU 运算实际上只支持它列举出来的这些类型, 其他都是不需要的. Triton 还定义了它自己的指针类型.

def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> {
    // ...
    let parameters = (ins "Type":$pointeeType, "int":$addressSpace);
    // ...
}
class TT_PtrOf<list[Type] pointeeType> : DialectType<...>;

指针类型接受两个参数, 一个是指针指向的数据类型, 另一个是地址空间. 指针可以指向标量类型, 但更常见的是指向一块张量, 也就是 def TT_TensorPtr : TT_PtrOf<[TT_Tensor]> 地址空间用来标记指针所指的内存位于哪一个内存层次, 是全局显存 (Global Memory), 还是片上共享内存 (Shared Memory)?

Triton 这个层级常用的复合类型是张量 (Tensor), 它屏蔽掉了很多数据 Layout 的细节. 对每一种标量类型都可以构造出对应的张量类型.

def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
def TT_IntTensor : RankedTensorOf<[TT_Int]>;
def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>;
def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>;

并且 Triton 将某一种类型本身和其构成的张量类型都归为 xxxLike 的类型族, 例如 FloatLike 类型族包含了 Float 和 FloatTensor 两种类型.

def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;

因此 Triton 的数据类型分为几大类, 包括 IntLike, FloatLike, PtrLike, TensorPtr 四种.

方言属性

Triton 中定义属性的文件在 include/triton/Dialect/Triton/IR/TritonAttrDefs.td 中. 这些属性大多为访存提供的辅助信息, 例如缓存策略, 内存语义, Padding 方式等.

def TT_CacheModifierAttr : I32EnumAttr<
    "CacheModifier", "",
    [
        I32EnumAttrCase<"NONE", 1, "none">,
        I32EnumAttrCase<"CA", 2, "ca">,
        I32EnumAttrCase<"CG", 3, "cg">,
        I32EnumAttrCase<"WB", 4, "wb">,
        I32EnumAttrCase<"CS", 5, "cs">,
        I32EnumAttrCase<"WT", 6, "wt">,
        I32EnumAttrCase<"CV", 7, "cv">,
    ]> {
}

这是缓存策略的枚举, 例如 CA 表示 Cache All, 数据在 L1 和 L2 都会被缓存; CG 表示 Cache Global, 数据只会被缓存到 L2; CS 表示 Cache Streaming, 数据只会被缓存到 L1; CV 表示 Cache Volatile, 每次读取数据都绕过 L1/L2, 直接从显存读取; WB 表示 Write Back, 写操作会先写到缓存, 后续同步到主存; WT 表示 Write Through, 写操作会直接写到显存, 同时更新缓存.

def TT_MemSemanticAttr : I32EnumAttr<
    "MemSemantic", "",
    [
      I32EnumAttrCase<"RELAXED", 1, "relaxed">,
      I32EnumAttrCase<"ACQUIRE", 2, "acquire">,
      I32EnumAttrCase<"RELEASE", 3, "release">,
      I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">,
    ]> {
}

这是内存语义的枚举, RELAXED 使用松散模型, 只保证操作是原子的, 不保证指令之间的先后顺序; ACQUIRE 获取语义, 后续的读写指令不能被重排到该指令之前; RELEASE 释放语义, 之前的读写指令不能被重排到该指令之后; ACQUIRE_RELEASE 同时具备获取和释放语义, 是最强的内存语义.

def TT_EvictionPolicyAttr : I32EnumAttr<
    "EvictionPolicy", "",
    [
        I32EnumAttrCase<"NORMAL", 1, "evict_normal">,
        I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
        I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
    ]> {
}

这是缓存驱逐的策略, NORMAL 表示默认的驱逐策略; EVICT_FIRST 表示优先淘汰所标记的数据; EVICT_LAST 表示优先保留所标记的数据.

def TT_PaddingOptionAttr : I32EnumAttr<
    "PaddingOption", "",
    [
        I32EnumAttrCase<"PAD_ZERO", 1, "zero">,
        I32EnumAttrCase<"PAD_NAN", 2, "nan">
    ]> {
}

这是 Padding 的选项, PAD_ZERO 表示使用 0 进行填充; PAD_NAN 表示使用 NaN 进行填充. 还有一些其他的辅助属性, 例如舍入模式等, 这里就不一一列举了.

方言操作

Triton 中定义操作的文件在 include/triton/Dialect/Triton/IR/TritonOps.td 中. 这里面的指令抽象层级都很高, 除了各种算术操作以外, 主要是一些和访存和张量操作相关的操作. 以下是一些重要的操作.

def TT_LoadOp : TT_Op<"load", [
    // ...
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
    TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
                    "($_op.getOperands().size() <= 1) || std::equal_to<>()">,
    TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)",
                    "($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
    // ...
    let arguments = (
        ins
        AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
        Optional<TT_BoolLike>:$mask,
        Optional<TT_Type>:$other,

        DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
        OptionalAttr<TT_PaddingOptionAttr>:$padding,
        DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
        DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict,
        DefaultValuedAttr<BoolAttr, "false">:$isVolatile
    );

    let results = (outs TT_Type:$result);
    // ...
}