제민욱
torch.export

1. Summary of torch.export

torch.export.export() performs ahead-of-time (AOT) compilation of a Python callable (e.g., a torch.nn.Module with a forward() method), producing an ExportedProgram: a sound, functional graph of tensor computations.

If you're confused about the difference between torch.compile and torch.export, check out this comparison.

Internally

torch.export() uses the following components:

  • TorchDynamo: Traces PyTorch graphs at the bytecode level for broader code coverage.
  • AOT Autograd: Functionalizes the graph and lowers it to ATen operators.
  • torch.fx.Graph: Provides the graph representation used for transformations.

Comparison Table

| Component      | Role                                      |
| -------------- | ----------------------------------------- |
| TorchDynamo    | Bytecode-level tracing                    |
| AOT Autograd   | Graph functionalization, ATen lowering    |
| torch.fx.Graph | Graph representation and transformations  |

Example

import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args
)  # 🔥🔥
ExportedProgram:  🔥🔥
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

    Graph signature: 🔥🔥
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='y'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='add'),
                    target=None
                )
            ]
        )
    Range constraints: {}

2. ExportedProgram

ExportedProgram consists of two main components:

  1. GraphModule
  2. Graph Signature

ExportedProgram::GraphModule

The GraphModule contains the computation graph, in which every instruction has been lowered to low-level ATen operations.

What is ATen?

ATen is fundamentally a tensor library, on top of which almost all other Python and C++ interfaces in PyTorch are built. It provides a core Tensor class, on which many hundreds of operations are defined. Most of these operations have both CPU and GPU implementations.

Additionally, all parameters and buffers are lifted to become inputs (arguments) of the graph's forward() method.
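As a minimal sketch (the module and names here are illustrative), a learnable weight shows up as an extra graph input rather than as module state:

import torch
from torch.export import export

class LinearMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight

ep = export(LinearMod(), args=(torch.randn(2, 4),))
# The parameter appears as an extra input spec with
# kind=InputKind.PARAMETER, alongside the user input x.
print(ep.graph_signature.input_specs)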

ExportedProgram::Graph Signature

The graph signature describes the graph's inputs and outputs. The graph itself is functional: it has no side effects and always produces the same output given the same input.


3. Strict vs Non-Strict Modes

Both modes ultimately compile the model to a torch.fx.Graph; they are selected via the strict flag of export() (see the sketch after this list).

  • Non-Strict Mode (strict=False):

    • Requires the Python runtime: the program is executed in eager mode
    • Traces tensor operations using ProxyTensor
  • Strict Mode (strict=True):

    • TorchDynamo inspects the Python bytecode and traces it without executing the program under the normal Python interpreter
    • Bytecode-level analysis allows it to capture a sound graph of the program's behavior
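As a minimal sketch (reusing Mod and example_args from the example above), the mode is chosen with the strict flag:

# Non-strict mode: runs the module eagerly and traces with ProxyTensor.
ep_nonstrict = export(Mod(), args=example_args, strict=False)

# Strict mode: TorchDynamo analyzes the bytecode without eager execution.
ep_strict = export(Mod(), args=example_args, strict=True)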

4. export_for_training()

Use export_for_training() when training requires non-functional ops (e.g., BatchNorm, which updates running statistics). It produces a generic IR containing all ATen operators, suitable for eager PyTorch Autograd and for use cases like PT2 Quantization. The result can be converted to the inference IR via run_decompositions(), as sketched below.
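A minimal sketch (the module here is illustrative; the API is available in recent PyTorch releases):

import torch
from torch.export import export_for_training

class BNMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)

# The training IR keeps the stateful batch_norm op (running-stat updates) intact.
ep_train = export_for_training(BNMod(), args=(torch.randn(2, 8),))

# Convert to the functional inference IR afterwards.
ep_infer = ep_train.run_decompositions()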


5. Dynamism

Dynamic shapes are supported using Dim() to generate range constraints during compilation.

from torch.export import Dim

batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}, "y": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args, dynamic_shapes=dynamic_shapes
)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        ...

    Graph signature:
        ...
    Range constraints: {batch: VR[0, int_oo]}
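Dim also accepts explicit bounds, which tighten the range constraints (a small sketch; the bounds are illustrative):

# Constrain the batch dimension to [1, 1024] instead of [0, int_oo].
batch = Dim("batch", min=1, max=1024)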

6. Serialization

  • torch.export.save(): Saves to *.pt2 format
  • torch.export.load(): Loads the saved model
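A minimal round-trip sketch, reusing exported_program from the example above (the file name is illustrative):

import torch

torch.export.save(exported_program, "mod.pt2")
loaded = torch.export.load("mod.pt2")

# The loaded program behaves like the original module.
print(loaded.module()(torch.randn(10, 10), torch.randn(10, 10)))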

7. Specialization

Specialization is the opposite of generalization.

Certain values (e.g., input shapes, Python primitives, container structures) are fixed as constants during export.

Effect: static values enable constant folding; operations whose inputs are all static can be precomputed and removed from the runtime graph.

torch.export fixes certain values as static constants in the graph:

  • Tensor Shapes: Static by default unless marked dynamic with dynamic_shapes.
  • Python Primitives: int, float, etc., are hardcoded unless using SymInt.
  • Python Containers: Lists, dictionaries, etc., have fixed structures at export.

Static inputs lead to precomputed results, simplifying the graph.
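As a minimal sketch of primitive specialization (the module and values are illustrative), a Python int argument is burned into the graph as a constant:

import torch
from torch.export import export

class AddN(torch.nn.Module):
    def forward(self, x: torch.Tensor, n: int) -> torch.Tensor:
        return x + n

# n=5 is specialized: the exported graph hardcodes the constant 5
# instead of keeping n as a runtime input.
ep = export(AddN(), args=(torch.randn(4), 5))
print(ep)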


8. Limitations

Limitation::Graph Breaks

torch.export may fail on untraceable code. Unlike torch.compile, which falls back to eager Python at a graph break, torch.export requires the code to be rewritten or extra information to be supplied. TorchDynamo keeps the number of required rewrites small; ExportDB and non-strict mode can help, and a typical break with one possible rewrite is sketched below.
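A sketch of a typical graph break: data-dependent Python control flow cannot be traced. One possible rewrite expresses the branch with torch.cond (the module names are illustrative):

import torch

class Branch(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # data-dependent branch: breaks tracing
            return x + 1
        return x - 1

class BranchCond(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def plus(x):
            return x + 1

        def minus(x):
            return x - 1

        # torch.cond is a traceable higher-order op for data-dependent branches.
        return torch.cond(x.sum() > 0, plus, minus, (x,))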

Limitation::Missing Fake Kernels

Tracing runs on FakeTensors, which carry shapes and dtypes but no data, so every operator needs a fake (meta) kernel for shape inference; an operator without one causes export to fail.
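For custom operators, a fake (meta) kernel can be registered so the tracer can infer output shapes without running real data; a minimal sketch using the torch.library API (the op itself is hypothetical):

import torch

# A hypothetical custom op with a real CPU implementation.
@torch.library.custom_op("mylib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# The fake kernel runs on FakeTensors: it only describes the output's
# shape and dtype, never touching real data.
@double.register_fake
def _(x):
    return torch.empty_like(x)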

