深度嵌入聚类算法 DEC

如何对图片做聚类,我的直觉是:先用预训练模型计算图片 Embedding,然后用 cosine 度量的 DBSCAN 无监督地计算图片 label,再用 MLP 有监督地学习上一步产生的 label。DBSCAN 的好处是可以把无监督转为有监督,且由于它基于密度的特性,还不需要指定聚类的类别数,这方便了生产环境使用,因为生产环境通常也是不知道类别数的。MLP 的好处是可以对输入泛化,即使没见过的输入,在不重训练的情况下,也可以有一个对应输出。图片特征提取器 + 传统聚类 + 神经网络,是简单且符合直觉的方法,但恐怕不是最好的方法。

我们探索使用更端到端的方法实现图片聚类:DINOv2 特征提取器 + DEC 聚类器

使用 DEC 的好处起码有两点。一是让训练过程更简单,端到端的架构肯定比两阶段模型的架构更简单。二是 DEC 用特征向量表示聚类中心,这和传统聚类用标签表示不同。特征向量表示的类心更便于微调和增量更新。K-Means 每次更新 label 都是乱的,需要用匈牙利算法,对前后两次结果进行桥接。神经网络在这一点上天生有优势,因为它是顺着梯度一点一点更新的,所以前后两次结果是天然有联系,并且可以限制更新的幅度。

GitHub 项目地址:dec-pytorch

本文的工作包括:

  • 用 DINOv2 模型生成图片 Embeddings
  • 用 FastAPI 开发 DINOv2 批量推理服务,支持分 batch 和 模型结果归一化
  • 训练 DEC 模型的三阶段:训练降噪自编码器、初始化聚类中心、训练 DEC
  • 开发集成的 DEC 训练框架,支持训练、推理、保存,详见 dec.py
  • 在我的数据集上,对比 DEC 与传统聚类算法的效果:与 K-Means 接近
  • 介绍 DEC 的创新点:软分配策略和目标分布优化
  • 在线学习探索:尝试两种思路,对 DEC 模型做小幅度的增量更新

✨ DEC 论文在这里 Unsupervised Deep Embedding for Clustering Analysis.

一、使用 DINOv2 生成图片 Embedding

本节我们来完成三项任务:

  1. 下载 DINOv2 模型文件,并完成单张图片的 Embedding 推理
  2. 先实现多张图片在 CPU 上的推理。然后进阶一点,固定 batch_size 参数,在 GPU 上实现分 batch 批量推理
  3. 开发 FastAPI 推理服务。输入图片的 base64,输出 Embedding

dinov2:

目录:

  1. 从 huggingface 下载模型文件
  2. 计算图片 Embedding
  3. 批量计算图片 Embedding
    • 在 CPU 上批量推理
    • 在 GPU 上批量推理
  4. 批量推理服务化
    • 启动服务端
    • 运行客户端


二、Embedding 数据准备

上一节中,我们已经开发了一个 FastAPI 推理服务,用于将图片转成 DINOv2 推理的 Embedding。

本节我们来做数据准备。首先下载 CIFAR-100 数据集,它是一个由 100 类图片组成的图片分类数据集,其中训练集有 5 万张图片,测试集有 1 万张图片。下载完成后,使用上一节搭建的 FastAPI 服务,将图片转成 Embedding,然后用 csv 格式存储 Embedding 和对应标签。

为了方便 DataFrame 和 csv 之间的转换,我开发了两个工具函数,见 utils.py

  • embedding_df_to_csv(): 将 DataFrame 存入 csv
  • read_embedding_csv(): 从 csv 读入 DataFrame

目录:

  1. 下载 CIFAR-100 数据集
  2. 图片转 Embedding


三、DEC 模型训练

DEC(Deep Embedded Clustering,深度嵌入聚类)是一种结合深度学习与 K-Means 聚类的算法。核心思想是通过联合优化特征表示和聚类目标,提升传统聚类方法在高维数据上的效果。

目录:

  1. 加载 Embedding 数据
  2. 训练 DEC 模型
    • 初始化配置:初始化设备;定义评估指标函数
    • 定义降噪自编码器:支持加入掩蔽噪声或高斯噪声;添加了 L2 归一化
    • 定义主要组件:target_distribution, ClusterAssignment, DEC
    • 阶段一:训练降噪自编码器
    • 阶段二:初始化聚类中心
    • 阶段三:训练 DEC
    • 保存最优模型
    • 计算指标
  3. 推理新数据
  4. 评估


四、对比传统聚类算法

在同一个数据集上,对比 DEC 与 K-Means, DBSCAN 两种传统聚类算法的效果。

目录:

  1. 加载数据
  2. 评估函数
  3. K-means 算法
  4. DBSCAN 算法
  5. 结论


五、深入学习 DEC 模型

第三节,我通过吃百家饭复现了 DEC 模型。不知道是不是我的问题,pt-dec 仓库的 DEC 模型的 loss 不降反升(见 test_ptdec.ipynb)。好在该仓库已经实现了 DEC 论文中几个重要的类和函数,将它们拼接一番,也顺利把模型跑起来了。

上一节,我们比较了 DEC 模型和传统的 K-Means 模型的准确率。在我的数据集上,它们准确率类似,都在 0.7 左右。这个结果不意外,因为 DEC 本身就是用 K-Means 来初始化聚类中心的。

经过前两节的工作,我们对 DEC 模型有了初步的理解,希望在这里停下来总结一下。

目录:

  1. 模型的创新点
    • 聚类中心初始化
    • 软分配策略
    • 目标分布优化
    • 模型训练
  2. 训练阶段
  3. 聚类标签匹配问题(匈牙利算法)
  4. 模型推理和优化
    • 模型推理
    • 模型优化


六、探索:在线学习

我希望当新一批 embeddings 进入时,只进行少量的训练。既让模型适应新数据,又尽量不使原来的 embedding - label 映射发生偏移。

我的计划是:

  • 先训练一次 DEC 模型
  • 再将原本一半样本丢弃,加入与丢弃数量相同的新样本

观察模型在新数据集上准确率是否有改善,以及聚类中心的变动是否平缓。

  1. 初次训练 DEC 模型
  2. 生成新样本
  3. 增量训练
    • 原模型在新数据集上的效果
    • 思路一:移动聚类中心
    • 思路二:重训练拟合目标分布的阶段