百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

NSFW 图片分类 图片分类管理系统

ztj100 2024-12-28 16:50 58 浏览 0 评论

NSFW指的是 不适宜工作场所("Not Safe (or Suitable) For Work;")。在本文中,我介绍如何创建一个检测NSFW图像的图像分类模型。

数据集

由于数据集的性质,我们无法从一些数据集的网站(如Kaggle等)获得所有图像。

但是我们找到了一个专门抓取这种类型图片的github库,所以我们可以直接使用。clone项目后可以运行下面的代码来创建文件夹,并将每个图像下载到其特定的文件夹中。

folders = ['drawings','hentai','neutral','porn','sexy']
urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
names = ['d','h','n','p','s']
for i,j,k in zip(folders,urls,names):
try:
#Specify the path of the folder that has to be made
folder_path = os.path.join('your directory',i)
os.mkdir(folder_path)
except:
pass
#setup the path of url text file
url_path = os.path.join('Datasets_Urls',j)
my_file = open(url_path, "r")
data = my_file.read()
#create a list with all urls
data_into_list = data.split("\n")
my_file.close()
icount = 0
for ii in data_into_list:
try:
#create a unique image names for each images
image_name = 'image'+str(icount)+str(k)+'.png'
image_path = os.path.join(folder_path,image_name)
#download it using the library
urllib.request.urlretrieve(ii, image_path)
icount+=1
except Exception as e:
pass
#this below code is done to make the count of the image same for all the data 
#you can use a big number if you are building a more complex model or if you have a good system
if icount == 2000:
break

这里的folder变量表示类的名称,urls变量用于获取URL文本文件(可以根据文本文件名更改它),name变量用于为每个图像创建唯一的名称。

上面代码将为每个类下载2000张图像,可以编辑最后一个“if”条件来更改下载图像的个数。

数据准备

我们下载的文件夹可能包含其他类型的文件,所以首先必须删除不需要的类型的文件。

image_exts = ['jpeg','.jpg','bmp','png']
path_list = ['drawings','hentai','neutral','porn','sexy']
cwd = os.getcwd()
def remove_other_images(path_list):
for ii in path_list:
data_dir = os.path.join(cwd,'DataSet',ii)
for image in os.listdir(os.path.join(data_dir)):
image_path = os.path.join(data_dir,image_class,image)
try:
img = cv2.imread(image_path)
tip = imghdr.what(image_path)
if tip not in image_exts:
print('Image not in ext list {}'.format(image_path))
os.remove(image_path)
except Exception as e:
print("Issue with image {}".format(image_path))
remove_other_images(path_list)

上面的代码删除了扩展名不是指定格式的图像。

另外图像可能包含许多重复的图像,所以我们必须从每个文件夹中删除重复的图像。

cwd = os.getcwd()
path_list = ['drawings','hentai','neutral','porn','sexy']
def remove_dup_images(path_list):
for ii in path_list:
os.chdir(os.path.join(cwd,'DataSet',ii))
filelist = os.listdir()
duplicates = []
hash_keys = dict()
for index, filename in enumerate(filelist):
if os.path.isfile(filename):
with open(filename,'rb') as f:
filehash = hashlib.md5(f.read()).hexdigest()
if filehash not in hash_keys:
hash_keys[filehash] = index
else:
duplicates.append((index,hash_keys[filehash]))

for index in duplicates:
os.remove(filelist[index[0]])
print('{} duplicates removed from {}'.format(len(duplicates),ii))
remove_dup_images(path_list)

这里我们使用hashlib.md5编码来查找每个类中的重复图像。

Md5为每个图像创建一个唯一的哈希值,如果哈希值重复(重复图像),那么我们将重复图片添加到一个列表中,稍后进行删除。

因为使用TensorFlow框架所以需要判断是否被TensorFlow支持,所以我们这里加一个判断:

