从Cython结构创建NumPy数据类型
问题描述
以下是SCRICKIT中当前使用的Cython代码片段-学习二叉树,
# Some compound datatypes used below:
cdef struct NodeHeapData_t:
DTYPE_t val
ITYPE_t i1
ITYPE_t i2
# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype
(完整源代码here)
最后一行从该Cython结构创建一个NumPy数据类型。我还没有找到很多关于它的文档,尤其是我不明白为什么需要切片[:1]
,或者它能做什么。有关更多讨论,请参阅scikit-learn#17228。有人对此有什么想法吗?
解决方案
这是一个聪明但令人困惑的把戏!
以下代码创建一个长度为1的cython-array,因为它使用的内存(但不拥有!)正好有一个元素。cdef NodeHeapData_t nhd_tmp
<NodeHeapData_t[:1]>(&nhd_tmp)
现在,cython-array实现了缓冲协议,因此Cython拥有创建format
字符串的机制,该字符串描述它所持有的元素的类型。
np.asarray
也使用缓冲协议,能够从format
-字符串构造dtype
-对象,format
-字符串由cython的数组提供。
可以通过以下方式查看格式字符串:
%%cython
import numpy as np
# Some compound datatypes used below:
cdef struct NodeHeapData_t:
double val
int i1
int i2
# build the corresponding numpy dtype for NodeHeapData
cdef NodeHeapData_t nhd_tmp
NodeHeapData = np.asarray(<NodeHeapData_t[:1]>(&nhd_tmp)).dtype
print("format string:",memoryview(<NodeHeapData_t[:1]>(&nhd_tmp)).format)
print(NodeHeapData )
这将导致
format string: T{d:val:i:i1:i:i2:}
[('val', '<f8'), ('i1', '<i4'), ('i2', '<i4')]
我脑子里想不出一个不那么令人困惑的解决方案,除了手动创建dtype
对象-这对于不同平台上的某些数据类型来说可能会变得很难看*,但在大多数情况下应该是直接的。
*)np.int
就是这样一个问题。很容易忽略np.int
映射到long
而不是int
(令人困惑,不是吗?)。
例如
memoryview(np.zeros(1, dtype=np.int)).itemsize
计算为
- 在Windows上:4(Windows上的大小为
long
,单位为字节)。 - 在Linux上:8(Linux上的大小为
long
,单位为字节)。
相关文章