Explain PyTorch Tensor.stride and Tensor.storage with code examples

Yang Zhang
Jun 22, 2020

PyTorch’s Tensor class has storage() and stride() methods. They are not often used directly, but they can be helpful when you need a closer look at the underlying data. (At the end, I’ll use them to illustrate the difference between Tensor.expand() and Tensor.repeat().)

As explained in PyTorch’s documentation, storage() simply “returns the underlying storage”, which is relatively straightforward.

But I found the documentation’s explanation of stride a bit difficult to understand:

Each strided tensor has an associated torch.Storage, which holds its data. These tensors provide multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride represents the jump in the memory necessary to go from one element to the next one in the k-th dimension of the Tensor.
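In symbols, the quoted definition can be paraphrased as follows. Take an n-dimensional tensor with strides s_0, …, s_{n-1}. Let offset be the position in the storage where the tensor starts. Tensor.storage_offset() returns this value, and it is 0 for every tensor in this post. For instance, index (2, 1) with strides (4, 1) and offset 0 lands at 2·4 + 1·1 = 9:

```latex
\mathrm{loc}(i_0, \dots, i_{n-1}) = \mathrm{offset} + \sum_{k=0}^{n-1} i_k \, s_k
```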

I had to try a few examples before it made sense, and I found them useful.

First, create an example tensor with 3 rows and 4 columns:

import torch
x = torch.arange(12).view((3, 4))

The output is:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

Take a look at storage():

storage = x.storage()

Not surprisingly, the output lists the numbers in the tensor, 0 through 11, followed by:

[torch.LongStorage of size 12]
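The printed form of a storage differs between PyTorch versions. Converting it to a plain Python list shows the same flat contents more directly. This is a minimal sketch. It assumes that in recent PyTorch releases storage() may emit a deprecation warning pointing to untyped_storage(), and it silences that warning only to keep the output tidy:

```python
import warnings

import torch

x = torch.arange(12).view((3, 4))

# Newer PyTorch versions may warn that storage() (TypedStorage) is deprecated;
# the warning is suppressed here only to keep the output readable.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    storage = x.storage()

values = storage.tolist()  # the flat, 1-D contents of the storage
print(values)              # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print(len(storage))        # 12
```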

Now look at stride():

stride = x.stride()

The output is (4, 1).
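Why (4, 1)? For a freshly created (contiguous, row-major) tensor, each dimension's stride is the product of the sizes of all the dimensions after it. Moving down one row skips 4 elements, and moving right one column skips 1. The sketch below uses contiguous_strides, a hypothetical helper written for illustration that is not a PyTorch function:

```python
import torch

def contiguous_strides(shape):
    """Row-major strides: each stride is the product of the later dimension sizes."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

x = torch.arange(12).view((3, 4))
print(contiguous_strides(x.shape))      # (4, 1)
print(x.stride())                       # (4, 1)
print(contiguous_strides((10, 30, 4)))  # (120, 4, 1)
```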

Here is where stride comes in. Suppose you want to access the element of x at row 2 and column 1 (counting from 0), whose value is 9. In code, with idx holding the indices of the element being accessed:

idx = (2, 1)
item = x[idx].item()

The output is as expected: 9.

Now, using idx together with stride, we can access the item directly from the storage, like this:

loc = idx[0]*stride[0] + idx[1]*stride[1]

Or equivalently (but more generally):

loc = sum([i*s for i,s in zip(idx, stride)])

If we run storage[loc]==item, the output is True, confirming that the same item is accessed from the storage.

The code above shows how stride maps an index tuple idx to a position in the storage. stride is a tuple with one entry per dimension of the tensor (here, 2 dimensions). For each dimension (e.g., dim-0), the corresponding entry of stride (in this case stride[0]) says how far one step in that index (idx[0]) moves along the 1-dimensional storage.
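The loc computation used so far leaves out one detail. A tensor can start partway into its storage, for example when it is a slice of another tensor. Tensor.storage_offset() gives that starting position, and it has to be added to loc. For the tensors in this post the offset is 0, so it never shows up. In this sketch, storage_loc is a hypothetical helper name:

```python
import warnings

import torch

warnings.simplefilter("ignore")  # hide storage() deprecation warnings in newer PyTorch

def storage_loc(t, idx):
    """Position of t[idx] in t's storage: offset plus the index-times-stride sum."""
    return t.storage_offset() + sum(i * s for i, s in zip(idx, t.stride()))

x = torch.arange(12).view((3, 4))
sub = x[1:, 2:]  # a view of rows 1-2, columns 2-3 of x; it shares x's storage

print(sub.storage_offset())  # 6: row 1 starts at 4, plus 2 columns
print(sub.stride())          # (4, 1), same as x
idx = (1, 0)
print(sub[idx].item())                       # 10
print(sub.storage()[storage_loc(sub, idx)])  # 10
```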

Let’s quickly look at an example with a larger tensor of 3 dimensions.

x = torch.rand(1200).view((10, 30, 4))
storage = x.storage()
stride = x.stride()

stride is now (120, 4, 1).

idx = (4, 23, 2)
item = x[idx].item()
loc = sum([i*s for i,s in zip(idx, stride)])
storage[loc] == item

Running storage[loc]==item still outputs True, again confirming that the same item is accessed.
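A single index is only a spot check. This short sketch repeats the same computation for every index of the 10×30×4 tensor, using itertools.product to enumerate all index tuples:

```python
import itertools
import warnings

import torch

warnings.simplefilter("ignore")  # hide storage() deprecation warnings in newer PyTorch

x = torch.rand(1200).view((10, 30, 4))
storage = x.storage()
stride = x.stride()

mismatches = 0
for idx in itertools.product(*(range(n) for n in x.shape)):
    loc = sum(i * s for i, s in zip(idx, stride))
    if storage[loc] != x[idx].item():
        mismatches += 1

print(stride)      # (120, 4, 1)
print(mismatches)  # 0
```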

Besides showing how data is stored and accessed, stride() and storage() can also help us understand some of PyTorch’s inner workings. Here’s an example. There are some helpful discussions about the differences between Tensor.repeat() and Tensor.expand(), one of the topics being that expand() does not use extra memory (see the discussions for details). We can use stride() and storage() to verify that.

Let’s create an example tensor first:

x = torch.tensor([[0, 1, 2]])

Now we create new tensors y and z using expand and repeat, respectively.

y = x.expand(2, 3)
z = x.repeat(2, 1)

If you print them out, y and z look the same (both repeat x along dim-0):

tensor([[0, 1, 2],
        [0, 1, 2]])

To show that they use different amounts of memory, we can call stride() and storage():

xstorage = x.storage()
xstride = x.stride()
ystorage = y.storage()
ystride = y.stride()
zstorage = z.storage()
zstride = z.stride()

Here, xstride is (3, 1), and xstorage contains 0, 1, 2:

[torch.LongStorage of size 3]

ystorage is the same as xstorage (0, 1, 2):

[torch.LongStorage of size 3]

But zstorage is different, holding twice as many numbers (0, 1, 2, 0, 1, 2):

[torch.LongStorage of size 6]

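Notice that ystride is (0, 1), different from xstride (3, 1), while zstride is (3, 1), the same as xstride. A stride of 0 along dim-0 means that moving to the next row does not move through the storage at all. Both rows of y therefore read the same three numbers, which is why y needs no extra memory.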
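Another way to confirm that expand() reuses x's memory while repeat() allocates new memory is to compare data pointers with Tensor.data_ptr(). A second check is to write into x and see which of y and z changes. This is a minimal sketch:

```python
import torch

x = torch.tensor([[0, 1, 2]])
y = x.expand(2, 3)  # a view with stride (0, 1): no new memory
z = x.repeat(2, 1)  # a new contiguous tensor with stride (3, 1)

# y points at the same memory as x; z has its own copy.
print(y.data_ptr() == x.data_ptr())  # True
print(z.data_ptr() == x.data_ptr())  # False

# Writing into x shows up in both rows of y (they read the same 3 numbers),
# but not in z.
x[0, 0] = 100
print(y.tolist())  # [[100, 1, 2], [100, 1, 2]]
print(z.tolist())  # [[0, 1, 2], [0, 1, 2]]
```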

As before, here is some code showing how stride and storage work together for y and z:

idx = (1,2)
yloc = sum([i*s for i,s in zip(idx, ystride)])
zloc = sum([i*s for i,s in zip(idx, zstride)])
assert y[idx].item()==ystorage[yloc]
assert z[idx].item()==zstorage[zloc]
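The 0 stride in ystride has one more consequence. y is not contiguous, and asking for a contiguous version with Tensor.contiguous() copies the repeated rows into new memory. The copy then has the same stride as z. This sketch shows that behavior; it is an illustration, not a claim about how repeat() is implemented internally:

```python
import torch

x = torch.tensor([[0, 1, 2]])
y = x.expand(2, 3)
z = x.repeat(2, 1)

print(y.is_contiguous(), z.is_contiguous())  # False True

y_copy = y.contiguous()  # materializes the repeated rows in new memory
print(y_copy.stride())                    # (3, 1), same as z.stride()
print(y_copy.data_ptr() == x.data_ptr())  # False: new memory was allocated
print(torch.equal(y_copy, z))             # True
```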

That’s it. Thanks for reading :)

Code is here: https://github.com/yang-zhang/yang-zhang.github.io/blob/master/ds_code/torch_stride.ipynb