Caffe2学习笔记(6) 创建属于自己的数据集
0. 前言
- 官方教程:
- 参考博客:
1. 基本概念
- 数据存储的基本格式:Caffe2使用二进制格式来存储数据,用于训练、预测等。
- Caffe2 DB的本质:一系列key-value对,将key与value作为字符串存储。
- key通常是随机化的(或在写入前先把数据打乱),这样按顺序读出的batch近似独立同分布。
- value用于实际数据存储。
- 使用结构化数据(包括数据类型信息、shape、以及多维数组)存储DB
- 使用 `TensorProtos` protocol buffer来存储数据。
- 后续使用 `TensorProtosDBInput` operator来导入数据。
- 其他:
- Caffe2中的shuffle似乎需要在保存DB之前完成,不能像 `tf.data` 那样在读取时实时、多次地shuffle数据。
2. Python实现
- 注意事项
- 源码主要来源于这里。
- 原来代码是Python 2的,下面把代码改为Python 3。
- 导入依赖
# First let's import a few things needed.
%matplotlib inline
import urllib # for downloading the dataset from the web.
import numpy as np
from matplotlib import pyplot
import io
from caffe2.python import core, utils, workspace
from caffe2.proto import caffe2_pb2
- 下载原始数据,并展示
- 注意,这里使用了 `urllib.request.urlopen`,而不是Python 2中的 `urllib2`(Python 3中需要 `import urllib.request` 才能保证该子模块已被加载)。
# Download the raw Iris CSV (returned as bytes).
# In Python 3 `urlopen` lives in the `urllib.request` submodule; a bare
# `import urllib` does not guarantee that submodule is loaded, so import it
# explicitly here.
import urllib.request

IRIS_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
# `with` closes the HTTP response once the body has been read.
with urllib.request.urlopen(IRIS_URL) as f:
    raw_data = f.read()
print('Raw data looks like this:')
print(raw_data[:100].decode() + '...')
- 将原始数据转变为ndarray
- 注意,这里没有使用Python 2中的 `StringIO.StringIO`,而是使用了 `io.BytesIO`(因为下载得到的 `raw_data` 是bytes)。
# Parse the CSV bytes: columns 0-3 are float32 features, column 4 is the
# species name, which is mapped to an integer class id.
features = np.loadtxt(io.BytesIO(raw_data), dtype=np.float32, delimiter=',', usecols=(0, 1, 2, 3))

_IRIS_LABELS = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}


def label_converter(s):
    """Map an Iris species name to its integer class id.

    Depending on the NumPy version and the `encoding` argument, `np.loadtxt`
    hands converters either `bytes` (legacy `encoding='bytes'`) or `str`
    (NumPy >= 2.0 default), so both are accepted.
    """
    if isinstance(s, bytes):
        s = s.decode('latin1')
    return _IRIS_LABELS[s.strip()]


# `np.int` was only an alias for the builtin `int` and was removed in NumPy 1.24.
labels = np.loadtxt(io.BytesIO(raw_data), dtype=int, delimiter=',', usecols=(4,), converters={4: label_converter})
- 对输入数据进行shuffle,并可视化展示
# Shuffle features and labels with one shared permutation so they stay
# aligned, then split: the first `num_train` samples train, the rest test.
num_samples = len(features)
num_train = 100
random_index = np.random.permutation(num_samples)
features = features[random_index]
labels = labels[random_index]
train_features = features[:num_train]
train_labels = labels[:num_train]
test_features = features[num_train:]
test_labels = labels[num_train:]

# One marker style per class id (0, 1, 2); plot feature 0 against feature 1.
legend = ['rx', 'b+', 'go']
pyplot.title("Training data distribution, feature 0 and 1")
for i in range(len(legend)):
    pyplot.plot(train_features[train_labels == i, 0], train_features[train_labels == i, 1], legend[i])
pyplot.figure()
pyplot.title("Testing data distribution, feature 0 and 1")
for i in range(len(legend)):
    pyplot.plot(test_features[test_labels == i, 0], test_features[test_labels == i, 1], legend[i])
