Caffe教程系列之LMDB

本教程旨在督促自己从头到尾重新学习一遍Caffe,主要参考Caffe主页的教程和大牛们的博文,若有不妥之处,还望告知。

Caffe系列教程 - 闲渔的文章 - 知乎专栏

LMDB是Cafffe中应用的一种数据库,我们常常需要对LMDB进行读写操作,本文介绍如何采用Python代码进行LMDB的读写操作。

HDF5和LMDB相比,HDF5的读写格式简单;LMDB采用内存-映射文件(memory-mapped files),所以拥有非常好的I/O性能,而且对于大型数据库来说,HDF5的文件常常整个写入内存,所以HDF5的文件大小就受限于内存大小,当然也可以通过文件分割来解决问题,但其I/O性能就不如LMDB的页缓存(page cachiing)策略了。

Python读写LMDB

首先确认你安装了lmdb(pip安装:pip install lmdb)和Caffe的python包(Caffe中的pycaffe)。LMDB采用键值对<key-value>的存储格式,key就是字符形式的ID,value是Caffe中Datum类的序列化形式。

import numpy as np
import lmdb
import caffe

N = 1000

# Let's pretend this is interesting data
X = np.zeros((N, 3, 32, 32), dtype=np.uint8)
y = np.zeros(N, dtype=np.int64)

# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10

env = lmdb.open('mylmdb', map_size=map_size)

with env.begin(write=True) as txn:
    # txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

同样地,也可以采用Python读取LMDB数据库。

import numpy as np
import lmdb
import caffe

env = lmdb.open('mylmdb', readonly=True)
with env.begin() as txn:
    raw_datum = txn.get(b'00000000')

datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(raw_datum)

flat_x = np.fromstring(datum.data, dtype=np.uint8)
x = flat_x.reshape(datum.channels, datum.height, datum.width)
y = datum.label

其中,<key, value>键值对的遍历:

with env.begin() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key, value)

Python读取图像数据

很多时候我们需要从硬盘读取图像数据并存入到LMDB数据库,或者读取LMDB数据Datum并还原图像。

读取LMDB数据库中的Datum数据,这里再稍微介绍一下Datum的格式:channels:图片的通道,彩色图有3个通道,灰度图只有1通道,当然也可以用通道数来表示其他意思,比如表示两张图片,每个通道一个单张的图;height:图片(即data)的高;width:图片(即data)的宽;data:图片的数据(像素值);label:图片的label。

import sys
import numpy as np
import lmdb
import caffe
import argparse
from matplotlib import pyplot
 
lmdbpath = 'you/lmdb/file/path'
env = lmdb.open(lmdbpath, readonly=True)
with env.begin() as txn:
  cursor = txn.cursor()
  for key, value in cursor:
    print 'key: ',key
    datum = caffe.proto.caffe_pb2.Datum() #datum类型
    datum.ParseFromString(value) #转成datum
    flat_x = np.fromstring(datum.data, dtype=np.uint8) #转成numpy类型
    x = flat_x.reshape(datum.channels, datum.height, datum.width)
    y = datum.label#图片的label
    fig = pyplot.figure()#把两张图片显示出来    
    pyplot.imshow(x, cmap='gray')

Caffe一般将图片名字定入到txt文件中,也可以采用其他方式,这里直接采用image_list表示所有图像的路径及其对应的标签,采用二维字符数组形式保存

import numpy as np
import lmdb
import Image as img
from skimage import io
import caffe

env = lmdb.Environment(args.lmdb, map_size=int(1e10)) # map_size指数据库大小,根据实际需要进行设置
with env.begin(write=True) as txn:
  # txn is a Transaction object
  for i in range(len(image_list)):
    datum = caffe.proto.caffe_pb2.Datum()
    img = np.array(io.imread(image_list[i][0]))
    label = int(image_list[i][1])

    img = img.transpose((2, 0, 1)) # caffe存储的原因,由RGB转为BGR
    datum = caffe.io.array_to_datum(img)
    datum.label = int(img_files[i]['label'])
    str_id = '%08d' % i # 可以加上文件名或者其他字符
    txn.put(str_id.encode('ascii'), datum.SerializeToString())

注意:数据库的读是按照key的排序读的,key的顺序并不是按照写的顺序,是字典序。所以写数据库时key必须重新写,如果把图片名字作为key读数据库出来的图片就是按照图片的字典序(不是写的顺序)。str_id是key。


参考:

Creating an LMDB database in Python · Deep learning at the University of Chicago

caffe 数据库LMDB的读写-野孩子的专栏

编辑于 2016-11-07