2026/3/4 1:21:21
网站建设
项目流程
荣誉章标志做网站,好看的网页设计模板,徐州商城建站系统,公司使用威联通nas做网站存储前言
本文介绍了内容感知Token聚合网络#xff08;CATANet#xff09;中的局部区域自注意力#xff08;LRSA#xff09;模块在YOLOv11中的结合。基于Transformer的图像超分辨率方法存在计算复杂度高、捕捉长距离依赖能力受限等问题。LRSA作为CATANet的核心辅助模块#x…前言本文介绍了内容感知Token聚合网络CATANet中的局部区域自注意力LRSA模块在YOLOv11中的结合。基于Transformer的图像超分辨率方法存在计算复杂度高、捕捉长距离依赖能力受限等问题。LRSA作为CATANet的核心辅助模块通过重叠补丁策略强化局部特征交互补充局部细节。我们将相关代码加入指定目录在ultralytics/nn/tasks.py中注册配置yolov11 - LRSA.yaml文件最后通过实验脚本和结果验证了方法的有效性。文章目录 YOLOv11改进大全卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总专栏链接: YOLOv11改进专栏介绍摘要基于 Transformer 的方法在图像超分辨率SR等底层视觉任务中展现出了卓越的性能。然而其计算复杂度随着空间分辨率的增加呈二次方级增长。一系列研究工作试图通过将低分辨率图像划分为局部窗口local windows、轴向条纹axial stripes或空洞窗口dilated windows来缓解这一问题。SR 通常利用图像的冗余性进行重建而这种冗余不仅存在于局部区域也存在于长距离区域中。然而上述方法将注意力计算局限在与内容无关的局部区域内直接限制了注意力机制捕捉长距离依赖关系的能力。为了解决这些问题我们提出了一种轻量级的内容感知 Token 聚合网络CATANet。具体而言我们提出了一种高效的内容感知 Token 聚合模块用于聚合长距离且内容相似的 Token该模块在所有图像 Token 间共享 Token 中心且仅在训练阶段对其进行更新。随后我们利用组内自注意力intra-group self-attention来实现长距离的信息交互。此外我们还设计了一种组间交叉注意力inter-group cross-attention以进一步增强全局信息的交互。实验结果表明与最先进的SOTA基于聚类的方法 SPIN 相比我们的方法取得了更优越的性能PSNR 最大提升了 0.33dB且推理速度几乎翻倍。文章链接论文地址论文地址代码地址代码地址基本原理Local-Region Self-AttentionLRSA局部区域自注意力是CATANet中负责细化图像局部细节的核心辅助模块与捕捉长距离依赖的Token-Aggregation BlockTAB形成功能互补共同支撑轻量级图像超分辨率任务的高效性能。其设计核心是在低计算复杂度前提下强化局部范围内像素/特征块的信息交互弥补长距离注意力在细节还原上的不足。一、核心定位与设计目标1. 核心定位作为CATANet深度特征提取阶段的关键组件每个残差组RG包含TAB、LRSA和3×3卷积LRSA专注于局部特征交互——在长距离依赖已被TAB捕捉后进一步优化图像边缘、纹理等细粒度细节避免因过度关注全局信息导致的局部模糊或 artifacts。2. 设计目标补充局部细节与TAB的长距离信息捕捉形成“全局局部”双重保障提升超分辨率图像的细节还原度保持轻量化采用高效结构设计避免局部注意力计算引入过多冗余适配手机等资源受限设备兼容并行计算通过重叠补丁Overlapping Patches设计兼顾局部交互效果与计算效率。二、核心设计与工作原理1. 结构来源与基础设计LRSA参考了HPINetHierarchical Pixel Integration Network的局部注意力结构核心采用重叠补丁Overlapping Patches策略将输入特征图划分为多个相互重叠的局部补丁而非无重叠的独立窗口确保相邻区域的特征能自然交互避免窗口边界处的细节断裂所有补丁共享查询Q、键K、值V的权重矩阵减少参数数量降低计算复杂度。2. 具体工作流程设LRSA的输入为经过TAB处理后的特征图 ( X_o \in \mathbb{R}^{N \times d} )其中 ( N ) 为特征token数量( d ) 为特征维度其工作流程可概括为3步补丁划分与特征投影将输入特征图按固定尺寸如8×8划分为重叠补丁每个补丁通过共享权重矩阵 ( WQ、WK、W^V \in \mathbb{R}^{d \times d} ) 分别投影为查询向量 ( Q )、键向量 ( K ) 和值向量 ( V )局部自注意力计算在每个补丁内部执行多头自注意力MSA运算捕捉补丁内特征的局部依赖关系如相邻像素的纹理关联、边缘连续性特征融合与输出将所有补丁的注意力输出按原位置拼接得到细化后的局部特征图 ( X_{out} \in \mathbb{R}^{N \times d} )传递给后续的ConvFFN卷积前馈网络进一步优化。3. 关键特性重叠补丁设计区别于SwinIR的非重叠固定窗口重叠设计让局部注意力更平滑避免“窗口效应”导致的图像边缘生硬权重共享所有补丁共用一套Q/K/V投影权重相比为每个补丁单独设计权重参数数量减少约10%-20%根据补丁数量调整符合轻量化需求低计算复杂度注意力计算仅局限于局部补丁内复杂度与补丁尺寸呈线性关系而非全局自注意力的二次复杂度确保推理效率。三、与其他注意力机制的区别与互补1. 与CATANet内部核心注意力的互补注意力机制关注范围核心功能计算复杂度LRSA局部补丁如8×8细化边缘、纹理等局部细节线性复杂度与补丁尺寸相关IASA组内自注意力跨图像的内容相似组捕捉长距离依赖线性复杂度与组内token数量相关IRCA组间交叉注意力组与全局token中心强化全局信息交互低复杂度( M \ll N )( M ) 为token中心数量LRSA的核心价值的是“补位”——IASA和IRCA解决了“长距离相似信息交互”问题但可能忽略局部像素的精细关联而LRSA专注于局部细节修复三者形成“全局依赖局部细节”的完整覆盖。2. 与传统局部注意力的区别相比SwinIR的“固定窗口注意力”LRSA的重叠补丁设计避免了窗口边界的信息割裂细节还原更自然相比NLSA非局部稀疏注意力LRSA不依赖哈希分组无需处理哈希冲突分组更稳定且计算更高效相比CNN的局部卷积LRSA通过自注意力机制能自适应捕捉局部特征的关联强度如强边缘与弱纹理的差异化关注而卷积的局部交互是固定权重的灵活性更弱。核心代码classLRSA(nn.Module):Attention module. Args: dim (int): Base channels. num (int): Number of blocks. qk_dim (int): Channels of query and key in Attention. mlp_dim (int): Channels of hidden mlp in Mlp. heads (int): Head numbers of Attention. def__init__(self,dim,qk_dim,mlp_dim,heads1):super().__init__()self.layernn.ModuleList([PreNorm(dim,Attention(dim,heads,qk_dim)),PreNorm(dim,ConvFFN(dim,mlp_dim))])defforward(self,x,ps):stepps-2crop_x,nh,nwpatch_divide(x,step,ps)# (b, n, c, ps, ps)b,n,c,ph,pwcrop_x.shape crop_xrearrange(crop_x,b n c h w - (b n) (h w) c)attn,ffself.layer crop_xattn(crop_x)crop_x crop_xrearrange(crop_x,(b n) (h w) c - b n c h w,nn,wpw)xpatch_reverse(crop_x,x,step,ps)_,_,h,wx.shape xrearrange(x,b c h w- b (h w) c)xff(x,x_size(h,w))x xrearrange(x,b (h w) c-b c h w,hh)returnxYOLO11引入代码在根目录下的ultralytics/nn/目录新建一个attention目录然后新建一个以LRSA.py为文件名的py文件 把代码拷贝进去。importtorchimporttorch.nnasnnimporttorch.nn.functionalasFfromeinopsimportrearrangedefpatch_divide(x,step,ps):Crop image into patches. Args: x (Tensor): Input feature map of shape(b, c, h, w). step (int): Divide step. ps (int): Patch size. Returns: crop_x (Tensor): Cropped patches. nh (int): Number of patches along the horizontal direction. nw (int): Number of patches along the vertical direction. b,c,h,wx.size()ifhpsandwps:stepps crop_x[]nh0foriinrange(0,hstep-ps,step):topi downipsifdownh:toph-ps downh nh1forjinrange(0,wstep-ps,step):leftj rightjpsifrightw:leftw-ps rightw crop_x.append(x[:,:,top:down,left:right])nwlen(crop_x)//nh crop_xtorch.stack(crop_x,dim0)# (n, b, c, ps, ps)crop_xcrop_x.permute(1,0,2,3,4).contiguous()# (b, n, c, ps, ps)returncrop_x,nh,nwdefpatch_reverse(crop_x,x,step,ps):Reverse patches into image. Args: crop_x (Tensor): Cropped patches. x (Tensor): Feature map of shape(b, c, h, w). step (int): Divide step. ps (int): Patch size. Returns: output (Tensor): Reversed image. b,c,h,wx.size()outputtorch.zeros_like(x)index0foriinrange(0,hstep-ps,step):topi downipsifdownh:toph-ps downhforjinrange(0,wstep-ps,step):leftj rightjpsifrightw:leftw-ps rightw output[:,:,top:down,left:right]crop_x[:,index]index1foriinrange(step,hstep-ps,step):topi downips-stepiftoppsh:toph-ps output[:,:,top:down,:]/2forjinrange(step,wstep-ps,step):leftj rightjps-stepifleftpsw:leftw-ps output[:,:,:,left:right]/2returnoutputclassPreNorm(nn.Module):Normalization layer. Args: dim (int): Base channels. fn (Module): Module after normalization. def__init__(self,dim,fn):super().__init__()self.normnn.LayerNorm(dim)self.fnfndefforward(self,x,**kwargs):returnself.fn(self.norm(x),**kwargs)classdwconv(nn.Module):def__init__(self,hidden_features,kernel_size5):super(dwconv,self).__init__()self.depthwise_convnn.Sequential(nn.Conv2d(hidden_features,hidden_features,kernel_sizekernel_size,stride1,padding(kernel_size-1)//2,dilation1,groupshidden_features),nn.GELU())self.hidden_featureshidden_featuresdefforward(self,x,x_size):xx.transpose(1,2).view(x.shape[0],self.hidden_features,x_size[0],x_size[1]).contiguous()# b Ph*Pw cxself.depthwise_conv(x)xx.flatten(2).transpose(1,2).contiguous()returnxclassConvFFN(nn.Module):def__init__(self,in_features,hidden_featuresNone,out_featuresNone,kernel_size5,act_layernn.GELU):super().__init__()out_featuresout_featuresorin_features hidden_featureshidden_featuresorin_features self.fc1nn.Linear(in_features,hidden_features)self.actact_layer()self.dwconvdwconv(hidden_featureshidden_features,kernel_sizekernel_size)self.fc2nn.Linear(hidden_features,out_features)defforward(self,x,x_size):xself.fc1(x)xself.act(x)xxself.dwconv(x,x_size)xself.fc2(x)returnxclassAttention(nn.Module):Attention module. Args: dim (int): Base channels. heads (int): Head numbers. qk_dim (int): Channels of query and key. def__init__(self,dim,heads,qk_dim):super().__init__()self.headsheads self.dimdim self.qk_dimqk_dim self.scaleqk_dim**-0.5self.to_qnn.Linear(dim,qk_dim,biasFalse)self.to_knn.Linear(dim,qk_dim,biasFalse)self.to_vnn.Linear(dim,dim,biasFalse)self.projnn.Linear(dim,dim,biasFalse)defforward(self,x):q,k,vself.to_q(x),self.to_k(x),self.to_v(x)q,k,vmap(lambdat:rearrange(t,b n (h d) - b h n d,hself.heads),(q,k,v))outF.scaled_dot_product_attention(q,k,v)outrearrange(out,b h n d - b n (h d))returnself.proj(out)classLRSA(nn.Module):def__init__(self,dim,qk_dimNone,mlp_dimNone,heads1):super().__init__()# Set default values to maintain compatibilityifqk_dimisNone:qk_dim32# Default value close to original usageifmlp_dimisNone:mlp_dim2*dim# Default: 2x expansionself.layernn.ModuleList([PreNorm(dim,Attention(dim,heads,qk_dim)),PreNorm(dim,ConvFFN(dim,mlp_dim))])defforward(self,x):ps8stepps-2crop_x,nh,nwpatch_divide(x,step,ps)# (b, n, c, ps, ps)b,n,c,ph,pwcrop_x.shape crop_xrearrange(crop_x,b n c h w - (b n) (h w) c)attn,ffself.layer crop_xattn(crop_x)crop_x crop_xrearrange(crop_x,(b n) (h w) c - b n c h w,nn,wpw)xpatch_reverse(crop_x,x,step,ps)_,_,h,wx.shape xrearrange(x,b c h w- b (h w) c)xff(x,x_size(h,w))x xrearrange(x,b (h w) c-b c h w,hh)returnx注册在ultralytics/nn/tasks.py中进行如下操作步骤1:fromultralytics.nn.attention.LRSAimportLRSA步骤2修改def parse_model(d, ch, verboseTrue):elifmisLRSA:args[ch[f],*args]配置yolov11-LRSA.yamlultralytics/cfg/models/11/yolov11-LRSA.yaml# Ultralytics YOLO , AGPL-3.0 license# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parametersnc:80# number of classesscales:# model compound scaling constants, i.e. modelyolo11n.yaml will call yolo11.yaml with scale n# [depth, width, max_channels]n:[0.50,0.25,1024]# summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss:[0.50,0.50,1024]# summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm:[0.50,1.00,512]# summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl:[1.00,1.00,512]# summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx:[1.00,1.50,512]# summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPsbackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,2,C3k2,[256,False,0.25]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,2,C3k2,[512,False,0.25]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,2,C3k2,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,2,C3k2,[1024,True]]-[-1,1,SPPF,[1024,5]]# 9-[-1,2,C2PSA,[1024]]# 10head:-[-1,1,nn.Upsample,[None,2,nearest]]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,2,C3k2,[512,False]]# 13-[-1,1,nn.Upsample,[None,2,nearest]]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,2,C3k2,[256,False]]# 16 (P3/8-small)-[-1,1,LRSA,[256]]# 17 (P3/8-small) - Local Region Self-Attention-[-1,1,Conv,[256,3,2]]-[[-1,13],1,Concat,[1]]# cat head P4-[-1,2,C3k2,[512,False]]# 20 (P4/16-medium)-[-1,1,LRSA,[512]]# 21 (P4/16-medium) - Local Region Self-Attention-[-1,1,Conv,[512,3,2]]-[[-1,10],1,Concat,[1]]# cat head P5-[-1,2,C3k2,[1024,True]]# 24 (P5/32-large)-[-1,1,LRSA,[1024]]# 25 (P5/32-large) - Local Region Self-Attention-[[17,21,25],1,Detect,[nc]]# Detect(P3, P4, P5)实验脚本importwarnings warnings.filterwarnings(ignore)fromultralyticsimportYOLOif__name____main__:# 修改为自己的配置文件地址modelYOLO(/root/ultralytics-main/ultralytics/cfg/models/11/yolov11-LRSA.yaml)# 修改为自己的数据集地址model.train(data/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml,cacheFalse,imgsz640,epochs10,single_clsFalse,# 是否是单类别检测batch8,close_mosaic10,workers0,optimizerSGD,ampTrue,projectruns/train,nameLRSA,)结果