在pytorch中如何使用lmdb

2022-04-15 00:00:00 修改 函数 文件 图片 文件夹

总述
1、lmdb使用源码github链接:pytorch_lmdb_imagenet
2、使用方法:修改folder2lmdb.py文件即可
①先修改folder2lmdb函数,将图片文件夹转化为lmdb文件;
②再在实际实验中,修改 ImageFolderLMDB类,将现成的lmdb文件转化为dataset,方便后续读取。

folder2lmdb.py完整源码及具体修改如下:
import部分

import os
import os.path as osp
from PIL import Image
import six
import lmdb
import pickle
import numpy as np

import torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

自定义的Dataset类:将图片数据集转换为lmdb后,通过ImageFolderLMDB将lmdb文件转化为Dataset,方便后续读取;修改__getitem__函数即可修改读取内容。注意:
输入地址应为包含了一个data.mdb和一个lock.mdb文件的文件夹名,若senti文件夹下有data.mdb、lock.mdb以及其他文件,则输入“senti\”即可。
本页代码中生成的lmdb文件的文件名格式为train.lmdb和train.lmdb-lock,只需分别把这两个文件的文件名改为data.mdb和lock.mdb即可。

def loads_data(buf):
    """Deserialize a pickled byte string back into a Python object.

    Args:
        buf: the output of `dumps_data`.

    Returns:
        The object reconstructed by ``pickle.loads``.
    """
    # NOTE: pickle.loads can execute arbitrary code while unpickling; only
    # feed it bytes that come from a trusted LMDB file.
    return pickle.loads(buf)


class ImageFolderLMDB(data.Dataset):
    """Dataset that serves (image, label) samples stored in an LMDB database.

    The database must contain two pickled bookkeeping entries, ``__len__``
    and ``__keys__``, plus one pickled ``(image_bytes, label)`` record for
    every key listed in ``__keys__``.
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        """Open the database read-only and load its key list.

        Args:
            db_path: path handed to ``lmdb.open``; it is treated as a
                sub-directory when it is an existing directory, otherwise
                as a single file.
            transform: optional callable applied to the decoded PIL image.
            target_transform: optional callable applied to the label.
        """
        self.db_path = db_path
        # NOTE(review): the environment is opened once here and reused by
        # every __getitem__ call; confirm this is safe with the DataLoader
        # worker setup in use.
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image_array, target)`` for the sample at ``index``.

        The image is decoded to RGB, passed through ``transform`` if one was
        given, and finally converted with ``np.array``.
        """
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])

        unpacked = loads_data(byteflow)

        # Record layout: element 0 is the encoded image, element 1 the label.
        imgbuf = unpacked[0]
        target = unpacked[1]

        # Decode straight from an in-memory buffer positioned at offset 0.
        img = Image.open(six.BytesIO(imgbuf)).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        im2arr = np.array(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return im2arr, target

    def __len__(self):
        """Return the sample count stored under ``__len__`` in the database."""
        return self.length

    def __repr__(self):
        """Return the class name followed by the database path."""
        return self.__class__.__name__ + ' (' + self.db_path + ')'
def raw_reader(path):
    """Return the complete contents of the file at ``path`` as bytes.

    Args:
        path: location of the file to read; it is opened in binary mode.

    Returns:
        The raw, undecoded bytes of the file.
    """
    with open(path, 'rb') as f:
        return f.read()


def dumps_data(obj):
    """Serialize an object with pickle.

    Args:
        obj: any picklable Python object.

    Returns:
        The pickled representation of ``obj`` as bytes.
    """
    return pickle.dumps(obj)
  • 将图片文件夹转化为lmdb文件的函数:输入图片所在文件夹,在该文件夹下输出.lmdb和.lmdb-lock文件。
def folder2lmdb(dpath, name="train", write_frequency=5000):
    """Convert an image folder into a single LMDB file.

    The folder ``<dpath>/<name>`` is read with ``ImageFolder``. Every sample
    is stored as the pickled pair ``(raw_image_bytes, label)`` under its
    ASCII-encoded running index. Two extra entries, ``__keys__`` and
    ``__len__``, hold the pickled key list and its length.

    Args:
        dpath: directory that contains the ``name`` sub-folder; the output
            is written to ``<dpath>/<name>.lmdb``.
        name: sub-folder to read, also used as the stem of the output file.
        write_frequency: the write transaction is committed every this many
            samples.
    """
    directory = osp.expanduser(osp.join(dpath, name))
    print("Loading dataset from %s" % directory)
    dataset = ImageFolder(directory, loader=raw_reader)
    # Identity collate_fn: each batch stays a list of (image_bytes, label)
    # tuples instead of being stacked into tensors.
    data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)

    # Expand "~" here as well so the output lands next to the input folder.
    lmdb_path = osp.expanduser(osp.join(dpath, "%s.lmdb" % name))
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    keys = []
    try:
        txn = db.begin(write=True)
        # `batch` rather than `data`, to avoid shadowing the module alias.
        for idx, batch in enumerate(data_loader):
            # The loader uses its default batch size, so take the only sample.
            image, label = batch[0]

            key = u'{}'.format(idx).encode('ascii')
            txn.put(key, dumps_data((image, label)))
            keys.append(key)
            if idx % write_frequency == 0:
                print("[%d/%d]" % (idx, len(data_loader)))
                # Commit periodically and start a fresh write transaction.
                txn.commit()
                txn = db.begin(write=True)

        # finish iterating through dataset
        txn.commit()
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', dumps_data(keys))
            txn.put(b'__len__', dumps_data(len(keys)))

        print("Flushing database ...")
        db.sync()
    finally:
        # Release the environment even when reading or writing fails.
        db.close()

调用folder2lmdb函数,注意:
若图片文件地址为"dataset/image/xxx.jpg",则此处的输入变量应为"dataset/",因为folder2lmdb函数在读取图片时,会把image作为图片的target存储,若输入变量写成"dataset/image/",会报错。

if __name__ == "__main__":
    # Generate the LMDB files for the training and validation splits.
    folder2lmdb("/home/jiang/dataset/imagenet/", name="train")
    folder2lmdb("/home/jiang/dataset/imagenet/", name="val")



相关文章