引言

整理一下当前主流的 GPU 技术栈进行概览,并重点探讨其关键组件及其连接方式。

技术栈图解

下面的 Mermaid 图展示了 GPU 技术栈的主要组件及其逻辑关系。

    graph TD
	    subgraph 高层框架
	        JAX[JAX];
	        TF[TensorFlow];
	        PT[PyTorch];
	    end
	
	    subgraph 中间层技术
	        MLIR[MLIR];
	        Triton[Triton];
	        ONNX[ONNX];
	        ORT[ONNX Runtime];
	        TVM["TVM (Relay)"];
	        XLA[XLA];
	        HLO["HLO (MLIR Dialect)"];
	        TritonIR[Triton IR];
	        TorchScript;
	        nvFuser;
	    end
	
	    subgraph 底层基础设施
	        LLVM[LLVM IR];
	        PTX;
	        EP[Execution Provider];
	        CUDA[CUDA];
	        ROCm;
	        NPU;
	        NPUHD[NPU Hardware];
	        CUDAEP[CUDA EP];
	        ROCmEP[ROC EP];
	        NPUEP[NPU EP];
	        GPU[GPU Hardware];
	    end
	
	    TF -- "图执行/Eager执行" --> XLA;
	    TF -- "模型导出" --> ONNX;
	    JAX -- "JIT 编译" --> XLA;
	    PT -- "模型导出" --> ONNX;
	    PT -- "TorchScript 编译" --> TVM;
	    PT -- "Kernel 生成" --> Triton;
	    PT -- "图编译" --> nvFuser;
	    PT -- "TorchScript 表示" --> TorchScript;
	
	    XLA -- "图优化" --> HLO;
	    HLO -- "Lowering" --> MLIR;
	    ONNX -- "onnx-mlir" --> MLIR;
	    ONNX -- "run" --> ORT;
	    ORT -- "leverage" --> EP;
	    MLIR -- "标准 Dialect" --> LLVM;
	    TVM -- "图优化 & Lowering" --> LLVM;
	
	    Triton -- "Kernel IR" --> TritonIR;
	    TritonIR -- "Lowering" --> LLVM;
	    nvFuser -- "图 IR" --> TorchScript;
	    TorchScript -- "Lowering & 编译" --> LLVM;
	  
	    LLVM -- "目标后端" --> NPU;
	    NPU --> NPUHD;
	
	    LLVM -- "目标后端" --> ROCm;
	
	    LLVM -- "代码生成" --> PTX;
	    PTX -- "汇编 SASS" --> CUDA;
	    CUDA -- "驱动 & 运行时" --> GPU;
	
	    LLVM -- "目标后端" --> GPU;
	  
	    EP -- "use" --> NPUEP;
	    EP -- "use" --> ROCmEP;
	    EP -- "use" --> CUDAEP;
	
	    NPUEP --> NPU;
	    ROCmEP --> ROCm;
	    CUDAEP --> CUDA;
	  
	    ROCm -- "驱动 & 运行时" --> GPU;

技术栈概览

该技术栈可以大致分为三个主要层次:底层基础设施、中间层技术和高层框架。

1. 底层基础设施

这一层是整个计算能力的基石,直接与硬件及其编程接口相关。

CUDA

  • NVIDIA 的并行计算平台和编程模型。
  • 提供了 C/C++ 语言扩展以及丰富的库,允许开发者直接利用 NVIDIA GPU 的并行计算能力。
  • 是 NVIDIA 硬件生态的核心。

ROCm

  • AMD 的开源 GPU 计算平台。
  • 旨在提供与 CUDA 竞争的功能,支持多种 AMD GPU 架构。

NPU (Neural Processing Unit)

  • 专为加速神经网络计算设计的专用处理器。
  • 代表产品包括华为昇腾、寒武纪、Intel Habana 等。
  • 通常提供高能效比和特定的指令集,需要专门的软件栈支持(如通过 MLIR、TVM、XLA 等后端)。

