DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

narrow() and narrow_copy() in PyTorch

narrow() and narrow_copy() can extract `length` consecutive elements (zero or more) along one dimension of a 1D or more D tensor, returning a tensor with the same number of dimensions, as shown below:

*Memos:

  • narrow() and narrow_copy() can be called either as torch functions (torch.narrow()) or as tensor methods (my_tensor.narrow()).
  • The 1st argument (tensor of int, float, complex or bool) with torch, or the tensor the method is called on, is input (Required).
  • The 2nd argument (int) with torch or the 1st argument (int) with a tensor is dim (Required). It can be negative to count from the last dimension.
  • The 3rd argument (int or 0-dim tensor of int) with torch or the 2nd argument (int or 0-dim tensor of int) with a tensor is start (Required). It can be negative to count from the end of dim.
  • The 4th argument (int) with torch or the 3rd argument (int) with a tensor is length (Required). It must be non-negative, and start + length must not exceed the size of dim.
  • narrow() returns a view that shares the underlying storage with input (so writing to the result also changes input), while narrow_copy() returns a copy in newly allocated memory, so narrow() is lighter and faster than narrow_copy().
import torch

my_tensor = torch.tensor([6, -4, 5, 7, 1, 9, -6, 0, 3, -2, -5, 8])

# Every call below takes 5 elements starting at index 3 of the only dimension.
# dim=-1 names the last (here: only) dimension, start=-9 counts from the end
# (-9 + 12 == 3), and start may also be given as a 0-dim integer tensor.
expected = torch.tensor([7, 1, 9, -6, 0])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow(dim=0, start=3, length=5),
    torch.narrow(input=my_tensor, dim=0, start=-9, length=5),
    torch.narrow(input=my_tensor, dim=-1, start=3, length=5),
    torch.narrow(input=my_tensor, dim=-1, start=-9, length=5),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(3), length=5),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(-9), length=5),
    torch.narrow(input=my_tensor, dim=-1, start=torch.tensor(3), length=5),
    torch.narrow(input=my_tensor, dim=-1, start=torch.tensor(-9), length=5),
    torch.narrow_copy(input=my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow_copy(dim=0, start=3, length=5),
    torch.narrow_copy(input=my_tensor, dim=0, start=-9, length=5),
    torch.narrow_copy(input=my_tensor, dim=-1, start=3, length=5),
    torch.narrow_copy(input=my_tensor, dim=-1, start=-9, length=5),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(3), length=5),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(-9), length=5),
    torch.narrow_copy(input=my_tensor, dim=-1, start=torch.tensor(3), length=5),
    torch.narrow_copy(input=my_tensor, dim=-1, start=torch.tensor(-9), length=5),
):
    assert torch.equal(result, expected), result

# narrow() returns a view that shares storage with its input, while
# narrow_copy() returns independent memory. Shown on a clone so my_tensor
# itself stays unchanged.
scratch = my_tensor.clone()
view = torch.narrow(input=scratch, dim=0, start=3, length=5)
copied = torch.narrow_copy(input=scratch, dim=0, start=3, length=5)
view[0] = 100    # writes through to scratch[3]
copied[1] = 100  # changes only the copy; scratch[4] is still 1
assert scratch[3] == 100 and scratch[4] == 1

my_tensor = torch.tensor([[6, -4, 5, 7],
                          [1, 9, -6, 0],
                          [3, -2, -5, 8]])

# Rows 1..2: for a 2D tensor dim=0 and dim=-2 name the same (first) axis,
# and start=-2 counts from the end of that axis (-2 + 3 == 1).
expected = torch.tensor([[1, 9, -6, 0],
                         [3, -2, -5, 8]])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=1, length=2),
    torch.narrow(input=my_tensor, dim=0, start=-2, length=2),
    torch.narrow(input=my_tensor, dim=-2, start=1, length=2),
    torch.narrow(input=my_tensor, dim=-2, start=-2, length=2),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(1), length=2),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(-2), length=2),
    torch.narrow(input=my_tensor, dim=-2, start=torch.tensor(1), length=2),
    torch.narrow(input=my_tensor, dim=-2, start=torch.tensor(-2), length=2),
    torch.narrow_copy(input=my_tensor, dim=0, start=1, length=2),
    torch.narrow_copy(input=my_tensor, dim=0, start=-2, length=2),
    torch.narrow_copy(input=my_tensor, dim=-2, start=1, length=2),
    torch.narrow_copy(input=my_tensor, dim=-2, start=-2, length=2),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(1), length=2),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(-2), length=2),
    torch.narrow_copy(input=my_tensor, dim=-2, start=torch.tensor(1), length=2),
    torch.narrow_copy(input=my_tensor, dim=-2, start=torch.tensor(-2), length=2),
):
    assert torch.equal(result, expected), result

# Columns 1..2: dim=1 and dim=-1 name the last axis, and start=-3 counts from
# the end of that axis (-3 + 4 == 1).
expected = torch.tensor([[-4, 5],
                         [9, -6],
                         [-2, -5]])
