从Cython结构创建NumPy数据类型

2022-04-04 00:00:00 python numpy cython

问题描述

以下是SCRICKIT中当前使用的Cython代码片段-学习二叉树,

  # Some compound datatypes used below:
  cdef struct NodeHeapData_t:
      DTYPE_t val
      ITYPE_t i1
      ITYPE_t i2

  # build the corresponding numpy dtype for NodeHeapData
  cdef NodeHeapData_t nhd_tmp
  NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

(完整源代码here)

最后一行从该Cython结构创建一个NumPy数据类型。我还没有找到很多关于它的文档,尤其是我不明白为什么需要切片[:1],或者它能做什么。有关更多讨论,请参阅scikit-learn#17228。有人对此有什么想法吗?


解决方案

这是一个聪明但令人困惑的把戏!

以下代码创建一个长度为1的cython-array,因为它使用的内存(但不拥有!)正好有一个元素。

cdef NodeHeapData_t nhd_tmp
<NodeHeapData_t[:1]>(&nhd_tmp)

现在,cython-array实现了缓冲协议,因此Cython拥有创建format字符串的机制,该字符串描述它所持有的元素的类型。

np.asarray也使用缓冲协议,能够从format-字符串构造dtype-对象,format-字符串由cython的数组提供。

可以通过以下方式查看格式字符串:

%%cython
import numpy as np

# Some compound datatypes used below:
cdef struct NodeHeapData_t:
  double val
  int i1
  int i2

# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype

print("format string:",memoryview(<NodeHeapData_t[:1]>(&nhd_tmp)).format)
print(NodeHeapData )

这将导致

format string: T{d:val:i:i1:i:i2:}
[('val', '<f8'), ('i1', '<i4'), ('i2', '<i4')]

我脑子里想不出一个不那么令人困惑的解决方案,除了手动创建dtype对象-这对于不同平台上的某些数据类型来说可能会变得很难看*,但在大多数情况下应该是直接的。


*)np.int就是这样一个问题。很容易忽略np.int映射到long而不是int(令人困惑,不是吗?)。

例如

memoryview(np.zeros(1, dtype=np.int)).itemsize

计算为

  • 在Windows上:4(Windows上的大小为long,单位为字节)。
  • 在Linux上:8(Linux上的大小为long,单位为字节)。

相关文章