KAN-Mamba:首次讲Kan与Mamba融合的医学图像分割

KAN-Mamba:首次讲Kan与Mamba融合的医学图像分割

_

KM-UNet:KAN与Mamba融合的U-Net用于医学图像分割

📄 🧬 🔬 🧪 📊


KM-UNet:KAN Mamba UNet 用于医学图像分割

_

KM-UNet:KAN Mamba UNet 用于医学图像分割

论文来源
KM-UNet: KAN Mamba UNet for medical image segmentation

Yibo Zhang、Jingwen Zhao*、Xiang Liu、Xian Tang、Yunyu Shi、Lina Wei、Guyue Zhang
(上海工程技术大学、杭州城市大学、浙江省质量科学研究院)

关键词: KAN、Mamba、state-space models、UNet、Medical image segmentation、Deep learning
(开源代码:https://github.com/2760613195/KM_UNet)

摘要

医学图像分割是医学图像分析中的关键任务。传统CNN方法难以建模长程依赖,而Transformer模型虽效果显著却存在二次计算复杂度问题。为解决这些局限,本文提出KM-UNet,一种新型U型网络架构,将Kolmogorov-Arnold Networks(KAN)与状态空间模型(SSM/Mamba)相结合。KM-UNet利用Kolmogorov-Arnold表示定理实现高效特征表示,并借助SSM实现可扩展的长程建模,在精度与计算效率之间取得良好平衡。

核心创新包括:

  • SEM注意力模块:融合Selective-Scan Efficient Multi-scale机制,通过多方向扫描(外层向中心旋转策略)+S6块实现高效多尺度特征提取与长程依赖建模。
  • Tok-KAN模块:在U-Net瓶颈层引入KAN网络,替代传统MLP,提供固有可解释性与参数高效的非线性映射。
  • 整体架构:三阶段编码器-解码器结构(Convolution Phase + SEM Phase + Tok-KAN Phase),跳跃连接采用简单相加,显著提升纯SSM模型的分割性能。

在ISIC17、ISIC18、CVC、BUSI和GLAS五个公开基准数据集上进行评估,KM-UNet在IoU和Dice指标上达到SOTA水平,平均IoU达81.17%、F1达89.20%,参数量仅7.35M、GFLOPs仅17.66,兼具高精度与轻量化优势。同时,KAN层显著提升模型可解释性,通过注意力热图验证了其在关键区域定位上的优越性。本文为高效、可解释的医学图像分割系统提供了重要基线与新思路。

1. 引言

过去十年,医学图像分割方法取得显著进展,以满足计算机辅助诊断和图像引导手术系统的需求。U-Net作为里程碑式架构,通过编码器-解码器+跳跃连接在医学分割中展现强大能力。随后涌现U-Net++、3D U-Net、V-Net等改进,以及U-NeXt等CNN+MLP混合架构。

Transformer系列(如TransUNet、Swin-UNet)虽擅长全局上下文建模,但对数据量和计算资源需求极高。近期,结构化状态空间模型(SSM/Mamba)以线性复杂度高效处理长序列,在U-Mamba、SegMamba等工作中展现潜力。然而,现有U-Net变体仍面临内核设计与黑箱属性问题,影响可解释性和临床可靠性。

本文提出KM-UNet,首次将KAN与SSM(Mamba)融合进U型架构:编码器采用SEM注意力+Patch Merging下采样,解码器采用SEM+Patch Expanding上采样,跳跃连接简单相加。KAN提供可解释性,Mamba解决长程依赖与效率权衡,实现精度、效率与透明度的统一。

2. 相关工作

基于KAN的U-Net架构研究
传统CNN受局部感受野限制,难以捕捉全局信息。KAN基于Kolmogorov-Arnold表示定理,以可学习激活函数替代线性变换矩阵,提升非线性建模能力和可解释性。与U-Net结合后,既保留多尺度特征融合优势,又增强全局上下文理解与模型稳定性。

Mamba模块与U-Net的集成
Mamba作为现代SSM,在视觉任务中展现线性复杂度优势。U-Mamba首次将SSM与CNN结合应用于医学分割,SegMamba则在3D任务中验证其潜力。多流网络(如PSA)和多尺度卷积进一步提升特征表示。本文提出的SEM模块通过跨空间学习优化全局上下文捕获,兼顾计算效率。

3. 方法论(核心模块详解)

KM-UNet采用对称三阶段编码器-解码器结构(Convolution Phase + SEM Phase + Tok-KAN Phase),通道数由超参数C_1C_5D_1D_5控制。

7.1 KM-UNet整体架构图.jpg

3.1 Selective-Scan Efficient Multi-scale (SEM)注意力模块

特征提取:改进SS2D扫描策略,支持四方向(左上→右下、右上→左下等)+自适应旋转(外层向中心)扫描。每个方向序列经S6块(Mamba改进版,动态参数调整)提取特征,再Re-weight融合恢复原尺寸。

多尺度注意力:采用1×1与3×3并行卷积子网络,避免通道降维,通过通道重塑+聚合实现短程与长程空间依赖联合建模。

7.2 SEM Modulejpg.jpg

3.2 Tokenized KAN (Tok-KAN)模块

将KAN置于瓶颈层,替代传统MLP。KAN结构为:

\text{KAN}(Z) = (\Phi_{K-1} \circ \Phi_{K-2} \circ \cdots \circ \Phi_0)Z

其中\Phi_i为可学习激活函数,实现高效非线性映射与固有可解释性。

3.3 整体轻量化设计

  • 编码器:卷积+SEM+Tok-MLP下采样
  • 解码器:Tok-KAN+SEM+卷积上采样
  • 跳跃连接:简单相加,突出纯SSM性能

4. 实验结果(精选对比)

在BUSI、GlaS、CVC、ISIC17、ISIC18五个异构数据集上进行评估(统一预处理为256×256或512×512,80/20划分,300 epochs,BCE+Dice损失)。

表1:与SOTA方法在五个数据集上的性能对比(IoU↑ / F1↑)

MethodsBUSIGlaSCVCISIC17ISIC18
U-Net57.22/71.9186.66/92.7983.79/91.0676.98/86.9977.86/87.55
U-Net++57.41/72.1187.07/92.9684.61/91.5378.58/86.3578.31/87.83
U-Mamba61.81/75.5587.01/93.0284.79/91.6381.47/89.0780.92/89.49
KM-UNet (Ours)65.42/78.7987.51/93.2785.01/91.7984.05/91.1583.84/91.00

表2:整体效率与分割指标对比

MethodsAvg IoU↑Avg F1↑GFLOPsParams (M)
U-Net76.5086.06524.234.53
U-Mamba79.2087.75208786.3
KM-UNet81.1789.2017.667.35

可视化结果显示KM-UNet在边界细节、假阳性抑制等方面显著优于U-Net、U-Mamba等基线。

可解释性验证:KAN层注意力热图显示,引入KAN后模型能更精确聚焦目标边界,IoU显著提升,验证了其透明性。

5. 结论与展望

KM-UNet通过SEM模块、KAN网络及先进训练策略,在医学图像分割任务中实现显著性能提升。消融实验表明SEM模块带来2%-3%的IoU/F1增益,余弦退火学习率优于固定学习率,KAN替换MLP进一步增强全局建模能力。

未来工作将聚焦降低SEM计算复杂度、探索轻量化技术(量化、剪枝)、扩展至遥感/视频分割、多模态融合(CT/MRI)及联邦学习,以进一步提升泛化能力与临床适用性。

6. 技术借鉴与实现建议

6.1 SEM注意力模块(即插即用PyTorch实现)

import torch
import torch.nn as nn
import torch.nn.functional as F

class S6Block(nn.Module):  # 来自archs.py & Mamba改进版(论文Algorithm 1)
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.dt_rank = int(d_model / 16)
        self.d_state = d_state
        # Linear projections for Δ, B, C
        self.x_proj = nn.Linear(d_model, self.dt_rank + 2 * d_state)
        self.dt_proj = nn.Linear(self.dt_rank, d_model)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(d_model))
        self.conv1d = nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=d_conv, padding=d_conv - 1, groups=d_model)

    def forward(self, x):  # x: [B, L, D]
        # Selective mechanism (论文S6伪代码)
        delta, B, C = self.x_proj(x).split([self.dt_rank, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        # SSM computation
        A = -torch.exp(self.A_log)
        # ... (完整S6离散化实现,详见repo archs.py)
        y = ...  # 返回 [B, L, D]
        return y

class SEMModule(nn.Module):  # 完整SEM模块(archs.py核心)
    def __init__(self, dim, directions=4):
        super().__init__()
        self.s6_blocks = nn.ModuleList([S6Block(dim) for _ in range(directions)])
        self.multi_scale = nn.ModuleList([
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        ])
        self.reweight = nn.Conv2d(dim * directions, dim, kernel_size=1)

    def forward(self, x):  # x: [B, C, H, W]
        B, C, H, W = x.shape
        # 多方向扫描(论文外层→中心旋转策略)
        scans = []
        for i, s6 in enumerate(self.s6_blocks):
            # Scan(X, direction) -> reshape to sequence
            seq = x.flatten(2).transpose(1, 2)  # 简化版,实际repo使用自定义Scan
            seq = s6(seq)
            scans.append(seq.transpose(1, 2).view(B, C, H, W))
        merged = torch.cat(scans, dim=1)
        merged = self.reweight(merged)
        # 多尺度注意力
        cross = sum(conv(merged) for conv in self.multi_scale)
        return F.sigmoid(cross) * x + x  # 残差融合

6.2 Tok-KAN模块(即插即用KAN实现,来自kan.py)

import torch
import torch.nn as nn
import torch.nn.functional as F

class KANLinear(nn.Module):  # 核心KAN线性层(kan.py)
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        # 可学习激活函数Φ
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size + spline_order))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.spline_weight, a=math.sqrt(5))

    def forward(self, x):
        # KAN公式:Φ(Z)
        base = F.linear(x, self.base_weight)
        spline = ...  # B-spline基函数(完整实现见repo kan.py)
        return base + spline

