Script mode is entered through torch.jit.trace or torch.jit.script. The two functions are different ways of converting Python code to TorchScript.
torch.jit.trace runs a PyTorch model on a specific example input (usually a tensor, which we must provide), records the operations executed on that input, and converts the recorded computation into TorchScript. This approach works well for models that can be fully described by a static graph, such as neural networks with fixed input sizes, and it is typically used to convert pre-trained models.
torch.jit.script converts a Python function (or a Python module) into TorchScript directly, by compiling its source code according to Python syntax rules. It is better suited to models with dynamic graphs, whose structure can change with the input at runtime. For example, torch.jit.script is more convenient for RNNs or other models with variable sequence lengths.
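As a minimal sketch of the two entry points (the module TinyNet and the tensor shapes are made up for illustration), the same module can be converted either way, and both results should agree with eager mode:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyNet().eval()
example = torch.rand(3, 4)

# Tracing needs an example input; it records the ops executed on it.
traced = torch.jit.trace(model, example)
# Scripting compiles the Python source; no example input is required.
scripted = torch.jit.script(model)

x = torch.rand(5, 4)
print(torch.allclose(traced(x), model(x)))
print(torch.allclose(scripted(x), model(x)))
```

Both calls return a ScriptModule, so the choice between them affects how the model is converted, not how the result is used afterwards.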
In general, prefer torch.jit.trace over torch.jit.script.
For model deployment, ONNX is heavily used. Exporting to ONNX also runs the model through torch.jit.trace, so tracing is covered here in a little more detail.
In order to write the model so that it can be traced by JIT, some compromises need to be made to the code, for example:
1. If the model contains a DataParallel submodule, converts tensors to numpy arrays, calls OpenCV functions, and so on, then it does not represent a correct, connected graph on a single device. In that case neither torch.jit.script nor torch.jit.trace can produce a correct TorchScript program.
2. The input and output of the model should be of type Union[Tensor, Tuple[Tensor], Dict[str, Tensor]], and the values in a dict should all be of the same type. The inputs and outputs of intermediate submodules, however, can be of any type that Python supports, such as dicts of Any, classes, or kwargs. The restriction on model input and output types is relatively easy to satisfy. Detectron2 has an example of this:
    outputs = model(inputs)  # inputs and outputs are Python types, such as dicts or classes
    # torch.jit.trace(model, inputs)  # Fails! trace only supports Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]
    adapter = TracingAdapter(model, inputs)  # wrap the model so its inputs are types that trace supports
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # tracing now succeeds

    # The traced model can only output a tuple of tensors:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # Convert it back to the desired output type through the adapter
    new_outputs = adapter.outputs_schema(flattened_outputs)
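The same idea can be sketched by hand without Detectron2. The names DictModel and FlatAdapter below are hypothetical and are not part of any library; the point is only that rich types stay inside the model while the traced boundary sees plain tensors:

```python
from typing import Dict

import torch
import torch.nn as nn


class DictModel(nn.Module):  # hypothetical model with dict inputs and outputs
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {"sum": batch["a"] + batch["b"], "prod": batch["a"] * batch["b"]}


class FlatAdapter(nn.Module):  # hypothetical hand-written adapter
    """Expose a tensors-in, tuple-of-tensors-out interface for tracing."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        out = self.model({"a": a, "b": b})  # rich types are fine internally
        return out["sum"], out["prod"]      # flat outputs at the boundary


a, b = torch.rand(3), torch.rand(3)
traced = torch.jit.trace(FlatAdapter(DictModel()), (a, b))

flat_sum, flat_prod = traced(a, b)
outputs = {"sum": flat_sum, "prod": flat_prod}  # rebuild the rich output type
```

A general-purpose adapter has to flatten and rebuild arbitrary nested structures; this sketch hard-codes the two keys to keep the mechanism visible.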
3. Numeric type issues. Consider the following code snippet:

    import torch

    a = torch.tensor([1, 2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))

In eager mode, these return values are all of type int. The output of the above code is

    <class 'int'>
    <class 'int'>
    <class 'int'>
But in trace mode, these expressions all return Tensor values. As a result, code that assumes a shape expression yields an int may fail to trace. In such code, torch.jit.is_tracing can be used to check whether it is executing in trace mode.
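A small sketch of guarding int-only code with torch.jit.is_tracing follows. The function f and the list seen are illustrative names; check_trace=False is passed only so that the function body runs exactly once during tracing:

```python
import torch

seen = []  # records whether each call ran under tracing


def f(x):
    seen.append(torch.jit.is_tracing())
    n = x.shape[0]
    if not torch.jit.is_tracing():
        # Only checked in eager mode, where shape entries are plain ints.
        assert isinstance(n, int)
    return x.reshape(n, -1) * 2


eager_out = f(torch.ones(2, 3))                                   # eager call
traced = torch.jit.trace(f, torch.ones(2, 3), check_trace=False)  # traced call
print(seen)
```

The first entry of seen comes from the eager call and the second from the call made by torch.jit.trace.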
4. Because of dynamic control flow, the model may not be fully traced. Look at the following example:

    import torch

    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

    m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
The output is
    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
As you can see, the traced model keeps only one branch. With input-dependent dynamic control flow, tracing is therefore prone to producing a wrong model.
In this case, we can use torch.jit.script to convert to TorchScript.
    import torch

    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

    m = torch.jit.script(f)
    print(m.code)
The output is
    def f(x: Tensor) -> Tensor:
      if bool(torch.gt(torch.sum(x), 0)):
        _0 = torch.sqrt(x)
      else:
        _0 = torch.square(x)
      return _0
In most cases we should use torch.jit.trace, but for dynamic control flow like the above we can mix torch.jit.trace and torch.jit.script, as explained later in this article.
In addition, some blogs define dynamic control flow incorrectly. For example,

    if x[0] == 4:
        x += 1

is dynamic control flow, but

    model: nn.Sequential = ...
    for m in model:
        x = m(x)

as well as

    class A(nn.Module):
        backbone: nn.Module
        head: Optional[nn.Module]

        def forward(self, x):
            x = self.backbone(x)
            if self.head is not None:
                x = self.head(x)
            return x

are not. Dynamic control flow means that different branches execute depending on a condition evaluated on the input.
5. During tracing, a variable may be captured as a constant. Look at the following example:

    import torch

    a, b = torch.rand(1), torch.rand(2)

    def f1(x):
        return torch.arange(x.shape[0])

    def f2(x):
        return torch.arange(len(x))

    print(torch.jit.trace(f1, a)(b))
    # Output: tensor([0, 1])
    # The traced model is correct. Here a is the example input given to torch.jit.trace,
    # and the resulting TorchScript is then called with b. b has 2 elements,
    # so the return value tensor([0, 1]) is correct.

    print(torch.jit.trace(f2, a)(b))
    # Output:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # This output is wrong: b has 2 elements, so the result should be tensor([0, 1]).
    # torch.jit.trace also warns that using len may make the trace incorrect.

    # Let's print both traces to see the difference
    print(torch.jit.trace(f1, a).code, '\n', torch.jit.trace(f2, a).code)
    # Output:
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
    # def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # With x.shape, the traced code keeps the shape as a variable.
    # With len, the traced code hard-codes the constant 1.
Exporting to ONNX also goes through torch.jit.trace, so during export we sometimes see

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

When this message appears, check the code for variables that may have been captured as constants during tracing, since this can make the exported ONNX model produce inaccurate results.
Besides len, several other things also cause tracing problems:
- .item() converts tensors to int/float during tracing
- Any code that converts torch types to numpy/Python types
- Some problematic operators, such as advanced indexing
- torch.jit.trace does not generalize over the device of the example input:
    import torch

    def f(x):
        return torch.arange(x.shape[0], device=x.device)

    m = torch.jit.trace(f, torch.tensor([3]))
    print(m.code)
    # Output:
    # def f(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
    print(m(torch.tensor([3]).cuda()).device)
    # Output: device(type='cpu')
The device is hard-coded in the trace, so the traced function does not follow an input that lives on a CUDA device.
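The .item() case from the list above can be reproduced on CPU with a short sketch. The function name scale_by_sum and the input values are chosen for illustration:

```python
import torch


def scale_by_sum(x):
    # .item() returns a Python float, which the tracer records as a constant.
    return x * x.sum().item()


a = torch.tensor([1.0, 2.0])  # example input, sum = 3
b = torch.tensor([3.0, 4.0])  # later input, sum = 7

traced = torch.jit.trace(scale_by_sum, a)  # emits a TracerWarning

print(scale_by_sum(b))  # eager: b * 7
print(traced(b))        # traced: b * 3, the sum of `a` was baked in
```

The traced function runs without error on b, which is what makes this class of bug easy to miss.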
To make sure the trace is correct, the following practices help:
1. Pay attention to warnings, such as

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

A TracerWarning signals that the model results may be incorrect, even though it is only at warning level.
2. Write unit tests. Verify that the output of the traced model matches the output of the eager-mode model, using a test input different from the one used for tracing.

    assert torch.allclose(torch.jit.trace(model, input1)(input2), model(input2))
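A slightly fuller sketch of such a check is shown below. The helper name assert_trace_matches_eager, the small convolutional model, and the tolerance are assumptions made for illustration:

```python
import torch
import torch.nn as nn


def assert_trace_matches_eager(model, trace_input, test_inputs):  # hypothetical helper
    """Trace on one input, then compare against eager mode on other inputs."""
    model.eval()
    traced = torch.jit.trace(model, trace_input)
    with torch.no_grad():
        for x in test_inputs:
            expected = model(x)
            actual = traced(x)
            assert actual.shape == expected.shape
            assert torch.allclose(actual, expected, atol=1e-5)
    return traced


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())
traced = assert_trace_matches_eager(
    model,
    torch.rand(1, 3, 8, 8),                              # input used for tracing
    [torch.rand(2, 3, 16, 16), torch.rand(4, 3, 5, 7)],  # different shapes
)
```

Testing with shapes that differ from the tracing input is the important part, because constants captured from the example input only show up when the shape changes.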
3. Avoid special cases such as empty inputs and outputs. For example, avoid code like the following:

    if x.numel() > 0:
        output = self.layers(x)
    else:
        output = torch.zeros((0, C, H, W))  # creates an empty output
4. Pay attention to how shapes are used. As mentioned earlier, tensor.size() returns Tensor values during tracing, and those Tensors are recorded in the computation graph. Avoid turning a Tensor-typed shape into a constant. Two points matter most:
- Use tensor.size(0) instead of len(tensor), because under tracing tensor.size(0) returns a Tensor while len(tensor) returns an int. For custom classes, implement a .size method, or call the .__len__() method instead of len().
- Do not use int() or torch.as_tensor to convert the type of a size, because the results of these operations are also treated as constants.
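The int() pitfall can be sketched in the same style as the len example earlier. The function names arange_good and arange_bad are illustrative:

```python
import torch


def arange_good(x):
    return torch.arange(x.shape[0])       # the shape stays a variable in the trace


def arange_bad(x):
    return torch.arange(int(x.shape[0]))  # int() freezes the value seen while tracing


a, b = torch.rand(1), torch.rand(3)

good = torch.jit.trace(arange_good, a)
bad = torch.jit.trace(arange_bad, a)      # emits a TracerWarning

print(good(b).shape)  # follows the size of b
print(bad(b).shape)   # stuck at the size of a
```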
5. Mix tracing and scripting. torch.jit.script can convert the small code fragments that torch.jit.trace cannot handle. Mixing the two solves almost all problems.
Mix tracing and scripting
Both tracing and scripting have their problems, and mixing them solves most of these. To minimize the negative impact on code quality, however, torch.jit.trace should be used in most cases and torch.jit.script only when necessary.
1. When using torch.jit.trace, apply the @script_if_tracing decorator so that the decorated function is compiled by scripting.
    def forward(self, ...):
        # ... some forward logic

        @torch.jit.script_if_tracing
        def _inner_impl(x, y, z, flag: bool):
            # use control flow, etc.
            return ...

        output = _inner_impl(x, y, z, flag)
        # ... other forward logic
However, a function decorated with @script_if_tracing must not contain any PyTorch modules. If it does, the code needs some changes, as in the following case:
    # self.layers() is a PyTorch module, so @script_if_tracing cannot be used here
    if x.numel() > 0:
        x = preprocess(x)
        output = self.layers(x)
    else:
        # Create empty outputs
        output = torch.zeros(...)
The following modifications need to be made here:
    # Move self.layers out of the if statement; now @script_if_tracing can be used
    if x.numel() > 0:
        x = preprocess(x)
    else:
        # Create empty inputs
        x = torch.zeros(...)
    # Modify self.layers() to support empty input, or move the original condition into self.layers
    output = self.layers(x)
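A runnable sketch of the decorator is given below, reusing the sqrt/square branch from the earlier dynamic control flow example. The module Net and the function sqrt_or_square are illustrative names, and the decorated function is defined at module level here rather than inside forward:

```python
import torch
import torch.nn as nn


@torch.jit.script_if_tracing
def sqrt_or_square(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: compiled by scripting when called during tracing.
    if bool(x.sum() > 0):
        return torch.sqrt(x)
    else:
        return torch.square(x)


class Net(nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        x = x * 2                     # traced as usual
        return sqrt_or_square(x) + 1  # both branches are kept


traced = torch.jit.trace(Net(), torch.tensor([2.0]))

print(traced(torch.tensor([2.0])))   # positive sum: sqrt branch
print(traced(torch.tensor([-3.0])))  # negative sum: square branch
```

Outside of tracing the decorator is a no-op, so eager-mode behavior and debugging are unaffected.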
2. Merge the results of multiple tracings
A model generated by torch.jit.script has two advantages over one generated by torch.jit.trace:
- It can use conditional control flow. For example, a bool value that controls the forward flow is not supported in traced modules.
- A traced module can have only one forward() function, whereas a scripted module can have multiple forward-like functions.
    class Detector(nn.Module):
        do_keypoint: bool

        def forward(self, img):
            box = self.predict_boxes(img)
            if self.do_keypoint:
                kpts = self.predict_keypoint(img, box)

        @torch.jit.export
        def predict_boxes(self, img): pass

        @torch.jit.export
        def predict_keypoint(self, img, box): pass
For this kind of control flow with bool values, in addition to using script, you can also trace multiple times and then merge the results.
    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
Then copy their weights and merge the results of the two traces
    det2.submodule.weight = det1.submodule.weight

    class Wrapper(nn.ModuleList):
        def forward(self, img, do_keypoint: bool):
            if do_keypoint:
                return self[0](img)
            else:
                return self[1](img)

    exported = torch.jit.script(Wrapper([det1, det2]))
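A self-contained sketch of the same pattern follows, with a toy Net standing in for Detector. All names are hypothetical. For simplicity the two variants receive equal weights through load_state_dict before tracing, which copies the weights rather than sharing storage, and the wrapper is a plain nn.Module with two named submodules instead of an nn.ModuleList:

```python
import torch
import torch.nn as nn


class Net(nn.Module):  # hypothetical stand-in for Detector
    def __init__(self, use_relu: bool):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.use_relu = use_relu

    def forward(self, x):
        y = self.fc(x)
        if self.use_relu:  # plain Python bool: each trace keeps one branch
            y = torch.relu(y)
        return y


class Wrapper(nn.Module):
    def __init__(self, with_relu, without_relu):
        super().__init__()
        self.with_relu = with_relu
        self.without_relu = without_relu

    def forward(self, x: torch.Tensor, use_relu: bool):
        if use_relu:
            return self.with_relu(x)
        else:
            return self.without_relu(x)


net_a, net_b = Net(True).eval(), Net(False).eval()
net_b.load_state_dict(net_a.state_dict())  # same weight values in both variants

example = torch.rand(2, 4)
exported = torch.jit.script(
    Wrapper(torch.jit.trace(net_a, example), torch.jit.trace(net_b, example))
)

x = torch.rand(3, 4)
print(torch.allclose(exported(x, True), net_a(x)))
print(torch.allclose(exported(x, False), net_b(x)))
```

Only the thin wrapper is scripted; the two variants themselves are still produced by tracing.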
Performance of trace and script
Tracing always generates a computation graph that is the same as or simpler than the one produced by scripting, so its performance is better. Scripting faithfully represents all the logic of the Python code, including parts that are unnecessary. Consider the following example:
    class A(nn.Module):
        def forward(self, x1, x2, x3):
            z = [0, 1, 2]
            xs = [x1, x2, x3]
            for k in z:
                x1 += xs[k]
            return x1

    model = A()
    print(torch.jit.script(model).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   z = [0, 1, 2]
    #   xs = [x1, x2, x3]
    #   x10 = x1
    #   for _0 in range(torch.len(z)):
    #     k = z[_0]
    #     x10 = torch.add_(x10, xs[k])
    #   return x10
    print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   x10 = torch.add_(x1, x1)
    #   x11 = torch.add_(x10, x2)
    #   return torch.add_(x11, x3)
Summary
Tracing has obvious limitations: most of this article is about tracing’s limitations and how to work around them. In fact, this is exactly the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works.
By contrast, scripting is more like a black box: no one knows whether it will work before trying it. The article doesn't mention any tricks on how to fix scripting: there are a lot of tricks, but it's not worth your time to dig into and fix a black box.
Both tracing and scripting affect the way the code is written, but because tracing has clear requirements, the changes it demands of our original code are not too severe:
- It restricts input/output formats, but only for the outermost module. (As mentioned above, this problem can be solved with a wrapper.)
- It requires some code modifications to make the trace generalize (such as adding some scripting when tracing), but these modifications only involve the internal implementation of the affected modules, not their interfaces.