pytorch 使用lmdb加载数据

2022-04-15 00:00:00 数据 方法 样本 权重 取样
1. 在pytorch中使用lmdb

lmdb数据由键值对组成,可以实现将所有键取出来,根据键去读取对应的值。lmdb的具体操作和原理不在这里说明。这里的示例使用由caffe生成的lmdb,也可以自己定义键值的形式并写入


示例:

class LmdbDataset_train(Dataset):
    def __init__(self,lmdb_path,optimizer,keys_path):
        # super().__init__()
        self.optimizer = optimizer

        self.datum=caffe_pb2.Datum()
        self.lmdb_path = lmdb_path
        keys = np.load(keys_path)
        self.keys = keys.tolist()
        self.length = len(self.keys)
    def open_lmdb(self):
        self.env = lmdb.open(self.lmdb_path, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        self.txn = self.env.begin(buffers=True,write=False)
   
    def __getitem__(self, index):
        if not hasattr(self, 'txn'):
            self.open_lmdb()

        serialized_str = self.txn.get(self.keys[index])
        self.datum.ParseFromString(serialized_str)
        size=self.datum.width*self.datum.height
        
        pixles1=self.datum.data[0:size]
        pixles2=self.datum.data[size:2*size]
        pixles3=self.datum.data[2*size:3*size]

        image1=Image.frombytes('L', (self.datum.width, self.datum.height), pixles1)
        image2=Image.frombytes('L', (self.datum.width, self.datum.height), pixles2)
        image3=Image.frombytes('L', (self.datum.width, self.datum.height), pixles3)

        img=Image.merge("RGB",(image3,image2,image1))

        img =self.optimizer(img)

        label=self.datum.label
        return img, label

    def __len__(self):
        return self.length

相关文章