ATSS论文阅读笔记以及核心代码解析

ATSS论⽂阅读笔记以及核⼼代码解析
⽂章⽬录
1 论⽂题⽬
Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
本篇论⽂也是⼊选了2020年的CVPR,算是今年CVPR⾥⽬标检测⽅向为数不多的论⽂之⼀。
2 论⽂⽬的
⾸先,本⽂作者指出在⽬标检测的⽅法中,anchor-based的⽅法与anchor-free的⽅法之间主要的区别在于如何定义正负样本。如果两类⽅法采取相同的定义⽅式,那么这两类⽅法将会取得差不多的效果。
针对上述发现,作者提出了Adaptive Training Sample Selection (ATSS)的⽅法去⾃动的选取正负样本,这种⽅法弥补了anchor-based 的⽅法与anchor-free⽅法之间的差距。
3 论⽂实现
作者分别选取了两类⽅法中具有代表性的⽅法:RetinaNet和FCOS,并以此详细的说明了ATSS⽅法。
⾸先,在RetinaNet中,是采⽤IOU的⽅式区分正负样本,因此是在空间维度和尺度维度同时进⾏选择。⽽在FCOS中,⾸先是在空间维度上选定⼀些候选的正样本,并在此基础上在尺度维度上选择正样本,剩余的作为负样本。
上述实验结果说明:(按列分析)对于RetinaNet来说,定义正负样本的⽅式从IOU改为Spatial and Scale Constraint的话,map从37.0%提升到37.8%。⽽对于FCOS来说,定义正负样本的⽅式从Spatial and Scale Constraint改变为IOU,map从37.8%降低到36.9%。
所以,如何定义正负样本,是很重要的。
从另⼀个⾓度来分析上述表格(按⾏分析),即回归的⽅式,对于RetinaNet来说,从anchor box得到最后的bbox,改为从anchor point得到最后的bbox,对于map⼏乎没有影响。对于FCOS也⼀样。
所以回归的⽅式不是⼀个重要的因素。
下⾯将具体说明ATSS具体的步骤:
先给出伪算法框图:
输⼊: g: 输⼊图⽚所有ground truth
L: ⾦字塔⽹络的层数
: 第i层⾦字塔⽹络的先验框集合
A: 所有先验框
K: 对于每个ground truth 的中⼼,我们每层选取k个与之最近的先验框。
输出: P:正样本集合
N: 负样本集合
具体流程:
(1)确定⼀些正例的候选样本(根据L2距离,每层选取k个先验框),所以如果有l层的话,则共有k*l个先验框
(2)计算先验框与ground truth之间的IOU
(3)利⽤上⾯计算的IOU,分别计算均值和标准差
(4)利⽤ + ,来定义正负样本,即IOU得分⼤于等于,则被定义成正样本,否则就是负样本。
吾守尔大爷的冰(5)在第(4)步选择正负样本的过程中,还有个判断条件就是该样本的中⼼必须在⽬标内,才能被定义成正样本。
关于ATSS算法的⼀些解释:
(1)作者认为,先验框的中⼼离ground truth越近,先验框的质量越好
(2)关于均值和标准差(从统计学的原理说明):
均值:某个⽬标的IOU均值能够说明:对于这个⽬标,相应的先验框设置的是否合理。
标准差:能够说明⾦字塔结构中的哪⼀层更加适合检测这个⽬标。
(3)保证先验框的中⼼在ground truth之内
4 核⼼代码解读
作者开源的代码是直接在FOCS代码上改进的,所以在此只分析关于ATSS的相关代码,如果想看FCOS的相关代码,可以看看这篇博客,写的很好。
下⾯⾔归正传,先给出ATSS核⼼代码的路径,其实刚开始我也没到,还是特意问的作者. (⼩⽩⼀枚)
ATSS-master/atss_core/modeling/rpn/atss/loss.py
建议下⾯的代码结合上⾯的伪算法流程图⼀起看,更容易理解
elif self.cfg.MODEL.ATSS.POSITIVE_TYPE =='ATSS':
num_anchors_per_loc =len(self.cfg.MODEL.ATSS.ASPECT_RATIOS)* self.cfg.MODEL.ATSS.SCALES_PER_OCTAVE
num_anchors_per_loc =len(self.cfg.MODEL.ATSS.ASPECT_RATIOS)* self.cfg.MODEL.ATSS.SCALES_PER_OCTAVE
#每个level 的 anchor 数量
num_anchors_per_level =[len(anchors_per_level.bbox)for anchors_per_level in anchors[im_i]]
ious =boxlist_iou(anchors_per_im, targets_per_im)
gt_cx =(bboxes_per_im[:,2]+ bboxes_per_im[:,0])/2.0
gt_cy =(bboxes_per_im[:,3]+ bboxes_per_im[:,1])/2.0
gt_points = torch.stack((gt_cx, gt_cy), dim=1)
anchors_cx_per_im =(anchors_per_im.bbox[:,2]+ anchors_per_im.bbox[:,0])/2.0
anchors_cy_per_im =(anchors_per_im.bbox[:,3]+ anchors_per_im.bbox[:,1])/2.0
anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)
#计算 anchor 和 GT 之间的L2距离
distances =(anchor_points[:, None,:]- gt_points[None,:,:]).pow(2).sum(-1).sqrt()
# Selecting candidates based on the center distance between anchor box and object
candidate_idxs =[]
star_idx =0
#遍历每⼀张img 的的每⼀个level 的  anchor 集合
交通安全与智能控制for level, anchors_per_level in enumerate(anchors[im_i]):
end_idx = star_idx + num_anchors_per_level[level]
种子包衣剂
distances_per_level = distances[star_idx:end_idx,:]
topk =min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
# 根据L2 距离选择前K个 anchor
_, topk_idxs_per_level = distances_pk(topk, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + star_idx)
#为了记录下⼀个level,不然的话,下⼀个level的candidate_idxs 会把上⼀个level的candidate_idxs 覆盖掉
star_idx = end_idx
candidate_idxs = torch.cat(candidate_idxs, dim=0)
# Using the sum of mean and standard deviation as the IoU threshold to select final positive samples
#计算 anchor 和 GT 之间的 IOU
candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
#计算均值
iou_mean_per_gt = an(0)
#计算标准差
iou_std_per_gt = candidate_ious.std(0)
#新的阈值=均值+标准差
iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt土楼公社
#选出candidate中 IOU ⼤于新的阈值的anchors
is_pos = candidate_ious >= iou_thresh_per_gt[None,:]
# 保证最后保留的anchor的中⼼在ground truth中
# Limiting the final positive samples’ center to object
anchor_num = anchors_cx_per_im.shape[0]
网上购物狂
for ng in range(num_gt):
candidate_idxs[:, ng]+= ng * anchor_num
e_anchors_cx = anchors_cx_per_im.view(1,-1).expand(num_gt, anchor_num).contiguous().view(-1)
e_anchors_cy = anchors_cy_per_im.view(1,-1).expand(num_gt, anchor_num).contiguous().view(-1)
#view(-1) 是将变量转换成1维
#这块怎么判断center 在 GT 之内,有点没看懂
candidate_idxs = candidate_idxs.view(-1)
l = e_anchors_cx[candidate_idxs].view(-1, num_gt)- bboxes_per_im[:,0]
t = e_anchors_cy[candidate_idxs].view(-1, num_gt)- bboxes_per_im[:,1]
r = bboxes_per_im[:,2]- e_anchors_cx[candidate_idxs].view(-1, num_gt)
b = bboxes_per_im[:,3]- e_anchors_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0]>0.01
#is_pos --> is_postive 最后返回 true or false 表⽰该图⽚是否是正例
is_pos = is_pos & is_in_gts
is_pos = is_pos & is_in_gts
# if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
#创建⼀个元素都为-INF的tensor,并展开成⼀维
ious_inf = torch.full_like(ious,-INF).t().contiguous().view(-1)
index = candidate_idxs.view(-1)[is_pos.view(-1)]
# 向ious_inf tensor中对应的位置(index)赋予iou的值
ious_inf[index]= ious.t().contiguous().view(-1)[index]
#将 ious_inf 转换成num_gt⾏,每⾏都对应着每个anchor 与这个GT 的IOU值
ious_inf = ious_inf.view(num_gt,-1).t()
#出与每个GT  IOU值最⼤的那个anchor 的  IOU值以及序号(index)
anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
#到这些anchor对应的label
cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
#将不具有最⼤IOU值的anchor  分类得分置为0
cls_labels_per_im[anchors_to_gt_values ==-INF]=0
#通过label保留最后的bbox
matched_gts = bboxes_per_im[anchors_to_gt_indexs]
想写的就这么多了,第⼀次写这种代码解析的博客,有什么错误请⼤家积极指出,我尽量改正,如果对于相关代码有什么好的理解,可以写在评论⾥。当然,⼤家有⽐较好的关于⽬标检测的论⽂可以写在评论⾥,也⽅便⼤家⼀起学习。
顾客价值
感谢⼤家。

本文发布于:2024-09-20 20:21:29,感谢您对本站的认可!

本文链接:https://www.17tex.com/xueshu/596827.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:样本   代码   先验   定义   正负
留言与评论(共有 0 条评论)
   
验证码:
Copyright ©2019-2024 Comsenz Inc.Powered by © 易纺专利技术学习网 豫ICP备2022007602号 豫公网安备41160202000603 站长QQ:729038198 关于我们 投诉建议