# Explain Pytorch Tensor.stride and Tensor.storage with code examples

PyTorch’s `Tensor` class has a `storage()` and a `stride()` method. They are not used directly very often, but they can be helpful when you need to take a closer look at the underlying data. (At the end, I’ll show an example of using them to illustrate the difference between `Tensor.expand()` and `Tensor.repeat()`.)

As explained in PyTorch’s documentation, `storage()` simply “returns the underlying storage”, which is relatively straightforward.

But I found the explanation of `stride` in PyTorch’s documentation a bit difficult to understand:

> Each strided tensor has an associated `torch.Storage`, which holds its data. These tensors provide multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride represents the jump in the memory necessary to go from one element to the next one in the k-th dimension of the Tensor.

I had to try a few examples, which I found helpful for understanding it.

First, create an example tensor with 3 rows and 4 columns:

```python
import torch

x = torch.arange(12).view((3, 4))
x
```

The output is:

```
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
```

Take a look at `storage()`:

```python
storage = x.storage()
storage
```

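Not surprisingly, the output contains the numbers in the tensor, laid out as a single flat sequence:

```
 0
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
[torch.LongStorage of size 12]
```

Depending on your PyTorch version, the last line may name a different storage class (newer releases also warn that `TypedStorage` is deprecated).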
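The storage is just this flat sequence, and a tensor is one way of looking at it. As a quick sketch (the names `a` and `b` are illustrative, not from the original notebook), two more views of the same 12 numbers share the memory of `x` but report different strides:

```python
import torch

x = torch.arange(12).view((3, 4))

# Two other views over the same 12-element storage; no data is copied.
a = x.view((4, 3))
b = x.view((2, 6))

# All three tensors start at the same memory address...
same_memory = x.data_ptr() == a.data_ptr() == b.data_ptr()

# ...but each interprets that memory with its own shape and strides.
print(same_memory)   # True
print(a.stride())    # (3, 1)
print(b.stride())    # (6, 1)
```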

Now look at `stride()`:

```python
stride = x.stride()
stride
```

The output is `(4, 1)`.
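Where does `(4, 1)` come from? For a freshly created, contiguous tensor, the elements are laid out row by row, so the stride of each dimension is the product of the sizes of all later dimensions. The helper below (`contiguous_strides` is a hypothetical name, not a PyTorch function) is a plain-Python sketch of that rule; it only describes contiguous tensors, not views such as transposes or expanded tensors.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, counted in elements, for a given shape."""
    strides = []
    step = 1
    # Walk from the last dimension to the first: the last dimension moves one
    # element at a time, and each earlier one jumps over a whole block of later ones.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

print(contiguous_strides((3, 4)))       # (4, 1)
print(contiguous_strides((10, 30, 4)))  # (120, 4, 1)
```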

Here comes the part that explains `stride`. Suppose you want to access the element of the tensor `x` at row 2 and column 1 (counting from zero), whose value is `9`. We can write this in code, where `idx` holds the indices of the element being accessed:

```python
idx = (2, 1)
item = x[idx].item()
item
```

The output is `9`, as expected.

Now, using `idx` together with `stride`, we can access the item directly from the storage, like this:

```python
loc = idx[0]*stride[0] + idx[1]*stride[1]
```

Or equivalently (but more generally):

```python
loc = sum([i*s for i, s in zip(idx, stride)])
```
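Written out for this example, with `idx = (2, 1)` and `stride = (4, 1)`, the formula gives the following (the result matches the value `9` only because `x` was built with `arange`, so each element equals its own storage position):

$$
\text{loc} = \sum_{k} \text{idx}_k \cdot \text{stride}_k = 2 \cdot 4 + 1 \cdot 1 = 9
$$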

If we run `storage[loc] == item`, the output is `True`, confirming that the same item is accessed from the storage.
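To see that nothing PyTorch-specific is hiding in this lookup, here is the same calculation in plain Python, with an ordinary list standing in for the storage. This is only an illustration; `flat_storage` and `offset_of` are made-up names.

```python
# A plain list plays the role of the 12-element storage of x = arange(12).view(3, 4).
flat_storage = list(range(12))
shape = (3, 4)
stride = (4, 1)

def offset_of(idx, stride):
    """Position in the flat storage of the element at multi-index idx."""
    return sum(i * s for i, s in zip(idx, stride))

# Rebuild the 2-D view element by element using only the strides.
rows = [[flat_storage[offset_of((r, c), stride)] for c in range(shape[1])]
        for r in range(shape[0])]

print(rows)                                     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(flat_storage[offset_of((2, 1), stride)])  # 9
```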

The code above shows how `stride` tells us where the element at `idx` sits in the storage. Note that `stride` is a tuple whose length equals the number of dimensions of the tensor (`2` in this example). For each dimension (e.g., dim `0`), the corresponding element of `stride` (here `stride[0]`) tells how far a step of one in that index (`idx[0]`) moves you along the 1-dimensional storage.
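One way to see that each stride controls how far its own index moves you is to swap them. Reading the same flat storage with shape `(4, 3)` and stride `(1, 4)` produces the transpose of `x` without moving any data; PyTorch’s `x.t()`, for instance, reports a stride of `(1, 4)`. The sketch below uses a plain list and a hypothetical `read_2d` helper rather than PyTorch itself:

```python
# Same flat data as x = arange(12).view(3, 4)
flat_storage = list(range(12))

def read_2d(storage, shape, stride):
    """Materialize a 2-D strided view of a flat list as nested lists."""
    return [[storage[r * stride[0] + c * stride[1]] for c in range(shape[1])]
            for r in range(shape[0])]

original = read_2d(flat_storage, (3, 4), (4, 1))
# Swapping both the shape and the strides gives the transpose; flat_storage is untouched.
transposed = read_2d(flat_storage, (4, 3), (1, 4))

print(transposed)  # [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]
```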

Let’s quickly look at an example with a larger, 3-dimensional tensor:

```python
x = torch.rand(1200).view((10, 30, 4))
storage = x.storage()
stride = x.stride()
stride
```

`stride` is now `(120, 4, 1)`.

```python
idx = (4, 23, 2)
item = x[idx].item()
loc = sum([i*s for i, s in zip(idx, stride)])
storage[loc] == item
```

The output is still `True`, again confirming that the same item is accessed.
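For a contiguous tensor like this one, the formula is a one-to-one mapping: every valid index lands on a different storage position, and together the indices cover all 1200 slots. A brute-force plain-Python check, as a sketch using the shape and strides from this example:

```python
from itertools import product

shape = (10, 30, 4)
stride = (120, 4, 1)

def offset_of(idx, stride):
    """Flat storage position of the element at multi-index idx."""
    return sum(i * s for i, s in zip(idx, stride))

# Storage position of every index in the 10 x 30 x 4 tensor.
offsets = [offset_of(idx, stride) for idx in product(*(range(n) for n in shape))]

print(offset_of((4, 23, 2), stride))         # 574, i.e. 4*120 + 23*4 + 2*1
print(sorted(offsets) == list(range(1200)))  # True: each slot is used exactly once
```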

Besides showing how the underlying data is stored and accessed, `stride()` and `storage()` can also help us understand some of PyTorch’s inner workings. Here’s an example. There are some helpful discussions about the differences between `Tensor.repeat()` and `Tensor.expand()`, one point being that `expand()` does not use extra memory (see the discussions for details). We can use `stride()` and `storage()` to verify that.

Let’s create an example tensor first:

```python
x = torch.tensor([[0, 1, 2]])
```

Now we create new tensors `y` and `z` using `expand` and `repeat`, respectively:

```python
y = x.expand(2, 3)
z = x.repeat(2, 1)
```

If you print them out, `y` and `z` look the same (both are `x` repeated along dim `0`):

```
tensor([[0, 1, 2],
        [0, 1, 2]])
```

To show that they use different amounts of memory, here come `stride()` and `storage()`:

```python
xstorage = x.storage()
xstride = x.stride()
ystorage = y.storage()
ystride = y.stride()
zstorage = z.storage()
zstride = z.stride()
```

Here, `xstride` is `(3, 1)`, and `xstorage` is:

```
 0
 1
 2
[torch.LongStorage of size 3]
```

`ystorage` is the same as `xstorage`:

```
 0
 1
 2
[torch.LongStorage of size 3]
```

But `zstorage` is different (doubled):

```
 0
 1
 2
 0
 1
 2
[torch.LongStorage of size 6]
```

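Notice that `ystride` is `(0, 1)`, different from `xstride` (`(3, 1)`), while `zstride` is `(3, 1)`, the same as `xstride`. A stride of `0` along dim `0` means that moving from one row of `y` to the next does not move through the storage at all, so both rows read the same three elements.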
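Another way to confirm that `y` shares memory with `x` while `z` owns a copy is to compare `data_ptr()` (the address of the first element) and to modify `x` in place. A minimal sketch, assuming a reasonably recent PyTorch: since `y` is a view, a write into `x` shows up in `y` but not in `z`.

```python
import torch

x = torch.tensor([[0, 1, 2]])
y = x.expand(2, 3)   # view: stride (0, 1), no new memory
z = x.repeat(2, 1)   # copy: its own 6-element storage

print(y.data_ptr() == x.data_ptr())  # True: y reads x's memory
print(z.data_ptr() == x.data_ptr())  # False: z has its own memory

# In-place write into x; only the view y sees the change.
x[0, 0] = 100
print(y.tolist())  # [[100, 1, 2], [100, 1, 2]]
print(z.tolist())  # [[0, 1, 2], [0, 1, 2]]
```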

As before, here’s some code showing how `stride` and `storage` work for `y` and `z`, respectively:

```python
idx = (1, 2)
yloc = sum([i*s for i, s in zip(idx, ystride)])
zloc = sum([i*s for i, s in zip(idx, zstride)])
assert y[idx].item() == ystorage[yloc]
assert z[idx].item() == zstorage[zloc]
```
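The zero stride is the whole trick behind `expand()`. In plain Python (a sketch; `read_2d` is a hypothetical helper, not a PyTorch function), reading a 3-element list with shape `(2, 3)` and stride `(0, 1)` returns the row twice, while the `repeat()` result needs a 6-element list with stride `(3, 1)` to produce the same values:

```python
def read_2d(storage, shape, stride):
    """Materialize a 2-D strided view of a flat list as nested lists."""
    return [[storage[r * stride[0] + c * stride[1]] for c in range(shape[1])]
            for r in range(shape[0])]

# Like expand(): keep the 3 original values, use stride 0 along dim 0.
expanded = read_2d([0, 1, 2], (2, 3), (0, 1))

# Like repeat(): the values are copied into a 6-element storage with ordinary strides.
repeated = read_2d([0, 1, 2, 0, 1, 2], (2, 3), (3, 1))

print(expanded)              # [[0, 1, 2], [0, 1, 2]]
print(expanded == repeated)  # True
```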

That’s it. Thanks for reading :)

Code is here: https://github.com/yang-zhang/yang-zhang.github.io/blob/master/ds_code/torch_stride.ipynb

References: