PyTorch memory allocation during tensor shape manipulation

March 18, 2026

During my daily MLSys development, I often run into PyTorch tensor operations such as .view() and .reshape(). Combined with .T (transpose) and .contiguous(), things get more complicated. In this post, I record whether each of these ops returns a brand-new tensor backed by newly allocated memory (GPU or CPU), or merely a "view" of the original one.

Monitoring PyTorch memory allocation

Before we officially start, we should find a way to detect every single memory allocation (at least for tensor creations). Of course, PyTorch maintains some visualization tools like torch.profiler, but I think it's too heavy for our purposes.

💡

What I want is a lightweight logging interface that records every tensor allocation served by PyTorch's caching allocator, rather than the underlying CUDA driver allocations. torch.cuda.memory_allocated reports the bytes currently occupied by tensors, so we use it here to monitor and log allocations.

The monitoring code looks like this:

import torch
from contextlib import contextmanager

@contextmanager
def detect_cuda_allocation(label: str):
    # Flush pending GPU work, then record the bytes currently held by tensors.
    torch.cuda.synchronize()
    start_mem = torch.cuda.memory_allocated()
    yield

    torch.cuda.synchronize()
    end_mem = torch.cuda.memory_allocated()
    # Net change in tensor memory across the wrapped block, in bytes.
    diff = end_mem - start_mem

    if diff > 0:
        print(f"🚀 [ALLOCATION] {label}: {diff / 1024:.2f} KB allocated.")
    elif diff < 0:
        print(f"🗑️  [RELEASE] {label}: {abs(diff) / 1024:.2f} KB freed.")
    else:
        print(f"✅ [NO CHANGE] {label}: No new CUDA memory allocated.")
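
torch.cuda.memory_allocated only tracks CUDA memory, so the context manager above stays silent for CPU tensors. As a complementary, device-agnostic check, one can compare storage pointers directly. The helper below is a minimal sketch: shares_storage is a hypothetical name (not a PyTorch API), and it assumes PyTorch 2.x, where Tensor.untyped_storage() is available.

```python
import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if both tensors are backed by the same storage buffer."""
    # Hypothetical helper for illustration, not part of PyTorch.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

base = torch.zeros(4, 4)   # CPU tensor, so no GPU is needed
alias = base[1:]           # slicing returns a view into the same storage
copy = base.clone()        # clone allocates a fresh buffer

print(shares_storage(base, alias))
print(shares_storage(base, copy))
```

Comparing storages rather than data_ptr() matters for views that start at an offset, such as alias here.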

.view

Let's start with the simplest case, .view(). The key takeaways are:

📚
  • .view never allocates
  • .view only works when the requested shape is compatible with the tensor's existing strides; contiguous tensors always qualify

.view accepts two kinds of arguments: a dtype or a shape. Because the requested shape must be expressible over the existing memory layout, .view(shape) only rewrites the tensor metadata (shape and stride); the underlying data_ptr is unchanged, so no allocation happens. If a dtype is passed instead, PyTorch returns a tensor that reinterprets the same bytes as the new type (like reinterpret_cast in C++), with the last dimension shrunk or grown by the ratio of the element sizes.

You can test this with:

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
with detect_cuda_allocation(".view"):
    y = x.view(512, 2048)
    z = x.view(torch.int32)
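
For a GPU-free cross-check, the same behavior can be observed through data_ptr() on CPU tensors. The sketch below uses small illustrative shapes (not the ones above) and also shows .view raising an error when the requested shape cannot be expressed over the existing strides.

```python
import torch

x = torch.randn(4, 6, dtype=torch.float16)

# Shape view: same buffer, only sizes and strides differ.
y = x.view(2, 12)
print(y.data_ptr() == x.data_ptr(), y.stride())

# Dtype view: the 6 float16 values of each row are reinterpreted as 3 int32 values.
z = x.view(torch.int32)
print(z.shape, z.data_ptr() == x.data_ptr())

# A transposed tensor cannot be flattened without moving data, so .view refuses.
try:
    x.T.view(24)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)
```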

.T (.transpose), .permute

Just like .view, transpose does not allocate any new memory. It just swaps the sizes and strides of the two dimensions in the tensor metadata.

with detect_cuda_allocation(".T / transpose"):
    y = x.T
    z = x.transpose(0, 1)

.permute, as a more general version of .transpose, does not allocate any memory either.
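The metadata swap is easy to see by printing strides. This is a small CPU sketch with an illustrative (2, 3, 4) tensor; nothing in it is specific to CUDA.

```python
import torch

t = torch.zeros(2, 3, 4)
print(t.shape, t.stride())   # row-major strides for a (2, 3, 4) tensor

p = t.permute(2, 0, 1)       # new dim i is old dim (2, 0, 1)[i]
print(p.shape, p.stride())   # sizes and strides are permuted the same way

# Same buffer, and swapping two dims is just a special case of permute.
print(p.data_ptr() == t.data_ptr())
print(t.transpose(0, 2).stride() == t.permute(2, 1, 0).stride())
```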

.contiguous

with detect_cuda_allocation(".contiguous"):
    y = x.T
    print(f"{x.is_contiguous() = }, {y.is_contiguous() = }")
    y = y.contiguous()

The above recipe prints out:

x.is_contiguous() = True, y.is_contiguous() = False
🚀 [ALLOCATION] .contiguous: 2048.00 KB allocated.

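So if the tensor layout is not contiguous, .contiguous() deliberately allocates a new tensor and copies the data into it in row-major order (here 1024 × 1024 float16 values, i.e. 2048 KB). If the tensor is already contiguous, .contiguous() returns the tensor itself and nothing is copied.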
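To see both sides of .contiguous() without a GPU, the sketch below compares data_ptr() before and after the call on small CPU tensors. The shapes are illustrative only.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # contiguous, strides (3, 1)
xt = x.T                            # view of x with strides (1, 3)

same = x.contiguous()               # already contiguous: no copy is made
fresh = xt.contiguous()             # non-contiguous: data is copied to a new buffer

print(same.data_ptr() == x.data_ptr())
print(fresh.data_ptr() == x.data_ptr(), fresh.stride())
print(torch.equal(fresh, xt))       # same values, different layout
```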

⚠️

So if you want to pass the data_ptr of a transposed/permuted tensor to a custom CUDA kernel, remember to call .contiguous() first and, most importantly, use TORCH_CHECK on the host side to verify that the inputs are contiguous.

That raises another question: given a tensor with dimensions $(d_0, d_1, \dots, d_k)$ and strides $(s_0, s_1, \dots, s_k)$, how do we determine whether it is contiguous?

First, $s_k$ must be 1; otherwise the tensor is not contiguous. The stride at dimension $i$ describes how many elements we step over when the index along dimension $i$ increases by 1, so for a contiguous tensor the strides and dimensions satisfy:

$$s_i = s_{i+1} \cdot d_{i+1}, \quad i = k-1, \dots, 0.$$
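The recurrence translates directly into a stride check. The function below is a minimal pure-Python sketch with a hypothetical name. Skipping size-1 dimensions is an extra assumption about how PyTorch treats them (their stride is never used for indexing), and empty tensors are not handled.

```python
def is_row_major_contiguous(shape, strides):
    """Check s_k == 1 and s_i == s_{i+1} * d_{i+1}, walking from the last dim."""
    expected = 1  # the stride a contiguous layout would have at the current dim
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue  # assumption: strides of size-1 dims are irrelevant
        if stride != expected:
            return False
        expected *= size
    return True

print(is_row_major_contiguous((1024, 1024), (1024, 1)))  # layout of x
print(is_row_major_contiguous((1024, 1024), (1, 1024)))  # layout of x.T
```

The two calls use the sizes and strides that x and x.T have in the examples of this post.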

.reshape

The official PyTorch documentation says .reshape returns a view whenever possible. Intuitively, .reshape returns a view if the self tensor is contiguous (more precisely, whenever the existing strides can express the new shape); otherwise it copies the data into a new contiguous tensor.

with detect_cuda_allocation(".reshape.view"):
    y = x.reshape(512, 2048)

with detect_cuda_allocation(".reshape.copy"):
    y = x.T.reshape(512, 2048)

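With the above recipe, you will observe no allocation in the first scope, since the tensor being reshaped is already contiguous. In the second case, a new allocation is needed. Note that if you run all snippets back to back in one session, rebinding y in the first scope also drops the contiguous copy created in the previous section, so the monitor may report a release there instead of no change.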
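Contiguity is sufficient for a view but not strictly necessary: what matters is whether the requested shape can be expressed with the existing strides. The sketch below (small CPU tensors, illustrative shapes) shows a non-contiguous slice that .reshape can still return as a view, next to a transposed tensor that forces a copy.

```python
import torch

x = torch.arange(24).reshape(4, 6)

# Column slice: non-contiguous (strides (6, 1)), yet each row is still dense,
# so splitting the last dim needs no data movement.
s = x[:, :4]
v = s.reshape(4, 2, 2)
print(s.is_contiguous(), v.data_ptr() == x.data_ptr(), v.stride())

# Flattening a transpose needs the elements in a different order, hence a copy.
c = x.T.reshape(24)
print(c.data_ptr() == x.data_ptr())

# Writes through the view reach the source; writes to the copy do not.
v[0, 0, 0] = -1
c[1] = -2
print(x[0, 0].item(), x[1, 0].item())
```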

.repeat and .expand

torch.Tensor.repeat and torch.Tensor.expand look similar but behave differently with regard to memory allocation. The key takeaway: .repeat always copies memory, while .expand always returns a view.
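The difference shows up in the strides: .expand gives the broadcast dimension a stride of 0, while .repeat materializes the tiled data. A small CPU sketch with illustrative shapes:

```python
import torch

x = torch.arange(3).reshape(3, 1)

e = x.expand(3, 4)   # view: the broadcast dim gets stride 0
r = x.repeat(1, 4)   # copy: the data is physically tiled 4 times along dim 1

print(e.stride(), e.data_ptr() == x.data_ptr())
print(r.stride(), r.data_ptr() == x.data_ptr())
print(torch.equal(e, r))   # same values either way
```

Because all four columns of e alias the same elements, in-place writes to an expanded tensor are best avoided; clone it first if you need to modify it.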

.narrow

Programmatically, Python's basic slicing syntax (with step 1) can be treated as equivalent to torch.Tensor.narrow: both return a view of part of the source tensor. Therefore, any modification to the sliced view is reflected in the source.
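A short CPU sketch of that equivalence, using an illustrative 3 × 4 tensor; the slice and the narrow call select the same two columns:

```python
import torch

x = torch.arange(12).reshape(3, 4)

a = x[:, 1:3]           # basic slice along dim 1
b = x.narrow(1, 1, 2)   # narrow(dim=1, start=1, length=2)

print(torch.equal(a, b), a.stride() == b.stride())
# Both views start one element past the beginning of x's buffer.
print(a.data_ptr() == b.data_ptr() == x.data_ptr() + x.element_size())

a[0, 0] = 100           # write through the slice ...
print(x[0, 1].item(), b[0, 0].item())   # ... and it is visible in x and in b
```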

Takeaways

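So the key takeaway is: .reshape and .contiguous allocate new memory when the source tensor is not contiguous (for .reshape, only if the existing strides cannot express the new shape), and .repeat always asks for a new memory buffer. The remaining operations (.view, .T/.transpose, .permute, .expand, .narrow) return views and don't take new memory.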