将PyTorch张量转换为Python列表

2022-02-23 00:00:00 python pytorch

问题描述

如何将PyTorchTensor转换为python列表?

我当前的用例是将大小[1, 2048, 1, 1]的张量转换为包含2048个元素的列表。

我的张量具有浮点值。是否有解决方案还可以解决INT和可能的其他数据类型?


解决方案

使用Tensor.tolist()例如:

>>> import torch
>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803

若要删除大小1的所有维度,请使用a.squeeze().tolist()

或者,如果除一个维之外的所有维的大小均为1(或者您希望获取张量的每个元素的列表),则可以使用a.flatten().tolist()

相关文章