# Explain Pytorch Tensor.stride and Tensor.storage with code examples

PyTorch’s `Tensor` class has a `storage()` and a `stride()` method. They are not often used directly, but can be helpful when you need to take a closer look at the underlying data. (At the end, I’ll show an example of using them to illustrate the difference between `Tensor.expand()` and `Tensor.repeat()`.)

As explained in PyTorch’s documentation, `storage()` simply “returns the underlying storage”, which is relatively straightforward.

But the explanation of `stride()` in PyTorch’s documentation was a bit difficult for me to understand:

> Each strided tensor has an associated `torch.Storage`, which holds its data. These tensors provide multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride represents the jump in the memory necessary to go from one element to the next one in the k-th dimension of the Tensor.

I had to try a few examples, which I found useful to help myself understand.

First, create an example tensor with 3 rows and 4 columns.

```python
import torch

x = torch.arange(12).view((3, 4))
x
```

The output is:

```
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
```

Take a look at `storage()`:

```python
storage = x.storage()
storage
```

Not surprisingly, the output contains the numbers in the tensor:

```
 0
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
[torch.LongStorage of size 12]
```

Now look at `stride()`:

```python
stride = x.stride()
stride
```

The output is `(4, 1)`.
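To see where the `(4, 1)` comes from: for a contiguous tensor, each stride is simply the product of the sizes of all the trailing dimensions. A quick sketch to verify:

```python
import torch

x = torch.arange(12).view(3, 4)

# For a contiguous tensor, each stride is the product of the sizes of
# all later dimensions: stepping along dim-0 (to the next row) must
# skip a whole row of 4 elements, while stepping along dim-1 (to the
# next column) skips just 1 element.
assert x.stride() == (4, 1)
assert x.stride(0) == x.size(1) * x.stride(1)  # 4 == 4 * 1
```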

Here comes the part that explains `stride`. Suppose we want to access the element of the tensor `x` at row-2 and column-1, whose value is `9`. We can write this in code; `idx` holds the indices of the element being accessed.

```python
idx = (2, 1)
item = x[idx].item()
item
```

The output is as expected: `9`.

Now, using `idx` together with `stride`, we can access the item directly from the storage, like this:

```python
loc = idx[0]*stride[0] + idx[1]*stride[1]
```

Or equivalently (but more generally):

```python
loc = sum([i*s for i, s in zip(idx, stride)])
```

If we run `storage[loc]==item`, the output is `True`, confirming that the same item is accessed from the storage.

The above code shows how `stride` tells `idx` how to access an element in the storage. Note that `stride` is a tuple whose length equals the number of dimensions of the tensor (in this example, `2`). For each dimension (e.g., dim-`0`), the corresponding element of `stride` (in this case `stride[0]`) tells how much the index along that dimension (`idx[0]`) matters in terms of moving along the 1-dimensional storage.
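The offset rule can be wrapped in a small helper and checked for several indices at once. (`flat_offset` is a name introduced here for illustration; it is not a PyTorch function.)

```python
import torch

def flat_offset(idx, stride):
    """Map a multi-dimensional index to a flat storage offset."""
    return sum(i * s for i, s in zip(idx, stride))

x = torch.arange(12).view(3, 4)
flat = x.flatten()  # x is contiguous, so flatten() mirrors the storage order
for idx in [(0, 0), (1, 3), (2, 1)]:
    loc = flat_offset(idx, x.stride())
    assert x[idx].item() == flat[loc].item()
```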

Let’s quickly look at an example with a larger tensor of 3 dimensions.

```python
x = torch.rand(1200).view((10, 30, 4))
storage = x.storage()
stride = x.stride()
stride
```

`stride` is now `(120, 4, 1)`.

```python
idx = (4, 23, 2)
item = x[idx].item()
loc = sum([i*s for i, s in zip(idx, stride)])
storage[loc] == item
```

Output is still `True`, again confirming that the same item is accessed.
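The same machinery also explains stride-only views such as a transpose: swapping the strides, without touching the storage, is all it takes. (This is an extra illustration I’m adding here, not part of the examples above.)

```python
import torch

x = torch.arange(12).view(3, 4)
xt = x.t()  # transpose: a view, not a copy

# Transposing simply swaps the strides...
assert x.stride() == (4, 1)
assert xt.stride() == (1, 4)
# ...while both tensors share the same underlying memory.
assert x.data_ptr() == xt.data_ptr()
```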

Besides showing how the underlying data is stored and accessed, `stride()` and `storage()` can also help with understanding some of PyTorch’s inner workings. Here’s an example. There are some helpful discussions about the differences between `Tensor.repeat()` and `Tensor.expand()`, one of the topics being that `expand()` does not use extra memory (see the discussions for details). We can use `stride()` and `storage()` to verify that.

Let’s create an example tensor first:

```python
x = torch.tensor([[0, 1, 2]])
```

Now we create new tensors `y` and `z` using `expand` and `repeat`, respectively.

```python
y = x.expand(2, 3)
z = x.repeat(2, 1)
```

If you print them out, `y` and `z` look the same (both repeated `x` along dim-`0`):

```
tensor([[0, 1, 2],
        [0, 1, 2]])
```

To show that they use different amounts of memory, here come `stride()` and `storage()`:

```python
xstorage = x.storage()
xstride  = x.stride()
ystorage = y.storage()
ystride  = y.stride()
zstorage = z.storage()
zstride  = z.stride()
```

Here, `xstride` is `(3, 1)`, and `xstorage` is:

```
 0
 1
 2
[torch.LongStorage of size 3]
```

`ystorage` is the same as `xstorage`:

```
 0
 1
 2
[torch.LongStorage of size 3]
```

But `zstorage` is different (doubled):

```
 0
 1
 2
 0
 1
 2
[torch.LongStorage of size 6]
```

Notice that `ystride` is `(0, 1)`, different from `xstride` (`(3, 1)`): a stride of `0` along dim-`0` means that moving to the next row does not move at all in the storage, which is exactly why `expand()` needs no extra memory. Meanwhile `zstride` is `(3, 1)`, the same as `xstride`, because `repeat()` produced an ordinary contiguous copy.
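The stride-`0` observation can also be checked with `data_ptr()`, which returns the address of a tensor’s first element. A quick sketch:

```python
import torch

x = torch.tensor([[0, 1, 2]])
y = x.expand(2, 3)
z = x.repeat(2, 1)

# expand() returns a view: stride 0 along dim-0 means the "second row"
# reads the very same storage elements as the first, so no new memory.
assert y.stride() == (0, 1)
assert y.data_ptr() == x.data_ptr()

# repeat() makes a real, contiguous copy with its own storage.
assert z.stride() == (3, 1)
assert z.data_ptr() != x.data_ptr()
```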

Similar to before, here’s some code showing how `stride` and `storage` work for `y` and `z`, respectively.

```python
idx = (1, 2)
yloc = sum([i*s for i, s in zip(idx, ystride)])
zloc = sum([i*s for i, s in zip(idx, zstride)])
assert y[idx].item() == ystorage[yloc]
assert z[idx].item() == zstorage[zloc]
```

That’s it. Thanks for reading :)