LLVM (Low Level Virtual Machine)

  • 一个成熟的开源编译器基础设施项目。
  • 提供了一套强大的中间表示(IR)和用于构建编译器前端及后端的工具。
  • 是许多现代编译器(包括 GPU 编译器)的核心组件,负责代码优化和跨平台代码生成。

2. 中间层技术

这一层是连接高层抽象与底层硬件的关键,负责计算图的表示、优化和向底层代码的转换。

MLIR (Multi-Level Intermediate Representation)

  • 一个灵活、可扩展的编译器基础设施,支持表示多种抽象层次的程序。
  • 通过引入方言(Dialects)机制,能够统一表示从高层计算图到低层硬件指令等不同领域的IR。
  • 正成为连接深度学习框架、特定领域加速器(如 NPU)和传统编译器基础设施(如 LLVM)的重要桥梁。

XLA (Accelerated Linear Algebra)

  • TensorFlow 和 JAX 使用的特定领域编译器。
  • 专注于数值计算图的优化,例如操作融合、内存优化等。
  • 其高级中间表示 HLO (High Level Optimizer) 常被用作 MLIR 的一个方言(mhlo),实现与 MLIR 生态的融合。

Triton

  • 一种高级编程语言和编译器,用于编写高效的 GPU Kernel。
  • 旨在简化 GPU 编程,通过自动化优化和代码生成实现媲美甚至超越手写 CUDA 的性能。
  • 与 MLIR 有着紧密的集成。

TVM (Apache Tensor Virtual Machine)

  • 一个端到端、自动优化的深度学习编译器堆栈。
  • 包含多级中间表示(如 Relay 用于图级别,TensorIR/TE 用于计算级别)和自动化调优机制。
  • 目标是为多种硬件后端(包括 GPU、CPU、NPU 等)生成优化的代码。

ONNX (Open Neural Network Exchange)

  • 一种用于表示深度学习模型的开放格式。
  • 旨在促进不同深度学习框架之间的模型互操作性。
  • 通常用作模型导出、交换和部署的标准格式,通过 ONNX Runtime (ORT) 或与其他编译器(如 ONNX-MLIR)结合实现跨平台推理。

3. 高层框架

这是开发者构建、训练和部署模型的常用工具。

TensorFlow

  • 谷歌开发的开源机器学习框架,拥有庞大的社区和生态系统。
  • 支持静态图和动态图执行,通过 XLA 实现性能优化。

PyTorch

  • Meta 开发的机器学习框架,以其灵活的动态图和易用性受到研究界的欢迎。
  • 通过 TorchScript、nvFuser、与 Triton 集成等方式不断提升性能和部署能力。

JAX

  • 谷歌开发的基于 Autograd 和 XLA 的数值计算库。
  • 强调函数式编程风格,通过 XLA 提供了强大的 JIT 编译、自动微分和自动并行化能力。

技术栈连接与转换流程

高层框架中定义的计算(通常表示为计算图)需要经过一系列转换和优化才能在底层硬件上高效执行。

  1. 从框架到中间表示:

    • TensorFlow 和 JAX 通常通过 XLA 将其计算图转换为 HLO。
    • PyTorch 可以导出为 ONNX 格式,或通过 TorchScript 表示。
    • 这些框架的中间表示(如 HLO、TorchScript)或标准交换格式(ONNX)为后续的优化和代码生成奠定基础。
  2. 中间表示的优化与转换:

    • HLO 可以在 XLA 中进行图级别的优化,然后可能被转换为 MLIR 的 mhlo 方言,以便利用 MLIR 的基础设施进行进一步处理。
    • ONNX 模型可以通过 ONNX-MLIR 等工具转换为 MLIR 表示,或由 ONNX Runtime 直接加载,并利用其 Execution Provider (EP) 机制调用特定硬件后端。
    • TVM 从框架 IR 或 ONNX 接收模型,进行图优化和低层表示转换(如到 TensorIR),最终生成代码。
    • Triton 将其高级语言表示转换为 Triton IR,然后 Lowering 到 LLVM IR。
    • MLIR 作为统一平台,可以接收来自不同源(HLO、ONNX、Triton 等)的表示,并进行跨领域、跨层次的优化,最终转换为标准的 MLIR 方言或直接 Lowering 到 LLVM IR。
  3. 从中间表示到硬件:

    • LLVM IR 是通用的低级中间表示,可以被 LLVM 的后端编译为针对特定目标硬件(x86 CPU, ARM CPU, NVIDIA GPU (通过 PTX), AMD GPU, NPU 等)的机器码。
    • 对于 NVIDIA GPU,LLVM 通常生成 PTX (Parallel Thread Execution),这是一种虚拟指令集,再由 NVIDIA 驱动编译为 SASS (硬件原生指令集),通过 CUDA 运行时加载执行。
    • ROCm 和 NPU 等硬件后端也通常通过各自的编译器或 LLVM 后端接收中间表示并生成目标代码。

实际应用案例

这个技术栈的协同工作在实际中带来了显著的性能提升和开发便利性。

深度学习模型优化

  • 高性能 Kernel: 利用 Triton 等工具可以直接编写或生成针对特定操作(如注意力机制中的矩阵乘法)高度优化的 GPU Kernel,取代通用实现,大幅提升模型性能。
  • 图级优化: XLA 和 TVM 等编译器通过操作融合、内存重排等方式优化整个计算图的执行,减少开销。
  • 跨平台部署: ONNX 使得在不同框架训练的模型能够在各种推理引擎和硬件上高效运行,简化了部署流程。

科学计算应用

  • 加速模拟: 利用 CUDA/ROCm 直接编写并行算法,或通过 JAX/TensorFlow 等框架结合 XLA 编译能力,将分子动力学、流体模拟等计算密集型任务加速数十到数百倍。
  • 代码生成效率: LLVM 和 MLIR 等编译器基础设施提供了强大的优化能力,能将高级数学表达式或领域特定语言高效地转换为 GPU/NPU 上的并行代码。

参考文献

官方文档

  1. MLIR 官方文档
  2. CUDA 编程指南
  3. LLVM 文档
  4. TensorFlow XLA 文档
  5. PyTorch 2.0 文档
  6. JAX 文档
  7. TVM 文档
  8. ONNX 文档
  9. Triton 文档

研究论文

  1. Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations
  2. TVM: An Automated End-to-End Optimizing Compiler for Deep Learning
  3. MLIR: A Compiler Infrastructure for the End of Moore’s Law
  4. XLA: TensorFlow, Compiled
  5. JAX: Autograd and XLA

技术博客

  1. NVIDIA Developer Blog
  2. Google AI Blog
  3. PyTorch Blog
  4. TensorFlow Blog
  5. LLVM Blog

教程资源

  1. CUDA C++ Programming Guide
  2. MLIR Tutorial
  3. TVM Tutorial
  4. PyTorch CUDA Tutorial
  5. TensorFlow GPU Guide

社区资源

  1. NVIDIA Developer Forums
  2. PyTorch Discussion Forums
  3. TensorFlow Discussion
  4. LLVM Discourse
  5. MLIR Discussion

性能优化指南

  1. CUDA Best Practices Guide
  2. PyTorch Performance Tuning Guide
  3. TensorFlow Performance Guide
  4. TVM Performance Tuning
  5. MLIR Performance Optimization

工具和库

  1. NVIDIA Nsight
  2. CUDA Profiler
  3. PyTorch Profiler
  4. TensorFlow Profiler
  5. LLVM Opt

会议和演讲

  1. GTC Conference
  2. MLIR Conference
  3. PyTorch Conference
  4. TensorFlow Dev Summit
  5. LLVM Developers’ Meeting