tensorplay.stax
The stax module provides static graph optimization and acceleration for TensorPlay. It allows users to capture a sequence of operations and compile them into an optimized execution graph, similar to torch.compile.
Key Features
- Graph Capturing: Intercepts `p10::Tensor` operations to build a static representation of the computation.
- Operator Fusion: Automatically fuses common operation sequences (e.g., `mul` + `add`) to reduce memory bandwidth and kernel launch overhead.
- AOT/JIT Compilation: Supports both Ahead-of-Time and Just-in-Time compilation strategies.
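
The capture and fusion features listed above can be illustrated with a small pure-Python sketch. It is not TensorPlay's implementation: `Node`, `ToyTracer`, `ToyProxy`, `fuse_mul_add`, and `evaluate` are hypothetical names, and plain floats stand in for tensors. Operator overloading on a proxy records each operation as a graph node, and a separate pass rewrites `add(mul(a, b), c)` into a single `fused_mul_add(a, b, c)` node when the `mul` result has no other consumer.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Node:
    kind: str     # "input", "mul", "add", "fused_mul_add", or "dead"
    args: tuple   # indices of the nodes that produce this node's operands


class ToyTracer:
    """Hypothetical tracer: appends one Node per recorded operation."""

    def __init__(self):
        self.nodes = []

    def record(self, kind, *args):
        self.nodes.append(Node(kind, tuple(a.index for a in args)))
        return ToyProxy(self, len(self.nodes) - 1)


class ToyProxy:
    """Stands in for a tensor during tracing; holds no data."""

    def __init__(self, tracer, index):
        self.tracer, self.index = tracer, index

    def __mul__(self, other):
        return self.tracer.record("mul", self, other)

    def __add__(self, other):
        return self.tracer.record("add", self, other)


def fuse_mul_add(nodes):
    """Rewrite add(mul(a, b), c) as fused_mul_add(a, b, c).

    A mul is fused only if the add is its sole consumer. Node indices are
    kept stable by leaving a "dead" placeholder where the mul used to be.
    """
    uses = Counter(i for n in nodes for i in n.args)
    fused = list(nodes)
    for i, n in enumerate(nodes):
        if n.kind != "add":
            continue
        # Try the mul on either side, since addition is commutative.
        for m, other in (n.args, n.args[::-1]):
            if nodes[m].kind == "mul" and uses[m] == 1:
                fused[i] = Node("fused_mul_add", nodes[m].args + (other,))
                fused[m] = Node("dead", ())
                break
    return fused


def evaluate(nodes, inputs):
    """Run a node list on concrete values; returns the last node's value."""
    feed, vals = iter(inputs), []
    for n in nodes:
        a = [vals[i] for i in n.args]
        if n.kind == "input":
            vals.append(next(feed))
        elif n.kind == "mul":
            vals.append(a[0] * a[1])
        elif n.kind == "add":
            vals.append(a[0] + a[1])
        elif n.kind == "fused_mul_add":
            vals.append(a[0] * a[1] + a[2])
        else:  # dead placeholder, never read
            vals.append(None)
    return vals[-1]


tracer = ToyTracer()
x, y, z = (tracer.record("input") for _ in range(3))
out = x * y + z                      # recorded, not computed
fused_nodes = fuse_mul_add(tracer.nodes)
kinds = [n.kind for n in fused_nodes]
result = evaluate(fused_nodes, [2.0, 3.0, 4.0])
```

In this toy graph the five recorded nodes become three inputs, one dead placeholder, and one fused node, and both the original and the fused graph evaluate to the same value. A real backend would map the fused node to a single kernel, which is where the bandwidth and launch-overhead savings come from.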
Classes
class ProxyTensor

```python
ProxyTensor(tracer, node, shape=None, dtype=None)
```

Methods

__init__(self, tracer, node, shape=None, dtype=None)
Initialize self. See help(type(self)) for accurate signature.
class PythonGraphExecutor

```python
PythonGraphExecutor(tracer)
```

Methods

__init__(self, tracer)
Initialize self. See help(type(self)) for accurate signature.

record_node(self, kind, args, output_val)

run(self, env)
class Tracer

```python
Tracer()
```

Methods

__init__(self)
Initialize self. See help(type(self)) for accurate signature.

create_input(self, tensor)

record(self, kind, *args)
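
The signature `ProxyTensor(tracer, node, shape=None, dtype=None)` suggests that a proxy carries metadata rather than data. How TensorPlay actually uses `shape` and `dtype` is not documented on this page, so the following is only a sketch of the general idea under stated assumptions: `MetaProxy` and `broadcast_shape` are hypothetical names, shapes follow NumPy-style broadcasting, and mixed dtypes are rejected instead of promoted to keep the example short.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """NumPy-style broadcast of two shapes, aligned from the last dimension."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


class MetaProxy:
    """Hypothetical proxy: tracks name, shape, and dtype, but never any data."""

    def __init__(self, log, name, shape, dtype):
        self.log, self.name, self.shape, self.dtype = log, name, shape, dtype

    def _elementwise(self, kind, other):
        if self.dtype != other.dtype:
            # Assumption for brevity: no type promotion in this sketch.
            raise TypeError(f"dtype mismatch: {self.dtype} vs {other.dtype}")
        name = f"%{len(self.log)}"
        self.log.append((name, kind, self.name, other.name))
        shape = broadcast_shape(self.shape, other.shape)
        return MetaProxy(self.log, name, shape, self.dtype)

    def __add__(self, other):
        return self._elementwise("add", other)

    def __mul__(self, other):
        return self._elementwise("mul", other)


log = []
x = MetaProxy(log, "x", (4, 1, 3), "float32")
b = MetaProxy(log, "b", (5, 1), "float32")
y = x * x + b   # records two operations and infers the output shape
```

Because the output shape is inferred while the graph is being recorded, a shape error can be reported at trace time, before any kernel runs.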
Functions
compile()

```python
compile(func)
```

fused_mul_add_impl()

```python
fused_mul_add_impl(*args, **kwargs)
```

patch_tensorplay()

```python
patch_tensorplay(tracer)
```

robust_add()

```python
robust_add(a, b)
```

robust_div()

```python
robust_div(a, b)
```

robust_mul()

```python
robust_mul(a, b)
```

robust_pow()

```python
robust_pow(a, b)
```

robust_sub()

```python
robust_sub(a, b)
```
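
This page gives only the signature of `compile(func)` and compares the module to torch.compile. As a rough illustration of that style of wrapper, and not a description of how `tensorplay.stax.compile` works internally, the sketch below traces a function once with symbolic arguments, caches the recorded program, and replays it on later calls. `Program`, `Sym`, and `toy_compile` are hypothetical names. The sketch handles only binary arithmetic between traced arguments and ignores Python constants and control flow.

```python
import functools
import operator


class Program:
    """A recorded straight-line program over numbered value slots."""

    def __init__(self, n_inputs):
        self.n_inputs = n_inputs
        self.ops = []            # (callable, lhs_slot, rhs_slot)
        self.result_slot = None

    def run(self, *values):
        slots = list(values)     # inputs occupy slots 0..n_inputs-1
        for fn, lhs, rhs in self.ops:
            slots.append(fn(slots[lhs], slots[rhs]))
        return slots[self.result_slot]


class Sym:
    """Symbolic value used only while tracing."""

    def __init__(self, program, slot):
        self.program, self.slot = program, slot

    def _emit(self, fn, other):
        p = self.program
        p.ops.append((fn, self.slot, other.slot))
        return Sym(p, p.n_inputs + len(p.ops) - 1)

    def __add__(self, other):
        return self._emit(operator.add, other)

    def __sub__(self, other):
        return self._emit(operator.sub, other)

    def __mul__(self, other):
        return self._emit(operator.mul, other)

    def __truediv__(self, other):
        return self._emit(operator.truediv, other)


def toy_compile(func):
    """Trace func on first use per argument count, then replay the program."""
    cache = {}

    @functools.wraps(func)
    def wrapper(*args):
        key = len(args)
        if key not in cache:
            program = Program(len(args))
            out = func(*(Sym(program, i) for i in range(len(args))))
            program.result_slot = out.slot
            cache[key] = program
            wrapper.traces += 1
        return cache[key].run(*args)

    wrapper.traces = 0
    return wrapper


@toy_compile
def f(a, b):
    return (a + b) * (a - b)


first = f(5.0, 3.0)    # traces f, then runs the recorded program
second = f(2.0, 1.0)   # replays the cached program without calling f again
```

The Python body of `f` executes only during the first call. A production compiler would key the cache on shapes and dtypes as well, and would optimize the recorded program before running it.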