pytorch 使用lmdb加载数据
lmdb数据由键值对组成,可以实现将所有键取出来,根据键去读取对应的值。lmdb的具体操作和原理不在这里说明。这里的示例使用由caffe生成的lmdb,也可以自己定义键值的形式并写入
示例:
class LmdbDataset_train(Dataset):
def __init__(self,lmdb_path,optimizer,keys_path):
# super().__init__()
self.optimizer = optimizer
self.datum=caffe_pb2.Datum()
self.lmdb_path = lmdb_path
keys = np.load(keys_path)
self.keys = keys.tolist()
self.length = len(self.keys)
def open_lmdb(self):
self.env = lmdb.open(self.lmdb_path, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False)
self.txn = self.env.begin(buffers=True,write=False)
def __getitem__(self, index):
if not hasattr(self, 'txn'):
self.open_lmdb()
serialized_str = self.txn.get(self.keys[index])
self.datum.ParseFromString(serialized_str)
size=self.datum.width*self.datum.height
pixles1=self.datum.data[0:size]
pixles2=self.datum.data[size:2*size]
pixles3=self.datum.data[2*size:3*size]
image1=Image.frombytes('L', (self.datum.width, self.datum.height), pixles1)
image2=Image.frombytes('L', (self.datum.width, self.datum.height), pixles2)
image3=Image.frombytes('L', (self.datum.width, self.datum.height), pixles3)
img=Image.merge("RGB",(image3,image2,image1))
img =self.optimizer(img)
label=self.datum.label
return img, label
def __len__(self):
return self.length
相关文章