本研究文章探讨了针对高级驾驶辅助系统(ADAS)的Stereo Transformer(STTR)模型的微调与推理流程。探索高级驾驶辅助系统(ADAS)的动态世界以及立体视觉这一创新领域。本文将深入探讨车载ADAS立体视觉如何改变游戏规则,为深度感知提供一种优于传统激光雷达(LiDAR)方法的智能替代方案。
这篇全面的研究文章包含一套详细的流程,逐步指导如何搭建并微调一个Stereo Transformer(STTR)模型,使其能够像人类双眼一样,从两个摄像头视频流中预测视差图。
本文不仅限于纯粹的计算机视觉理论,还包含了在KITTI立体视觉数据集上微调后的真实世界实验结果。这些令人印象深刻的结果,堪比更复杂且昂贵的激光雷达(LiDAR)系统所能达到的性能。此外,您还将深入了解此类模型的实际训练与推理流程——这是基础教程中通常不会涉及的内容。
为什么选择立体视觉? ADAS 立体视觉是指利用两个摄像头(类似于人类双眼)从略微不同的角度捕捉图像的技术。这种配置模拟了人类的双目视觉,使系统能够感知环境的深度和三维结构。
立体视觉的关键技术要点 ADAS 立体视觉具备多项优势和高度复杂的功能。让我们深入探讨,以更好地理解:
对极几何与深度估计 :
立体视觉依赖于对极几何(epipolar geometry)——这是计算机视觉中的一个基本概念,用于描述立体成像系统中两个视角之间的几何关系。通过在一对图像中寻找对应点(如边缘或角点等特征),系统可计算出视差(disparity),即两幅图像中相同特征点坐标的差异。该视差与场景中物体到摄像头的距离成反比,从而实现深度估计。
三维重建与点云生成 :
借助三角测量方法,立体视觉系统可以重建所观测场景的三维模型,并生成点云(point cloud)。点云中的每个点都代表场景中某一物理位置,其三维坐标由视差图推导得出。
在某些方面优于激光雷达(LiDAR) :
事实证明,与三维激光雷达系统相比,立体视觉通常更具成本效益。立体视觉系统中的摄像头能够捕获高分辨率图像,提供当前激光雷达尚无法获取的丰富纹理信息。此外,激光雷达在特定光照条件下(例如强日光直射或黑暗环境)可能表现不佳,而立体视觉系统则能在更广泛的光照场景中保持稳定性能,尤其是在低照度成像技术不断进步的背景下。
在 ADAS 中的应用 :
立体视觉可为高级驾驶辅助系统提供前沿功能,例如障碍物检测与避让、车道线检测、行人识别等,这些都具有重要实用价值。
视差的概念
在关于提升 ADAS 立体视觉性能的计算机视觉研究中,Okutomi 与 Kanade 在其论文《A Multiple-Baseline Stereo》中提出了一种创新的立体匹配方法:通过使用多个具有不同基线长度的立体图像对,提高距离估计的精度,同时降低匹配歧义的风险。
立体视觉中视差 ( d ) 与摄像头到物体距离 ( z ) 之间的数学关系,可通过摄像头的基线 ( B )(即两个摄像头光心之间的距离)和焦距 ( F ) 表示如下:
其中,( d ) 表示视差,( B ) 为基线长度,( F ) 为摄像头焦距,( z ) 为物体到摄像头的距离。该公式表明,视差与基线和焦距的乘积成正比,与物体距离成反比。
该方法通过横向移动单个摄像头来生成多个不同基线的图像对,从而规避了传统立体匹配中精度与准确度之间的权衡问题。
该技术的核心在于:通过对多个立体图像对计算并累加“平方差和”(Sum of Squared Differences, SSD)值,并以逆距离(而非视差)作为变量进行表示,有效减少了全局匹配错误。这种方法能妥善处理匹配过程中固有的模糊性问题(例如重复纹理图案),在不依赖搜索或序列滤波技术的前提下显著提升了匹配精度。
文献综述——当前研究趋势概览 在他们的研究论文中,Naveen Appiah 与 Nitin Bandaru 提出了一种新颖的方法:利用一对360°全景摄像头构建ADAS立体视觉系统,并通过垂直方向上的摄像头位移,实现全向视角下的全面深度感知。该方法主要聚焦于一种基于几何的聚类技术用于障碍物识别,其中将障碍物定义为相对于地面平面存在高度抬升的点或区域。
该障碍物检测算法通过两个定量标准来界定障碍物:
Hendrik Königshof 所开展的研究提出了一种面向自动驾驶的突破性三维目标检测与姿态估计方法。该方法创新性地将深度卷积神经网络(CNN)提供的语义信息与视差数据及几何约束相结合,能够在实时条件下为各类道路使用者精确生成三维边界框(3D bounding boxes)。
该系统采用基于 ResNet-38 的编码器,用于逐像素的语义分割与目标检测,并结合一种受 SSD 和 RetinaNet 启发的、无需候选框(proposal-free)的边界框检测机制。在立体视频的视差估计方面,系统利用 GPU 加速的块匹配算法,采用倾斜平面(slanted planes)方法,并引入一种新颖的置信度度量指标 CPKR,以实现可靠且鲁棒的视差计算。通过融合上述技术,该系统构建出一种高效、具备实时处理能力的算法,在 KITTI 三维目标检测基准测试中展现出具有竞争力的性能,其运行速度显著优于现有的基于图像的方法。
由 Zhaoshuo Li 等人撰写的论文《Revisiting Stereo Depth Estimation From a Sequence-to-Sequence Perspective with Transformers》提出了 Stereo Transformer(STTR)方法——一种用于立体深度估计的全新序列到序列建模范式。与传统方法不同,STTR 利用位置信息和注意力机制进行密集像素匹配,摆脱了固定视差范围的限制,从而显著提升了遮挡区域的检测能力与视差置信度估计的准确性。
该架构采用了一个沙漏形(hourglass-shaped)的特征提取器,结合残差连接(residual connections)与空间金字塔池化(spatial pyramid pooling),以高效获取上下文信息。Transformer 模型则交替使用自注意力(self-attention)和交叉注意力(cross-attention)层,对特征描述符进行优化,从而实现高精度的视差估计。该模型的独特之处在于引入了熵正则化的最优传输 (entropy-regularized optimal transport),用于在立体匹配中施加唯一性约束,生成带有梯度流的软分配(soft assignments)。此外,STTR 还包含一个上下文调整层(context adjustment layer),该层利用卷积模块与残差网络对原始视差图和遮挡图进行精细化处理,并通过跨对极线(cross-epipolar line)信息进一步提升深度感知的准确性。
KITTI ADAS 立体视觉数据集概览 KITTI 2015 ADAS 立体视觉数据集因其在计算机视觉与自动驾驶研究中的广泛应用而广为人知,是一个全面且被广泛采用的数据集。该数据集由卡尔斯鲁厄理工学院(Karlsruhe Institute of Technology)与芝加哥丰田技术研究所(Toyota Technological Institute at Chicago)联合开发,属于 KITTI 视觉基准套件(KITTI Vision Benchmark Suite)的一部分。数据集包含 200 个训练场景和 200 个测试场景,每个场景提供四张彩色图像,均以无损 PNG 格式保存,为模型训练与验证提供了充足的数据支持。
KITTI 2015 数据集的一个重要特点是聚焦于动态场景 ,这与早期版本(如 KITTI 2012)有所不同。场景中包含移动物体,使其特别适用于自动驾驶等需要理解动态环境的应用场景。
在继续深入之前,我们不妨先来看几个来自该数据集的样本示例,如何?
从上述图中可以看出,在每个时间单位内,每个样本包含三个数据点:
左摄像头视频流 右摄像头视频流 融合后的真值视差图(Ground Truth Disparity Map)
Stereo Transformer(STTR)—— 网络架构 在本节中,我们将深入探讨立体视觉 Transformer(STereo TRansformer,简称 STTR)的内部架构。该模型的详细结构示意图如下图所示:
亮点 特征提取器 :采用先进的沙漏架构(hourglass architecture),结合残差连接(residual connections)和空间金字塔池化(spatial pyramid pooling),全面捕捉局部与全局上下文信息。Transformer :使用交替的自注意力机制(self-attention mechanism)和交叉注意力机制(cross-attention mechanism),根据图像上下文和位置关系更新特征描述符。最优传输 :应用熵正则化的最优传输(entropy-regularized optimal transport)实现立体匹配中的软分配,确保灵活性和可微性。上下文调整层 :利用卷积块和激活函数细化视差图和遮挡图估计,并集成跨对极线(cross-epipolar line)上下文信息。内存可行的实现方式 :通过梯度检查点(gradient checkpointing)和混合精度训练(mixed-precision training)实现内存的有效管理,使得在标准硬件上注意力层(attention layers)的扩展性得以保障。
特征提取器 STTR 中的特征提取器使用类似于先前模型的沙漏形架构,但进行了显著改进。它集成了残差连接和空间金字塔池化模块,以高效获取全局上下文信息。解码路径设计包括转置卷积、密集块(dense blocks)以及最终的卷积层,确保每个像素的特征描述符既能编码局部也能编码全局上下文信息,同时保持与输入图像相同的空域分辨率。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL class SppBackbone(nn.Module):
"""
使用空间金字塔池化(SPP)构建特征描述符的收缩路径,
SPP 模块参考自 PSMNet (https://github.com/JiaRenChang/PSMNet)
"""
def __init__ (self ):
super (SppBackbone, self ).__init__()
self .inplanes = 32
self .in_conv = nn.Sequential(
nn.Conv2d(3 , 16 , kernel_size=3 , padding=1 , stride=2 , bias=False ),
nn.BatchNorm2d(16 ),
nn.ReLU(inplace=True ),
nn.Conv2d(16 , 16 , kernel_size=3 , padding=1 , bias=False ),
nn.BatchNorm2d(16 ),
nn.ReLU(inplace=True ),
nn.Conv2d(16 , 32 , kernel_size=3 , padding=1 , bias=False ),
nn.BatchNorm2d(32 ),
nn.ReLU(inplace=True )
)
self .resblock_1 = self ._make_layer(BasicBlock, 64 , 3 , 2 )
self .resblock_2 = self ._make_layer(BasicBlock, 128 , 3 , 2 )
self .branch1 = nn.Sequential(
nn.AvgPool2d((16 , 16 ), stride=(16 , 16 )),
nn.Conv2d(128 , 32 , kernel_size=1 , bias=False ),
nn.BatchNorm2d(32 ),
nn.ReLU(inplace=True )
)
self .branch2 = nn.Sequential(
nn.AvgPool2d((8 , 8 ), stride=(8 , 8 )),
nn.Conv2d(128 , 32 , kernel_size=1 , bias=False ),
nn.BatchNorm2d(32 ),
nn.ReLU(inplace=True )
)
self .branch3 = nn.Sequential(
nn.AvgPool2d((4 , 4 ), stride=(4 , 4 )),
nn.Conv2d(128 , 32 , kernel_size=1 , bias=False ),
nn.BatchNorm2d(32 ),
nn.ReLU(inplace=True )
)
self .branch4 = nn.Sequential(
nn.AvgPool2d((2 , 2 ), stride=(2 , 2 )),
nn.Conv2d(128 , 32 , kernel_size=1 , bias=False ),
nn.BatchNorm2d(32 ),
nn.ReLU(inplace=True )
)
def _make_layer (self, block, planes, blocks, stride=1 ):
downsample = None
if stride != 1 or self .inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self .inplanes, planes * block.expansion,
kernel_size=1 , stride=stride, bias=False ),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self .inplanes, planes, stride, downsample))
self .inplanes = planes * block.expansion
for i in range (1 , blocks):
layers.append(block(self .inplanes, planes))
return nn.Sequential(*layers)
def forward (self, x: NestedTensor ):
"""
:param x: NestedTensor
:return: 包含不同空间分辨率下特征描述符的列表
0: [2N, 3, H, W]
1: [2N, C0, H//4, W//4]
2: [2N, C1, H//8, W//8]
3: [2N, C2, H//16, W//16]
"""
_, _, h, w = x.left.shape
src_stereo = torch.cat([x.left, x.right], dim=0 )
output = self .in_conv(src_stereo)
output_1 = self .resblock_1(output)
output_2 = self .resblock_2(output_1)
h_spp, w_spp = math.ceil(h / 16 ), math.ceil(w / 16 )
spp_1 = F.interpolate(self .branch1(output_2), size=(h_spp, w_spp), mode='bilinear' , align_corners=False )
spp_2 = F.interpolate(self .branch2(output_2), size=(h_spp, w_spp), mode='bilinear' , align_corners=False )
spp_3 = F.interpolate(self .branch3(output_2), size=(h_spp, w_spp), mode='bilinear' , align_corners=False )
spp_4 = F.interpolate(self .branch4(output_2), size=(h_spp, w_spp), mode='bilinear' , align_corners=False )
output_3 = torch.cat([spp_1, spp_2, spp_3, spp_4], dim=1 )
return [src_stereo, output_1, output_2, output_3]
Transformer 模块 STTR 中的 Transformer 架构是核心组件之一,采用交替注意力机制。它利用自注意力计算同一图像中沿极线方向像素间的注意力,同时利用交叉注意力处理左右图像对应极线之间的关系。该机制在前 (N-1) 层中交替执行自注意力与交叉注意力,持续根据图像上下文和相对位置更新特征描述符。最后一层交叉注意力专注于原始视差估计,并结合最优传输以满足唯一性约束,同时使用注意力掩码缩小搜索空间。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL class Transformer(nn.Module):
"""
Transformer 计算自注意力(图像内)和交叉注意力(图像间)
"""
def __init__ (self, hidden_dim: int = 128, nhead: int = 8, num_attn_layers: int = 6 ):
super ().__init__()
self_attn_layer = TransformerSelfAttnLayer(hidden_dim, nhead)
self .self_attn_layers = get_clones(self_attn_layer, num_attn_layers)
cross_attn_layer = TransformerCrossAttnLayer(hidden_dim, nhead)
self .cross_attn_layers = get_clones(cross_attn_layer, num_attn_layers)
self .norm = nn.LayerNorm(hidden_dim)
self .hidden_dim = hidden_dim
self .nhead = nhead
self .num_attn_layers = num_attn_layers
def _alternating_attn (self, feat: torch.Tensor, pos_enc: torch.Tensor, pos_indexes: Tensor, hn: int ):
"""
交替执行自注意力与交叉注意力,并使用梯度检查点节省内存
:param feat: 左右图像拼接后的特征,[W,2HN,C]
:param pos_enc: 位置编码,[W,HN,C]
:param pos_indexes: 用于切片位置编码的索引,[W,HN,C]
:param hn: HN 的尺寸
:return: 注意力权重 [N,H,W,W]
"""
global layer_idx
for idx, (self_attn, cross_attn) in enumerate (zip (self .self_attn_layers, self .cross_attn_layers)):
layer_idx = idx
def create_custom_self_attn (module ):
def custom_self_attn (*inputs ):
return module(*inputs)
return custom_self_attn
feat = checkpoint(create_custom_self_attn(self_attn), feat, pos_enc, pos_indexes)
if idx == self .num_attn_layers - 1 :
def create_custom_cross_attn (module ):
def custom_cross_attn (*inputs ):
return module(*inputs, True )
return custom_cross_attn
else :
def create_custom_cross_attn (module ):
def custom_cross_attn (*inputs ):
return module(*inputs, False )
return custom_cross_attn
feat, attn_weight = checkpoint(
create_custom_cross_attn(cross_attn),
feat[:, :hn], feat[:, hn:], pos_enc, pos_indexes
)
layer_idx = 0
return attn_weight
def forward (self, feat_left: torch.Tensor, feat_right: torch.Tensor, pos_enc: Optional[Tensor] = None ):
"""
:param feat_left: 左图特征描述符,[N,C,H,W]
:param feat_right: 右图特征描述符,[N,C,H,W]
:param pos_enc: 相对位置编码,[N,C,H,2W-1]
:return: 交叉注意力权重 [N,H,W,W],dim=2 为左图,dim=3 为右图
"""
bs, c, hn, w = feat_left.shape
feat_left = feat_left.permute(1 , 3 , 2 , 0 ).flatten(2 ).permute(1 , 2 , 0 )
feat_right = feat_right.permute(1 , 3 , 2 , 0 ).flatten(2 ).permute(1 , 2 , 0 )
if pos_enc is not None :
with torch.no_grad():
indexes_r = torch.linspace(w - 1 , 0 , w).view(w, 1 ).to(feat_left.device)
indexes_c = torch.linspace(0 , w - 1 , w).view(1 , w).to(feat_left.device)
pos_indexes = (indexes_r + indexes_c).view(-1 ).long()
else :
pos_indexes = None
feat = torch.cat([feat_left, feat_right], dim=1 )
attn_weight = self ._alternating_attn(feat, pos_enc, pos_indexes, hn)
attn_weight = attn_weight.view(hn, bs, w, w).permute(1 , 0 , 2 , 3 )
return attn_weight
class TransformerSelfAttnLayer(nn.Module):
"""自注意力层"""
def __init__ (self, hidden_dim: int, nhead: int ):
super ().__init__()
self .self_attn = MultiheadAttentionRelative(hidden_dim, nhead)
self .norm1 = nn.LayerNorm(hidden_dim)
def forward (self, feat: Tensor, pos: Optional[Tensor] = None, pos_indexes: Optional[Tensor] = None ):
feat2 = self .norm1(feat)
feat2, attn_weight, _ = self .self_attn(query=feat2, key=feat2, value=feat2, pos_enc=pos, pos_indexes=pos_indexes)
feat = feat + feat2
return feat
class TransformerCrossAttnLayer(nn.Module):
"""交叉注意力层"""
def __init__ (self, hidden_dim: int, nhead: int ):
super ().__init__()
self .cross_attn = MultiheadAttentionRelative(hidden_dim, nhead)
self .norm1 = nn.LayerNorm(hidden_dim)
self .norm2 = nn.LayerNorm(hidden_dim)
def forward (self, feat_left: Tensor, feat_right: Tensor,
pos: Optional[Tensor] = None,
pos_indexes: Optional[Tensor] = None,
last_layer: Optional[bool] = False):
feat_left_2 = self .norm1(feat_left)
feat_right_2 = self .norm1(feat_right)
if pos is not None :
pos_flipped = torch.flip(pos, [0 ])
else :
pos_flipped = pos
feat_right_2 = self .cross_attn(query=feat_right_2, key=feat_left_2, value=feat_left_2,
pos_enc=pos_flipped, pos_indexes=pos_indexes)[0 ]
feat_right = feat_right + feat_right_2
if last_layer:
w = feat_left_2.size(0 )
attn_mask = self ._generate_square_subsequent_mask(w).to(feat_left.device)
else :
attn_mask = None
feat_right_2 = self .norm2(feat_right)
feat_left_2, attn_weight, raw_attn = self .cross_attn(
query=feat_left_2, key=feat_right_2, value=feat_right_2,
attn_mask=attn_mask, pos_enc=pos, pos_indexes=pos_indexes
)
feat_left = feat_left + feat_left_2
feat = torch.cat([feat_left, feat_right], dim=1 )
return feat, raw_attn
@torch.no_grad()
def _generate_square_subsequent_mask (self, sz: int ):
mask = torch.triu(torch.ones(sz, sz), diagonal=1 )
mask[mask == 1 ] = float ('-inf' )
return mask
def build_transformer(args):
return Transformer(
hidden_dim=args.channel_dim,
nhead=args.nheads,
num_attn_layers=args.num_attn_layers
)
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL 在 STTR 中,注意力模块使用点积相似度计算查询集与键向量之间的注意力,并据此加权值向量。模型采用多头注意力机制,通过将通道维度划分为多个组,增强特征描述符的表达能力,从而优化注意力计算并提升特征表示效果。
最优传输(Optimal Transport) STTR 中的最优传输模块用于解决立体匹配中的唯一性约束问题。不同于以往模型采用的硬分配(hard assignment)会阻碍梯度流动,STTR 采用熵正则化的最优传输方法,其软分配特性具备良好的可微分性。该方法特别适用于稀疏特征匹配和语义对应等任务,提供了更灵活高效的匹配机制。
上下文调整层(Context Adjustment Layer) 该层旨在弥补原始视差图和遮挡图中缺乏跨极线上下文信息的问题。通过将这些图与左图拼接,并使用卷积块和 ReLU 激活函数,模型对视差估计进行精细化处理。最终遮挡估计采用 Sigmoid 激活函数,而视差细化则引入残差块,确保基于输入图像上下文进行全面调整。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL class ContextAdjustmentLayer(nn.Module):
"""
基于图像上下文对视差和遮挡进行调整,
设计思路大致参考 https://github.com/JiahuiYu/wdsr_ntire2018
"""
def __init__ (self, num_blocks=8, feature_dim=16, expansion=3 ):
super ().__init__()
self .num_blocks = num_blocks
self .in_conv = nn.Conv2d(4 , feature_dim, kernel_size=3 , padding=1 )
self .layers = nn.ModuleList([ResBlock(feature_dim, expansion) for _ in range (num_blocks)])
self .out_conv = nn.Conv2d(feature_dim, 1 , kernel_size=3 , padding=1 )
self .occ_head = nn.Sequential(
weight_norm(nn.Conv2d(1 + 3 , feature_dim, kernel_size=3 , padding=1 )),
weight_norm(nn.Conv2d(feature_dim, feature_dim, kernel_size=3 , padding=1 )),
nn.ReLU(inplace=True ),
weight_norm(nn.Conv2d(feature_dim, feature_dim, kernel_size=3 , padding=1 )),
weight_norm(nn.Conv2d(feature_dim, feature_dim, kernel_size=3 , padding=1 )),
nn.ReLU(inplace=True ),
nn.Conv2d(feature_dim, 1 , kernel_size=3 , padding=1 ),
nn.Sigmoid()
)
def forward (self, disp_raw: Tensor, occ_raw: Tensor, img: Tensor ):
"""
:param disp_raw: 原始视差图,[N,1,H,W]
:param occ_raw: 原始遮挡掩码,[N,1,H,W]
:param img: 输入左图,[N,3,H,W]
:return:
disp_final: 最终视差图 [N,1,H,W]
occ_final: 最终遮挡图 [N,1,H,W]
"""
feat = self .in_conv(torch.cat([disp_raw, img], dim=1 ))
for layer in self .layers:
feat = layer(feat, disp_raw)
disp_res = self .out_conv(feat)
disp_final = disp_raw + disp_res
occ_final = self .occ_head(torch.cat([occ_raw, img], dim=1 ))
return disp_final, occ_final
内存友好型实现(Memory-Feasible Implementation) STTR 解决了注意力机制通常伴随的高内存消耗问题。通过采用梯度检查点、混合精度训练以及注意力步长调整等技术,有效控制内存使用。这些策略显著降低了内存占用,使得网络在注意力层数量方面具备良好的可扩展性,从而能够在常规硬件上实际部署和运行。
代码详解 – STTR 微调策略 正如本文前面所述,本次微调所选用的数据集为 KITTI ADAS 立体视觉 2015 数据集。不过,该数据集的 2012 版本目前仍可作为开源资源下载。在任何深度学习流程中,对原始数据集进行预处理都至关重要。需要注意的是,下方提供的预处理脚本同时支持 2015 和 2012 两个版本的数据集,也可将两者结合使用。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL class KITTIBaseDataset (data.Dataset):
def __init__ (self, datadir, split='train' ):
super (KITTIBaseDataset, self ).__init__()
self .datadir = datadir
self .split = split
if split == 'train' or split == 'validation' or split == 'validation_all' :
self .sub_folder = 'training/'
elif split == 'test' :
self .sub_folder = 'testing/'
self .left_fold = None
self .right_fold = None
self .disp_fold = None
self ._augmentation()
def _read_data (self ):
assert self .left_fold is not None
self .left_data = natsorted([os.path.join(self .datadir, self .sub_folder, self .left_fold, img) for img in
os.listdir(os.path.join(self .datadir, self .sub_folder, self .left_fold)) if
img.find('_10' ) > -1 ])
self .right_data = [img.replace(self .left_fold, self .right_fold) for img in self .left_data]
self .disp_data = [img.replace(self .left_fold, self .disp_fold) for img in self .left_data]
self ._split_data()
def _split_data (self ):
train_val_frac = 0.95
if len (self .left_data) > 1 :
if self .split == 'train' :
self .left_data = self .left_data[:int (len (self .left_data) * train_val_frac)]
self .right_data = self .right_data[:int (len (self .right_data) * train_val_frac)]
self .disp_data = self .disp_data[:int (len (self .disp_data) * train_val_frac)]
elif self .split == 'validation' :
self .left_data = self .left_data[int (len (self .left_data) * train_val_frac):]
self .right_data = self .right_data[int (len (self .right_data) * train_val_frac):]
self .disp_data = self .disp_data[int (len (self .disp_data) * train_val_frac):]
def _augmentation (self ):
if self .split == 'train' :
self .transformation = Compose([
RGBShiftStereo(always_apply=True , p_asym=0.5 ),
RandomBrightnessContrastStereo(always_apply=True , p_asym=0.5 )
])
elif self .split == 'validation' or self .split == 'test' or self .split == 'validation_all' :
self .transformation = None
else :
raise Exception("Split not recognized" )
def __len__ (self ):
return len (self .left_data)
def __getitem__ (self, idx ):
input_data = {}
left_fname = self .left_data[idx]
left = np.array(Image.open (left_fname)).astype(np.uint8)
input_data['left' ] = left
right_fname = self .right_data[idx]
right = np.array(Image.open (right_fname)).astype(np.uint8)
input_data['right' ] = right
if not self .split == 'test' :
disp_fname = self .disp_data[idx]
disp = np.array(Image.open (disp_fname)).astype(float ) / 256.
input_data['disp' ] = disp
input_data['occ_mask' ] = np.zeros_like(disp).astype(bool )
if self .split == 'train' :
input_data = random_crop(200 , 640 , input_data, self .split)
input_data = augment(input_data, self .transformation)
else :
input_data = normalization(**input_data)
return input_data
class KITTIDataset (KITTIBaseDataset ):
"""
合并了 KITTI 2015 与 2012 数据的混合数据集
"""
def __init__ (self, datadir, split='train' ):
super (KITTIDataset, self ).__init__(datadir, split)
self .left_fold_2015 = 'image_2'
self .right_fold_2015 = 'image_3'
self .disp_fold_2015 = 'disp_occ_0'
self .preprend_2015 = '2015'
self .left_fold_2012 = 'colored_0'
self .right_fold_2012 = 'colored_1'
self .disp_fold_2012 = 'disp_occ'
self .preprend_2012 = '2012'
self ._read_data()
def _read_data (self ):
assert self .left_fold_2015 is not None
assert self .left_fold_2012 is not None
left_data_2015 = [os.path.join(self .datadir, self .preprend_2015, self .sub_folder, self .left_fold_2015, img) for
img in os.listdir(os.path.join(self .datadir, '2015' , self .sub_folder, self .left_fold_2015)) if
img.find('_10' ) > -1 ]
left_data_2015 = natsorted(left_data_2015)
right_data_2015 = [img.replace(self .left_fold_2015, self .right_fold_2015) for img in left_data_2015]
disp_data_2015 = [img.replace(self .left_fold_2015, self .disp_fold_2015) for img in left_data_2015]
left_data_2012 = [os.path.join(self .datadir, self .preprend_2012, self .sub_folder, self .left_fold_2012, img) for
img in os.listdir(os.path.join(self .datadir, '2012' , self .sub_folder, self .left_fold_2012)) if
img.find('_10' ) > -1 ]
left_data_2012 = natsorted(left_data_2012)
right_data_2012 = [img.replace(self .left_fold_2012, self .right_fold_2012) for img in left_data_2012]
disp_data_2012 = [img.replace(self .left_fold_2012, self .disp_fold_2012) for img in left_data_2012]
self .left_data = natsorted(left_data_2015 + left_data_2012)
self .right_data = natsorted(right_data_2015 + right_data_2012)
self .disp_data = natsorted(disp_data_2015 + disp_data_2012)
self ._split_data()
class KITTI2015Dataset (KITTIBaseDataset ):
def __init__ (self, datadir, split='train' ):
super (KITTI2015Dataset, self ).__init__(datadir, split)
self .left_fold = 'image_2/'
self .right_fold = 'image_3/'
self .disp_fold = 'disp_occ_0/'
self ._read_data()
class KITTI2012Dataset (KITTIBaseDataset ):
def __init__ (self, datadir, split='train' ):
super (KITTI2012Dataset, self ).__init__(datadir, split)
self .left_fold = 'colored_0/'
self .right_fold = 'colored_1/'
self .disp_fold = 'disp_occ/'
self ._read_data()
让我们详细理解上述代码片段:
KITTIBaseDataset :这是用于处理 KITTI 数据集的基类,继承自 torch.utils.data.Dataset。构造函数 (__init__) 接收两个参数:datadir(数据集所在目录)和 split(指定数据划分,如训练、验证、测试等)。它初始化数据路径,并调用 _augmentation() 方法,根据数据划分设置数据增强策略。_read_data() 方法构建左图、右图和视差图的路径,并将数据集划分为训练集和验证集。_split_data() 方法根据预设比例划分训练/验证数据,而 _augmentation() 方法则为训练数据定义增强策略,例如 RGB 偏移、随机亮度/对比度调整等。KITTI2015Dataset 与 KITTI2012Dataset :这两个类继承自 KITTIBaseDataset,分别专门处理 KITTI 2015 和 KITTI 2012 数据集。每个类根据各自数据集的结构,设置左图、右图和视差图的具体目录(left_fold、right_fold、disp_fold)。KITTIDataset :该类同样继承自 KITTIBaseDataset,用于处理合并后的 KITTI 2015 与 2012 数据。它为每一年份的数据分别设置目录路径,然后读取并合并两者。其 _read_data() 方法被重写,以支持从两个数据集中读取并合并数据。
要使上述脚本正常运行,需建立特定的数据集目录结构。如果你通过本文提供的代码包下载代码,该结构已预先配置好。但若你希望自行下载原始 KITTI ADAS 立体视觉数据集,可访问以下链接:
KITTI ADAS 立体视觉 2015 数据集(2GB)
以下是推荐的目录结构:
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL stereo-transformer
├── sample_data
│ └── KITTI_2015
│ ├── 2012
│ │ ├── testing
│ │ │ ├── colored_0
│ │ │ ├── colored_1
│ │ │ ├── image_0
│ │ └── training
│ │ ├── colored_0
│ │ ├── colored_1
│ │ └── disp_occ
│ └── 2015
│ ├── testing
│ │ ├── image_2
│ │ └── disp_occ
│ └── training
│ ├── image_2
│ ├── image_3
│ └── disp_occ_0
├── dataset
├── media
├── module
├── run
├── scripts
└── utilities
要启动微调过程,只需在 stereo-transformer 项目根目录下执行以下命令:
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL python main.py --epochs 400\
--batch_size 1\
--checkpoint kitti_ft\
--num_workers 2\
--dataset kitti\
--dataset_directory sample_data/KITTI_2015\
--ft\
--resume kitti_finetuned_model.pth.tar
该命令接收以下参数:
训练轮数(Epochs) 批次大小(Batch Size) 检查点保存目录(Checkpoint Directory) 数据加载工作线程数(Workers) 数据集类型(Dataset Type) 数据集路径(Dataset Directory) 恢复训练的检查点路径(Resume Checkpoint)
你可以根据自己的计算资源适当增加 Epochs、Batch Size 和 Workers 的数值。
注意 :最初该微调实验在配备 Nvidia RTX 3080 Ti 的深度学习机器上进行测试,但很快显存(vRAM)耗尽。因此,最终改用配备 24GB 显存的 Nvidia RTX A5000 对 STTR 模型进行 ADAS 立体视觉任务的微调。Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL def print_param(model):
"""
打印模型中各部分的参数数量
"""
n_parameters = sum (p.numel() for n, p in model.named_parameters() if 'backbone' in n and p.requires_grad)
print ('number of params in backbone:' , f'{n_parameters:,}' )
n_parameters = sum (p.numel() for n, p in model.named_parameters() if
'transformer' in n and 'regression' not in n and p.requires_grad)
print ('number of params in transformer:' , f'{n_parameters:,}' )
n_parameters = sum (p.numel() for n, p in model.named_parameters() if 'tokenizer' in n and p.requires_grad)
print ('number of params in tokenizer:' , f'{n_parameters:,}' )
n_parameters = sum (p.numel() for n, p in model.named_parameters() if 'regression' in n and p.requires_grad)
print ('number of params in regression:' , f'{n_parameters:,}' )
def main(args):
device = torch.device(args.device)
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model = STTR(args).to(device)
print_param(model)
param_dicts = [
{"params" : [p for n, p in model.named_parameters() if
"backbone" not in n and "regression" not in n and p.requires_grad]},
{
"params" : [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr" : args.lr_backbone,
},
{
"params" : [p for n, p in model.named_parameters() if "regression" in n and p.requires_grad],
"lr" : args.lr_regression,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay_rate)
if args.apex:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1' )
else :
amp = None
prev_best = np.inf
if args.resume != '' :
if not os.path.isfile(args.resume):
raise RuntimeError(f"=> no checkpoint found at '{args.resume}'" )
checkpoint = torch.load(args.resume)
pretrained_dict = checkpoint['state_dict' ]
missing, unexpected = model.load_state_dict(pretrained_dict, strict=False )
if len (missing) > 0 :
print ("Missing keys: " , ',' .join(missing))
raise Exception("Missing keys." )
unexpected_filtered = [k for k in unexpected if
'running_mean' not in k and 'running_var' not in k]
if len (unexpected_filtered) > 0 :
print ("Unexpected keys: " , ',' .join(unexpected_filtered))
raise Exception("Unexpected keys." )
print ("Pre-trained model successfully loaded." )
if not (args.ft or args.inference or args.eval ):
if len (unexpected) > 0 :
raise Exception("Resuming legacy model with BN parameters. Not possible due to BN param change. " +
"Do you want to finetune or inference? If so, check your arguments." )
else :
args.start_epoch = checkpoint['epoch' ] + 1
optimizer.load_state_dict(checkpoint['optimizer' ])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler' ])
prev_best = checkpoint['best_pred' ]
if args.apex:
amp.load_state_dict(checkpoint['amp' ])
print ("Pre-trained optimizer, lr scheduler and stats successfully loaded." )
if args.inference:
print ("Start inference" )
_, _, data_loader = build_data_loader(args)
inference(model, data_loader, device, args.downsample)
return
checkpoint_saver = Saver(args)
summary_writer = TensorboardSummary(checkpoint_saver.experiment_dir)
data_loader_train, data_loader_val, _ = build_data_loader(args)
criterion = build_criterion(args)
set_downsample(args)
if args.eval :
print ("Start evaluation" )
evaluate(model, criterion, data_loader_val, device, 0 , summary_writer, True )
return
print ("Start training" )
for epoch in range (args.start_epoch, args.epochs):
print ("Epoch: %d" % epoch)
train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, summary_writer,
args.clip_max_norm, amp)
if not args.pre_train:
lr_scheduler.step()
print ("current learning rate" , lr_scheduler.get_lr())
torch.cuda.empty_cache()
if args.pre_train or epoch % 50 == 0 :
save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False , amp)
eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False )
if prev_best > eval_stats['epe' ] and 0.5 > eval_stats['px_error_rate' ]:
save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, True , amp)
save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False , amp)
return
让我们深入理解在开发环境中执行上述命令时内部发生了什么:
print_param 函数 :该函数计算并打印 STTR 模型不同组件(如骨干网络 backbone、Transformer、分词器 tokenizer 和回归头 regression)的参数数量,使用 PyTorch 的 named_parameters 方法实现。main() 函数 是脚本的核心: 根据输入参数设置计算设备(CPU/GPU); 通过固定 PyTorch、NumPy 和 random 库的随机种子确保结果可复现; 初始化 STTR 模型并将其移至指定设备; 调用 print_param 打印各组件参数量; 为模型不同部分(如 backbone)配置不同的学习率; 定义优化器(AdamW)和学习率调度器(ExponentialLR); 可选地启用 Nvidia Apex 库进行混合精度训练以提升性能; 若提供检查点路径,则加载模型状态以及优化器、调度器和 AMP 状态,用于恢复训练或微调; 在推理模式下,加载数据并运行模型进行预测; 构建训练和验证数据加载器; 设置损失函数(criterion); 包含一个训练循环,调用训练和验证函数,并根据验证性能保存最佳模型; 训练结束后保存最终模型检查点。推理策略 现在我们已经获得了一个微调后的模型。但如何对其进行推理,以检验其性能呢?本节将介绍微调后 STTR 模型的推理流程。请参考项目目录下 scripts 子目录中的 inference-kitti.ipynb 文件。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL from PIL import Image
import torch
import numpy as np
import cv2
import glob
import os
import argparse
import matplotlib.pyplot as plt
import sys
sys.path.append('../' )
from module.sttr import STTR
from dataset.preprocess import normalization, compute_left_occ_region
from utilities.misc import NestedTensor
首先,需要导入必要的包,如 PIL、torch、cv2、glob、os 以及其他内部依赖项。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL
def load_images(image_dir, pattern):
filenames = sorted (glob.glob(os.path.join(image_dir, pattern)))
return [np.array(Image.open (filename)) for filename in filenames[:500 ]]
此代码片段从 KITTI ADAS 立体视觉数据集的测试文件夹中加载约 500 对图像用于推理。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL
args = type ('' , (), {})()
args.channel_dim = 128
args.position_encoding = 'sine1d_rel'
args.num_attn_layers = 6
args.nheads = 8
args.regression_head = 'ot'
args.context_adjustment_layer = 'cal'
args.cal_num_blocks = 8
args.cal_feat_dim = 16
args.cal_expansion_ratio = 4
每个模型都需要一组称为 args 的参数来实例化。此处列出的是默认参数。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL model = STTR(args).cuda().eval ()
由于我们对已微调的模型进行推理,需将其设为评估模式(.eval())。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL
model_file_name = "../kitti_finetuned_model.pth.tar"
checkpoint = torch.load(model_file_name)
pretrained_dict = checkpoint['state_dict' ]
model.load_state_dict(pretrained_dict, strict=False )
print ("Pre-trained model successfully loaded." )
上述代码从最后保存的检查点文件中加载预训练模型。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL
left_images = load_images('../sample_data/KITTI_2015/2015/training/image_2' , '*.png' )
right_images = load_images('../sample_data/KITTI_2015/2015/training/image_3' , '*.png' )
初始化 KITTI ADAS 立体视觉数据集测试集中左右图像的路径。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL
height, width, _ = left_images[0 ].shape
output_dir = '../inference_output/'
os.makedirs(output_dir, exist_ok=True )
推理结果需保存到指定目录。脚本将自动创建名为 inference_output 的文件夹用于存储所有结果。
Plain Bash C++ C# CSS Diff HTML/XML Java Javascript Markdown PHP Python Ruby SQL for i, (left, right) in enumerate(zip(left_images, right_images)):
# 归一化并创建 NestedTensor
input_data = normalization(left=left, right=right)
h, w, _ = left.shape
bs = 1
downsample = 3
col_offset = int(downsample / 2)
row_offset = int(downsample / 2)
sampled_cols = torch.arange(col_offset, w, downsample)[None,].expand(bs, -1).cuda()
sampled_rows = torch.arange(row_offset, h, downsample)[None,].expand(bs, -1).cuda()
input_data = NestedTensor(input_data['left'].cuda()[None,], input_data['right'].cuda()[None,], sampled_cols=sampled_cols, sampled_rows=sampled_rows)
# 执行推理
output = model(input_data)
disp_pred = output['disp_pred'].data.cpu().numpy()[0]
occ_pred = output['occ_pred'].data.cpu().numpy()[0] > 0.5
disp_pred[occ_pred] = 0.0
# 将 disp_pred 和 occ_pred 归一化并转为 uint8
disp_pred_norm = cv2.normalize(disp_pred, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
occ_pred_uint8 = np.uint8(occ_pred * 255)
# 拼接视差图与遮挡图
combined_output = np.hstack((disp_pred_norm, occ_pred_uint8))
# 保存为 PNG 文件
output_filename = os.path.join(output_dir, f'inference_{i:03d}.png')
cv2.imwrite(output_filename, combined_output)
print(f"Saved: {output_filename}")
print("All inferences saved as PNG files.")
上述代码对一系列立体图像对(左图和右图)进行处理,用于计算机视觉任务中的视差与遮挡预测。流程如下:
遍历每一对左右图像; 对图像进行归一化,并封装为自定义数据结构 NestedTensor,其中包含图像及由下采样因子决定的采样行列索引; 将该结构输入模型进行推理; 模型输出视差和遮挡预测结果,将其从 GPU 移至 CPU 并转为 NumPy 数组; 视差图通过 OpenCV 归一化为 8 位格式,遮挡图经阈值处理生成二值掩码并转为 8 位; 在遮挡区域将视差值置零; 将处理后的视差图与遮挡图水平拼接,保存为 PNG 文件; 文件名通过循环计数器确保唯一性; 所有图像对处理完毕后,打印完成信息。
该流程完整展示了立体视觉任务中从预处理、模型推理到结果保存的典型工作流。
实验结果:立体视觉视差图 本节将可视化本研究工作的结果:
有趣的结果,对吧?可以看本研究文章上面的“代码详解”部分,以深入了解这一精细调优过程。
关键要点 上一节展示了ADAS立体视觉Transformer模型的推理输出。让我们总结一下本研究工作的关键发现:
借助STTR提升深度感知能力 :对立体视觉Transformer(STTR)模型进行微调,显著提升了ADAS系统中的深度感知性能,尤其在低光照和动态环境等具有挑战性的条件下表现突出。这表明STTR模型在生成精确的视差图和遮挡图方面具备出色的鲁棒性与适应性。立体视觉作为高性价比的3D LiDAR替代方案 :立体视觉能够提供丰富的纹理信息,并在各种光照条件下保持稳定性能,而这些是LiDAR系统可能欠缺的。计算与性能限制 :尽管优势明显,STTR模型仍存在较高的计算开销,尤其在实时应用中面临较大挑战。ADAS中基于深度估计的目标检测 :研究表明,该模型在ADAS的关键功能(如障碍物检测和行人检测)方面展现出巨大潜力。这些应用得益于模型对环境要素的精准检测与分析能力,有助于实现更安全、高效的自动驾驶导航。
结论 本研究聚焦于将立体视觉作为3D LiDAR的替代方案应用于ADAS系统,重点在于在KITTI ADAS立体视觉数据集上对立体视觉Transformer(STTR)模型进行微调。结果表明,该方法显著提升了深度感知性能,尤其在低光照和动态环境中效果显著,说明STTR是一种可行且成本效益高的LiDAR替代方案。然而,在极端天气条件或低纹理场景下,其计算需求和性能表现仍面临一定挑战。