- 使用 `TensorProtos` 保存数据
# Pack the first sample (its feature vector and its label) into a single
# TensorProtos message, then show both the readable text form and the
# serialized bytes that are stored as a DB value.
feature_and_label = caffe2_pb2.TensorProtos(protos=[
    utils.NumpyArrayToCaffe2Tensor(features[0]),
    utils.NumpyArrayToCaffe2Tensor(labels[0]),
])
print('This is what the tensor proto looks like for a feature and its label:')
print(feature_and_label)
print('This is the compact string that gets written into the db:')
print(feature_and_label.SerializeToString())
# Actually write the samples into a DB.
def write_db(db_type, db_name, features, labels, key_prefix='train_'):
    """Serialize each (feature, label) pair as a TensorProtos record into a Caffe2 DB.

    Args:
        db_type: Caffe2 DB backend name, e.g. "minidb".
        db_name: path of the DB to create.
        features: array whose first axis indexes samples.
        labels: one label per sample, indexed the same way as `features`.
        key_prefix: prefix of every record key; keys look like
            '<key_prefix>000', '<key_prefix>001', ...
    """
    db = core.C.create_db(db_type, db_name, core.C.Mode.write)
    transaction = db.new_transaction()
    for i in range(features.shape[0]):
        feature_and_label = caffe2_pb2.TensorProtos()
        feature_and_label.protos.extend([
            utils.NumpyArrayToCaffe2Tensor(features[i]),
            utils.NumpyArrayToCaffe2Tensor(labels[i])])
        # Use str.format placeholders: combining a '%03d' pattern with
        # .format() leaves it untouched, which would give every record the
        # same literal key 'train_%03d'.
        transaction.put(
            '{}{:03d}'.format(key_prefix, i),
            feature_and_label.SerializeToString())
    # Close the transaction, and then close the db (releasing the handles is
    # presumably what flushes pending writes — confirm for the chosen backend).
    del transaction
    del db


write_db("minidb", "iris_train.minidb", train_features, train_labels)
write_db("minidb", "iris_test.minidb", test_features, test_labels, key_prefix="test_")
- 读取DB中的数据
- 流程:先创建DB读取器(本质就是 `CreateDB` Op),再创建 `TensorProtosDBInput` Op来读取。
# Build a reader net: CreateDB opens the training DB, and TensorProtosDBInput
# pulls 16 records per run into the output blobs "X" (features) and "Y" (labels).
net_proto = core.Net("example_reader")
dbreader = net_proto.CreateDB(
    [],
    "dbreader",
    db="iris_train.minidb",
    db_type="minidb",
)
net_proto.TensorProtosDBInput([dbreader], ["X", "Y"], batch_size=16)
print("The net looks like this:")
print(net_proto.Proto())
- 通过workspace读取数据
# Instantiate the reader net once, then run it twice. Each run fetches the
# next batch from the DB into blobs "X" and "Y", which are printed.
workspace.CreateNet(net_proto)
for feature_msg, label_msg in (
    ("The first batch of feature is:", "The first batch of label is:"),
    ("The second batch of feature is:", "The second batch of label is:"),
):
    workspace.RunNet(net_proto.Proto().name)
    print(feature_msg)
    print(workspace.FetchBlob("X"))
    print(label_msg)
    print(workspace.FetchBlob("Y"))
3. C++实现
- 代码主要来源于这里,主要功能是通过C++创建mnist Caffe2 DB。
- 其他可参考代码还有 make_image_db.cc(推荐)、make_cifar_db.cc等。
- 头文件与命令行参数
  `CAFFE2_DEFINE_string` 等宏在编译时启用gflags的情况下调用 `gflags` 中的 `DEFINE_string` 等方法(否则使用Caffe2自带的flag实现),用于设置命令行参数。
// 头文件
#include <fstream>
#include <string>
#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/core/logging.h"
// Command-line flags. Each CAFFE2_DEFINE_* macro registers a flag that the
// code reads back as caffe2::FLAGS_<name> (e.g. caffe2::FLAGS_db,
// caffe2::FLAGS_channel_first).
CAFFE2_DEFINE_string(image_file, "", "The input image file name.");
CAFFE2_DEFINE_string(label_file, "", "The label file name.");
CAFFE2_DEFINE_string(output_file, "", "The output db name.");
CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
CAFFE2_DEFINE_int(data_limit, -1,
    "If set, only output this number of data points.");
CAFFE2_DEFINE_bool(channel_first, false,
    "If set, write the data as channel-first (CHW order) as the old "
    "Caffe does.");
