KM-UNet:KAN与Mamba融合的U-Net用于医学图像分割
📄 🧬 🔬 🧪 📊
KM-UNet:KAN Mamba UNet 用于医学图像分割
_
KM-UNet:KAN Mamba UNet 用于医学图像分割
论文来源
KM-UNet: KAN Mamba UNet for medical image segmentation
Yibo Zhang、Jingwen Zhao*、Xiang Liu、Xian Tang、Yunyu Shi、Lina Wei、Guyue Zhang
(上海工程技术大学、杭州城市大学、浙江省质量科学研究院)
关键词: KAN、Mamba、state-space models、UNet、Medical image segmentation、Deep learning
(开源代码:https://github.com/2760613195/KM_UNet)
摘要
医学图像分割是医学图像分析中的关键任务。传统CNN方法难以建模长程依赖,而Transformer模型虽效果显著却存在二次计算复杂度问题。为解决这些局限,本文提出KM-UNet,一种新型U型网络架构,将Kolmogorov-Arnold Networks(KAN)与状态空间模型(SSM/Mamba)相结合。KM-UNet利用Kolmogorov-Arnold表示定理实现高效特征表示,并借助SSM实现可扩展的长程建模,在精度与计算效率之间取得良好平衡。
核心创新包括:
- SEM注意力模块:融合Selective-Scan Efficient Multi-scale机制,通过多方向扫描(外层向中心旋转策略)+S6块实现高效多尺度特征提取与长程依赖建模。
- Tok-KAN模块:在U-Net瓶颈层引入KAN网络,替代传统MLP,提供固有可解释性与参数高效的非线性映射。
- 整体架构:三阶段编码器-解码器结构(Convolution Phase + SEM Phase + Tok-KAN Phase),跳跃连接采用简单相加,显著提升纯SSM模型的分割性能。
在ISIC17、ISIC18、CVC、BUSI和GLAS五个公开基准数据集上进行评估,KM-UNet在IoU和Dice指标上达到SOTA水平,平均IoU达81.17%、F1达89.20%,参数量仅7.35M、GFLOPs仅17.66,兼具高精度与轻量化优势。同时,KAN层显著提升模型可解释性,通过注意力热图验证了其在关键区域定位上的优越性。本文为高效、可解释的医学图像分割系统提供了重要基线与新思路。
1. 引言
过去十年,医学图像分割方法取得显著进展,以满足计算机辅助诊断和图像引导手术系统的需求。U-Net作为里程碑式架构,通过编码器-解码器+跳跃连接在医学分割中展现强大能力。随后涌现U-Net++、3D U-Net、V-Net等改进,以及U-NeXt等CNN+MLP混合架构。
Transformer系列(如TransUNet、Swin-UNet)虽擅长全局上下文建模,但对数据量和计算资源需求极高。近期,结构化状态空间模型(SSM/Mamba)以线性复杂度高效处理长序列,在U-Mamba、SegMamba等工作中展现潜力。然而,现有U-Net变体仍面临内核设计与黑箱属性问题,影响可解释性和临床可靠性。
本文提出KM-UNet,首次将KAN与SSM(Mamba)融合进U型架构:编码器采用SEM注意力+Patch Merging下采样,解码器采用SEM+Patch Expanding上采样,跳跃连接简单相加。KAN提供可解释性,Mamba解决长程依赖与效率权衡,实现精度、效率与透明度的统一。
2. 相关工作
基于KAN的U-Net架构研究
传统CNN受局部感受野限制,难以捕捉全局信息。KAN基于Kolmogorov-Arnold表示定理,以可学习激活函数替代线性变换矩阵,提升非线性建模能力和可解释性。与U-Net结合后,既保留多尺度特征融合优势,又增强全局上下文理解与模型稳定性。
Mamba模块与U-Net的集成
Mamba作为现代SSM,在视觉任务中展现线性复杂度优势。U-Mamba首次将SSM与CNN结合应用于医学分割,SegMamba则在3D任务中验证其潜力。多流网络(如PSA)和多尺度卷积进一步提升特征表示。本文提出的SEM模块通过跨空间学习优化全局上下文捕获,兼顾计算效率。
3. 方法论(核心模块详解)
KM-UNet采用对称三阶段编码器-解码器结构(Convolution Phase + SEM Phase + Tok-KAN Phase),通道数由超参数C_1C_5和D_1D_5控制。

3.1 Selective-Scan Efficient Multi-scale (SEM)注意力模块
特征提取:改进SS2D扫描策略,支持四方向(左上→右下、右上→左下等)+自适应旋转(外层向中心)扫描。每个方向序列经S6块(Mamba改进版,动态参数调整)提取特征,再Re-weight融合恢复原尺寸。
多尺度注意力:采用1×1与3×3并行卷积子网络,避免通道降维,通过通道重塑+聚合实现短程与长程空间依赖联合建模。

3.2 Tokenized KAN (Tok-KAN)模块
将KAN置于瓶颈层,替代传统MLP。KAN结构为:
其中\Phi_i为可学习激活函数,实现高效非线性映射与固有可解释性。
3.3 整体轻量化设计
- 编码器:卷积+SEM+Tok-MLP下采样
- 解码器:Tok-KAN+SEM+卷积上采样
- 跳跃连接:简单相加,突出纯SSM性能
4. 实验结果(精选对比)
在BUSI、GlaS、CVC、ISIC17、ISIC18五个异构数据集上进行评估(统一预处理为256×256或512×512,80/20划分,300 epochs,BCE+Dice损失)。
表1:与SOTA方法在五个数据集上的性能对比(IoU↑ / F1↑)
| Methods | BUSI | GlaS | CVC | ISIC17 | ISIC18 |
|---|---|---|---|---|---|
| U-Net | 57.22/71.91 | 86.66/92.79 | 83.79/91.06 | 76.98/86.99 | 77.86/87.55 |
| U-Net++ | 57.41/72.11 | 87.07/92.96 | 84.61/91.53 | 78.58/86.35 | 78.31/87.83 |
| U-Mamba | 61.81/75.55 | 87.01/93.02 | 84.79/91.63 | 81.47/89.07 | 80.92/89.49 |
| KM-UNet (Ours) | 65.42/78.79 | 87.51/93.27 | 85.01/91.79 | 84.05/91.15 | 83.84/91.00 |
表2:整体效率与分割指标对比
| Methods | Avg IoU↑ | Avg F1↑ | GFLOPs | Params (M) |
|---|---|---|---|---|
| U-Net | 76.50 | 86.06 | 524.2 | 34.53 |
| U-Mamba | 79.20 | 87.75 | 2087 | 86.3 |
| KM-UNet | 81.17 | 89.20 | 17.66 | 7.35 |
可视化结果显示KM-UNet在边界细节、假阳性抑制等方面显著优于U-Net、U-Mamba等基线。
可解释性验证:KAN层注意力热图显示,引入KAN后模型能更精确聚焦目标边界,IoU显著提升,验证了其透明性。
5. 结论与展望
KM-UNet通过SEM模块、KAN网络及先进训练策略,在医学图像分割任务中实现显著性能提升。消融实验表明SEM模块带来2%-3%的IoU/F1增益,余弦退火学习率优于固定学习率,KAN替换MLP进一步增强全局建模能力。
未来工作将聚焦降低SEM计算复杂度、探索轻量化技术(量化、剪枝)、扩展至遥感/视频分割、多模态融合(CT/MRI)及联邦学习,以进一步提升泛化能力与临床适用性。
6. 技术借鉴与实现建议
6.1 SEM注意力模块(即插即用PyTorch实现)
import torch
import torch.nn as nn
import torch.nn.functional as F
class S6Block(nn.Module): # 来自archs.py & Mamba改进版(论文Algorithm 1)
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
super().__init__()
self.d_model = d_model
self.dt_rank = int(d_model / 16)
self.d_state = d_state
# Linear projections for Δ, B, C
self.x_proj = nn.Linear(d_model, self.dt_rank + 2 * d_state)
self.dt_proj = nn.Linear(self.dt_rank, d_model)
self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
self.D = nn.Parameter(torch.ones(d_model))
self.conv1d = nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=d_conv, padding=d_conv - 1, groups=d_model)
def forward(self, x): # x: [B, L, D]
# Selective mechanism (论文S6伪代码)
delta, B, C = self.x_proj(x).split([self.dt_rank, self.d_state, self.d_state], dim=-1)
delta = F.softplus(self.dt_proj(delta))
# SSM computation
A = -torch.exp(self.A_log)
# ... (完整S6离散化实现,详见repo archs.py)
y = ... # 返回 [B, L, D]
return y
class SEMModule(nn.Module): # 完整SEM模块(archs.py核心)
def __init__(self, dim, directions=4):
super().__init__()
self.s6_blocks = nn.ModuleList([S6Block(dim) for _ in range(directions)])
self.multi_scale = nn.ModuleList([
nn.Conv2d(dim, dim, kernel_size=1),
nn.Conv2d(dim, dim, kernel_size=3, padding=1)
])
self.reweight = nn.Conv2d(dim * directions, dim, kernel_size=1)
def forward(self, x): # x: [B, C, H, W]
B, C, H, W = x.shape
# 多方向扫描(论文外层→中心旋转策略)
scans = []
for i, s6 in enumerate(self.s6_blocks):
# Scan(X, direction) -> reshape to sequence
seq = x.flatten(2).transpose(1, 2) # 简化版,实际repo使用自定义Scan
seq = s6(seq)
scans.append(seq.transpose(1, 2).view(B, C, H, W))
merged = torch.cat(scans, dim=1)
merged = self.reweight(merged)
# 多尺度注意力
cross = sum(conv(merged) for conv in self.multi_scale)
return F.sigmoid(cross) * x + x # 残差融合
6.2 Tok-KAN模块(即插即用KAN实现,来自kan.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
class KANLinear(nn.Module): # 核心KAN线性层(kan.py)
def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
# 可学习激活函数Φ
self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.spline_weight, a=math.sqrt(5))
def forward(self, x):
# KAN公式:Φ(Z)
base = F.linear(x, self.base_weight)
spline = ... # B-spline基函数(完整实现见repo kan.py)
return base + spline
class KANBlock(nn.Module): # Tok-KAN块(archs.py中使用)
def __init__(self, dim):
super().__init__()
self.kan = KANLinear(dim, dim)
self.norm = nn.LayerNorm(dim)
self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
def forward(self, x): # x: [B, C, H, W] 或 tokenized
x = x.flatten(2).transpose(1, 2) # Tokenize
x = self.norm(x + self.dwconv(x.transpose(1,2).view(...))) # 论文公式(6)
x = self.kan(x)
return x.transpose(1,2).view(B, C, H, W) # 恢复空间维度
6.3 KM-UNet整体类(即插即用主架构,archs.py)
class KM_UNet(nn.Module): # 完整KM-UNet(直接复制自repo archs.py)
def __init__(self, num_classes=1, input_channels=3, C1=64, D1=128, ...): # 超参与论文一致
super().__init__()
# Convolution Phase + SEM + Tok-KAN Phase
self.encoder = nn.ModuleList([SEMModule(Ci) for Ci in [C1, C2, ...]])
self.tok_kan = nn.ModuleList([KANBlock(Di) for Di in [D1, D2, ...]])
self.decoder = ... # Patch Expanding + SEM
# 跳跃连接:简单相加
def forward(self, x):
# 三阶段前向(详见repo)
...
return output # [B, num_classes, H, W]
使用方式(直接从README.txt提取,即插即用):
# 训练命令(repo原生)
python train.py --arch KM_UNet --dataset busi --input_w 256 --input_h 256 --name busi_KM-UNet --data_dir ./inputs
模块独立集成:将SEMModule或KANBlock直接复制到任意U-Net中作为即插即用替换(无需改动其他代码)。
参考文献(部分)
[1] Ronneberger et al. U-Net...
[16] Gu et al. Mamba...
(完整参考文献见原文)
许可协议
本文采用 署名-非商业性使用-相同方式共享 4.0 国际 许可协议,转载请注明出处。
相关文章
轻量级 Vision Mamba 编码 UNet 用于医学图像分割 2026-03-08
已抵达博客尽头