如何对图片做聚类,我的直觉是:先用预训练模型计算图片 Embedding,然后用 cosine 度量的 DBSCAN 无监督地计算图片 label,再用 MLP 有监督地学习上一步产生的 label。DBSCAN 的好处是可以把无监督转为有监督,且由于它基于密度的特性,还不需要指定聚类的类别数,这方便了生产环境使用,因为生产环境通常也是不知道类别数的。MLP 的好处是可以对输入泛化,即使没见过的输入,在不重训练的情况下,也可以有一个对应输出。图片特征提取器 + 传统聚类 + 神经网络,是简单且符合直觉的方法,但恐怕不是最好的方法。
我们探索使用更端到端的方法实现图片聚类:DINOv2 特征提取器 + DEC 聚类器
使用 DEC 的好处起码有两点。一是让训练过程更简单,端到端的架构肯定比两阶段模型的架构更简单。二是 DEC 用特征向量表示聚类中心,这和传统聚类用标签表示不同。特征向量表示的类心更便于微调和增量更新。K-Means 每次更新 label 都是乱的,需要用匈牙利算法,对前后两次结果进行桥接。神经网络在这一点上天生有优势,因为它是顺着梯度一点一点更新的,所以前后两次结果是天然有联系,并且可以限制更新的幅度。
GitHub 项目地址:dec-pytorch
本文的工作包括:
- 用 DINOv2 模型生成图片 Embeddings
- 用 FastAPI 开发 DINOv2 批量推理服务,支持分 batch 和 模型结果归一化
- 训练 DEC 模型的三阶段:训练降噪自编码器、初始化聚类中心、训练 DEC
- 开发集成的 DEC 训练框架,支持训练、推理、保存,详见 dec.py
- 在我的数据集上,对比 DEC 与传统聚类算法的效果:与 K-Means 接近
- 介绍 DEC 的创新点:软分配策略和目标分布优化
- 在线学习探索:尝试两种思路,对 DEC 模型做小幅度的增量更新
✨ DEC 论文在这里 Unsupervised Deep Embedding for Clustering Analysis.
一、使用 DINOv2 生成图片 Embedding
本节我们来完成三项任务:
- 下载 DINOv2 模型文件,并完成单张图片的 Embedding 推理
- 先实现多张图片在 CPU 上的推理。然后进阶一点,固定 batch_size 参数,在 GPU 上实现分 batch 批量推理
- 开发 FastAPI 推理服务。输入图片的 base64,输出 Embedding
dinov2:
- GitHub: facebookresearch/dinov2
- Hugging Face: facebook/dinov2-base
目录:
- 从 huggingface 下载模型文件
- 计算图片 Embedding
- 批量计算图片 Embedding
- 在 CPU 上批量推理
- 在 GPU 上批量推理
- 批量推理服务化
- 启动服务端
- 运行客户端
二、Embedding 数据准备
上一节中,我们已经开发了一个 FastAPI 推理服务,用于将图片转成 DINOv2 推理的 Embedding。
本节我们来做数据准备。首先下载 CIFAR-100 数据集,它是一个由 100 类图片组成的图片分类数据集,其中训练集有 5 万张图片,测试集有 1 万张图片。下载完成后,使用上一节搭建的 FastAPI 服务,将图片转成 Embedding,然后用 csv 格式存储 Embedding 和对应标签。
为了方便 DataFrame 和 csv 之间的转换,我开发了两个工具函数,见 utils.py
:
embedding_df_to_csv()
: 将 DataFrame 存入 csvread_embedding_csv()
: 从 csv 读入 DataFrame
目录:
- 下载 CIFAR-100 数据集
- 图片转 Embedding
三、DEC 模型训练
DEC(Deep Embedded Clustering,深度嵌入聚类)是一种结合深度学习与 K-Means 聚类的算法。核心思想是通过联合优化特征表示和聚类目标,提升传统聚类方法在高维数据上的效果。
目录:
- 加载 Embedding 数据
- 训练 DEC 模型
- 初始化配置:初始化设备;定义评估指标函数
- 定义降噪自编码器:支持加入掩蔽噪声或高斯噪声;添加了 L2 归一化
- 定义主要组件:target_distribution, ClusterAssignment, DEC
- 阶段一:训练降噪自编码器
- 阶段二:初始化聚类中心
- 阶段三:训练 DEC
- 保存最优模型
- 计算指标
- 推理新数据
- 评估
四、对比传统聚类算法
在同一个数据集上,对比 DEC 与 K-Means, DBSCAN 两种传统聚类算法的效果。
目录:
- 加载数据
- 评估函数
- K-means 算法
- DBSCAN 算法
- 结论
五、深入学习 DEC 模型
第三节,我通过吃百家饭复现了 DEC 模型。不知道是不是我的问题,pt-dec 仓库的 DEC 模型的 loss 不降反升(见 test_ptdec.ipynb)。好在该仓库已经实现了 DEC 论文中几个重要的类和函数,将它们拼接一番,也顺利把模型跑起来了。
上一节,我们比较了 DEC 模型和传统的 K-Means 模型的准确率。在我的数据集上,它们准确率类似,都在 0.7 左右。这个结果不意外,因为 DEC 本身就是用 K-Means 来初始化聚类中心的。
经过前两节的工作,我们对 DEC 模型有了初步的理解,希望在这里停下来总结一下。
目录:
- 模型的创新点
- 聚类中心初始化
- 软分配策略
- 目标分布优化
- 模型训练
- 训练阶段
- 聚类标签匹配问题(匈牙利算法)
- 模型推理和优化
- 模型推理
- 模型优化
六、探索:在线学习
我希望当新一批 embeddings 进入时,只进行少量的训练。既让模型适应新数据,又尽量不使原来的 embedding - label 映射发生偏移。
我的计划是:
- 先训练一次 DEC 模型
- 再将原本一半样本丢弃,加入与丢弃数量相同的新样本
观察模型在新数据集上准确率是否有改善,以及聚类中心的变动是否平缓。
- 初次训练 DEC 模型
- 生成新样本
- 增量训练
- 原模型在新数据集上的效果
- 思路一:移动聚类中心
- 思路二:重训练拟合目标分布的阶段