Summary of torch.export
`torch.export.export()` performs ahead-of-time (AOT) compilation on a Python callable (e.g., a `torch.nn.Module` with a `forward()` method), producing an `ExportedProgram`: a sound, functional graph of tensor computations.
If you're confused about the difference between `torch.compile` and `torch.export`, check out this comparison.
Internally, `torch.export()` uses:

- TorchDynamo: Traces PyTorch graphs at the bytecode level for broader code coverage.
- AOT Autograd: Functionalizes the graph and lowers it to ATen operators.
- torch.fx.graph: Provides the graph representation for transformations.
Comparison Table

| Component | Role |
|---|---|
| TorchDynamo | Bytecode-level tracing |
| AOT Autograd | Graph functionalization, ATen lowering |
| torch.fx.graph | Graph representation, transformations |
Example

```python
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(Mod(), args=example_args)
print(exported_program)  # prints the ExportedProgram shown below
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
            # code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

            # code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

            # code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
            return (add,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)
    ]
)
Range constraints: {}
```
2. ExportedProgram

An `ExportedProgram` consists of two main components:

- GraphModule
- Graph Signature
ExportedProgram::GraphModule

The `GraphModule` holds the graph in which every instruction has been compiled down to low-level ATen operations.

What is ATen?
ATen is fundamentally a tensor library, on top of which almost all other Python and C++ interfaces in PyTorch are built. It provides a core Tensor class, on which many hundreds of operations are defined. Most of these operations have both CPU and GPU implementations.

Additionally, the model's parameters and buffers are lifted so that they become arguments of the graph's `forward()` method.
ExportedProgram::Graph Signature

The graph signature records the input and output specs of the graph. The graph itself is functional, meaning it has no side effects and will always produce the same output given the same input.
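As a quick illustration, both components can be inspected directly on the `exported_program` from the example above:

```python
# Both pieces are accessible as attributes of the ExportedProgram:
print(exported_program.graph_module)     # the ATen-level torch.fx GraphModule
print(exported_program.graph_signature)  # the ExportGraphSignature shown earlier
```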
3. Strict vs Non-Strict Modes

Both modes eventually compile the model to a `torch.fx.Graph`; they differ in how the model is traced (see the sketch after this list).

- Non-Strict Mode:
  - Requires the Python runtime for compilation
  - Runs the program in eager mode
  - Uses tracing with `ProxyTensor`
- Strict Mode:
  - TorchDynamo inspects the bytecode and compiles it
  - Potentially generates an IR graph with `cuda.graph`
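A minimal sketch of choosing the mode via the `strict` flag, reusing `Mod` and `example_args` from the example above (the default for `strict` has shifted across PyTorch releases, so passing it explicitly is safest):

```python
from torch.export import export

# Strict mode: TorchDynamo traces the module at the bytecode level.
ep_strict = export(Mod(), args=example_args, strict=True)

# Non-strict mode: the module runs eagerly and is traced via ProxyTensor.
ep_nonstrict = export(Mod(), args=example_args, strict=False)
```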
4. export_for_training()

Use `export_for_training()` for training with non-functional ops (e.g., `BatchNorm` with state updates). It produces a generic IR containing all ATen operators that works with eager PyTorch Autograd, making it ideal for cases like PT2 Quantization. The result can be converted to an inference IR via `run_decompositions()`.
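A minimal sketch, assuming a PyTorch version where `torch.export.export_for_training()` is available (2.5+):

```python
import torch
from torch.export import export_for_training

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # BatchNorm is non-functional: it mutates its running statistics.
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(x)

# Training IR: keeps BatchNorm intact for Autograd / PT2 Quantization.
ep_train = export_for_training(M(), (torch.randn(1, 3, 8, 8),))

# Lower to the functional inference IR afterwards.
ep_infer = ep_train.run_decompositions()
```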
5. Dynamism

Dynamic shapes are supported using `Dim()` to generate range constraints during compilation.

```python
from torch.export import Dim

batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}, "y": {0: batch}}

exported_program: torch.export.ExportedProgram = export(
    Mod(), args=example_args, dynamic_shapes=dynamic_shapes
)
```
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        ...
Graph signature: ...
Range constraints: {batch: VR[0, int_oo]}
```
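With the range constraint in place, the same exported graph should accept any batch size along dimension 0; a quick sketch:

```python
# The batch dimension is now symbolic, so one graph serves many sizes.
m = exported_program.module()
print(m(torch.randn(4, 10), torch.randn(4, 10)).shape)    # torch.Size([4, 10])
print(m(torch.randn(64, 10), torch.randn(64, 10)).shape)  # torch.Size([64, 10])
```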
6. Serialization

- `torch.export.save()`: Saves to the *.pt2 format
- `torch.export.load()`: Loads the saved model
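A minimal round-trip sketch, reusing `exported_program` from above (the file name is arbitrary):

```python
import torch

# Save the exported program to the .pt2 archive format.
torch.export.save(exported_program, "mod.pt2")

# Load it back and run the recovered module.
loaded = torch.export.load("mod.pt2")
out = loaded.module()(torch.randn(8, 10), torch.randn(8, 10))
```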
7. Specialization

Specialization is the opposite of generalization: certain values (e.g., input shapes, Python primitives, container structures) are fixed as constants during export.

Effect: static values enable constant folding; operations whose inputs are all static can be precomputed and removed from the runtime graph.

`torch.export` fixes certain values as static constants in the graph:

- Tensor Shapes: Static by default unless marked dynamic with `dynamic_shapes`.
- Python Primitives: `int`, `float`, etc., are hardcoded unless using `SymInt`.
- Python Containers: Lists, dictionaries, etc., have fixed structures at export.
Static inputs lead to precomputed results, simplifying the graph.
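A minimal sketch of primitive specialization (the `Scale` module here is hypothetical):

```python
import torch
from torch.export import export

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor, factor: int) -> torch.Tensor:
        return x * factor

# The Python int 2 is burned into the graph as a constant; the resulting
# program is specialized to factor == 2 rather than reading it at runtime.
ep = export(Scale(), args=(torch.randn(4), 2))
print(ep)  # the graph multiplies by the literal 2
```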
8. Limitations

Limitation::Graph Breaks

`torch.export` may fail on untraceable code, requiring rewrites or extra information (unlike `torch.compile`, which can fall back to eager execution). TorchDynamo keeps the number of required rewrites small; use ExportDB or non-strict mode for help.
Limitation::Missing Fake Kernels

Tracing needs `FakeTensor` kernels for shape inference; a missing kernel causes export to fail.
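A minimal sketch of supplying a fake kernel for a custom op (the `mylib::double` op is hypothetical; assumes PyTorch 2.4+, where `torch.library.custom_op` is available):

```python
import torch

# A hypothetical custom op defined via torch.library.
@torch.library.custom_op("mylib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# The fake kernel computes only output metadata (shape/dtype),
# which is exactly what export's FakeTensor tracing needs.
@double.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)
```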