Explain Pytorch Tensor.stride and Tensor.storage with code examples
PyTorch's Tensor class has a storage() and a stride() method. They are not very often used directly, but can sometimes be helpful when you need to take a closer look at the underlying data. (I'll show an example of using them to illustrate the difference between Tensor.expand() and Tensor.repeat() at the end.)
As explained in PyTorch's documentation, storage() simply "returns the underlying storage", which is relatively straightforward. But the explanation of stride() in PyTorch's documentation was a bit difficult for me to understand:
Each strided tensor has an associated torch.Storage, which holds its data. These tensors provide multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride represents the jump in the memory necessary to go from one element to the next one in the k-th dimension of the Tensor.
I had to try a few examples, which I found useful to help myself understand.
First, create an example tensor with 3 rows and 4 columns.

import torch

x = torch.arange(12).view((3, 4))
x
The output is:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
Take a look at storage():

storage = x.storage()
storage
Not surprisingly, the output contains the numbers in the tensor:
0
1
2
3
4
5
6
7
8
9
10
11
[torch.LongStorage of size 12]
Now look at stride():

stride = x.stride()
stride

The output is (4, 1).
Here comes the part that explains stride. Suppose you want to access the element of the tensor x at row 2 and column 1, whose value is 9. We can write this in code. Note that idx holds the indices of the element being accessed.

idx = (2, 1)
item = x[idx].item()
item

The output is 9, as expected.
Now, using idx together with stride, we can access the item directly from the storage, like this:

loc = idx[0]*stride[0] + idx[1]*stride[1]

Or equivalently (but more generally):

loc = sum([i*s for i, s in zip(idx, stride)])

If we run storage[loc] == item, the output is True, confirming that the same item is accessed from the storage.
The above code shows how stride tells idx how to access an element in the storage. Note that stride is a tuple whose length equals the number of dimensions of the tensor (in this example, 2). For each dimension (e.g., dim-0), the corresponding element of stride (in this case stride[0]) tells how much an index (idx[0]) matters in terms of moving along the 1-dimensional storage.
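To make this connection concrete, here is a small sketch (my own illustration, not from the PyTorch docs) showing that for a contiguous tensor, each stride is simply the product of the sizes of all dimensions that come after it:

```python
import torch

# For a contiguous tensor, moving one step along dim-k must skip over
# every element of all the later dimensions. So stride[k] is the
# product of the sizes of dimensions k+1, k+2, ...
x = torch.arange(12).view((3, 4))

strides = []
prod = 1
for size in reversed(x.shape):
    strides.insert(0, prod)
    prod *= size

print(tuple(strides))  # (4, 1), computed from the shape alone
print(x.stride())      # (4, 1), as reported by PyTorch
```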
Let's quickly look at an example with a larger, 3-dimensional tensor.

x = torch.rand(1200).view((10, 30, 4))
storage = x.storage()
stride = x.stride()
stride

stride is now (120, 4, 1).
idx = (4, 23, 2)
item = x[idx].item()
loc = sum([i*s for i,s in zip(idx, stride)])
storage[loc]==item
The output is again True, confirming that the same item is accessed.
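Another place where stride() is revealing is transposition (a side example of my own, assuming the same x as above): x.t() returns a view with swapped strides over the same storage, so no data is copied.

```python
import torch

x = torch.arange(12).view(3, 4)
xt = x.t()  # transpose is a view: no data is copied

print(x.stride())   # (4, 1)
print(xt.stride())  # (1, 4) -- the strides are simply swapped

# Both tensors start at the same memory address, i.e. they share storage:
print(x.data_ptr() == xt.data_ptr())  # True
```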
Other than showing the underlying data storage and access, stride() and storage() can also help us understand some of the inner workings of PyTorch. Here's an example. There are some helpful discussions about the differences between Tensor.repeat() and Tensor.expand(), one of the topics being that expand() does not use extra memory (see the discussions for details). We can use stride() and storage() to verify that.
Let's create an example tensor first:

x = torch.tensor([[0, 1, 2]])
Now we create new tensors y and z using expand and repeat, respectively.

y = x.expand(2, 3)
z = x.repeat(2, 1)

If you print them out, y and z look the same (both repeat x along dim-0):
tensor([[0, 1, 2],
[0, 1, 2]])
To show that they use different amounts of memory, here come stride() and storage():
xstorage = x.storage()
xstride = x.stride()
ystorage = y.storage()
ystride = y.stride()
zstorage = z.storage()
zstride = z.stride()
Here, xstride is (3, 1), and xstorage is:
0
1
2
[torch.LongStorage of size 3]
ystorage is the same as xstorage:
0
1
2
[torch.LongStorage of size 3]
But zstorage is different (doubled):
0
1
2
0
1
2
[torch.LongStorage of size 6]
Notice that ystride is (0, 1), different from xstride ((3, 1)), while zstride is (3, 1), the same as xstride. The 0 in ystride means that moving along dim-0 of y does not move through the storage at all; every row of y aliases the same three elements, which is how expand avoids copying data.
As before, here is some code showing how stride and storage work for y and z, respectively.
idx = (1,2)
yloc = sum([i*s for i,s in zip(idx, ystride)])
zloc = sum([i*s for i,s in zip(idx, zstride)])
assert y[idx].item()==ystorage[yloc]
assert z[idx].item()==zstorage[zloc]
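Another way to verify the memory behavior is Tensor.data_ptr(), which returns the memory address of a tensor's first element. A quick sketch, assuming the same x, y, and z as above:

```python
import torch

x = torch.tensor([[0, 1, 2]])
y = x.expand(2, 3)
z = x.repeat(2, 1)

# expand() returns a view, so y shares x's memory:
print(y.data_ptr() == x.data_ptr())  # True

# repeat() copies the data into a new, larger storage:
print(z.data_ptr() == x.data_ptr())  # False
```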
That’s it. Thanks for reading :)
Code is here: https://github.com/yang-zhang/yang-zhang.github.io/blob/master/ds_code/torch_stride.ipynb