Image Super-Resolution Reconstruction with SRCNN

This post is based on the ECCV 2014 paper "Learning a Deep Convolutional Network for Image Super-Resolution" and reimplements the authors' work with TensorFlow on Windows (the original implementation uses Caffe).

The paper and the project page are available at:

http://personal.ie.cuhk.edu.hk/~ccloy/files/eccv_2014_deepresolution.pdf
http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html

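1. SRCNN overview

(1) Overview

The SRCNN network contains only three convolutional layers, which makes its structure very simple.

SRCNN first upscales the low-resolution image to the target size with bicubic interpolation, then fits a non-linear mapping with a three-layer convolutional network, and finally outputs the high-resolution result. In the paper, the authors interpret the three convolutional layers as three steps: patch extraction and representation, non-linear mapping, and reconstruction.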

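(2) Network design rationale

The design is derived from sparse coding, and the three steps above correspond to it as follows:

Patch extraction: extract patches from the image and convolve them to obtain features, analogous to mapping image patches onto a low-resolution dictionary in sparse coding.

Non-linear mapping: map the low-resolution features to high-resolution features, analogous to finding the high-resolution dictionary that corresponds to each image patch in dictionary learning.

Reconstruction: rebuild the image from the high-resolution features, analogous to reconstructing the image from the high-resolution dictionary in dictionary learning.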

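(3) Network structure

Layer 1 is a convolutional layer (reads the input and extracts patch features)

Input: low-resolution patches

Kernel: c*f1*f1*n1 (c is the number of input channels; the paper uses only the Y channel of YCbCr, so c=1; f1=9; n1 is the output depth of this layer, set to 64)

Layer 2 is a convolutional layer (non-linear mapping)

Input: output of layer 1

Kernel: n1*1*1*n2 (n1=64 is the output depth of the previous layer; n2=32 is the output depth of this layer)

Layer 3 is a convolutional layer (reconstruction)

Input: output of layer 2

Kernel: n2*f3*f3*c (n2 is the output depth of the previous layer; f3=5; c is the number of channels of the reconstructed high-resolution image, the same as the input, so c=1)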
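
The kernel sizes above determine both the output size and the number of parameters. The sketch below works through that arithmetic in plain Python. It assumes stride-1 convolutions without padding and a 33 x 33 input sub-image, which are the settings used by the code later in this post; the helper names valid_conv_out and conv_params are made up for this illustration.

```python
# Illustrative helpers (hypothetical names); layer sizes are taken from the text above.

def valid_conv_out(size, kernel):
    """Spatial size after a stride-1 convolution without padding ('VALID')."""
    return size - kernel + 1

def conv_params(in_ch, kernel, out_ch):
    """Number of weights plus biases in one convolutional layer."""
    return in_ch * kernel * kernel * out_ch + out_ch

# (input channels, kernel side, output channels) for the three layers
layers = [(1, 9, 64), (64, 1, 32), (32, 5, 1)]

size = 33   # side length of an input sub-image (assumed, matches image_size in main.py)
total = 0
for in_ch, kernel, out_ch in layers:
    size = valid_conv_out(size, kernel)
    total += conv_params(in_ch, kernel, out_ch)
    print("layer", (in_ch, kernel, out_ch), "-> output side", size)

print("final output side:", size)   # 21
print("total parameters:", total)   # 8129
```

The output side of 21 is why the labels are 21 x 21 while the inputs are 33 x 33: each side loses (9-1) + (1-1) + (5-1) = 12 pixels, or 6 on each border.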

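(4) Training and testing

Training: patches are randomly sampled from the high-resolution images, downsampled, and then upsampled back to the original size to form the network input. The original high-resolution patches serve as targets, and a pixel-wise loss is minimised.

Testing: the image is first upscaled by interpolation with the desired factor to form the input, then passed through the network to produce the result.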
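
As a rough illustration of how a training pair and the pixel-wise loss fit together, here is a one-dimensional toy in plain Python. It is only a sketch: block averaging followed by sample repetition stands in for the bicubic down/up-sampling used in the real pipeline, and the function names degrade and mse are invented for this example.

```python
def degrade(signal, scale):
    """Shrink a 1-D signal by block averaging, then enlarge it by repetition.

    This is a crude stand-in (an assumption of this sketch) for bicubic
    downsampling followed by bicubic upsampling.
    """
    small = [sum(signal[i:i + scale]) / scale for i in range(0, len(signal), scale)]
    return [value for value in small for _ in range(scale)]

def mse(pred, target):
    """Pixel-wise mean squared error between two equally long signals."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

high_res = [0.0, 0.3, 0.6, 0.9, 0.6, 0.3]   # plays the role of the label
low_res = degrade(high_res, 3)               # plays the role of the network input

print(low_res)                 # same length as high_res, but with the detail removed
print(mse(low_res, high_res))  # the error the network is trained to reduce
```

The network never sees a smaller image: its input already has the target size, and training only has to restore the detail that the degradation removed.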


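2. Python (TensorFlow) implementation

The implementation has three parts: the entry point main.py, the network definition model.py, and the helper functions utils.py.

(1) main.py:

Defines the training and testing parameters (including the batch size used with SGD, the learning rate, the sub-image stride, and whether to run in training or testing mode), then trains or tests according to those settings.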

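(2) model.py:

Defines the network structure (three convolutional layers and their kernel sizes):

# Layer 1: feature extraction from the input image (9 x 9 x 64 kernels)

# Layer 2: non-linear mapping of the features extracted by layer 1 (1 x 1 x 32 kernels)

# Layer 3: reconstruction of the mapped features into the high-resolution image (5 x 5 x 1 kernel)

Training method: SGD, Adam, etc. (in my tests SGD gave better results)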
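
To show what a single layer of this kind computes, the following sketch implements a single-channel, stride-1 convolution without padding and a ReLU in plain Python. It illustrates the operation only and is not the TensorFlow code used in this post; conv2d_valid and relu are names chosen for this example. Like tf.nn.conv2d, it slides the kernel over the image without flipping it.

```python
def conv2d_valid(image, kernel, bias=0.0):
    """Single-channel 2-D convolution (no kernel flip), stride 1, no padding."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            acc = bias
            for di in range(kh):
                for dj in range(kw):
                    acc += image[i + di][j + dj] * kernel[di][dj]
            row.append(acc)
        out.append(row)
    return out

def relu(feature_map):
    """Element-wise max(0, x)."""
    return [[max(0.0, value) for value in row] for row in feature_map]

image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]
ones = [[1, 1, 1]] * 3   # 3 x 3 kernel that sums each window

print(conv2d_valid(image, ones))                     # 2 x 2 map of window sums
print(relu(conv2d_valid(image, ones, bias=-60.0)))   # negative responses clipped to 0
```

A 4 x 4 input with a 3 x 3 kernel gives a 2 x 2 output, the same shrinking effect that turns 33 x 33 inputs into 21 x 21 outputs in the full network.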

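(3) utils.py:

Holds the small helper functions that are needed, including:

read_data(path) # read an h5 data file

preprocess(path, scale=3) # crop the image at path so that its sides are multiples of scale, shrink it by 1/scale and enlarge it again by scale to obtain the low-resolution image input_; the cropped image is the high-resolution image label_

prepare_data(sess, dataset) # prepare the data (collect the image file paths)

make_data(sess, data, label) # save the data in .h5 format

imread(path, is_grayscale=True) # read the image at the given path

modcrop(image, scale=3) # crop the image so that its height and width are multiples of scale

modcrop_small(image) # crop the image so that result and origin have the same size

input_setup(sess, config) # read the image set, cut it into sub-images and save them in h5 format, for both training and testing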
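
The cropping and tiling arithmetic behind modcrop and input_setup can be checked with a few lines of plain Python. This is a sketch under the settings described in this post (33 x 33 sub-images, stride 14 for training and 21 for testing); modcrop_size and count_sub_images are hypothetical helpers that only compute sizes and do not touch any image data.

```python
def modcrop_size(h, w, scale=3):
    """Height and width after cropping both to a multiple of scale."""
    return h - h % scale, w - w % scale

def count_sub_images(h, w, image_size=33, stride=21):
    """Number of sub-images along each axis when sliding a window over the image."""
    nx = len(range(0, h - image_size + 1, stride))
    ny = len(range(0, w - image_size + 1, stride))
    return nx, ny

h, w = modcrop_size(512, 512, scale=3)
print(h, w)                                # 510 510

print(count_sub_images(h, w, stride=21))   # (23, 23): non-overlapping 21 x 21 labels
print(count_sub_images(h, w, stride=14))   # (35, 35): overlapping sub-images
```

With stride 21 the 21 x 21 label tiles sit exactly side by side, which is what allows the test output to be stitched back into one image. The smaller training stride of 14 simply produces more, overlapping training samples from the same images.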


The source code follows:

(1) main.py

from model import SRCNN
from utils import input_setup
import numpy as np
import tensorflow as tf
import pprint
import os
flags = tf.app.flags
flags.DEFINE_integer("epoch", 2000, "number of training epochs")
#flags.DEFINE_integer("batch_size", 128, "The size of batch images [128]")
# At first batch sizes of 128 and 64 were tried: the initial loss was large and training often diverged after a while
# The gradients of the samples in a batch may compete strongly, which slowed convergence
# Later the batch size was set back to 128
flags.DEFINE_integer("batch_size", 128, "batch size")
flags.DEFINE_integer("image_size", 33, "size of the input sub-images")
flags.DEFINE_integer("label_size", 21, "size of the label sub-images")
# The paper uses a learning rate of 1e-4 for the first two layers and 1e-5 for the third
# Here SGD is used with an exponentially decayed learning rate starting from 1e-2
flags.DEFINE_float("learning_rate", 1e-2, "learning rate")
flags.DEFINE_integer("c_dim", 1, "number of image channels")
flags.DEFINE_integer("scale", 3, "scale factor used for sampling")
# stride: 14 for training, 21 for testing
flags.DEFINE_integer("stride", 21 , "stride, 14 or 21")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "name of the checkpoint directory")
flags.DEFINE_string("sample_dir", "sample", "name of the sample directory")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing")  # testing
#flags.DEFINE_boolean("is_train", True, "True for training, False for testing")  # training
FLAGS = flags.FLAGS
pp = pprint.PrettyPrinter()
def main(_):
  pp.pprint(flags.FLAGS.__flags)
  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)
  with tf.Session() as sess:
    srcnn = SRCNN(sess, 
                  image_size=FLAGS.image_size, 
                  label_size=FLAGS.label_size, 
                  batch_size=FLAGS.batch_size,
                  c_dim=FLAGS.c_dim, 
                  checkpoint_dir=FLAGS.checkpoint_dir,
                  sample_dir=FLAGS.sample_dir)
    srcnn.train(FLAGS)
if __name__ == '__main__':
  tf.app.run()
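
The learning-rate flag above is only the starting value: model.py feeds it into tf.train.exponential_decay with staircase=True, which multiplies the rate by a fixed factor every fixed number of steps. The plain-Python sketch below reproduces that schedule with the values used in this post (1480 steps, factor 0.98); staircase_decay is a name invented for this illustration.

```python
def staircase_decay(initial_lr, step, decay_steps=1480, decay_rate=0.98):
    """Learning rate after `step` updates with a staircase exponential decay.

    The integer division makes the rate stay constant within each block of
    decay_steps updates and drop by decay_rate at the start of the next block.
    """
    return initial_lr * decay_rate ** (step // decay_steps)

for step in (0, 1479, 1480, 2960):
    print(step, staircase_decay(1e-2, step))
```

The rate therefore stays at 1e-2 for the first 1480 updates, drops to 0.98 times that value for the next 1480, and so on.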

(2) model.py

from utils import (
  read_data, 
  input_setup, 
  imsave,
  psnr,
  merge
)
import time
import os
import cv2
import matplotlib.pyplot as plt
from skimage import data, exposure, img_as_float

import numpy as np
import tensorflow as tf
# Define the SRCNN class
class SRCNN(object):

  def __init__(self, 
               sess, 
               image_size=33,
               label_size=21, 
               batch_size=64,
               c_dim=1, 
               checkpoint_dir=None, 
               sample_dir=None):
    self.sess = sess
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size
    self.c_dim = c_dim
    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir
    self.build_model()
  # Build the network
  def build_model(self):
    self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
    self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
    # Layer 1: feature extraction from the input image (9 x 9 x 64 kernels)
    # Layer 2: non-linear mapping of the features extracted by layer 1 (1 x 1 x 32 kernels)
    # Layer 3: reconstruction of the mapped features into the high-resolution image (5 x 5 x 1 kernel)
    # Weights
    self.weights = {
      # Setting given in the paper for faster training: n1=32, n2=16
      'w1': tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3), name='w1'),
      'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
      'w3': tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3), name='w3')
    }
    # Biases
    self.biases = {
      'b1': tf.Variable(tf.zeros([64]), name='b1'),
      'b2': tf.Variable(tf.zeros([32]), name='b2'),
      'b3': tf.Variable(tf.zeros([1]), name='b3')
    }
    self.pred = self.model()
    # Use MSE as the loss function
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
    self.saver = tf.train.Saver()
  # Called from main (training or testing)
  def train(self, config):
    if config.is_train:  # training or testing, as passed in from main
      input_setup(self.sess, config)
    else:
      nx, ny = input_setup(self.sess, config)
    # Training uses train.h5 in the checkpoint directory
    # Testing uses test.h5 in the checkpoint directory
    if config.is_train:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
    else:
      data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "test.h5")
    train_data, train_label = read_data(data_dir)  # read the .h5 file (train or test, depending on the mode)
    global_step=tf.Variable(0)  # global_step is incremented by 1 automatically on every update
    # Build the learning rate with exponential_decay
    learning_rate_exp=tf.train.exponential_decay(config.learning_rate , global_step , 1480 , 0.98 , staircase=True)  # multiply the learning rate by 0.98 every 1480 steps (one epoch here)
    # Stochastic gradient descent with standard backpropagation
    #self.train_op = tf.train.GradientDescentOptimizer(config.learning_rate).minimize(self.loss)  # fixed learning rate, minimises self.loss
    self.train_op = tf.train.GradientDescentOptimizer(learning_rate_exp).minimize(self.loss , global_step=global_step)
    # Adam: use the line below in place of the four lines above
    #self.train_op = tf.train.AdamOptimizer(config.learning_rate).minimize(self.loss, global_step=global_step)

    # tf.initialize_all_variables() gives the warning: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
    #tf.initialize_all_variables().run()
    tf.global_variables_initializer().run()  # replacement for the deprecated call above
    counter = 0
    start_time = time.time()
    if self.load(self.checkpoint_dir):
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed...")
    # Training
    if config.is_train:
      print("Training...")
      for ep in range(config.epoch):  # loop over epochs
        # Process the data batch by batch
        batch_idxs = len(train_data) // config.batch_size
        for idx in range(0, batch_idxs):
          batch_images = train_data[idx*config.batch_size : (idx+1)*config.batch_size]
          batch_labels = train_label[idx*config.batch_size : (idx+1)*config.batch_size]
          counter += 1
          _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels})
          if counter % 10 == 0:  # print progress every 10 steps
            print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
              % ((ep+1), counter, time.time()-start_time, err))
          if counter % 500 == 0:  # save a checkpoint every 500 steps
            self.save(config.checkpoint_dir, counter)
    # Testing
    else:
      print("Testing...")
      result = self.pred.eval({self.images: train_data, self.labels: train_label})  # data read from test.h5
      result = merge(result, [nx, ny])
      result = result.squeeze()  # drop the dimensions of size 1
      #result= exposure.adjust_gamma(result, 1.07)  # darken the result slightly
      image_path = os.path.join(os.getcwd(), config.sample_dir)
      image_path = os.path.join(image_path, "MySRCNN.bmp")
      imsave(result, image_path)
  def model(self):
    conv1 = tf.nn.relu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, self.weights['w2'], strides=[1,1,1,1], padding='VALID') + self.biases['b2'])
    conv3 = tf.nn.conv2d(conv2, self.weights['w3'], strides=[1,1,1,1], padding='VALID') + self.biases['b3']
    return conv3

  def save(self, checkpoint_dir, step):
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # the path becomes checkpoint/srcnn_21
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),  # file name is SRCNN.model-<step>
                    global_step=step)
  def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)  # path is checkpoint/srcnn_<label_size> (21)
    # Load the model from this path (.meta stores the graph structure; .index stores the variable names; .data stores the variable values)
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))  # given the SRCNN.model-<step> path, saver.restore() finds the name and value files and loads them
        return True
    else:
        return False

(3) utils.py

import os
import cv2
import glob
import h5py
import random
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import scipy.ndimage
import tensorflow as tf
from skimage import transform,data
from scipy.ndimage import filters
from PIL import Image  # for loading images as YCbCr format
from skimage import exposure  # gamma adjustment

FLAGS = tf.app.flags.FLAGS

def read_data(path):
  """
  Read an h5 data file for training or testing.
  Args:
    path: path to the h5 file, which holds two datasets
      'data': the network inputs
      'label': the target outputs
  """
  with h5py.File(path, 'r') as hf:  # open the h5 file (train or test)
    data = np.array(hf.get('data'))
    label = np.array(hf.get('label'))
    return data, label

def preprocess(path, scale=3):
  # Crop the image at path so that its sides are multiples of scale, shrink it by 1/scale and
  # enlarge it again by scale to get the low-resolution image input_; the cropped image is the
  # high-resolution image label_

  #image = imread(path, is_grayscale=True)
  #label_ = modcrop(image, scale)
  scale-=1  # note: everything below uses scale-1 as the factor (2 when scale=3)
  image=scipy.misc.imread(path, mode='YCbCr').astype(np.float)
  image=modcrop(image , scale)
  label_=image[:,:,0]  # keep only the Y channel
  # Normalise to [0, 1]
  image = image/255.
  label_ = label_/255.
  input_ = scipy.ndimage.interpolation.zoom(label_, (1./(scale)),mode='wrap',prefilter=False)  # shrink; order=4 mode='wrap' was also tried
  input_ = scipy.ndimage.interpolation.zoom(input_, ((scale)/1.),mode='wrap',prefilter=False)  # enlarge back
  label_small=modcrop_small(label_)  # crop the original to the same size as the network output
  input_small=modcrop_small(input_)  # crop the interpolated image to the same size as the network output
  imsave(input_small, "C:\\Users\\lenovo\\Desktop\\SRCNN\\SRCNN-Tensorflow-master_MY\\sample\\bicubic.bmp")  # save the interpolated image
  imsave(label_small, "C:\\Users\\lenovo\\Desktop\\SRCNN\\SRCNN-Tensorflow-master_MY\\sample\\origin.bmp")  # save the original image
  imsave(input_, "C:\\Users\\lenovo\\Desktop\\SRCNN\\SRCNN-Tensorflow-master_MY\\sample\\input_.bmp")  # save the input_ image
  imsave(label_, "C:\\Users\\lenovo\\Desktop\\SRCNN\\SRCNN-Tensorflow-master_MY\\sample\\label_.bmp")  # save the label_ image
  return input_, label_
  
def prepare_data(sess, dataset):
  if FLAGS.is_train:
    filenames = os.listdir(dataset)
    data_dir = os.path.join(os.getcwd(), dataset)
    data = glob.glob(os.path.join(data_dir, "*.bmp"))
  else:
    # The test images are taken from the Set5 folder
    data_dir = os.path.join((os.path.join(os.getcwd(), dataset)),"Set5")
    data = glob.glob(os.path.join(data_dir,"*.bmp"))
  return data
def make_data(sess, data, label):
  # Save the data in .h5 format
  if FLAGS.is_train:
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
  else:
    savepath = os.path.join(os.getcwd(), 'checkpoint/test.h5')
  with h5py.File(savepath, 'w') as hf:
    hf.create_dataset('data', data=data)
    hf.create_dataset('label', data=label)
def imread(path, is_grayscale=True):
  # Read the image at the given path
  if is_grayscale:
    return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)
  else:
    return scipy.misc.imread(path, mode='YCbCr').astype(np.float)
def modcrop(image, scale=3):
  # Crop the image so that its height and width are multiples of scale
  if len(image.shape) == 3:
    h, w, _ = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w, :]
  else:
    h, w = image.shape
    h = h - np.mod(h, scale)
    w = w - np.mod(w, scale)
    image = image[0:h, 0:w]
  return image
# Crop the image so that result and origin have the same size
def modcrop_small(image):
  # 6 comes from padding = abs(config.image_size - config.label_size) // 2
  # 21 comes from label_size
  # 33 comes from image_size
  padding2 = 6
  #padding2 = 0
  if len(image.shape) == 3:
    h, w, _ = image.shape
    h = (h-33+1)//21*21+21+padding2
    w =(w-33+1)//21*21+21+padding2
    image1 = image[padding2:h, padding2:w, :]  # skip the 6-pixel border
  else:
    h, w = image.shape
    h = (h-33+1)//21*21+21+padding2
    w =(w-33+1)//21*21+21+padding2
    image1 = image[padding2:h, padding2:w]
  return image1
def input_setup(sess, config):
  #global nx  # added later
  #global ny  # added later
  # Read the image set, cut it into sub-images and save them in h5 format
  # Get the data paths
  if config.is_train:
    data = prepare_data(sess, dataset="Train")
    print(len(data))
  else:
    data = prepare_data(sess, dataset="Test")
    print(len(data))
  sub_input_sequence = []
  sub_label_sequence = []
  padding = abs(config.image_size - config.label_size) // 2 # 6
  #padding=0;  # change the padding to test its effect
  # Training
  if config.is_train: 
    for i in range(len(data)):  # each image is one entry of data
      input_, label_ = preprocess(data[i], config.scale)  # LR image input_ and HR image label_ for data[i]
      if len(input_.shape) == 3:
        h, w, _ = input_.shape
      else:
        h, w = input_.shape
      # Split input_ and label_ into sub-images sub_input and sub_label
      for x in range(0, h-config.image_size+1, config.stride):
        for y in range(0, w-config.image_size+1, config.stride):
          sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
          sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21]
          sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  # reshape by image_size, so image_size must be 33 and label_size must be 21
          sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
          sub_input_sequence.append(sub_input)  # append sub_input to sub_input_sequence (the list starts empty)
          sub_label_sequence.append(sub_label)
  else:
        # Testing
        input_, label_ = preprocess(data[0], config.scale)  # the test image
        if len(input_.shape) == 3:
          h, w, _ = input_.shape
        else:
          h, w = input_.shape
        nx = 0  # number of sub-images along x
        ny = 0  # number of sub-images along y
        # The sub-images have to be merged again after prediction
        for x in range(0, h-config.image_size+1, config.stride):  # x from 0 to h-33+1 with step stride (21)
          nx += 1
          ny = 0
          for y in range(0, w-config.image_size+1, config.stride):  # y from 0 to w-33+1 with step stride (21)
            ny += 1
            # Tiles: sub_input=input_[x:x+33, y:y+33]  sub_label=label_[x+6:x+6+21, y+6:y+6+21]
            sub_input = input_[x:x+config.image_size, y:y+config.image_size] # [33 x 33]
            sub_label = label_[x+padding:x+padding+config.label_size, y+padding:y+padding+config.label_size] # [21 x 21] 
            sub_input = sub_input.reshape([config.image_size, config.image_size, 1])  
            sub_label = sub_label.reshape([config.label_size, config.label_size, 1])
            sub_input_sequence.append(sub_input)
            sub_label_sequence.append(sub_label)
  # From here on, training and testing share the same steps
  arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
  arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]
  make_data(sess, arrdata, arrlabel)  # save in h5 format
  if not config.is_train:  # only testing needs the grid size for merging
    return nx, ny
def imsave(image, path):
  return scipy.misc.imsave(path, image)
def merge(images, size):
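  h, w = images.shape[1], images.shape[2]  # I wondered whether the indices should be 0, 1; images has shape [N, h, w, c], so 1 and 2 are the tile height and width
  #h, w = images.shape[0], images.shape[1]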
  img = np.zeros((h*size[0], w*size[1], 1))
  for idx, image in enumerate(images):
    i = idx % size[1]
    j = idx // size[1]
    img[j*h:j*h+h, i*w:i*w+w, :] = image
  return img
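def psnr( im1 , im2 ):  # PSNR of the result (the reported values were computed in MATLAB instead)
  diff = np.asarray(im1, dtype=np.float64) - np.asarray(im2, dtype=np.float64)
  rmse = np.sqrt(np.mean(diff ** 2))  # root-mean-square error over all pixels
  value = 20*np.log10(255/rmse)  # assumes pixel values in [0, 255]
  return value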
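
The merge step works because the test stride equals the label size (21), so the predicted tiles do not overlap and can simply be laid out on a grid of nx rows and ny columns. The sketch below shows the same index arithmetic on nested lists in plain Python; tile_origin and merge_tiles are illustrative names and not part of utils.py.

```python
def tile_origin(idx, ny, tile_h, tile_w):
    """Top-left corner of tile number idx in a grid with ny tiles per row."""
    row = idx // ny   # which row of tiles
    col = idx % ny    # which column of tiles
    return row * tile_h, col * tile_w

def merge_tiles(tiles, nx, ny):
    """Stitch nx * ny equally sized, non-overlapping tiles (row-major order)."""
    tile_h, tile_w = len(tiles[0]), len(tiles[0][0])
    canvas = [[0.0] * (ny * tile_w) for _ in range(nx * tile_h)]
    for idx, tile in enumerate(tiles):
        top, left = tile_origin(idx, ny, tile_h, tile_w)
        for r in range(tile_h):
            for c in range(tile_w):
                canvas[top + r][left + c] = tile[r][c]
    return canvas

# Four 1 x 2 tiles arranged on a 2 x 2 grid
tiles = [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]
print(merge_tiles(tiles, nx=2, ny=2))   # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

The tiles must arrive in the order in which they were cut (outer loop over rows, inner loop over columns), which is the order input_setup uses at test time.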

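3. Test results

Tests were run on the Set5 and Set14 test sets. The result for the baby image is as follows:

Fine details such as the eyelashes are recovered quite well.

Finally, the PSNR of the SRCNN reproduced in this post is compared with that of the authors' SRCNN on the Set5 test set.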
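
For reference, PSNR is defined as 10 * log10(peak^2 / MSE), where MSE is the mean squared error over all pixels and peak is the largest possible pixel value. The following plain-Python sketch computes it for two small images given as nested lists. It assumes 8-bit images (peak 255) and is not the MATLAB routine behind the numbers in this post.

```python
import math

def psnr(img_a, img_b, peak=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized 2-D images."""
    flat_a = [value for row in img_a for value in row]
    flat_b = [value for row in img_b for value in row]
    mse = sum((a - b) ** 2 for a, b in zip(flat_a, flat_b)) / len(flat_a)
    if mse == 0:
        return float("inf")   # identical images
    return 10 * math.log10(peak ** 2 / mse)

reference = [[100, 100], [100, 100]]
estimate = [[110, 90], [100, 100]]   # two pixels are off by 10, so MSE = 50

print(psnr(reference, estimate))     # about 31.1 dB
print(psnr(reference, reference))    # inf
```

When comparing with published numbers, both images should cover the same region and use the same value range; super-resolution papers usually report PSNR on the Y channel only.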


Edited on 2018-11-14