深度学习项目示例 使用自编码器进行模糊图像修复
ztj100 2024-12-19 17:56 22 浏览 0 评论
图像模糊是由相机或拍摄对象移动、对焦不准确或使用光圈配置不当导致的图像不清晰。 为了获得更清晰的照片,我们可以使用相机镜头的首选焦点重新拍摄同一张照片,或者使用深度学习知识重现模糊的图像。 由于我的专长不是摄影,只能选择使用深度学习技术对图像进行去模糊处理!
在开始这个项目之前,本文假定读者应该了解深度学习的基本概念,例如神经网络、CNN。 还要稍微熟悉一下 Keras、Tensorflow 和 OpenCV。
有各种类型的模糊——运动模糊、高斯模糊、平均模糊等。 但我们将专注于高斯模糊图像。 在这种模糊类型中,像素权重是不相等的。 模糊在中心处较高,在边缘处按照钟形曲线减少。
数据集
在开始使用代码之前,首先需要的是一个由 2 组图像组成的数据集——模糊图像和干净图像。 目前可能没有现成的数据集可以使用,但是就像我们上面所说的,如果你有opencv的基础这个对于我们来说是非常个简单的,只要我们有原始图像,使用opencv就可以自己生成训练需要的数据集。
这里我的数据集大小约为 50 张图像(50 张干净图像和 50 张模糊图像),因为只是演示目的所以只选择了少量图像。
编写代码
已经准备好数据集,可以开始编写代码了。
依赖项
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import random
import cv2
import os
import tensorflow as tf
from tqdm import tqdm
这里导入了 tqdm 库来帮助我创建进度条,这样可以知道运行代码需要多长时间。
导入数据
good_frames = '/content/drive/MyDrive/mini_clean'
bad_frames = '/content/drive/MyDrive/mini_blur'
现在创建了2 个列表。 我们将使用 keras 预处理库读取“.jpg”、“jpeg”或“.png”类型的图像,并转换为数组。这里图像尺寸为 128x128。
clean_frames = []
for file in tqdm(sorted(os.listdir(good_frames))):
if any(extension in file for extension in ['.jpg', 'jpeg', '.png']):
image = tf.keras.preprocessing.image.load_img(good_frames + '/' + file, target_size=(128,128))
image = tf.keras.preprocessing.image.img_to_array(image).astype('float32') / 255
clean_frames.append(image)
clean_frames = np.array(clean_frames)
blurry_frames = []
for file in tqdm(sorted(os.listdir(bad_frames))):
if any(extension in file for extension in ['.jpg', 'jpeg', '.png']):
image = tf.keras.preprocessing.image.load_img(bad_frames + '/' + file, target_size=(128,128))
image = tf.keras.preprocessing.image.img_to_array(image).astype('float32') / 255
blurry_frames.append(image)
blurry_frames = np.array(blurry_frames)
导入模型库
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from keras.utils.vis_utils import plot_model
from keras import backend as K
random.seed = 21
np.random.seed = seed
将数据集拆分为训练集和测试集
现在我们按 80:20 的比例将数据集分成训练和测试集。
x = clean_frames;
y = blurry_frames;
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
检查训练和测试数据集的形状
print(x_train[0].shape)
print(y_train[0].shape)
r = random.randint(0, len(clean_frames)-1)
print(r)
fig = plt.figure()
fig.subplots_adjust(hspace=0.1, wspace=0.2)
ax = fig.add_subplot(1, 2, 1)
ax.imshow(clean_frames[r])
ax = fig.add_subplot(1, 2, 2)
ax.imshow(blurry_frames[r])
上面的代码可以查看来自训练和测试数据集的图像,例如:
下面初始化一些编写模型时需要用到的参数
# Network Parameters
input_shape = (128, 128, 3)
batch_size = 32
kernel_size = 3
latent_dim = 256
# Encoder/Decoder number of CNN layers and filters per layer
layer_filters = [64, 128, 256]
编码器模型
自编码器的结构我们以前的文章中已经详细介绍过多次了,这里就不详细说明了
inputs = Input(shape = input_shape, name = 'encoder_input')
x = inputs
首先就是输入(图片的数组),获取输入后构建一个 Conv2D(64) - Conv2D(128) - Conv2D(256) 的简单的编码器,编码器将图片压缩为 (16, 16, 256) ,该数组将会是解码器的输入。
for filters in layer_filters:
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=2,
activation='relu',
padding='same')(x)
shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
这里的 K.int_shape()将张量转换为整数元组。
实例化编码器模型,如下
encoder = Model(inputs, latent, name='encoder')
encoder.summary()
解码器模型
解码器模型类似于编码器模型,但它进行相反的计算。 解码器以将输入解码回 (128, 128, 3)。 所以这里的将使用 Conv2DTranspose(256) - Conv2DTranspose(128) - Conv2DTranspose(64)。
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
x = Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)for filters in layer_filters[::-1]:
x = Conv2DTranspose(filters=filters,
kernel_size=kernel_size,
strides=2,
activation='relu',
padding='same')(x)
outputs = Conv2DTranspose(filters=3,
kernel_size=kernel_size,
activation='sigmoid',
padding='same',
name='decoder_output')(x)
解码器如下:
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
整合成自编码器
自编码器 = 编码器 + 解码器
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')
autoencoder.summary()
最后但是非常重要的是在训练我们的模型之前需要设置超参数。
autoencoder.compile(loss='mse', optimizer='adam',metrics=["acc"])
我选择损失函数为均方误差,优化器为adam,评估指标为准确率。然后还需要定义学习率调整的计划,这样可以在指标没有改进的情况下降低学习率,
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
cooldown=0,
patience=5,
verbose=1,
min_lr=0.5e-6)
学习率的调整需要在训练的每个轮次都调用,
callbacks = [lr_reducer]
训练模型
history = autoencoder.fit(blurry_frames,
clean_frames,
validation_data=(blurry_frames, clean_frames),
epochs=100,
batch_size=batch_size,
callbacks=callbacks)
运行此代码后,可能需要大约 5-6 分钟甚至更长时间才能看到最终输出,因为我们设置了训练轮次为100,
最后结果
现在已经成功训练了模型,让我们看看我们的模型的预测,
print("\n Input Ground Truth Predicted Value")
for i in range(3):
r = random.randint(0, len(clean_frames)-1)
x, y = blurry_frames[r],clean_frames[r]
x_inp=x.reshape(1,128,128,3)
result = autoencoder.predict(x_inp)
result = result.reshape(128,128,3)
fig = plt.figure(figsize=(12,10))
fig.subplots_adjust(hspace=0.1, wspace=0.2)
ax = fig.add_subplot(1, 3, 1)
ax.imshow(x)
ax = fig.add_subplot(1, 3, 2)
ax.imshow(y)
ax = fig.add_subplot(1, 3, 3)
plt.imshow(result)
可以看到该模型在去模糊图像方面做得很好,并且几乎能够获得原始图像。 因为我们只用了3层的卷积架构,所以如果我们使用更深的模型,还有一些超参数的调整应该会获得更好的结果。
为了查看训练的情况,可以绘制损失函数和准确率的图表,可以通过这些数据做出更好的决策。
损失的变化
plt.figure(figsize=(12,8))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.legend(['Train', 'Test'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(np.arange(0, 101, 25))
plt.show()
可以看到损失显着减少,然后从第 80 个 epoch 开始停滞不前。
准确率
plt.figure(figsize=(12,8))
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.legend(['Train', 'Test'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.xticks(np.arange(0, 101, 25))
plt.show()
这里可以看到准确率显着提高,如果训练更多轮,它可能会进一步提高。 因此,可以尝试增加 epoch 大小并检查准确率是否确实提高了,或者增加早停机制,让训练自动停止
总结
我们取得了不错的准确率,为 78.07%。 对于实际的应用本文只是开始,例如更好的网络架构,更多的数据,和超参数的调整等等,如果你有什么改进的想法也欢迎留言
作者:Chandana Kuntala
- 上一篇:用U-Net提取航拍图中的建筑物轮廓
- 下一篇:深度学习中的「卷积层」如何深入理解?
相关推荐
- SpringBoot整合SpringSecurity+JWT
-
作者|Sans_https://juejin.im/post/5da82f066fb9a04e2a73daec一.说明SpringSecurity是一个用于Java企业级应用程序的安全框架,主要包含...
- 「计算机毕设」一个精美的JAVA博客系统源码分享
-
前言大家好,我是程序员it分享师,今天给大家带来一个精美的博客系统源码!可以自己买一个便宜的云服务器,当自己的博客网站,记录一下自己学习的心得。开发技术博客系统源码基于SpringBoot,shiro...
- springboot教务管理系统+微信小程序云开发附带源码
-
今天给大家分享的程序是基于springboot的管理,前端是小程序,系统非常的nice,不管是学习还是毕设都非常的靠谱。本系统主要分为pc端后台管理和微信小程序端,pc端有三个角色:管理员、学生、教师...
- SpringBoot+LayUI后台管理系统开发脚手架
-
源码获取方式:关注,转发之后私信回复【源码】即可免费获取到!项目简介本项目本着避免重复造轮子的原则,建立一套快速开发JavaWEB项目(springboot-mini),能满足大部分后台管理系统基础开...
- Spring Boot的Security安全控制——认识SpringSecurity!
-
SpringBoot的Security安全控制在Web项目开发中,安全控制是非常重要的,不同的人配置不同的权限,这样的系统才安全。最常见的权限框架有Shiro和SpringSecurity。Shi...
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
-
前言不得不佩服SpringBoot的生态如此强大,今天给大家推荐几款优秀的后台管理系统,小伙伴们再也不用从头到尾撸一个项目了。SmartAdmin...
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
-
SpringBoot算是目前Java领域最火的技术栈了,除了书呢?当然就是开源项目了,今天整理15个开源领域非常不错的SpringBoot项目供大家学习,参考。高富帅的路上只能帮你到这里了,...
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
-
前言推荐这个项目是因为使用手册部署手册非常...
- 2021年超详细的java学习路线总结—纯干货分享
-
本文整理了java开发的学习路线和相关的学习资源,非常适合零基础入门java的同学,希望大家在学习的时候,能够节省时间。纯干货,良心推荐!第一阶段:Java基础...
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
-
jeecg-boot学习总结及使用心得1.jeecg-boot是一个真正前后端分离的模版项目,便于二次开发,使用的都是较流行的新技术,后端技术主要有spring-boot2.x、shiro、Myb...
- 后勤集团原料管理系统springboot+Layui+MybatisPlus+Shiro源代码
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述后勤集团原料管理系统spr...
- 白卷开源SpringBoot+Vue的前后端分离入门项目
-
简介白卷是一个简单的前后端分离项目,主要采用Vue.js+SpringBoot技术栈开发。除了用作入门练习,作者还希望该项目可以作为一些常见Web项目的脚手架,帮助大家简化搭建网站的流程。...
- Spring Security 自动踢掉前一个登录用户,一个配置搞定
-
登录成功后,自动踢掉前一个登录用户,松哥第一次见到这个功能,就是在扣扣里边见到的,当时觉得挺好玩的。自己做开发后,也遇到过一模一样的需求,正好最近的SpringSecurity系列正在连载,就结...
- 收藏起来!这款开源在线考试系统,我爱了
-
大家好,我是为广大程序员兄弟操碎了心的小编,每天推荐一个小工具/源码,装满你的收藏夹,每天分享一个小技巧,让你轻松节省开发效率,实现不加班不熬夜不掉头发,是我的目标!今天小编推荐一款基于Spr...
- Shiro框架:认证和授权原理(shiro权限认证流程)
-
优质文章,及时送达前言Shiro作为解决权限问题的常用框架,常用于解决认证、授权、加密、会话管理等场景。本文将对Shiro的认证和授权原理进行介绍:Shiro可以做什么?、Shiro是由什么组成的?举...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- SpringBoot整合SpringSecurity+JWT
- 「计算机毕设」一个精美的JAVA博客系统源码分享
- springboot教务管理系统+微信小程序云开发附带源码
- SpringBoot+LayUI后台管理系统开发脚手架
- Spring Boot的Security安全控制——认识SpringSecurity!
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
- 2021年超详细的java学习路线总结—纯干货分享
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)