for result in (
    torch.narrow(input=my_tensor, dim=1, start=1, length=2),
    torch.narrow(input=my_tensor, dim=1, start=-3, length=2),
    torch.narrow(input=my_tensor, dim=-1, start=1, length=2),
    torch.narrow(input=my_tensor, dim=-1, start=-3, length=2),
    torch.narrow(input=my_tensor, dim=1, start=torch.tensor(1), length=2),
    torch.narrow(input=my_tensor, dim=1, start=torch.tensor(-3), length=2),
    torch.narrow(input=my_tensor, dim=-1, start=torch.tensor(1), length=2),
    torch.narrow(input=my_tensor, dim=-1, start=torch.tensor(-3), length=2),
    torch.narrow_copy(input=my_tensor, dim=1, start=1, length=2),
    torch.narrow_copy(input=my_tensor, dim=1, start=-3, length=2),
    torch.narrow_copy(input=my_tensor, dim=-1, start=1, length=2),
    torch.narrow_copy(input=my_tensor, dim=-1, start=-3, length=2),
    torch.narrow_copy(input=my_tensor, dim=1, start=torch.tensor(1), length=2),
    torch.narrow_copy(input=my_tensor, dim=1, start=torch.tensor(-3), length=2),
    torch.narrow_copy(input=my_tensor, dim=-1, start=torch.tensor(1), length=2),
    torch.narrow_copy(input=my_tensor, dim=-1, start=torch.tensor(-3), length=2),
):
    assert torch.equal(result, expected), result

my_tensor = torch.tensor([[[6, -4], [5, 7]],
                          [[1, 9], [-6, 0]],
                          [[3, -2], [-5, 8]]])

# dim=0 (size 3): keep only the second 2x2 block; start=-2 == -2 + 3 == 1.
expected = torch.tensor([[[1, 9],
                          [-6, 0]]])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=1, length=1),
    torch.narrow(input=my_tensor, dim=0, start=-2, length=1),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(1), length=1),
    torch.narrow(input=my_tensor, dim=0, start=torch.tensor(-2), length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=-2, length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(1), length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=torch.tensor(-2), length=1),
):
    assert torch.equal(result, expected), result

# dim=1 (size 2): keep the second row of every block; start=-1 == -1 + 2 == 1.
expected = torch.tensor([[[5, 7]],
                         [[-6, 0]],
                         [[-5, 8]]])
for result in (
    torch.narrow(input=my_tensor, dim=1, start=1, length=1),
    torch.narrow(input=my_tensor, dim=1, start=-1, length=1),
    torch.narrow(input=my_tensor, dim=1, start=torch.tensor(1), length=1),
    torch.narrow(input=my_tensor, dim=1, start=torch.tensor(-1), length=1),
    torch.narrow_copy(input=my_tensor, dim=1, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=1, start=-1, length=1),
    torch.narrow_copy(input=my_tensor, dim=1, start=torch.tensor(1), length=1),
    torch.narrow_copy(input=my_tensor, dim=1, start=torch.tensor(-1), length=1),
):
    assert torch.equal(result, expected), result

# dim=2 (size 2): keep the second column of every block; start=-1 == 1.
expected = torch.tensor([[[-4], [7]],
                         [[9], [0]],
                         [[-2], [8]]])
for result in (
    torch.narrow(input=my_tensor, dim=2, start=1, length=1),
    torch.narrow(input=my_tensor, dim=2, start=-1, length=1),
    torch.narrow(input=my_tensor, dim=2, start=torch.tensor(1), length=1),
    torch.narrow(input=my_tensor, dim=2, start=torch.tensor(-1), length=1),
    torch.narrow_copy(input=my_tensor, dim=2, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=2, start=-1, length=1),
    torch.narrow_copy(input=my_tensor, dim=2, start=torch.tensor(1), length=1),
    torch.narrow_copy(input=my_tensor, dim=2, start=torch.tensor(-1), length=1),
):
    assert torch.equal(result, expected), result

my_tensor = torch.tensor([[[6., -4.], [5., 7.]],
                          [[1., 9.], [-6., 0.]],
                          [[3., -2.], [-5., 8.]]])

# Float tensors are narrowed the same way, and the dtype is kept.
expected = torch.tensor([[[1., 9.],
                          [-6., 0.]]])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=1, length=1),
):
    assert result.dtype == my_tensor.dtype and torch.equal(result, expected), result

my_tensor = torch.tensor([[[6.+0.j, -4.+0.j], [5.+0.j, 7.+0.j]],
                          [[1.+0.j, 9.+0.j], [-6.+0.j, 0.+0.j]],
                          [[3.+0.j, -2.+0.j], [-5.+0.j, 8.+0.j]]])

# Complex tensors are narrowed the same way, and the dtype is kept.
expected = torch.tensor([[[1.+0.j, 9.+0.j],
                          [-6.+0.j, 0.+0.j]]])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=1, length=1),
):
    assert result.dtype == my_tensor.dtype and torch.equal(result, expected), result

my_tensor = torch.tensor([[[True, False], [True, False]],
                          [[False, True], [False, True]],
                          [[True, False], [True, False]]])

# Bool tensors are narrowed the same way, and the dtype is kept.
expected = torch.tensor([[[False, True],
                          [False, True]]])
for result in (
    torch.narrow(input=my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(input=my_tensor, dim=0, start=1, length=1),
):
    assert result.dtype == my_tensor.dtype and torch.equal(result, expected), result
Enter fullscreen mode Exit fullscreen mode

Top comments (0)