一、当卷积遇到自注意力:一场视觉特征的化学反应

想象一下,你正在玩拼图游戏。CNN就像是一个拿着放大镜的小朋友,专注地观察每一块拼图的细节(局部特征);而ViT则像是个站在高处的小大人,一眼就能看出拼图的整体布局(全局特征)。现在问题来了:能不能让这两个小朋友合作完成拼图呢?

在计算机视觉领域,CNN通过卷积核滑动捕捉局部特征,就像用3×3或5×5的窗口扫描图像。而Vision Transformer(ViT)通过自注意力机制,让图像块之间建立全局联系。举个具体例子:

# 技术栈:PyTorch
# 典型CNN局部特征提取示例
import torch.nn as nn

class CNNBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # 3x3卷积保留空间信息
        self.relu = nn.ReLU()
    
    def forward(self, x):
        # 输出特征图尺寸与输入相同,但每个点只感知周围3x3区域
        return self.relu(self.conv(x))
# 技术栈:PyTorch
# 典型ViT全局注意力示例
from torch import nn
import math

class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        
        self.to_qkv = nn.Linear(dim, dim*3)  # 生成Q,K,V
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        # x形状: [batch, num_patches+1, dim]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(*t.shape[:2], self.heads, -1).transpose(1, 2), qkv)
        
        # 计算所有图像块间的注意力权重
        dots = (q @ k.transpose(-1, -2)) * self.scale
        attn = dots.softmax(dim=-1)
        
        # 输出形状: [batch, heads, num_patches+1, dim//heads]
        out = (attn @ v).transpose(1, 2).reshape(x.shape)
        return self.to_out(out)

二、融合架构的三种经典范式

2.1 串行式融合:先局部后全局

这种设计就像先让CNN小朋友拼好局部区域,再交给ViT小朋友调整整体布局。典型的代表是Convolutional Vision Transformer(CvT):

# 技术栈:PyTorch
# CvT的简化实现
class CVTBlock(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        # 先用卷积进行局部建模
        self.conv_proj = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)
        # 再进行注意力计算
        self.attention = Attention(dim)
        
    def forward(self, x):
        # 输入形状: [batch, dim, height, width]
        x = self.conv_proj(x)  # 深度可分离卷积
        # 转换形状为ViT需要的格式
        batch, dim, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)  # [batch, h*w, dim]
        x = self.attention(x)
        x = x.transpose(1, 2).view(batch, dim, h, w)
        return x

优点:保留了CNN的平移等变性,适合处理纹理丰富的图像
缺点:全局注意力计算开销随分辨率平方增长

2.2 并行式融合:双管齐下

让CNN和ViT同时工作,最后融合结果。就像两个小朋友分别拼图,然后对比修正:

# 技术栈:PyTorch
class ParallelBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv_branch = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.GELU()
        )
        self.attn_branch = Attention(dim)
        self.proj = nn.Conv2d(dim*2, dim, 1)

    def forward(self, x):
        # 卷积分支处理
        conv_out = self.conv_branch(x)
        
        # 注意力分支需要转换形状
        batch, dim, h, w = x.shape
        attn_in = x.flatten(2).transpose(1, 2)
        attn_out = self.attn_branch(attn_in)
        attn_out = attn_out.transpose(1, 2).view(batch, dim, h, w)
        
        # 特征拼接与融合
        fused = torch.cat([conv_out, attn_out], dim=1)
        return self.proj(fused)

适用场景:需要同时捕捉局部细节和全局关系的任务,如医学图像分析

2.3 层次化融合:渐进式抽象

这种设计模仿人类视觉系统,底层用CNN,高层用ViT。代表工作是CoAtNet:

# 技术栈:PyTorch
class CoAtBlock(nn.Module):
    def __init__(self, dim, stage):
        super().__init__()
        if stage == 'early':
            # 早期阶段纯卷积
            self.block = nn.Sequential(
                nn.Conv2d(dim, dim, 3, padding=1),
                nn.BatchNorm2d(dim),
                nn.GELU()
            )
        elif stage == 'middle':
            # 中期MBConv风格
            self.block = nn.Sequential(
                nn.Conv2d(dim, dim*4, 1),
                nn.GELU(),
                nn.Conv2d(dim*4, dim*4, 3, padding=1, groups=dim*4),
                nn.GELU(),
                nn.Conv2d(dim*4, dim, 1)
            )
        else:
            # 后期纯注意力
            self.block = Attention(dim)
    
    def forward(self, x):
        return self.block(x)

