尺寸为M<32的火炬张量分度错误?

2022-02-23 00:00:00 python pytorch debugging

问题描述

我正在尝试通过索引矩阵访问pytorch张量,但我最近发现这段代码找不到无法工作的原因。

下面的代码分为两部分。前半部分被证明是有效的,而后半部分是错误的。我看不出原因。有没有人能解释一下这件事?

import torch
import numpy as np

a = torch.rand(32, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx]   # WORKS for a torch.tensor of size M >= 32. It doesn't work otherwise.

a = torch.rand(16, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx]   # IndexError: too many indices for tensor of dimension 2

如果我更改a = np.random.rand(16, 16),它也可以正常工作。


解决方案

首先,让我快速了解一下如何使用一个数值数组和另一个张量来索引张量。

示例:这是我们要索引的目标张量

    numpy_indices = torch.tensor([[0, 1, 2, 7],
                                  [0, 1, 2, 3]])   # numpy array

    tensor_indices = torch.tensor([[0, 1, 2, 7],
                                   [0, 1, 2, 3]])   # 2D tensor

    t = torch.tensor([[1,  2,  3,   4],            # targeted tensor
                      [5,  6,  7,   8],
                      [9,  10, 11, 12],
                      [13, 14, 15, 16],
                      [17, 18, 19, 20],
                      [21, 22, 23, 24],
                      [25, 26, 27, 28],
                      [29, 30, 31, 32]])
     numpy_result = t[numpy_indices]
     tensor_result = t[tensor_indices]
  • 使用2D数值数组编制索引:索引的读取方式类似于成对(x,y)张量[行,列],例如t[0,0], t[1,1], t[2,2], and t[7,3]

    print(numpy_result)  # tensor([ 1,  6, 11, 32])
    
  • 使用2D张量进行索引:以行的方式遍历索引张量,每个值都是目标张量中一行的索引。 例如[ [t[0],t[1],t[2],[7]] , [[0],[1],[2],[3]] ]参见下例,索引后的tensor_result的新形状为(tensor_indices.shape[0],tensor_indices.shape[1],t.shape[1])=(2,4,4)

    print(tensor_result)     # tensor([[[ 1,  2,  3,  4],
                             #          [ 5,  6,  7,  8],
                             #          [ 9, 10, 11, 12],
                             #          [29, 30, 31, 32]],
    
                             #          [[ 1,  2,  3,  4],
                             #           [ 5,  6,  7,  8],
                             #           [ 9, 10, 11, 12],
                             #           [ 13, 14, 15, 16]]])
    

如果您尝试在numpy_indices中添加第三行,您将收到相同的错误,因为索引将由3D表示,例如,(0,0,0).(7,3,3)。

indices = np.array([[0, 1, 2, 7],
                    [0, 1, 2, 3],
                    [0, 1, 2, 3]])

print(numpy_result)   # IndexError: too many indices for tensor of dimension 2

但是,张量索引不是这种情况,形状将更大(3,4,4)。

最后,如您所见,这两种索引类型的输出完全不同。要解决您的问题,您可以使用

xx = torch.tensor(xx).long()  # convert a numpy array to a tensor

高级索引(NUMPY_INDEX>;3行)的情况如何,因为您的情况仍然不明确且未解决,您可以检查1、2、3。

相关文章