class KANBlock(nn.Module):  # Tok-KAN块(archs.py中使用)
    def __init__(self, dim):
        super().__init__()
        self.kan = KANLinear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x):  # x: [B, C, H, W] 或 tokenized
        x = x.flatten(2).transpose(1, 2)  # Tokenize
        x = self.norm(x + self.dwconv(x.transpose(1,2).view(...)))  # 论文公式(6)
        x = self.kan(x)
        return x.transpose(1,2).view(B, C, H, W)  # 恢复空间维度

6.3 KM-UNet整体类(即插即用主架构,archs.py)

class KM_UNet(nn.Module):  # 完整KM-UNet(直接复制自repo archs.py)
    def __init__(self, num_classes=1, input_channels=3, C1=64, D1=128, ...):  # 超参与论文一致
        super().__init__()
        # Convolution Phase + SEM + Tok-KAN Phase
        self.encoder = nn.ModuleList([SEMModule(Ci) for Ci in [C1, C2, ...]])
        self.tok_kan = nn.ModuleList([KANBlock(Di) for Di in [D1, D2, ...]])
        self.decoder = ...  # Patch Expanding + SEM
        # 跳跃连接:简单相加

    def forward(self, x):
        # 三阶段前向(详见repo)
        ...
        return output  # [B, num_classes, H, W]

使用方式(直接从README.txt提取,即插即用):

# 训练命令(repo原生)
python train.py --arch KM_UNet --dataset busi --input_w 256 --input_h 256 --name busi_KM-UNet --data_dir ./inputs

模块独立集成:将SEMModuleKANBlock直接复制到任意U-Net中作为即插即用替换(无需改动其他代码)。

参考文献(部分)
[1] Ronneberger et al. U-Net...
[16] Gu et al. Mamba...
(完整参考文献见原文)

许可协议
本文采用 署名-非商业性使用-相同方式共享 4.0 国际 许可协议,转载请注明出处。


相关文章
轻量级 Vision Mamba 编码 UNet 用于医学图像分割 2026-03-08

已抵达博客尽头

轻量级 Vision Mamba 编码 UNet 用于医学图像分割 2026-03-08

评论区