技术要点:随着网络深度增加,逐步从局部处理过渡到全局建模

三、关键技术挑战与解决方案

3.1 计算效率问题

全局注意力的计算复杂度是O(n²),当处理高分辨率图像时会显存爆炸。解决方案包括:

# 技术栈:PyTorch
# 使用局部窗口注意力示例
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size=7):
        super().__init__()
        self.window_size = window_size
        self.attn = Attention(dim)
        
    def forward(self, x):
        # x形状: [batch, dim, h, w]
        batch, dim, h, w = x.shape
        x = x.view(batch, dim, h//self.window_size, self.window_size, 
                  w//self.window_size, self.window_size)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, self.window_size*self.window_size, dim)
        
        # 只在窗口内计算注意力
        x = self.attn(x)
        
        # 恢复原始形状
        x = x.view(batch, h//self.window_size, w//self.window_size, 
                  self.window_size, self.window_size, dim)
        return x.permute(0, 5, 1, 3, 2, 4).reshape(batch, dim, h, w)

3.2 位置信息编码

ViT需要显式的位置编码,而CNN天然具有位置信息。融合时需要特别注意:

# 技术栈:PyTorch
class HybridPositionEncoding(nn.Module):
    def __init__(self, dim, image_size=224, patch_size=16):
        super().__init__()
        # CNN风格的位置编码(通过零填充实现)
        self.conv_pos = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        
        # ViT风格的绝对位置编码
        num_patches = (image_size // patch_size) ** 2
        self.vit_pos = nn.Parameter(torch.randn(1, num_patches, dim))
        
    def forward(self, x, is_conv):
        if is_conv:
            return x + self.conv_pos(x)
        else:
            return x + self.vit_pos

四、实战应用与选型建议

4.1 典型应用场景

  1. 医疗影像分析:CNN捕捉细胞形态,ViT理解组织分布
  2. 自动驾驶:CNN处理近处物体细节,ViT理解远处场景关系
  3. 卫星图像:CNN识别建筑物纹理,ViT分析城市布局

4.2 技术选型对照表

架构类型 参数量 计算成本 适合分辨率 典型应用
串行式融合 中等 中等 中等 通用图像分类
并行式融合 较大 较高 较高 精细粒度分类
层次化融合 可变 可变 宽范围 多尺度目标检测

4.3 注意事项

  1. 当训练数据少于100万张时,优先考虑卷积为主的架构
  2. 输入分辨率超过512×512时,建议采用局部注意力机制
  3. 部署到边缘设备时,可以使用MobileNet+轻量级ViT的混合结构
# 技术栈:PyTorch
# 边缘设备友好型混合架构示例
class EdgeHybrid(nn.Module):
    def __init__(self):
        super().__init__()
        # 轻量级CNN主干
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(),
            # 更多深度可分离卷积层...
        )
        
        # 稀疏注意力头
        self.attn = Attention(dim=16, heads=4)
        
    def forward(self, x):
        x = self.cnn(x)  # 下采样到较低分辨率
        batch, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.attn(x)
        x = x.transpose(1, 2).view(batch, c, h, w)
        return x

五、未来发展方向

  1. 动态融合:根据输入内容自动调整CNN和ViT的权重比例
  2. 神经架构搜索:自动寻找最优混合模式
  3. 3D视觉扩展:将融合思路推广到视频理解和体积数据

最终选择哪种融合方式,就像决定让两个小朋友如何合作拼图——没有绝对正确的答案,关键要看具体的拼图(任务)特点。在实践中,建议先用现成的混合架构(如MobileViT)作为基线,再根据实际需求进行调整。记住,好的设计应该让CNN和ViT各展所长,而不是简单堆砌。