- 读取输入数据,并判断其合法性
- 其中,形如 `CAFFE_ENFORCE` 的宏都用于检查数据合法性,条件不满足时会抛出异常报错。源码位于 `logging.h` 中。
// Open both input files in binary mode and fail fast if either cannot be opened.
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
CAFFE_ENFORCE(image_file, "Unable to open file ", image_filename);
CAFFE_ENFORCE(label_file, "Unable to open file ", label_filename);
// Zero-initialize the header fields so a short read can never leave them
// indeterminate; each read is additionally checked right after it happens.
uint32_t magic = 0;
uint32_t num_items = 0;
uint32_t num_labels = 0;
uint32_t rows = 0;
uint32_t cols = 0;
// Check the file format. Every 4-byte header word is passed through
// swap_endian() after reading (the file's byte order presumably differs from
// the host's — verify against swap_endian's definition).
image_file.read(reinterpret_cast<char*>(&magic), 4);
CAFFE_ENFORCE(image_file, "Failed to read the magic number from ", image_filename);
magic = swap_endian(magic);
// 529205256 == 0x1f8b0808, i.e. the gzip signature bytes 1f 8b 08 08:
// the file was not decompressed.
if (magic == 529205256) {
  LOG(FATAL) <<
      "It seems that you forgot to unzip the mnist dataset. You should "
      "first unzip them using e.g. gunzip on Linux.";
}
CAFFE_ENFORCE_EQ(magic, 2051, "Incorrect image file magic.");
label_file.read(reinterpret_cast<char*>(&magic), 4);
CAFFE_ENFORCE(label_file, "Failed to read the magic number from ", label_filename);
magic = swap_endian(magic);
CAFFE_ENFORCE_EQ(magic, 2049, "Incorrect label file magic.");
// Number of records; the image and label files must agree.
image_file.read(reinterpret_cast<char*>(&num_items), 4);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
CAFFE_ENFORCE(image_file && label_file, "Failed to read the item counts.");
num_items = swap_endian(num_items);
num_labels = swap_endian(num_labels);
CAFFE_ENFORCE_EQ(num_items, num_labels);
// Image dimensions (only the image file carries them).
image_file.read(reinterpret_cast<char*>(&rows), 4);
image_file.read(reinterpret_cast<char*>(&cols), 4);
CAFFE_ENFORCE(image_file, "Failed to read the image size from ", image_filename);
rows = swap_endian(rows);
cols = swap_endian(cols);
- 创建DB文件,并存储数据
- 与Python相同,通过TensorProtos对象来写入数据。
- 数据主要通过DB的 `Transaction` 对象写入DB。
// 创建CreateDB以及对应的Transaction
std::unique_ptr<db::DB> mnist_db(db::CreateDB(caffe2::FLAGS_db, db_path, db::NEW));
std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
// 准备写入数据
char label_value;
std::vector<char> pixels(rows * cols);
int count = 0;
const int kMaxKeyLength = 10;
char key_cstr[kMaxKeyLength];
string value;
# 创建图片与标签对应的TensorProtos格式
TensorProtos protos;
TensorProto* data = protos.add_protos();
TensorProto* label = protos.add_protos();
data->set_data_type(TensorProto::BYTE);
if (caffe2::FLAGS_channel_first) {
data->add_dims(1);
data->add_dims(rows);
data->add_dims(cols);
} else {
data->add_dims(rows);
data->add_dims(cols);
data->add_dims(1);
}
label->set_data_type(TensorProto::INT32);
label->add_int32_data(0);
LOG(INFO) << "A total of " << num_items << " items.";
LOG(INFO) << "Rows: " << rows << " Cols: " << cols;
for (int item_id = 0; item_id < num_items; ++item_id) {
image_file.read(pixels.data(), rows * cols); # 读取image_file中的数据
label_file.read(&label_value, 1); # 读取label_value中的数据
// 设置数据
for (int i = 0; i < rows * cols; ++i) {
data->set_byte_data(pixels.data(), rows * cols);
}
label->set_int32_data(0, static_cast<int>(label_value));
// 获取key与value值
snprintf(key_cstr, kMaxKeyLength, "%08d", item_id);
protos.SerializeToString(&value);
string keystr(key_cstr);
// 将数据添加到DB中,通过Transaction实例
transaction->Put(keystr, value);
if (++count % 1000 == 0) {
transaction->Commit();
}
if (data_limit > 0 && count == data_limit) {
LOG(INFO) << "Reached data limit of " << data_limit << ", stop.";
break;
}
}
编辑于 2018-05-14 16:27