2026/3/12 14:23:20
网站建设
项目流程
hao123网站模板,网络营销策划推广,wordpress页面调用子页面,多语言 网站源码TorchInductor 源码深度剖析
为什么需要 TorchInductor#xff1f;从 Dynamo 到机器码的桥梁
TorchDynamo 的局限
在学习完 TorchDynamo 后#xff0c;我们知道 Dynamo 通过字节码级追踪可以捕获完整的计算图#xff08;FX Graph#xff09;#xff0c;但它有一个关键问…TorchInductor 源码深度剖析为什么需要 TorchInductor从 Dynamo 到机器码的桥梁TorchDynamo 的局限在学习完 TorchDynamo 后我们知道 Dynamo 通过字节码级追踪可以捕获完整的计算图FX Graph但它有一个关键问题# TorchDynamo 的输出torch.compiledefmodel(x,y):zxyreturnz.relu()# Dynamo 生成的 FX Graphdefforward(x,y):add_tensortorch.ops.aten.add.Tensor(x,y)relu_defaulttorch.ops.aten.relu.default(add_tensor)returnrelu_default# 问题这只是一个计算图还不是可执行的高效代码TorchDynamo 只负责捕获图不负责执行优化生成了 FX Graph计算图的中间表示但没有生成优化的 GPU/CPU 代码没有算子融合、内存优化等仍然需要逐个调用 PyTorch 算子TorchInductor 的使命TorchInductor 是 PyTorch 编译栈的代码生成器负责将 FX Graph 编译为高效的机器码。# 完整的编译流程用户代码 ↓[TorchDynamo]字节码追踪 → FX Graph ↓[AOTAutograd]前向/反向分离 → 优化的 FX Graph ↓[TorchInductor]代码生成 → 优化的机器码 ← 本文重点 ↓ GPU/CPU 执行技术对比维度TorchDynamoTorchInductor职责捕获计算图生成优化代码输入Python 字节码FX Graph输出FX GraphTriton/C 代码核心技术字节码分析 GuardLowering Fusion CodeGen优化无只捕获算子融合、内存规划、循环优化性能提升0x只是捕获1.5-5x真正的加速核心设计思想TorchDynamo: 分析字节码 → 构建 FX Graph [静态分析] [符号执行] TorchInductor: 分析 FX Graph → 生成优化代码 → 编译执行 [Lowering] [Fusion] [CodeGen]TorchInductor 是什么一句话总结TorchInductor 是 PyTorch 的 JIT 编译器后端通过 Lowering、Fusion、Scheduling 等技术将 FX Graph 编译为高效的 Triton/C 代码实现端到端的性能优化。核心组成TorchInductor 架构 ├─ Graph Lowering [转换] FX Graph → IR (Intermediate Representation) ├─ Scheduler [调度] 算子融合决策、内存规划 ├─ Code Generation [生成] IR → Triton/C 源代码 │ ├─ Triton CodeGen (GPU) │ ├─ C CodeGen (CPU) │ └─ Wrapper CodeGen (调用包装) ├─ Compilation [编译] 源代码 → 机器码 │ ├─ Triton Compiler → PTX/CUBIN │ └─ C Compiler → .so └─ Execution [执行] 动态加载并运行工作流程FX Graph (来自 Dynamo) | v [1] Graph Lowering | | 将高层算子转换为低层 IR | v IR Nodes (Pointwise, Reduction, etc.) | v [2] Scheduling Fusion | | 决定哪些算子可以融合 | 生成内存规划 | v Fused IR Groups | v [3] Code Generation | | 为每个融合组生成 Triton/C 代码 | v Source Code (Triton/C) | v [4] Compilation | | Triton → PTX → CUBIN | C → .so | v Machine Code | v [5] Execution | | 动态加载并执行 | v Result源码文件索引TorchInductor 的核心代码位于torch/_inductor/目录torch/_inductor/ ├── compile_fx.py # [入口] 编译 FX Graph 的主入口 ├── graph.py # [核心] GraphLowering - 图降低与转换 ├── lowering.py # [核心] FX 算子 → IR 的 lowering 规则 ├── ir.py # [核心] IR 节点定义 (Pointwise, Reduction, etc.) ├── scheduler.py # [核心] Scheduler - 融合决策与调度 ├── dependencies.py # 依赖分析 ├── sizevars.py # 符号化形状处理 ├── codegen/ # 代码生成 │ ├── common.py # 公共代码生成工具 │ ├── triton.py # [核心] Triton 代码生成 │ ├── cpp.py # C 代码生成 │ ├── wrapper.py # Wrapper 代码生成 │ └── simd.py # SIMD 优化 ├── kernel/ # 特殊 Kernel │ ├── mm.py # 矩阵乘法 │ ├── conv.py # 卷积 │ └── ... ├── fx_passes/ # FX Graph 优化 Pass │ ├── joint_graph.py # 联合图优化 │ ├── pre_grad.py # 梯度前优化 │ └── post_grad.py # 梯度后优化 ├── runtime/ # 运行时支持 │ ├── hints.py # 启发式提示 │ └── triton_heuristics.py # Triton 启发式 └── utils.py # 工具函数第一部分整体架构与编译流程章节 1从 FX Graph 到机器码的完整流程1.1 编译入口# torch/_inductor/compile_fx.pydefcompile_fx(gm:torch.fx.GraphModule,example_inputs:List[torch.Tensor],inner_compileNone,): 编译 FX GraphModule Args: gm: FX GraphModule (来自 TorchDynamo) example_inputs: 示例输入用于形状推导 inner_compile: 内部编译函数可选 Returns: compiled_fn: 编译后的可执行函数 # [1] 创建 GraphLowering 实例withV.set_graph_handler(GraphLowering(gm,example_inputsexample_inputs)):# [2] 执行 LoweringFX Graph → IRgraph_handlerV.graph graph_handler.run(*example_inputs)# [3] 编译为可执行函数compiled_fngraph_handler.compile_to_fn()returncompiled_fn执行流程时间线t0ms: compile_fx() 被调用 ↓ t1ms: 创建 GraphLowering 实例 ↓ t2ms: 开始 Lowering (FX Graph → IR) ├─→ 遍历 FX Graph 的每个节点 ├─→ 调用对应的 lowering 函数 └─→ 生成 IR 节点 ↓ t50ms: Lowering 完成 ↓ t51ms: 开始 Scheduling ├─→ 分析依赖关系 ├─→ 融合决策 └─→ 内存规划 ↓ t100ms: Scheduling 完成 ↓ t101ms: 开始 Code Generation ├─→ 为每个融合组生成 Triton 代码 ├─→ 生成 Wrapper 代码 └─→ 生成调用代码 ↓ t150ms: Code Generation 完成 ↓ t151ms: 开始 Compilation ├─→ Triton 编译为 PTX ├─→ PTX 编译为 CUBIN └─→ 动态加载 ↓ t300ms: Compilation 完成 ↓ t301ms: 返回 compiled_fn1.2 V.graph 上下文管理器# torch/_inductor/virtualized.pyclassV: 全局上下文管理器 提供对当前 GraphLowering 实例的访问 类似于线程局部存储Thread Local Storage _graph_handlerNonestaticmethoddefset_graph_handler(graph):设置当前的 GraphLowering 实例V._graph_handlergraphreturncontextlib.contextmanager(lambda:(yield))()propertydefgraph(self):获取当前的 GraphLowering 实例returnV._graph_handler# 使用方式withV.set_graph_handler(GraphLowering(gm)):# 在这个上下文中所有 lowering 函数都可以通过 V.graph 访问当前图V.graph.register_buffer(buf0,some_ir_node)1.3 GraphLowering - 核心控制器# torch/_inductor/graph.pyclassGraphLowering: 图降低器 - TorchInductor 的核心控制器 职责 1. 将 FX Graph 转换为 IR 2. 管理缓冲区buffers 3. 协调 Scheduler 4. 生成最终代码 def__init__(self,gm:torch.fx.GraphModule,example_inputsNone,):self.gmgm self.graphgm.graph self.example_inputsexample_inputs# 缓冲区管理self.buffers{}# name → IRNodeself.constants{}# name → Tensor# 图输入/输出self.graph_inputs{}self.graph_outputs[]# Scheduler稍后创建self.schedulerNonedefrun(self,*args): 执行 Lowering 遍历 FX Graph将每个节点转换为 IR # 一次遍历处理所有节点fornodeinself.graph.nodes:ifnode.opplaceholder:# [1] 处理输入节点self.graph_inputs[node.name]self.wrap_input(node,args)elifnode.opcall_function:# [2] 处理计算节点 - 调用对应的 lowering 函数self.call_function(node)elifnode.opcall_method:self.call_method(node)elifnode.opget_attr:self.get_attr(node)elifnode.opoutput:# [3] 处理输出节点self.graph_outputsnode.args[0]defcall_function(self,node): 处理 call_function 节点 例如torch.ops.aten.add.Tensor targetnode.target argsself.fetch_args(node.args)kwargsself.fetch_kwargs(node.kwargs)# 查找对应的 lowering 函数# lowerings 是一个全局字典映射 ATen 算子到 Inductor IR 的转换函数fromtorch._inductorimportlowering lowering_fnlowering.lowerings.get(target)iflowering_fnisNone:raiseNotImplementedError(fNo lowering for{target})# 调用 lowering 函数生成 IRresultlowering_fn(*args,**kwargs)# 保存结果self.register_buffer(node.name,result)defcompile_to_fn(self): 编译为可执行函数 流程 1. 创建 Scheduler 2. 执行调度与融合 3. 生成代码 4. 编译并加载 # [1] 创建 Schedulerfromtorch._inductor.schedulerimportScheduler self.schedulerScheduler(self.buffers)# [2] 执行调度self.scheduler.codegen()# [3] 生成 Python 模块returnself.compile_to_module().call1.4 完整示例逐步追踪让我们通过一个具体例子完整追踪从 FX Graph 到机器码的每一步# 用户代码importtorchtorch.compiledefmodel(x,y):zxyreturnz.relu()# 运行xtorch.randn(1024,devicecuda)ytorch.randn(1024,devicecuda)resultmodel(x,y)步骤 1FX Graph来自 Dynamo# TorchDynamo 生成的 FX Graphdefforward(x,y):add_tensortorch.ops.aten.add.Tensor(x,y)relu_defaulttorch.ops.aten.relu.default(add_tensor)returnrelu_default# 图结构# graph():# %x : Tensor(1024)# %y : Tensor(1024)# %add_tensor : Tensor(1024) call_function[targettorch.ops.aten.add.Tensor](args(%x, %y))# %relu_default : Tensor(1024) call_function[targettorch.ops.aten.relu.default](args(%add_tensor,))# return %relu_default步骤 2LoweringFX Graph → IR# torch/_inductor/lowering.pyregister_lowering(torch.ops.aten.add)defadd_tensor(x,y): 将 aten.add 转换为 Pointwise IR definner_fn(idx):returnops.add(ops.load(x,idx),ops.load(y,idx))returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)register_lowering(torch.ops.aten.relu)defrelu(x): 将 aten.relu 转换为 Pointwise IR definner_fn(idx):returnops.maximum(ops.load(x,idx),ops.constant(0.0,x.get_dtype()))returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)# 生成的 IR# buf0 Pointwise(inner_fnlambda idx: load(x, idx) load(y, idx), ranges[1024])# buf1 Pointwise(inner_fnlambda idx: max(load(buf0, idx), 0.0), ranges[1024])步骤 3Fusion算子融合# torch/_inductor/scheduler.pyclassScheduler:deffusion_pass(self): 算子融合决策 分析 - buf0 (add) 和 buf1 (relu) 都是 Pointwise - buf0 只被 buf1 使用 - 形状相同 → 可以融合 # 融合后的 IR# buf_fused Pointwise(# inner_fnlambda idx: max(load(x, idx) load(y, idx), 0.0),# ranges[1024]# )步骤 4Code Generation生成 Triton 代码# torch/_inductor/codegen/triton.pyclassTritonKernel:defcodegen(self): 生成 Triton 代码 # 生成的 Triton 代码code triton.jit def triton_poi_fused_add_relu_0( in_ptr0, # x in_ptr1, # y out_ptr0, # output xnumel, # 1024 XBLOCK: tl.constexpr, ): pid tl.program_id(0) xoffset pid * XBLOCK xindex xoffset tl.arange(0, XBLOCK) xmask xindex xnumel # Load tmp0 tl.load(in_ptr0 xindex, xmask) tmp1 tl.load(in_ptr1 xindex, xmask) # Compute (融合) tmp2 tmp0 tmp1 # add tmp3 tl.maximum(tmp2, 0.0) # relu # Store tl.store(out_ptr0 xindex, tmp3, xmask) returncode步骤 5Compilation编译# Triton 编译流程Triton Python Code ↓ Triton IR(TTIR)↓ LLVM IR ↓ PTX(GPU Assembly)↓ CUBIN(GPU Binary)步骤 6Execution执行# 生成的 Wrapper 代码defcall(args):xargs[0]# input xyargs[1]# input y# 分配输出缓冲区buf0torch.empty_strided((1024,),(1,),devicecuda,dtypetorch.float32)# 调用 Triton Kernelgridlambdameta:(triton.cdiv(1024,meta[XBLOCK]),)triton_poi_fused_add_relu_0[grid](x,y,buf0,1024,XBLOCK256,num_warps4,)return(buf0,)第二部分Lowering 机制 - FX Graph 到 IR 的转换章节 2Lowering 的核心原理2.1 什么是 LoweringLowering降低是将高层抽象FX Graph 中的 ATen 算子转换为低层中间表示IR的过程。# 高层FX Graphtorch.ops.aten.add.Tensor(x,y)# ↓ Lowering# 低层IRPointwise(inner_fnlambdaidx:ops.add(ops.load(x,idx),ops.load(y,idx)),ranges[N])为什么需要 Lowering统一表示将数百个 ATen 算子统一为少数几种 IR 类型便于优化IR 更容易分析和优化融合、内存规划等跨平台同一个 IR 可以生成不同后端的代码Triton/C/CUDA2.2 IR 节点类型# torch/_inductor/ir.pyclassIRNode:IR 节点基类defget_size(self)-List[Expr]:返回张量形状符号表达式passdefget_dtype(self)-torch.dtype:返回数据类型passdefget_device(self)-torch.device:返回设备passdefget_reads(self)-Set[str]:返回读取的缓冲区pass# 主要 IR 类型classPointwise(IRNode): 逐点操作 特点每个输出元素只依赖对应位置的输入元素 例如add, mul, relu, sigmoid def__init__(self,inner_fn,ranges,**kwargs):self.inner_fninner_fn# 计算逻辑self.rangesranges# 输出形状...classReduction(IRNode): 归约操作 特点多个输入元素归约为一个输出元素 例如sum, mean, max, min def__init__(self,inner_fn,ranges,reduction_ranges,reduction_type,**kwargs):self.inner_fninner_fn self.rangesranges# 输出形状self.reduction_rangesreduction_ranges# 归约维度self.reduction_typereduction_type# sum/max/min...classTensorBox(IRNode): 张量引用 包装一个 IRNode提供张量接口 这是 Inductor 中最常用的包装类型 def__init__(self,data:IRNode):self.datadatadefrealize(self): 实体化将延迟计算转换为实际的缓冲区 调用时机 1. 当张量被多次使用时避免重复计算 2. 当张量作为输出时 3. 当无法融合时如算子类型不兼容 实体化后会创建 ComputedBuffer ifisinstance(self.data,ComputedBuffer):return# 已经实体化# 创建新的 ComputedBuffernameV.graph.register_buffer(self.data)self.dataComputedBuffer(namename,layoutself.data.get_layout(),dataself.data)classComputedBuffer(IRNode): 计算缓冲区 表示一个需要计算并存储的中间结果 def__init__(self,name,layout,data):self.namename self.layoutlayout self.datadata# IRNode...2.3 Lowering 规则注册# torch/_inductor/lowering.py# 全局 lowering 注册表lowerings{}defregister_lowering(aten_op): 装饰器注册 lowering 函数 用法 register_lowering(torch.ops.aten.add) def add_tensor(x, y): ... defdecorator(fn):lowerings[aten_op]fnreturnfnreturndecorator# 示例add 的 loweringregister_lowering(torch.ops.aten.add)defadd_tensor(x,y): aten.add → Pointwise IR Args: x, y: TensorBox (包装的 IR 节点) Returns: TensorBox (包装的 Pointwise IR) definner_fn(idx):# idx 是符号索引例如 (i, j) 对于 2D 张量x_valops.load(x,idx)y_valops.load(y,idx)returnops.add(x_val,y_val)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)2.4 Lowering 执行流程# torch/_inductor/graph.pyclassGraphLowering:defcall_function(self,node): 处理 FX Graph 中的 call_function 节点 例如 %add call_function[targettorch.ops.aten.add.Tensor](args(%x, %y)) # [1] 获取目标算子targetnode.target# torch.ops.aten.add.Tensor# [2] 获取参数已经是 IR 节点args[self.get_buffer(arg)forarginnode.args]kwargs{k:self.get_buffer(v)fork,vinnode.kwargs.items()}# [3] 查找 lowering 函数fromtorch._inductorimportlowering lowering_fnlowering.lowerings.get(target)iflowering_fnisNone:raiseNotImplementedError(fNo lowering for{target})# [4] 调用 lowering 函数resultlowering_fn(*args,**kwargs)# [5] 注册结果缓冲区self.register_buffer(node.name,result)returnresult2.5 深入理解inner_fn 和 ops.load 的工作机制这是 TorchInductor 最精妙的设计 —— 延迟计算 符号化表达式树问题的本质# 在 lowering 函数中我们经常看到这样的代码definner_fn(idx):returnops.maximum(ops.load(x,idx),ops.constant(0.0,x.get_dtype()))关键问题inner_fn为什么不立即执行ops.load返回的是什么这些如何转换成 Triton 代码inner_fn 的本质计算逻辑的模板# ❌ 错误理解认为 inner_fn 会立即计算definner_fn(idx):returnops.maximum(ops.load(x,idx),# 马上加载 x[idx] 的值NO!ops.constant(0.0,x.get_dtype()))# ✅ 正确理解inner_fn 是一个计算模板# - 定义如何计算而不是立即计算# - 类似于 SQL 查询语句只是描述# - 类比把计算逻辑序列化成数据结构为什么需要 inner_fn# 原因 1延迟绑定 - 在不知道具体索引时定义计算逻辑# 对于 1D 张量idx 可能是 xindex# 对于 2D 张量idx 可能是 (i, j)# 对于 3D 张量idx 可能是 (i, j, k)# 原因 2便于融合 - 可以内联到其他计算中# relu 的 inner_fn 可以内联到 add 的结果中# 原因 3便于优化 - 可以分析和重写表达式树# 例如识别 x * 0 并优化为 0ops.load 返回的是表达式节点# torch/_inductor/ops.py (简化版)classLoad: Load 操作的符号表示 不会真正加载数据只是构建一个加载节点 def__init__(self,name,index):self.namename# 缓冲区名称如 buf0, in_ptr0self.indexindex# 索引表达式如 xindex, (i, j)def__repr__(self):returnfLoad({self.name}[{self.index}])classMaximum:Maximum 操作的符号表示def__init__(self,lhs,rhs):self.lhslhs# 左操作数可能是 Load 节点self.rhsrhs# 右操作数可能是 Constant 节点def__repr__(self):returnfMaximum({self.lhs},{self.rhs})classConstant:常量的符号表示def__init__(self,value,dtype):self.valuevalue self.dtypedtypedef__repr__(self):returnfConstant({self.value})# ops 模块提供构建这些节点的工厂函数defload(buffer,index):returnLoad(buffer.name,index)defmaximum(lhs,rhs):returnMaximum(lhs,rhs)defconstant(value,dtype):returnConstant(value,dtype)执行 inner_fn 时构建表达式树# 假设我们有xTensorBox(namebuf0,size[1024])# 定义 inner_fndefinner_fn(idx):returnops.maximum(ops.load(x,idx),ops.constant(0.0,torch.float32))# 现在符号执行 inner_fnidxxindex# 这只是一个符号不是具体的数值exprinner_fn(idx)# expr 的结果是一个表达式树AST# Maximum# / \# Load Constant# / \ |# buf0 xindex 0.0print(expr)# 输出Maximum(Load(buf0[xindex]), Constant(0.0))完整示例从表达式树到融合# 例子(x y).relu()# [1] add 的 inner_fndefadd_inner_fn(idx):returnops.add(ops.load(x,idx),# Load(x, xindex)ops.load(y,idx)# Load(y, xindex))# 符号执行add_expradd_inner_fn(xindex)# 结果表达式树# Add# / \# Load Load# / \ / \# x xindex y xindex# [2] relu 的 inner_fn引用 add_bufferdefrelu_inner_fn(idx):returnops.maximum(ops.load(add_buffer,idx),# 引用上一步的结果ops.constant(0.0,torch.float32))relu_exprrelu_inner_fn(xindex)# 结果Maximum(Load(add_buffer[xindex]), Constant(0.0))# [3] 融合内联 add_buffer 的计算deffused_inner_fn(idx):# 直接内联 add 的计算add_resultops.add(ops.load(x,idx),ops.load(y,idx))# 在 add 结果上应用 relureturnops.maximum(add_result,# 不再需要 load(add_buffer)ops.constant(0.0,torch.float32))fused_exprfused_inner_fn(xindex)# 融合后的表达式树# Maximum# / \# Add Constant# / \ |# Load Load 0.0# / \ / \# x idx y idx# 这样生成的 Triton 代码就只有一个 Kernel# 没有中间结果 add_buffer 写回内存从表达式树到 Triton 代码# torch/_inductor/codegen/triton.pyclassTritonKernel:defcodegen_pointwise(self,pointwise_node): 为 Pointwise IR 生成 Triton 代码 # [1] 符号执行 inner_fn获取表达式树symbolic_idxxindexexpr_treepointwise_node.inner_fn(symbolic_idx)# expr_tree Maximum(Load(x[xindex]), Constant(0.0))# [2] 递归生成代码result_varself.codegen_expr(expr_tree)# [3] 生成 storeself.stores.writeline(ftl.store(out_ptr0 xindex,{result_var}, xmask))defcodegen_expr(self,expr): 递归生成表达式的 Triton 代码 ifisinstance(expr,ops.Load):# 生成 load 指令tmp_varself.new_tmp_var()# tmp0self.loads.writeline(f{tmp_var} tl.load({expr.name}{expr.index}, xmask))returntmp_varelifisinstance(expr,ops.Maximum):# 递归生成左右操作数lhs_varself.codegen_expr(expr.lhs)# tmp0rhs_varself.codegen_expr(expr.rhs)# 0.0# 生成 maximum 指令tmp_varself.new_tmp_var()# tmp1self.compute.writeline(f{tmp_var} tl.maximum({lhs_var},{rhs_var}))returntmp_varelifisinstance(expr,ops.Constant):# 常量直接返回字符串表示returnstr(expr.value)else:raiseNotImplementedError(fUnknown expr:{type(expr)})完整转换流程可视化┌─────────────────────────────────────────────────────────────────┐ │ 阶段 1Lowering定义 inner_fn │ └─────────────────────────────────────────────────────────────────┘ register_lowering(torch.ops.aten.relu) def relu(x): def inner_fn(idx): ← 定义计算模板不执行 return ops.maximum(ops.load(x, idx), 0.0) return Pointwise.create(inner_fninner_fn, ...) ┌─────────────────────────────────────────────────────────────────┐ │ 阶段 2Scheduler符号执行构建表达式树 │ └─────────────────────────────────────────────────────────────────┘ expr inner_fn(xindex) ← 符号执行不求值 生成表达式树 Maximum / \ Load Constant / \ | x xindex 0.0 ┌─────────────────────────────────────────────────────────────────┐ │ 阶段 3CodeGen遍历表达式树生成 Triton 代码 │ └─────────────────────────────────────────────────────────────────┘ 遍历 Maximum 节点 └─ 遍历 Load 节点 → 生成: tmp0 tl.load(in_ptr0 xindex, xmask) └─ 遍历 Constant 节点 → 生成: 0.0 └─ 生成: tmp1 tl.maximum(tmp0, 0.0) 最终 Triton 代码 triton.jit def kernel(...): xindex pid * XBLOCK tl.arange(0, XBLOCK) xmask xindex xnumel tmp0 tl.load(in_ptr0 xindex, xmask) ← 这时才真正加载 tmp1 tl.maximum(tmp0, 0.0) ← 这时才真正计算 tl.store(out_ptr0 xindex, tmp1, xmask) ← 这时才真正写入设计优势总结1. 延迟绑定 - 灵活处理不同维度# 同一个 inner_fn 可以处理不同维度的张量inner_fn(xindex)# 1D: idx xindexinner_fn((i,j))# 2D: idx (i, j)inner_fn((i,j,k))# 3D: idx (i, j, k)2. 便于融合 - 内联表达式# 融合前两个独立的 Kernel中间结果写回内存buf0Pointwise(inner_fnlambdaidx:ops.add(ops.load(x,idx),ops.load(y,idx)))buf1Pointwise(inner_fnlambdaidx:ops.maximum(ops.load(buf0,idx),0.0))# 融合后一个 Kernel无中间结果buf_fusedPointwise(inner_fnlambdaidx:ops.maximum(ops.add(ops.load(x,idx),ops.load(y,idx)),# 内联0.0))3. 便于优化 - 代数简化# 优化器可以分析表达式树并优化exprops.mul(ops.load(x,idx),ops.constant(0.0))# 识别 x * 0 0optimizedops.constant(0.0)exprops.add(ops.load(x,idx),ops.constant(0.0))# 识别 x 0 xoptimizedops.load(x,idx)4. 跨后端 - 统一的 IR# 同一个表达式树可以生成不同后端的代码# Triton 后端triton_codegen.codegen_expr(expr)# → tl.maximum(tmp0, 0.0)# C 后端cpp_codegen.codegen_expr(expr)# → std::max(tmp0, 0.0f)# CUDA 后端cuda_codegen.codegen_expr(expr)# → fmaxf(tmp0, 0.0f)类比理解inner_fn ≈ SQL 查询语句描述操作不执行 ops.load ≈ SQL 中的 SELECT构建查询节点 表达式树 ≈ SQL 查询计划AST Pointwise IR ≈ 逻辑查询计划 CodeGen ≈ 物理查询计划 执行 Triton 代码 ≈ 实际的数据库操作核心要点inner_fn是延迟执行的计算配方ops.load返回符号化的表达式节点不加载数据转换过程定义模板 → 构建表达式树 → 遍历树生成代码这是编译器中经典的AST抽象语法树设计模式章节 3常见算子的 Lowering 实现3.1 Pointwise 算子# torch/_inductor/lowering.py# 一元算子 register_lowering(torch.ops.aten.relu)defrelu(x):ReLU: max(x, 0)definner_fn(idx):x_valops.load(x,idx)zeroops.constant(0.0,x.get_dtype())returnops.maximum(x_val,zero)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)register_lowering(torch.ops.aten.sigmoid)defsigmoid(x):Sigmoid: 1 / (1 exp(-x))definner_fn(idx):x_valops.load(x,idx)neg_xops.neg(x_val)exp_neg_xops.exp(neg_x)oneops.constant(1.0,x.get_dtype())denomops.add(one,exp_neg_x)returnops.truediv(one,denom)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)register_lowering(torch.ops.aten.tanh)deftanh(x):Tanh: (exp(x) - exp(-x)) / (exp(x) exp(-x))definner_fn(idx):x_valops.load(x,idx)# 使用 libdevice.tanhGPU 库函数returnops.tanh(x_val)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)# 二元算子 register_lowering(torch.ops.aten.mul)defmul_tensor(x,y):Element-wise multiplicationdefinner_fn(idx):x_valops.load(x,idx)y_valops.load(y,idx)returnops.mul(x_val,y_val)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)register_lowering(torch.ops.aten.maximum)defmaximum(x,y):Element-wise maximumdefinner_fn(idx):x_valops.load(x,idx)y_valops.load(y,idx)returnops.maximum(x_val,y_val)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(x.get_size()),)3.2 Reduction 算子register_lowering(torch.ops.aten.sum)defsum_dim(x,dimNone,keepdimFalse): Sum reduction Args: x: 输入张量 dim: 归约维度None 表示全部 keepdim: 是否保持维度 ifdimisNone:# 全局归约dimlist(range(len(x.get_size())))ifnotisinstance(dim,(list,tuple)):dim[dim]# 标准化维度处理负数索引ndimlen(x.get_size())dim[difd0elsedndimfordindim]# 计算输出形状output_size[]reduction_ranges[]fori,sizeinenumerate(x.get_size()):ifiindim:reduction_ranges.append(size)ifkeepdim:output_size.append(1)else:output_size.append(size)definner_fn(idx,reduction_idx): idx: 输出索引 reduction_idx: 归约维度的索引 # 构建完整的输入索引full_idx[]idx_iteriter(idx)reduction_iteriter(reduction_idx)foriinrange(ndim):ifiindim:full_idx.append(next(reduction_iter))else:full_idx.append(next(idx_iter))returnops.load(x,tuple(full_idx))returnReduction.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangesoutput_size,reduction_rangesreduction_ranges,reduction_typesum,)register_lowering(torch.ops.aten.mean)defmean_dim(x,dimNone,keepdimFalse):Mean sum / count# 先计算 sumsum_resultsum_dim(x,dim,keepdim)# 计算元素数量ifdimisNone:numel1forsizeinx.get_size():numel*sizeelse:ifnotisinstance(dim,(list,tuple)):dim[dim]numel1fordindim:numel*x.get_size()[d]# sum / numeldefinner_fn(idx):sum_valops.load(sum_result,idx)countops.constant(numel,x.get_dtype())returnops.truediv(sum_val,count)returnPointwise.create(devicex.get_device(),dtypex.get_dtype(),inner_fninner_fn,rangeslist(sum_result.get_size()),)3.3 复杂算子LayerNormregister_lowering(torch.ops.aten.native_layer_norm)deflayer_norm(input,normalized_shape,weightNone,biasNone,eps1e-5): Layer Normalization 公式 y (x - mean) / sqrt(var eps) * weight bias 其中 mean 和 var 在 normalized_shape 维度上计算 # [1] 计算归约维度ndimlen(input.get_size())norm_ndimlen(normalized_shape)reduction_dimslist(range(ndim-norm_ndim,ndim))# [2] 计算 meanmeanmean_dim(input,dimreduction_dims,keepdimTrue)# [3] 计算 variance# var mean((x - mean)^2)defcentered_squared(idx):x_valops.load(input,idx)mean_valops.load(mean,idx)# 会自动广播centeredops.sub(x_val,mean_val)returnops.mul(centered,centered)centered_sqPointwise.create(deviceinput.get_device(),dtypeinput.get_dtype(),inner_fncentered_squared,rangeslist(input.get_size()),)varmean_dim(centered_sq,dimreduction_dims,keepdimTrue)# [4] 归一化defnormalize(idx):x_valops.load(input,idx)mean_valops.load(mean,idx)var_valops.load(var,idx)# (x - mean) / sqrt(var eps)centeredops.sub(x_val,mean_val)eps_constops.constant(eps,input.get_dtype())var_epsops.add(var_val,eps_const)stdops.sqrt(var_eps)normalizedops.truediv(centered,std)# * weight biasifweightisnotNone:weight_valops.load(weight,idx[-norm_ndim:])# 只取最后几维normalizedops.mul(normalized,weight_val)ifbiasisnotNone:bias_valops.load(bias,idx[-norm_ndim:])normalizedops.add(normalized,bias_val)returnnormalized outputPointwise.create(deviceinput.get_device(),dtypeinput.get_dtype(),inner_fnnormalize,rangeslist(input.get_size()),)returnoutput,mean,var# 返回 (output, mean, rstd)第三部分Scheduler 调度系统章节 4Scheduler 的核心职责4.1 Scheduler 是什么Scheduler调度器负责决定如何执行 IR 节点包括融合决策哪些节点可以融合成一个 Kernel执行顺序节点的执行顺序拓扑排序内存规划缓冲区的分配与复用# torch/_inductor/scheduler.pyclassScheduler: 调度器 输入IR 节点列表来自 GraphLowering 输出调度计划融合组 执行顺序 def__init__(self,buffers): Args: buffers: Dict[str, IRNode] - 所有缓冲区 self.buffersbuffers self.nodes[]# SchedulerNode 列表self.fused_nodes[]# 融合后的节点组defcodegen(self): 主调度流程 1. 创建 SchedulerNode 2. 分析依赖关系 3. 融合决策 4. 生成代码 # [1] 为每个缓冲区创建 SchedulerNodeself.create_scheduler_nodes()# [2] 分析依赖关系self.compute_dependencies()# [3] 融合决策self.fusion_pass()# [4] 拓扑排序self.topological_sort()# [5] 生成代码self.generate_code()4.2 SchedulerNode - 调度节点classSchedulerNode: 调度节点 包装一个 IR 节点添加调度信息 def__init__(self,scheduler,node:IRNode):self.schedulerscheduler self.nodenode# IR 节点# 依赖关系self.read_writesself.node.get_read_writes()self.unmet_dependenciesset()# 未满足的依赖self.users[]# 使用该节点的节点列表# 融合信息self.groupNone# 所属融合组self.can_inplaceFalse# 是否可以原地操作defcan_fuse(self,other:SchedulerNode)-bool: 判断是否可以与另一个节点融合 条件 1. 都是 Pointwise 或 Reduction 2. 设备相同 3. 形状兼容 4. 无循环依赖 # [1] 类型检查ifnotself.is_fusable_type()ornotother.is_fusable_type():returnFalse# [2] 设备检查ifself.node.get_device()!other.node.get_device():returnFalse# [3] 形状检查ifnotself.is_compatible_shape(other):returnFalse# [4] 依赖检查ifself.has_circular_dependency(other):returnFalsereturnTruedefis_fusable_type(self)-bool:是否是可融合的类型returnisinstance(self.node,(Pointwise,Reduction))defis_compatible_shape(self,other)-bool:形状是否兼容# 简化要求形状完全相同returnself.node.get_size()other.node.get_size()章节 5融合决策算法5.1 融合的基本原则classScheduler:deffusion_pass(self): 融合决策 策略 1. 优先融合 Pointwise 操作最容易 2. 尝试融合 Reduction Pointwise 3. 避免融合会增加内存使用的情况 # [1] 构建融合候选fusion_candidatesself.find_fusion_candidates()# [2] 贪心融合forproducer,consumerinfusion_candidates:ifself.should_fuse(producer,consumer):self.fuse_nodes(producer,consumer)defshould_fuse(self,producer,consumer)-bool: 决定是否融合两个节点 考虑因素 1. 是否可以融合类型、形状等 2. 融合后的收益减少内存访问 3. 融合后的成本增加寄存器压力 # [1] 基本检查ifnotproducer.can_fuse(consumer):returnFalse# [2] 检查 producer 是否只被 consumer 使用iflen(producer.users)!1:# producer 有多个使用者融合会导致重复计算returnFalse# [3] 估算收益benefitself.estimate_fusion_benefit(producer,consumer)costself.estimate_fusion_cost(producer,consumer)returnbenefitcostdefestimate_fusion_benefit(self,producer,consumer)-float: 估算融合收益 主要收益减少内存访问 # producer 的输出不需要写回内存producer_sizeproducer.node.get_numel()elem_sizeproducer.node.get_dtype().itemsize# 节省的内存访问字节saved_bytesproducer_size*elem_size*2# 1次写 1次读returnsaved_bytesdefestimate_fusion_cost(self,producer,consumer)-float: 估算融合成本 主要成本增加寄存器使用 # 简化假设每个操作需要固定数量的寄存器producer_regsself.estimate_register_usage(producer)consumer_regsself.estimate_register_usage(consumer)# 融合后的寄存器使用fused_regsproducer_regsconsumer_regs# 如果超过阈值成本很高MAX_REGS64iffused_regsMAX_REGS:returnfloat(inf)return0.0# 否则成本可忽略5.2 融合实现classScheduler:deffuse_nodes(self,producer,consumer): 融合两个节点 策略 1. 内联 producer 的计算到 consumer 2. 更新依赖关系 3. 移除 producer # [1] 创建融合后的 inner_fndeffused_inner_fn(idx):# 内联 producer 的计算producer_resultproducer.node.inner_fn(idx)# 在 consumer 的 inner_fn 中# 将对 producer 的 load 替换为直接使用 producer_resultreturnconsumer.node.inner_fn_with_inline(producer.name,producer_result,idx)# [2] 创建融合节点fused_nodePointwise.create(deviceconsumer.node.get_device(),dtypeconsumer.node.get_dtype(),inner_fnfused_inner_fn,rangesconsumer.node.get_size(),)# [3] 更新图# 将 consumer 替换为 fused_nodeself.replace_node(consumer,fused_node)# 移除 producerself.remove_node(producer)# [4] 更新依赖关系# producer 的依赖 → fused_node 的依赖fordepinproducer.unmet_dependencies:fused_node.unmet_dependencies.add(dep)章节 6内存规划6.1 缓冲区生命周期分析classScheduler:defcompute_buffer_lifetimes(self): 计算每个缓冲区的生命周期 生命周期 [first_use, last_use] lifetimes{}# 拓扑排序后的节点fori,nodeinenumerate(self.ordered_nodes):# 该节点读取的缓冲区forbuf_nameinnode.read_writes.reads:ifbuf_namenotinlifetimes:lifetimes[buf_name][i,i]else:lifetimes[buf_name][1]i# 更新 last_use# 该节点写入的缓冲区forbuf_nameinnode.read_writes.writes:ifbuf_namenotinlifetimes:lifetimes[buf_name][i,i]# first_use 已经设置returnlifetimesdefallocate_buffers(self): 分配缓冲区 策略 1. 复用生命周期不重叠的缓冲区 2. 对齐内存以提高访问效率 lifetimesself.compute_buffer_lifetimes()# 按大小排序大的优先分配buffers_by_sizesorted(self.buffers.items(),keylambdax:x[1].get_numel(),reverseTrue)# 内存池memory_pool[]allocations{}forbuf_name,buf_nodeinbuffers_by_size:buf_lifetimelifetimes[buf_name]buf_sizebuf_node.get_numel()*buf_node.get_dtype().itemsize# 尝试从内存池中复用reusedFalseforpool_entryinmemory_pool:pool_buf,pool_lifetime,pool_offsetpool_entry# 检查生命周期是否不重叠if(buf_lifetime[0]pool_lifetime[1]orbuf_lifetime[1]pool_lifetime[0]):# 可以复用allocations[buf_name](pool_buf,pool_offset)# 更新生命周期pool_entry[1][min(pool_lifetime[0],buf_lifetime[0]),max(pool_lifetime[1],buf_lifetime[1])]reusedTruebreakifnotreused:# 分配新缓冲区new_buffbuf_pool_{len(memory_pool)}allocations[buf_name](new_buf,0)memory_pool.append([new_buf,buf_lifetime,buf_size])returnallocations第四部分CodeGen 代码生成章节 7Triton 代码生成原理7.1 IndentedBuffer - 代码缓冲区工具# torch/_inductor/codegen/common.pyclassIndentedBuffer: 缩进缓冲区 - 用于生成格式化的代码 特点 1. 自动管理缩进 2. 支持嵌套 3. 可以拼接其他 IndentedBuffer def__init__(self,indent_level0):self._lines[]self._indent_levelindent_level self._indent_str # 4 空格defwriteline(self,line):写入一行自动添加缩进ifline:self._lines.append(self._indent_str*self._indent_levelline)else:self._lines.append()defindent(self):增加缩进级别self._indent_level1defdedent(self):减少缩进级别self._indent_level-1defsplice(self,other_buffer):拼接另一个缓冲区self._lines.extend(other_buffer._lines)defgetvalue(self):获取完整代码字符串return\n.join(self._lines)# 使用示例codeIndentedBuffer()code.writeline(def my_function():)code.indent()code.writeline(x 1)code.writeline(return x)code.dedent()print(code.getvalue())# 输出# def my_function():# x 1# return x7.2 Triton 简介Triton 是一种 Python DSL领域特定语言用于编写 GPU Kernel。优势Python 语法易于学习自动内存管理自动并行化性能接近手写 CUDA90-95%示例importtritonimporttriton.languageastltriton.jitdefadd_kernel(x_ptr,# 输入指针y_ptr,# 输入指针out_ptr,# 输出指针n_elements,# 元素数量BLOCK_SIZE:tl.constexpr,# 编译时常量):# 获取当前线程块的 IDpidtl.program_id(0)# 计算该线程块处理的元素范围block_startpid*BLOCK_SIZE offsetsblock_starttl.arange(0,BLOCK_SIZE)# 边界检查maskoffsetsn_elements# 加载数据xtl.load(x_ptroffsets,maskmask)ytl.load(y_ptroffsets,maskmask)# 计算outputxy# 存储结果tl.store(out_ptroffsets,output,maskmask)# 调用gridlambdameta:(triton.cdiv(n,meta[BLOCK_SIZE]),)add_kernel[grid](x,y,out,n,BLOCK_SIZE1024)7.3 TritonKernel - 代码生成器# torch/_inductor/codegen/triton.pyclassTritonKernel: Triton Kernel 代码生成器 职责 1. 将 IR 节点转换为 Triton 代码 2. 管理参数、索引、加载、计算、存储 3. 选择最优的 Kernel 参数BLOCK_SIZE, num_warps 等 def__init__(self,*groups): Args: groups: 融合后的节点组 self.groupsgroups# 代码缓冲区self.argsIndentedBuffer()# 参数列表self.indexingIndentedBuffer()# 索引计算self.loadsIndentedBuffer()# 加载操作self.computeIndentedBuffer()# 计算逻辑self.storesIndentedBuffer()# 存储操作# 临时变量计数器self.tmp_counter0# 参数配置self.block_sizeNoneself.num_warpsNonedefcodegen(self): 生成完整的 Triton Kernel 代码 流程 1. 生成参数列表 2. 生成索引计算 3. 生成加载/计算/存储 4. 组装完整代码 # [1] 生成参数self.codegen_args()# [2] 生成索引self.codegen_indexing()# [3] 生成计算self.codegen_body()# [4] 选择参数self.select_kernel_config()# [5] 组装代码returnself.assemble()defcodegen_args(self): 生成参数列表 包括 - 输入缓冲区指针 - 输出缓冲区指针 - 元素数量 - 编译时常量BLOCK_SIZE 等 # 输入缓冲区fori,inpinenumerate(self.get_inputs()):self.args.writeline(fin_ptr{i},)# 输出缓冲区fori,outinenumerate(self.get_outputs()):self.args.writeline(fout_ptr{i},)# 元素数量self.args.writeline(xnumel,)# 编译时常量self.args.writeline(XBLOCK: tl.constexpr,)defcodegen_indexing(self): 生成索引计算代码 对于 1D 张量 xindex pid * XBLOCK tl.arange(0, XBLOCK) xmask xindex xnumel self.indexing.writeline(# 计算索引)self.indexing.writeline(pid tl.program_id(0))self.indexing.writeline(xoffset pid * XBLOCK)self.indexing.writeline(xindex xoffset tl.arange(0, XBLOCK))self.indexing.writeline(xmask xindex xnumel)defcodegen_body(self): 生成计算主体 遍历融合组中的所有节点生成对应的代码 forgroupinself.groups:fornodeingroup.nodes:self.codegen_node(node)defcodegen_node(self,node): 为单个 IR 节点生成代码 ifisinstance(node,Pointwise):self.codegen_pointwise(node)elifisinstance(node,Reduction):self.codegen_reduction(node)else:raiseNotImplementedError(fUnknown node type:{type(node)})defcodegen_pointwise(self,node): 生成 Pointwise 节点的代码 策略 1. 符号执行 inner_fn 2. 递归生成表达式树 3. 生成 load/compute/store # 符号执行symbolic_idxsympy.Symbol(xindex)exprnode.inner_fn(symbolic_idx)# 生成表达式代码result_varself.codegen_expr(expr)# 生成 storeself.stores.writeline(ftl.store(out_ptr0 xindex,{result_var}, xmask))defcodegen_expr(self,expr): 递归生成表达式的 Triton 代码 Args: expr: 符号表达式ops.Load, ops.Add, etc. Returns: tmp_var: 临时变量名 ifisinstance(expr,ops.Load):# 加载操作buffer_nameexpr.name indexxindextmp_varself.new_tmp_var()self.loads.writeline(f{tmp_var} tl.load({buffer_name}{index}, xmask))returntmp_varelifisinstance(expr,ops.Add):# 加法操作lhs_varself.codegen_expr(expr.lhs)rhs_varself.codegen_expr(expr.rhs)tmp_varself.new_tmp_var()self.compute.writeline(f{tmp_var}{lhs_var}{rhs_var})returntmp_varelifisinstance(expr,ops.Maximum):# Maximum 操作lhs_varself.codegen_expr(expr.lhs)rhs_varself.codegen_expr(expr.rhs)tmp_varself.new_tmp_var()self.compute.writeline(f{tmp_var} tl.maximum({lhs_var},{rhs_var}))returntmp_varelifisinstance(expr,ops.Constant):# 常量returnstr(expr.value)else:raiseNotImplementedError(fUnknown expr:{type(expr)})defnew_tmp_var(self):分配新的临时变量varftmp{self.tmp_counter}self.tmp_counter1returnvardefassemble(self): 组装完整的 Kernel 代码 codeIndentedBuffer()# 函数签名code.writeline(triton.jit)code.writeline(fdef{self.kernel_name}()code.indent()code.splice(self.args)code.dedent()code.writeline():)# 函数体code.indent()code.splice(self.indexing)code.writeline()code.splice(self.loads)code.writeline()code.splice(self.compute)code.writeline()code.splice(self.stores)code.dedent()returncode.getvalue()章节 8完整示例 - 从 IR 到 Triton 代码让我们通过一个完整的例子看看如何从 IR 生成 Triton 代码# 输入 IR融合后fused_irPointwise(inner_fnlambdaidx:ops.maximum(ops.add(ops.load(in_ptr0,idx),# xops.load(in_ptr1,idx)# y),ops.constant(0.0,torch.float32)),ranges[1024],namebuf_fused)# 代码生成过程kernelTritonKernel(fused_ir)# [1] 生成参数# in_ptr0,# in_ptr1,# out_ptr0,# xnumel,# XBLOCK: tl.constexpr,# [2] 生成索引# pid tl.program_id(0)# xoffset pid * XBLOCK# xindex xoffset tl.arange(0, XBLOCK)# xmask xindex xnumel# [3] 生成计算# 符号执行 inner_fn(xindex)exprfused_ir.inner_fn(xindex)# expr ops.maximum(# ops.add(# ops.load(in_ptr0, xindex),# ops.load(in_ptr1, xindex)# ),# ops.constant(0.0, torch.float32)# )# 递归生成代码# tmp0 tl.load(in_ptr0 xindex, xmask) # load x# tmp1 tl.load(in_ptr1 xindex, xmask) # load y# tmp2 tmp0 tmp1 # add# tmp3 tl.maximum(tmp2, 0.0) # relu# [4] 生成 store# tl.store(out_ptr0 xindex, tmp3, xmask)# 最终生成的 Triton 代码generated_code triton.jit def triton_poi_fused_add_relu_0( in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr, ): # 计算索引 pid tl.program_id(0) xoffset pid * XBLOCK xindex xoffset tl.arange(0, XBLOCK) xmask xindex xnumel # Load tmp0 tl.load(in_ptr0 xindex, xmask) tmp1 tl.load(in_ptr1 xindex, xmask) # Compute tmp2 tmp0 tmp1 tmp3 tl.maximum(tmp2, 0.0) # Store tl.store(out_ptr0 xindex, tmp3, xmask) 章节 9Kernel 参数选择启发式9.1 BLOCK_SIZE 选择classTritonKernel:defselect_block_size(self,numel,dtype): 选择最优的 BLOCK_SIZE 考虑因素 1. 元素数量 2. 数据类型大小 3. 内存对齐 4. 并行度 elem_sizedtype.itemsize# 候选值2 的幂次candidates[128,256,512,1024]forblock_sizeinreversed(candidates):# 计算 grid sizegrid_sizetriton.cdiv(numel,block_size)# 条件 1: grid 不能太小至少要有足够的并行度min_gridself.get_device_sm_count()*4ifgrid_sizemin_grid:continue# 条件 2: 内存对齐128 字节 缓存行大小ifblock_size*elem_size%1280:returnblock_sizereturn256# 默认值defselect_num_warps(self,block_size): 选择 num_warps 每个 warp 32 threads num_warps ceil(block_size / 32)向上取整到 2 的幂次 min_warps(block_size31)//32importmath num_warps2**math.ceil(math.log2(min_warps))# 限制范围returnmin(max(num_warps,1),32)章节 10Wrapper 代码生成# torch/_inductor/codegen/wrapper.pyclassWrapperCodegen: 生成 Python Wrapper 代码 职责 1. 解包输入参数 2. 分配输出缓冲区 3. 调用 Triton Kernel 4. 返回结果 defgenerate(self):生成完整的 Python 模块codeIndentedBuffer()# [1] 导入code.writeline(import torch)code.writeline(import triton)code.writeline(import triton.language as tl)code.writeline()# [2] Kernel 定义forkernelinself.kernels:code.splice(kernel.code)code.writeline()# [3] 调用函数code.writeline(def call(args):)code.indent()# [3.1] 解包参数fori,inpinenumerate(self.graph_inputs):code.writeline(fprimals_{i1} args[{i}])code.writeline()# [3.2] 分配输出缓冲区fori,bufinenumerate(self.buffers):code.writeline(fbuf{i} torch.empty_strided()code.indent()code.writeline(f{buf.size},)code.writeline(f{buf.stride},)code.writeline(fdevice{buf.device},)code.writeline(fdtype{buf.dtype})code.dedent()code.writeline())code.writeline()# [3.3] 调用 Kernelforkernel_callinself.kernel_calls:code.splice(kernel_call)code.writeline()# [3.4] 返回结果output_names[fbuf{i}foriinself.output_indices]code.writeline(freturn ({, .join(output_names)},))code.dedent()returncode.getvalue()第五部分优化技术章节 11内存优化11.1 缓冲区复用classScheduler:defoptimize_memory(self): 内存优化 策略 1. 分析缓冲区生命周期 2. 复用生命周期不重叠的缓冲区 3. 原地操作inplace # [1] 计算生命周期lifetimesself.compute_buffer_lifetimes()# [2] 构建干扰图interference_graphself.build_interference_graph(lifetimes)# [3] 图着色寄存器分配算法allocationself.graph_coloring(interference_graph)returnallocationdefbuild_interference_graph(self,lifetimes): 构建干扰图 如果两个缓冲区的生命周期重叠则它们干扰 graph{}bufferslist(lifetimes.keys())fori,buf1inenumerate(buffers):graph[buf1]set()forbuf2inbuffers[i1:]:# 检查生命周期是否重叠ifself.lifetimes_overlap(lifetimes[buf1],lifetimes[buf2]):graph[buf1].add(buf2)ifbuf2notingraph:graph[buf2]set()graph[buf2].add(buf1)returngraphdefgraph_coloring(self,graph): 图着色算法 为每个节点分配一个颜色内存池 ID 相邻节点不能有相同颜色 allocation{}colors_used{}# 按度数排序度数高的优先nodessorted(graph.keys(),keylambdan:len(graph[n]),reverseTrue)fornodeinnodes:# 找到邻居使用的颜色neighbor_colors{allocation[neighbor]forneighboringraph[node]ifneighborinallocation}# 选择最小的未使用颜色color0whilecolorinneighbor_colors:color1allocation[node]color colors_used[color]colors_used.get(color,0)1returnallocation11.2 内存布局优化classLayoutOptimizer: 内存布局优化 目标 1. 减少 transpose 操作 2. 提高内存访问连续性 3. 利用 Tensor Core如果可用 defoptimize_layout(self,graph): 优化整个图的内存布局 # [1] 分析每个节点的首选布局preferred_layoutsself.analyze_preferred_layouts(graph)# [2] 传播布局约束final_layoutsself.propagate_layouts(graph,preferred_layouts)# [3] 插入必要的 transposeself.insert_transposes(graph,final_layouts)returnfinal_layoutsdefanalyze_preferred_layouts(self,graph): 分析每个节点的首选布局 例如 - Matmul 倾向于 (M, K) (K, N) → (M, N) - Conv2d 倾向于 NHWC在某些硬件上 layouts{}fornodeingraph.nodes:ifnode.opcall_function:ifnode.targettorch.ops.aten.mm:# Matmul: 倾向于行主序layouts[node]row_majorelifnode.targettorch.ops.aten.conv2d:# Conv: 根据硬件选择ifself.has_tensor_cores():layouts[node]nhwcelse:layouts[node]nchwelse:layouts[node]anyreturnlayouts章节 12循环优化12.1 Loop TilingclassLoopOptimizer: 循环优化 技术 1. Tiling分块 2. Unrolling展开 3. Vectorization向量化 defapply_tiling(self,loop,tile_size): 应用循环分块 原始循环 for i in range(N): for j in range(M): C[i][j] A[i][j] B[i][j] 分块后 for ii in range(0, N, TILE_I): for jj in range(0, M, TILE_J): for i in range(ii, min(iiTILE_I, N)): for j in range(jj, min(jjTILE_J, M)): C[i][j] A[i][j] B[i][j] # 在 Triton 中这通过 BLOCK_SIZE 自动实现pass章节 13算子特化13.1 Matmul 优化# torch/_inductor/kernel/mm.pyclassMatmulKernel: 矩阵乘法的特化实现 策略 1. 调用 cuBLAS/cuDNN最优 2. 如果需要融合生成 Triton Kernel staticmethoddefshould_use_extern_kernel(A,B): 决定是否使用外部库cuBLAS 条件 1. 矩阵足够大 128x128 2. 没有需要融合的后续操作 M,KA.get_size()K2,NB.get_size()# 小矩阵使用 TritonifM*N*K128*128*128:returnFalse# 大矩阵使用 cuBLASreturnTruestaticmethoddefgenerate_triton_matmul(A,B): 生成 Triton Matmul Kernel 使用 Triton 的 tl.dot 指令 code triton.jit def matmul_kernel( A_ptr, B_ptr, C_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): pid_m tl.program_id(0) pid_n tl.program_id(1) # 计算该 block 处理的范围 rm pid_m * BLOCK_M tl.arange(0, BLOCK_M) rn pid_n * BLOCK_N tl.arange(0, BLOCK_N) # 累加器 acc tl.zeros((BLOCK_M, BLOCK_N), dtypetl.float32) # 循环遍历 K 维度 for k in range(0, K, BLOCK_K): rk k tl.arange(0, BLOCK_K) # 加载 A 和 B 的块 A_block tl.load(A_ptr rm[:, None] * stride_am rk[None, :] * stride_ak) B_block tl.load(B_ptr rk[:, None] * stride_bk rn[None, :] * stride_bn) # 矩阵乘法 acc tl.dot(A_block, B_block) # 存储结果 C acc.to(tl.float16) tl.store(C_ptr rm[:, None] * stride_cm rn[None, :] * stride_cn, C) returncode第六部分编译与执行章节 14Triton 编译流程14.1 编译管道# Triton 编译流程Triton Python Code ↓[1]解析 Triton AST ↓[2]类型推断 Typed Triton AST ↓[3]转换为 Triton IR(TTIR)TTIR ↓[4]优化 TTIR Optimized TTIR ↓[5]转换为 LLVM IR LLVM IR ↓[6]LLVM 优化 Optimized LLVM IR ↓[7]转换为 PTX PTX(GPU Assembly)↓[8]PTX 编译为 CUBIN CUBIN(GPU Binary)14.2 缓存机制# torch/_inductor/codecache.pyclassCodeCache: 代码缓存 避免重复编译相同的 Kernel def__init__(self):self.cache_dirPath(/tmp/torchinductor_cache)self.cache_dir.mkdir(exist_okTrue)defget_cache_key(self,code,config): 计算缓存 Key 基于 1. Kernel 代码 2. 编译配置BLOCK_SIZE, num_warps 等 3. 设备信息 importhashlib key_data{code:code,config:config,device:torch.cuda.get_device_name(),triton_version:triton.__version__,}key_strstr(key_data)returnhashlib.sha256(key_str.encode()).hexdigest()defload(self,cache_key):从缓存加载编译结果cache_fileself.cache_dir/f{cache_key}.cubinifcache_file.exists():returncache_file.read_bytes()returnNonedefsave(self,cache_key,cubin):保存编译结果到缓存cache_fileself.cache_dir/f{cache_key}.cubincache_file.write_bytes(cubin)章节 15动态加载与执行# torch/_inductor/runtime/runtime_utils.pyclassCompiledModule: 编译后的模块 包含 1. 编译后的 Kernel 2. Wrapper 函数 3. 元数据 def__init__(self,module_path): Args: module_path: 生成的 Python 模块路径 # 动态导入模块importimportlib.util specimportlib.util.spec_from_file_location(compiled_module,module_path)self.moduleimportlib.util.module_from_spec(spec)spec.loader.exec_module(self.module)# 获取调用函数self.callself.module.calldef__call__(self,*args,**kwargs):执行编译后的代码returnself.call(args)第七部分高级主题章节 17AutoTuning 机制# Triton AutoTuningtriton.autotune(configs[triton.Config({BLOCK_SIZE:128},num_warps2),triton.Config({BLOCK_SIZE:256},num_warps4),triton.Config({BLOCK_SIZE:512},num_warps8),triton.Config({BLOCK_SIZE:1024},num_warps16),],key[n_elements],)triton.jitdefautotuned_kernel(x_ptr,y_ptr,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):# Kernel 代码pass# 第一次调用会测试所有配置选择最快的# 后续调用直接使用缓存的最优配置章节 18自定义后端# 为新硬件添加自定义后端classMyDeviceBackend:自定义设备后端staticmethoddefcompile_fx(gm,example_inputs):编译 FX Graph# 自定义编译逻辑passstaticmethoddefgenerate_code(ir_nodes):生成设备特定的代码# 例如生成 NPU 指令pass# 注册后端torch._dynamo.list_backends()[mydevice]MyDeviceBackend.compile_fx章节 19调试技巧# [1] 查看生成的代码importos os.environ[TORCH_LOGS]output_codetorch.compiledefmodel(x):returnx.relu()xtorch.randn(100,devicecuda)ymodel(x)# 会打印生成的 Triton 代码# [2] 保存生成的代码到文件torch._inductor.config.debugTruetorch._inductor.config.trace.enabledTruetorch._inductor.config.trace.log_dir./inductor_logs# [3] 禁用特定优化torch._inductor.config.triton.autotuneFalsetorch._inductor.config.fx_graph_cacheFalse# [4] 强制重新编译torch._dynamo.reset()第八部分实战案例章节 20案例 1 - GELU 激活函数优化importtorchimportmathtorch.compiledefgelu_approximate(x): GELU 激活函数的近似实现 GELU(x) ≈ 0.5 * x * (1 tanh(sqrt(2/π) * (x 0.044715 * x^3))) sqrt_2_over_pimath.sqrt(2.0/math.pi)x_cubedx*x*x innersqrt_2_over_pi*(x0.044715*x_cubed)tanh_innertorch.tanh(inner)result0.5*x*(1.0tanh_inner)returnresult# TorchInductor 会将所有操作融合成一个 Kernel# 性能提升~9x相比 Eager 模式xtorch.randn(1000000,devicecuda)ygelu_approximate(x)章节 21案例 2 - LayerNorm 优化torch.compiledeflayer_norm_manual(x,weight,bias,eps1e-5): 手动实现的 LayerNorm TorchInductor 会优化为 1-2 个融合 Kernel # 计算 mean 和 variancemeanx.mean(dim-1,keepdimTrue)varx.var(dim-1,keepdimTrue,unbiasedFalse)# 归一化x_normalized(x-mean)/torch.sqrt(vareps)# Scale and shiftreturnx_normalized*weightbias# 性能对比xtorch.randn(128,512,devicecuda)weighttorch.randn(512,devicecuda)biastorch.randn(512,devicecuda)# Eager: ~0.5ms# Compiled: ~0.1ms (5x 加速)章节 22案例 3 - 自定义融合算子torch.compiledeffused_linear_gelu(x,weight,bias): 融合 Linear GELU TorchInductor 会 1. Linear 使用 cuBLAS 2. GELU 生成融合 Triton Kernel # Linearouttorch.nn.functional.linear(x,weight,bias)# GELUreturntorch.nn.functional.gelu(out,approximatetanh)# 性能提升# - Linear: 使用 cuBLAS最优# - GELU: 融合 Kernel避免中间结果写回# 总体加速~2-3x总结核心要点TorchInductor 的职责将 FX Graph 编译为高效的机器码通过 Lowering、Fusion、CodeGen 实现端到端优化关键技术Lowering: FX Graph → IR统一表示Fusion: 算子融合减少内存访问CodeGen: IR → Triton/C高效代码生成Compilation: Triton → PTX → CUBIN机器码性能优化算子融合减少 Kernel 启动和内存访问内存规划缓冲区复用、布局优化循环优化Tiling、向量化AutoTuning自动选择最优参数与 TorchDynamo 的关系TorchDynamo: 捕获图字节码 → FX Graph TorchInductor: 优化执行FX Graph → 机器码技术栈总览用户 Python 代码 ↓ [TorchDynamo] FX Graph ↓ [TorchInductor Lowering] IR (Pointwise, Reduction, etc.) ↓ [Scheduler] Fused IR Groups ↓ [CodeGen] Triton/C Code ↓ [Compilation] PTX/CUBIN ↓ [Execution] 高效计算结果推荐阅读顺序如果你想深入阅读 TorchInductor 源码建议按以下顺序torch/_inductor/compile_fx.py- 理解编译入口 (200 行)torch/_inductor/graph.py- 理解 GraphLowering (1000 行核心)torch/_inductor/lowering.py- 理解 Lowering 规则 (3000 行)torch/_inductor/ir.py- 理解 IR 节点定义 (2000 行)torch/_inductor/scheduler.py- 理解调度与融合 (2000 行核心)torch/_inductor/codegen/triton.py- 理解 Triton 代码生成 (3000 行核心)总计~15,000 行核心代码本文档基于 PyTorch 2.0 源码编写部分实现细节可能因版本而异