torch.jit.trace and torch.jit.script

Script mode is entered through torch.jit.trace or torch.jit.script. The two functions are different ways of converting Python code to TorchScript.

torch.jit.trace takes a specific example input (usually a tensor, which we must provide), runs it through a PyTorch model, records the computation performed on that input, and converts the recording into TorchScript. This approach works well for models that can be fully described by a static graph, such as neural networks with fixed input sizes. It is typically used to convert pre-trained models.

torch.jit.script directly converts a Python function (or a Python module) into TorchScript by parsing and compiling its Python source. torch.jit.script is more suitable for dynamic graph models, whose structure and inputs can change at runtime. For example, for RNNs or models with variable sequence length, torch.jit.script is more convenient.

In general, prefer torch.jit.trace over torch.jit.script.
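As a minimal sketch of the two entry points, the snippet below converts the same small module both ways and checks each result against eager mode. TinyNet and its layer sizes are made up for illustration and do not come from any particular codebase.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model with no data-dependent control flow."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyNet().eval()
example = torch.rand(3, 4)

# Tracing needs an example input; scripting compiles the source directly.
traced = torch.jit.trace(model, example)
scripted = torch.jit.script(model)

# Compare against eager mode on an input that was not used for tracing.
new_input = torch.rand(5, 4)
same_traced = torch.allclose(traced(new_input), model(new_input))
same_scripted = torch.allclose(scripted(new_input), model(new_input))
print(same_traced, same_scripted)
```

For a model this simple both conversions should agree with eager mode; the differences discussed below only appear once shapes, Python values, or control flow get involved.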

In terms of model deployment, ONNX is heavily used. Exporting to ONNX also runs torch.jit.trace on the model, so here we introduce tracing in a little more detail.

In order to write the model so that it can be traced by JIT, some compromises need to be made to the code, for example:

1. If the model contains a DataParallel submodule, converts tensors to numpy arrays, calls OpenCV functions, and so on, then it is not a properly connected graph on a single device. In this case, neither torch.jit.script nor torch.jit.trace can produce a correct TorchScript.

2. The input and output of the model should be of type Union[Tensor, Tuple[Tensor], Dict[str, Tensor]], and the values in a dict should all be of the same type. However, the inputs and outputs of the model's intermediate submodules can be of any type, such as dicts of Any, classes, kwargs, and anything else Python supports. The restrictions on the model's input and output types are relatively easy to meet. Detectron2 has an example of this:

from detectron2.export import TracingAdapter

outputs = model(inputs)  # inputs and outputs are Python types, such as dicts or classes
# torch.jit.trace(model, inputs)  # Fails! trace only supports Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]
adapter = TracingAdapter(model, inputs)  # Wrap the model so its inputs become types supported by trace
traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # Tracing now succeeds

# The traced model can only output a tuple of tensors:
flattened_outputs = traced(*adapter.flattened_inputs)
# Convert it back to the desired output type through the adapter
new_outputs = adapter.outputs_schema(flattened_outputs)
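The same idea can be sketched by hand without Detectron2. The snippet below is only an illustration of the wrapping pattern, not how TracingAdapter is implemented: DictModel, FlatAdapter, and the key names are hypothetical, and the output keys are assumed to be known up front.

```python
import torch
import torch.nn as nn


class DictModel(nn.Module):
    """Hypothetical model whose interface uses plain dicts."""

    def forward(self, batch):
        a, b = batch["a"], batch["b"]
        return {"sum": a + b, "prod": a * b}


class FlatAdapter(nn.Module):
    """Hypothetical wrapper: tensors in, tuple of tensors out."""

    def __init__(self, model, input_keys, output_keys):
        super().__init__()
        self.model = model
        self.input_keys = tuple(input_keys)
        self.output_keys = tuple(output_keys)

    def forward(self, *tensors):
        # Rebuild the dict the inner model expects, then flatten its output.
        out = self.model(dict(zip(self.input_keys, tensors)))
        return tuple(out[k] for k in self.output_keys)


inputs = {"a": torch.rand(3), "b": torch.rand(3)}
adapter = FlatAdapter(DictModel(), ("a", "b"), ("sum", "prod"))
flat_inputs = tuple(inputs[k] for k in adapter.input_keys)

traced = torch.jit.trace(adapter, flat_inputs)

# The traced module returns a flat tuple; restore the dict outside of it.
flat_outputs = traced(*flat_inputs)
rebuilt = dict(zip(adapter.output_keys, flat_outputs))
print(sorted(rebuilt))
```

Only the outermost interface is flattened here; the inner model keeps its dict-based signature, which is the point made above about intermediate submodules.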

3. Numeric type issues. Consider the following code snippet:

import torch
a=torch.tensor([1,2])
print(type(a.size(0)))
print(type(a.size()[0]))
print(type(a.shape[0]))

In eager mode, the types of these return values are all int. The output of the above code is

<class 'int'>
<class 'int'>
<class 'int'>

But in trace mode, these expressions all return Tensors, so that shape computations can be recorded in the graph. Code that assumes they are ints can therefore misbehave, and if a shape value is converted to an int during tracing, the computation that depends on it is not traced. You can use torch.jit.is_tracing() to check whether code is executing in trace mode.
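A small sketch of this behaviour: the function below records, for eager and trace mode separately, whether the shape entry it sees is a Tensor. The function name and the observed dict are hypothetical helpers used only for this illustration.

```python
import torch

# Maps "was this call made while tracing?" to "was shape[0] a Tensor?"
observed = {}


def scale_by_length(x):
    n = x.shape[0]
    observed[torch.jit.is_tracing()] = isinstance(n, torch.Tensor)
    # n is used without converting it, so in trace mode it stays in the graph.
    return x * n


scale_by_length(torch.ones(3))                            # eager call
traced = torch.jit.trace(scale_by_length, torch.ones(3))  # traced call
print(observed)

# Because the shape stayed symbolic, the trace follows a new input length.
result = traced(torch.ones(5))
print(result)
```

Keeping the shape value as returned, without int() or similar conversions, is what lets the traced function adapt to inputs of a different length.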

4. Dynamic control flow prevents the model from being fully traced. Consider the following example:

import torch

def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

m = torch.jit.trace(f, torch.tensor(3))
print(m.code)

The output is

def f(x: Tensor) -> Tensor:
  return torch.sqrt(x)

You can see that the traced model retains only one branch. When control flow depends on the input, the traced result is therefore prone to errors.

In this case, we can use torch.jit.script to convert the function to TorchScript.

import torch

def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

m = torch.jit.script(f)
print(m.code)

The output is

def f(x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = torch.sqrt(x)
  else:
    _0 = torch.square(x)
  return _0

In most cases we should use torch.jit.trace, but when there is dynamic control flow like the above, we can mix torch.jit.trace and torch.jit.script, as explained later in this article.
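To make the failure concrete, the following sketch feeds a negative input to the traced and scripted versions of the same f and compares them with eager mode. The warning filter is only there to keep the expected TracerWarning out of the output.

```python
import warnings

import torch


def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing warns about the tensor-to-bool conversion
    traced = torch.jit.trace(f, torch.tensor(3.0))
scripted = torch.jit.script(f)

neg = torch.tensor(-2.0)
eager_out = f(neg)            # takes the square branch
traced_out = traced(neg)      # only the sqrt branch was recorded
scripted_out = scripted(neg)  # both branches were compiled
print(eager_out, traced_out, scripted_out)
```

The traced function applies sqrt to a negative number and returns nan without raising an error, which is why this kind of mistake is easy to miss.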

In addition, some blogs define dynamic control flow incorrectly. For example, if x[0] == 4: x += 1 is dynamic control flow, but

model: nn.Sequential = ...
for m in model:
  x = m(x)

as well as

class A(nn.Module):
  backbone: nn.Module
  head: Optional[nn.Module]
  def forward(self, x):
    x = self.backbone(x)
    if self.head is not None:
        x = self.head(x)
    return x

Neither is dynamic control flow. Dynamic control flow means that different branches are executed depending on conditions computed from the input.
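A sketch of why such static control flow is safe to trace: the branch and the loop below depend only on how the module was constructed, so one trace per configuration stays valid for any input. The layer sizes and the constructor signature are assumptions made for this example.

```python
from typing import Optional

import torch
import torch.nn as nn


class A(nn.Module):
    def __init__(self, head: Optional[nn.Module] = None):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        self.head = head

    def forward(self, x):
        for m in self.backbone:    # loop over a fixed list of modules
            x = m(x)
        if self.head is not None:  # decided at construction time, not by x
            x = self.head(x)
        return x


matches = []
for model in (A(), A(head=nn.Linear(4, 2))):
    model.eval()
    traced = torch.jit.trace(model, torch.rand(1, 4))
    test_input = torch.rand(6, 4)  # different batch size than the example
    matches.append(torch.allclose(traced(test_input), model(test_input)))
print(matches)
```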

5. During tracing, values that should remain variable are turned into constants. Consider the following example:

import torch
a, b = torch.rand(1), torch.rand(2)

def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))

print(torch.jit.trace(f1, a)(b))
# Output: tensor([0, 1])
# The traced function works as expected. Here a is the example input given to
# torch.jit.trace, and b is then passed to the resulting TorchScript function.
# b has 2 elements, so the returned tensor([0, 1]) is correct.

print(torch.jit.trace(f2, a)(b))
# Output:
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
#tensor([0])
# This output is wrong: b has 2 elements, so the output should be tensor([0, 1]).
# torch.jit.trace also warns that using len may make the trace incorrect.

# Let's print the code of both traces to see the difference
print(torch.jit.trace(f1, a).code, '\n', torch.jit.trace(f2, a).code)
# Output:
# def f1(x: Tensor) -> Tensor:
#   _0 = ops.prim.NumToTensor(torch.size(x, 0))
#   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _1
#
# def f2(x: Tensor) -> Tensor:
#   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _0

# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

# The traced code shows that with x.shape the shape remains a variable in the graph,
# whereas with len the traced code simply contains the constant 1.

As noted earlier, exporting to ONNX also traces the model with torch.jit.trace. When exporting to ONNX, we sometimes encounter

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

When you see this message, check the code to see whether variables are being treated as constants during tracing, which can make the exported ONNX model produce inaccurate results.

In addition to len, several other constructs also cause tracing problems:

  • .item() converts tensors to Python int/float values during tracing

  • Any code that converts torch types to numpy/python types

  • Some problematic operators, such as advanced indexing
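The .item() case can be illustrated with a short sketch. The function name and the input values are made up; the point is only that the Python number produced during tracing is baked into the graph.

```python
import warnings

import torch


def scale_by_first(x):
    # .item() leaves the tensor world, so the tracer only sees a Python float.
    return x * x[0].item()


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(scale_by_first, torch.tensor([2.0, 1.0]))

other = torch.tensor([3.0, 1.0])
eager_out = scale_by_first(other)  # multiplies by 3.0, read from `other`
traced_out = traced(other)         # still multiplies by 2.0, frozen at trace time
print(eager_out, traced_out)
```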

6. torch.jit.trace does not generalize over the device of the input
import torch
def f(x):
    return torch.arange(x.shape[0], device=x.device)
m = torch.jit.trace(f, torch.tensor([3]))
print(m.code)
# Output:
# def f(x: Tensor) -> Tensor:
#   _0 = ops.prim.NumToTensor(torch.size(x, 0))
#   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _1
print(m(torch.tensor([3]).cuda()).device)
# Output: device(type='cpu')

The trace hard-codes the device of the example input, so passing a CUDA tensor still produces a result on the CPU.
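One possible way around this, in the spirit of the mixing approach described later in this article, is to script this small function instead of tracing it, since scripting compiles x.device as an expression rather than recording its value. The sketch below only exercises the CUDA path when a GPU is available.

```python
import torch


def f(x):
    return torch.arange(x.shape[0], device=x.device)


scripted = torch.jit.script(f)

cpu_out = scripted(torch.tensor([3, 4]))
print(cpu_out.device)

# The device is looked up at call time, so a CUDA input is expected to give a
# CUDA output. This branch is skipped on machines without a GPU.
if torch.cuda.is_available():
    gpu_out = scripted(torch.tensor([3, 4]).cuda())
    print(gpu_out.device)
```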

To make sure the trace is correct, the following practices help keep the traced model from going wrong:

1. Pay attention to warnings, such as: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results. A TracerWarning indicates that the traced model may produce incorrect results, even though it is only reported at warning level.

2. Write unit tests. Verify that the eager-mode output matches the output of the traced model, using an input different from the one used for tracing.

assert torch.allclose(torch.jit.trace(model, input1)(input2), model(input2))
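A slightly fuller sketch of such a check is given below, applied to the f1 and f2 functions from the len example. The helper name trace_matches_eager is hypothetical; it also compares shapes first, because torch.allclose is not meant for tensors of unrelated sizes.

```python
import warnings

import torch


def trace_matches_eager(fn, trace_input, test_input):
    """Trace fn on one input, then compare it with eager mode on another."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # warnings are reviewed separately
        traced = torch.jit.trace(fn, trace_input)
    expected = fn(test_input)
    actual = traced(test_input)
    return expected.shape == actual.shape and bool(torch.allclose(expected, actual))


def f1(x):
    return torch.arange(x.shape[0])


def f2(x):
    return torch.arange(len(x))


a, b = torch.rand(1), torch.rand(2)
f1_ok = trace_matches_eager(f1, a, b)
f2_ok = trace_matches_eager(f2, a, b)
print(f1_ok, f2_ok)
```

The check only catches the problem because the test input differs in shape from the tracing input; reusing the tracing input would let f2 pass.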

3. Avoid special-case conditions. For example, the following code

if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W)) # will create an empty output

special-cases empty inputs and outputs. Avoid conditions like this where possible.

4. Pay attention to how shapes are used. As mentioned earlier, tensor.size() returns Tensors during tracing so that shape computations are added to the computation graph. Avoid converting these Tensor-typed shapes into constants. Two points matter most:

  • Use tensor.size(0) instead of len(tensor), because during tracing tensor.size(0) returns a Tensor while len(tensor) returns an int. For custom classes, implement a .size method or call the .__len__() method instead of len().
  • Do not use int() or torch.as_tensor to convert sizes, because the results of these operations are also treated as constants.
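The effect of int() can be sketched with a flatten operation. Both function names and the tensor shapes below are arbitrary choices for this illustration.

```python
import warnings

import torch


def flatten_good(x):
    # The shape entry is passed through untouched and stays in the graph.
    return x.reshape(x.shape[0], -1)


def flatten_bad(x):
    # int() turns the shape entry into a plain Python constant.
    return x.reshape(int(x.shape[0]), -1)


example = torch.rand(2, 3, 4)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # flatten_bad triggers a TracerWarning
    good = torch.jit.trace(flatten_good, example)
    bad = torch.jit.trace(flatten_bad, example)

other = torch.rand(5, 3, 4)
good_shape = tuple(good(other).shape)
bad_shape = tuple(bad(other).shape)
print(good_shape, bad_shape)
```

The second trace keeps the batch size of the example input and silently regroups the data into two rows, which is harder to notice than an outright error.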

5. Mix tracing and scripting methods. You can use torch.jit.script to convert some small code fragments that torch.jit.trace cannot handle. Mixing tracing and scripting can basically solve all problems.

Mix tracing and scripting

Both tracing and scripting have their problems, and a mixture can solve most of them. However, to minimize the negative impact on code quality, torch.jit.trace should be used in most cases, and torch.jit.script only when necessary.

1. When using torch.jit.trace, use the @script_if_tracing decorator to allow the decorated function to be compiled using scripting.

def forward(self, ...):
  # ... some forward logic
  @torch.jit.script_if_tracing
  def _inner_impl(x, y, z, flag: bool):
      # use control flow, etc.
      return ...
  output = _inner_impl(x, y, z, flag)
  # ... other forward logic
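A minimal runnable sketch of this pattern follows. The helper is defined at module level rather than inside forward, and its name, the module M, and the inputs are made up for the example.

```python
import torch
import torch.nn as nn


@torch.jit.script_if_tracing
def sqrt_or_square(x):
    # Data-dependent branch: compiled by scripting when called during tracing,
    # and run as ordinary Python otherwise.
    if bool(x.sum() > 0):
        return torch.sqrt(x)
    else:
        return torch.square(x)


class M(nn.Module):
    def forward(self, x):
        # The rest of forward is traced as usual.
        return sqrt_or_square(x) + 1


traced = torch.jit.trace(M(), torch.tensor([4.0]))

pos_out = traced(torch.tensor([9.0]))   # sqrt branch, then + 1
neg_out = traced(torch.tensor([-3.0]))  # square branch, then + 1
print(pos_out, neg_out)
```

Even though the trace was recorded with a positive input, the negative input still reaches the square branch, because the decorated function was compiled as a whole instead of being recorded.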

However, when using @script_if_tracing, you need to make sure the decorated function does not contain any PyTorch modules. If it does, some modifications are needed, as in the following case:

# self.layers is a PyTorch module, so @script_if_tracing cannot be used on this code
if x.numel() > 0:
  x = preprocess(x)
  output = self.layers(x)
else:
  # Create empty outputs
  output = torch.zeros(...)

The following modifications need to be made here:

# Move self.layers out of the if statement so that @script_if_tracing can be used on the branch
if x.numel() > 0:
  x = preprocess(x)
else:
  # Create empty inputs
  x = torch.zeros(...)
# self.layers() must be modified to support empty input, or the original condition must be moved inside self.layers
output = self.layers(x)

2. Merge the results of multiple tracings

Using the model generated by torch.jit.script has two advantages over using torch.jit.trace:

  • You can use conditional control flow. For example, using a bool value in the model to control the forward flow is not supported in traced modules.
  • A traced module can have only one forward() function, whereas a scripted module can have multiple forward computation functions.

class Detector(nn.Module):
  do_keypoint: bool

  def forward(self, img):
      box = self.predict_boxes(img)
      if self.do_keypoint:
          kpts = self.predict_keypoint(img, box)

  @torch.jit.export
  def predict_boxes(self, img): pass

  @torch.jit.export
  def predict_keypoint(self, img, box): pass

For this kind of control flow with bool values, in addition to using script, you can also trace multiple times and then merge the results.

det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)

Then make the two traces share their weights and merge them:

det2.submodule.weight = det1.submodule.weight
class Wrapper(nn.ModuleList):
  def forward(self, img, do_keypoint: bool):
    if do_keypoint:
        return self[0](img)
    else:
        return self[1](img)
exported = torch.jit.script(Wrapper([det1, det2]))
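The steps above can be put together in a runnable sketch. It deviates from the snippets above in a few labeled ways: Detector is a stand-in with two made-up Linear heads, the weights are copied with load_state_dict before tracing instead of being aliased afterwards, and the wrapper is a plain nn.Module with two named submodules rather than an nn.ModuleList.

```python
import torch
import torch.nn as nn


class Detector(nn.Module):
    """Hypothetical stand-in for a detector with an optional keypoint head."""

    def __init__(self, do_keypoint: bool):
        super().__init__()
        self.do_keypoint = do_keypoint
        self.box_head = nn.Linear(8, 4)
        self.kpt_head = nn.Linear(4, 2)

    def forward(self, img):
        box = self.box_head(img)
        if self.do_keypoint:  # fixed per instance, so each trace is valid
            return self.kpt_head(box)
        return box


class Wrapper(nn.Module):
    def __init__(self, with_kpts, without_kpts):
        super().__init__()
        self.with_kpts = with_kpts
        self.without_kpts = without_kpts

    def forward(self, img, do_keypoint: bool):
        # This branch is scripted, so it is kept in the exported module.
        if do_keypoint:
            return self.with_kpts(img)
        else:
            return self.without_kpts(img)


eager_kpts = Detector(do_keypoint=True).eval()
eager_boxes = Detector(do_keypoint=False).eval()
eager_boxes.load_state_dict(eager_kpts.state_dict())  # same weights in both

inputs = torch.rand(1, 8)
det1 = torch.jit.trace(eager_kpts, inputs)
det2 = torch.jit.trace(eager_boxes, inputs)
exported = torch.jit.script(Wrapper(det1, det2))

img = torch.rand(3, 8)
kpts_out = exported(img, True)
boxes_out = exported(img, False)
print(kpts_out.shape, boxes_out.shape)
```

Copying the weights doubles their storage, so a real export would more likely alias them as in the snippet above; the copy is used here only to keep the sketch short.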

Performance of trace and script

Tracing always generates a computation graph that is the same as or simpler than the one scripting generates, so its performance is at least as good. This is because scripting faithfully represents all the logic of the Python code, including code that is unnecessary. Consider the following example:

class A(nn.Module):
  def forward(self, x1, x2, x3):
    z = [0, 1, 2]
    xs = [x1, x2, x3]
    for k in z: x1 += xs[k]
    return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

Summary

Tracing has obvious limitations: most of this article is about tracing’s limitations and how to work around them. In fact, this is exactly the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works.

By contrast, scripting is more like a black box: no one knows whether it will work before trying it. This article does not cover any tricks for fixing scripting: there are many, but it is not worth your time to dig into and fix a black box.

Both tracing and scripting affect the way code is written, but because tracing has clear requirements, the changes it forces on our original code are not too severe:

  • It restricts input/output formats, but only to the outermost module. (As mentioned above, this problem can be solved with a wrapper).
  • It requires some code changes so that the trace generalizes (such as mixing in some scripting while tracing), but these changes only involve the internal implementation of the affected modules, not their interfaces.