import tensorflow as tf
os.chdir('{data-set} directory')
cwd = os.getcwd()
for ii in path_list:
os.chdir(os.path.join(cwd,ii))
filelist = os.listdir()
for image_file in filelist:
with open(image_file, 'rb') as f:
image_data = f.read()
# Check the file format
_, ext = os.path.splitext(image_file)
if ext.lower() not in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
print('Unsupported image format:', ext)
os.remove(os.path.join(cwd,ii,image_file)) 
else:
# Decode the image
try:
image = tf.image.decode_image(image_data)
except:
print(image_file)
print("unspported")
os.remove(os.path.join(cwd,ii,image_file))

以上就是数据准备的所有工作,在清理完数据后,我们可以拆分数据。比如分割创建一个训练、验证和测试文件夹,并手动添加文件夹中的图像,我们将80%用于训练,10%用于验证,10%用于测试。

模型

首先导入tensorflow

import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import hashlib
from imageio import imread
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.layers import Flatten,Dense,Input
from tensorflow.keras.models import Model,Sequential
from keras import optimizers

对于图像,默认大小设置为224,224。

IMAGE_SIZE = [224,224]

可以使用ImageDataGenerator库,进行数据增强。数据增强也叫数据扩充,是为了增加数据集的大小。ImageDataGenerator根据给定的参数创建新图像,并将其用于训练(注意:当使用ImageDataGenerator时,原始数据将不用于训练)。

train_datagen = ImageDataGenerator(
rescale=1./255,
preprocessing_function=preprocess_input,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')

对于测试集也是这样:

test_datagen = ImageDataGenerator(rescale=1./255)

为了演示,我们直接使用VGG模型

vgg = VGG16(input_shape=IMAGE_SIZE+[3],weights='imagenet',include_top=False

然后冻结前面的层:

for layer in vgg.layers:
layer.trainable = False

最后我们加入自己的分类头:

x = Flatten()(vgg.output)
prediction = Dense(5,activation='softmax')(x)
model = Model(inputs=vgg.input, outputs=prediction)
model.summary()

模型是这样的:

训练

看看我们训练集:

train_set = train_datagen.flow_from_directory('DataSet/train',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

验证集

val_set = train_datagen.flow_from_directory('DataSet/validation',
target_size=(224,224),
batch_size=32,
class_mode='sparse')

使用' sparse_categorical_crossentropy '损失,这样可以将标签编码为整数而不是独热编码。

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
adam = optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy',MeanSquaredError(name='val_loss'),CategoricalAccuracy(name='val_accuracy')])

然后就可以训练了:

from datetime import datetime
from keras.callbacks import ModelCheckpoint
log_dir = 'vg_log'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)
start = datetime.now()
history = model.fit_generator(train_set,
validation_data=val_set,
epochs=100,
steps_per_epoch=len(train_set)// batch_size,
validation_steps=len(val_set)//batch_size,
callbacks=[tensorboard_callback],
verbose=1)
duration = datetime.now() - start
print("Time taken for training is ",duration)

模型训练了100次。得到了80%的验证准确率。f1得分为93%

预测

下面的函数将获取一个图像列表并根据该列表进行预测。

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
def print_classes(images,model):
classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
fig, ax = plt.subplots(ncols=len(images), figsize=(20,20))
for idx,img in enumerate(images):
img = mpimg.imread(img)
resize = tf.image.resize(img,(224,224))
result = model.predict(np.expand_dims(resize/255,0))
result = np.argmax(result)
if classes[result] == 'Porn':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Sexual':
img = gaussian_filter(img, sigma=6)
elif classes[result] == 'Hentai':
img = gaussian_filter(img, sigma=6)
ax[idx].imshow(img)
ax[idx].title.set_text(classes[result])
li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

看结果还是可以的。

最后,本文的源代码:

https://github.com/Nikhilthalappalli/NSFW-Classifier

但是我觉得源代码不重要,Alex Kim的项目才是你需要的:

https://github.com/alex000kim/nsfw_data_scraper

作者:Nikhil Thalappalli

相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: