深度学习--Pytorch构建栈式自编码器实现以图搜图任务(以cifar10数据集...

深度学习--Pytorch构建栈式编码器实现以图搜图任务(以cifar10数据集为
例)
Pytorch构建栈式⾃编码器实现以图搜图任务
本⽂旨在使⽤CIFAR-10数据集,构建与训练栈式⾃编码器,提取数据集中图像的特征;基于所提取的特征完成CIFAR-10中任意图像的检索任务并展⽰效果。搞清楚pytorch与tensorflow区别
pytorch
pytorch是⼀种python科学计算框架
作⽤:
⽆缝替换numpy,通过GPU实现神经⽹络的加速
通过⾃动微分机制,让神经⽹络实现更容易(即⾃动求导机制)
张量:类似于数组和矩阵,是⼀种特殊的数据结构。在pytorch中,神经⽹络的输⼊、输出以及⽹络的参数等数据,都是使⽤张量来进⾏描述的。
每个变量中都有两个标志:requires_grad volatile
requires_grad:
如果有⼀个单⼀的输⼊操作需要梯度,它的输出就需要梯度。只有所有输⼊都不需要梯度时,输出才不需要。
volatile:
只需要⼀个volatile的输⼊就会得到⼀个volatile输出。
tensorflow
TensorFlow 是由 Google Brain 团队为深度神经⽹络(DNN)开发的功能强⼤的开源软件库
TensorFlow 则还有更多的特点,如下:
⽀持所有流⾏语⾔,如 Python、C++、Java、R和Go。
可以在多种平台上⼯作,甚⾄是移动平台和分布式平台。
它受到所有云服务(AWS、Google和Azure)的⽀持。
Keras——⾼级神经⽹络 API,已经与 TensorFlow 整合。
与Torch/Theano ⽐较,TensorFlow 拥有更好的计算图表可视化。 允
许模型部署到⼯业⽣产中,并且容易使⽤。
有⾮常好的社区⽀持。 TensorFlow 不仅仅是⼀个软件库,它是⼀套包括 TensorFlow,TensorBoard 和TensorServing 的软件。搞清楚栈式⾃编码器的内部原理
  我们构建栈式编码器,⽤编码器再解码出来的结果和原标签对⽐进⾏训练模型,然后⽤中间编码提取到的特征直接和原图的特征进⾏对⽐,得到相似度,实现以图搜图。
整个⽹络的训练不是⼀蹴⽽就的,⽽是逐层进⾏的。
效果图
随机取测试集的五张图⽚,进⾏以图搜图(TOP8)
提取的分布式特征聚集图像:第⼀张为原图散点图,第⼆张以检索的TOP8的TOP1的提取特征散点图为例
代码及效果图
⽋完备编码器
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 24 18:37:55 2021
@author: ASUS
"""
import torch
import torchvision
import torch.utils.data
as nn
import matplotlib.pyplot as plt
import random #随机取测试集的图⽚
import time
starttime = time.time()
torch.manual_seed(1)
EPOCH =10
BATCH_SIZE =64
LR =0.005
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
ansforms.ToTensor(),
download=False)
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
卷积编码
ansforms.ToTensor(),
download=False)
# dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
shuffle=True)
train_data = torchvision.datasets.MNIST(
root='./data',
train=True,
ansforms.ToTensor(),
download=False
)
loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
class Stack_AutoEncoder(nn.Module):
def__init__(self):
super(Stack_AutoEncoder,self).__init__()
nn.Linear(32*32,256),
nn.Tanh(),
nn.Linear(256,128),
nn.Tanh(),
nn.Linear(128,32),
nn.Tanh(),
nn.Linear(32,16),
nn.Tanh(),
nn.Linear(16,8)
)
self.decoder = nn.Sequential(
nn.Linear(8,16),
nn.Tanh(),
nn.Linear(16,32),
nn.Tanh(),
nn.Linear(32,128),
nn.Tanh(),
nn.Linear(128,256),
nn.Tanh(),
nn.Linear(256,32*32),
nn.Sigmoid()
)
def forward(self, x):
encoded = der(x)
decoded = self.decoder(encoded)
return encoded,decoded
Coder = Stack_AutoEncoder()
print(Coder)
optimizer = torch.optim.Adam(Coder.parameters(),lr=LR)
loss_func = nn.MSELoss()
for epoch in range(EPOCH):
for step,(x,y)in enumerate(trainloader):
b_x = x.view(-1,32*32)
b_y = x.view(-1,32*32)
b_label = y
encoded , decoded = Coder(b_x)
#        print(encoded)
loss = loss_func(decoded,b_y)
<_grad()
loss.backward()
optimizer.step()
#        if step%5 == 0:
print('Epoch :', epoch,'|','train_loss:%.4f'%loss.data)
torch.save(Coder,'Stack_AutoEncoder.pkl')
print('________________________________________')
print('finish training')
endtime = time.time()
print('训练耗时:',(endtime - starttime))
#以图搜图函数
Coder = Stack_AutoEncoder()
Coder = torch.load('Stack_AutoEncoder.pkl')
def search_by_image(x,inputImage,K):
c =['b','g','r']#画特征散点图
loss_func = nn.MSELoss()
x_ = inputImage.view(-1,32*32)
encoded , decoded = Coder(x_)
#    print(encoded)
lossList=[]
for step,(test_x,y)in enumerate(testset):
if(step == x):#去掉原图
lossList.append((x,1))
continue
b_x = test_x.view(-1,32*32)
b_y = test_x.view(-1,32*32)
b_label = y
test_encoded , test_decoded = Coder(b_x)
loss = loss_func(encoded,test_encoded)
#        loss = round(loss, 4) #保留⼩数
lossList.append((step,loss.item()))
lossList=sorted(lossList,key=lambda x:x[1],reverse=False)[:K] print(lossList)
plt.figure(1)
#    plt.figure(figsize=(10, 10))
trueImage = shape((3,32,32)).transpose(0,2)
plt.imshow(trueImage)
plt.title('true')
plt.show()
for j in range(K):
showImage = testset[lossList[j][0]][0]#遍历相似度最⾼列表⾥的图        showImage = shape((3,32,32)).transpose(0,2)

本文发布于:2024-09-21 15:44:41,感谢您对本站的认可!

本文链接:https://www.17tex.com/tex/3/378273.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:编码器   特征   实现   栈式
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2024 Comsenz Inc.Powered by © 易纺专利技术学习网 豫ICP备2022007602号 豫公网安备41160202000603 站长QQ:729038198 关于我们 投诉建议