深度学习工具箱 FlameAI

在上一篇博客《LightGBM 工程实践》中,我把通用函数写到 util.py 文件,以实现代码复用。但是这么复用好麻烦,每次开新项目,得把 util.py 文件贴过去才行。为了省下这个贴的动作,我把它写成 Python 包了。现在下载一次,就可以到处使用啦!

GitHub 项目地址:flameai

一、简单介绍

我的这个包叫 FlameAI。作为一个深度学习工具包,它的主要功能是 数据预处理模型评估。FlameAI 旨在解决最后一公里问题。即在框架之外,业务代码之内,寻求最佳实践。为了让我的包看起来上流一点,我让 Kimi 帮我想了几个名字,最后还是选择了 FlameAI,因为这个名字最霸气。

执行以下命令安装 FlameAI:

pip install flameai

如果嫌安装速度慢,可以使用阿里源:

pip install flameai -i https://mirrors.aliyun.com/pypi/simple/

二、模块功能介绍

2.1 二分类模型评估

二分类模型评估是咱们的拳头产品。只需一行命令,就可以计算下面全部指标:

  • accuracy: 准确率
  • precision: 精确率
  • recall: 召回率
  • f1_score: 精确率和召回率的调和平均数
  • auc: ROC曲线下的面积
  • cross-entropy loss: 交叉熵损失
  • True Positive (TP): 真阳
  • True Negative (TN): 真阴
  • False Positive (FP): 假阳
  • False Negative (FN): 假阴
  • confusion matrix: 混淆矩阵
1)简单用法

你可以直接指定阈值,下面代码中指定阈值为 0.5

Note: 这里的阈值指二分类模型划分正负样本的阈值。对每个样本,模型输出一个概率值,一般而言,若概率小于阈值,则判为负类;若大于等于阈值,则判为正类。

import flameai

# y_true 是真实标号,y_pred 是模型预测标号为 1 的概率
y_true = [0, 0, 0, 1, 0, 1, 0, 1, 1, 0]
y_pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# 一行命令,计算全部指标
flameai.eval_binary(y_true, y_pred, threshold=0.5)

输出:

threshold: 0.50000
accuracy: 0.70000
precision: 0.60000
recall: 0.75000
f1_score: 0.66667
auc: 0.70833
cross-entropy loss: 4.03816
True Positive (TP): 3
True Negative (TN): 4
False Positive (FP): 2
False Negative (FN): 1
confusion matrix:
[[4 2]
 [1 3]]
2)对阈值寻优

你也可以不设定阈值,让 FlameAI 帮你寻找最优阈值。你只需要做设定优化目标就好了,FlameAI 会帮你找。

FlameAI 内置四个目标函数,分别是:

  • 准确率:Metric.ACCURACY
  • 精确率:Metric.PRECISION
  • 召回率:Metric.RECALL
  • f1 分数:Metric.F1_SCORE

下面这个例子在参数空间搜索最优阈值,以使精确率最大化。

from flameai import eval_binary, Metric

y_true = [0, 0, 0, 1, 0, 1, 0, 1, 1, 0]
y_pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# 以最大化精确率为目标,对阈值寻优
eval_binary(y_true, y_pred, metric=Metric.PRECISION)
3)定义返回值

有时候,你想在代码中使用最优阈值,且不希望结果打印到标准输出。这可以通过设置 retverbose 参数实现:

from flameai import eval_binary, Metric

y_true = [0, 0, 0, 1, 0, 1, 0, 1, 1, 0]
y_pred = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

y_label, threshold = eval_binary(y_true,
                                 y_pred,
                                 metric=Metric.F1_SCORE,
                                 ret=True,  # 返回预测标号和阈值
                                 verbose=0)  # 不打印评估指标
print(f'y_label: {y_label}')
print(f'threshold: {threshold:.3f}')

输出:

y_label: [0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
threshold: 0.378

2.2 绘制 ROC 曲线

运行以下代码绘制 ROC 曲线:

from flameai.plot import roc_curve

y_true = [0, 1, 1, 0, 1, 1, 0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8, 0.15, 0.35, 0.2, 0.7, 0.05, 0.9]

roc_curve(y_true, y_score)

绘制效果:

2.3 绘制混淆矩阵

运行以下代码绘制混淆矩阵:

from flameai.plot import confusion_matrix

y_true = [0, 1, 1, 0, 1, 1, 0, 0, 0, 1]
y_label = [1, 1, 1, 0, 1, 0, 0, 0, 1, 1]

confusion_matrix(y_true, y_label)

绘制效果:

2.4 线性回归模型评估

eval_regression 函数可输出三个线性回归评估指标,分别是:

  • mae: 平均绝对误差
  • mse: 均方误差
  • r2_score: 确定系数

运行以下代码,对回归模型的结果进行评估:

import flameai

# y_true 是真实值,y_pred 是预测值
y_true = [0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 1.00]
y_pred = [0.11, 0.23, 0.29, 0.45, 0.50, 0.59, 0.72, 0.76, 0.94, 1.00]

flameai.eval_regression(y_true, y_pred)

输出:

mae:      0.02100
mse:      0.00073
r2_score: 0.99115

2.5 数据预处理

下面是我实现的野生 DataLoader,它是 torch.utils.data.DataLoader 的简化版本。它复刻了官方 DataLoader 的关键特性:每次调用索引从零开始迭代。凭此特性,DataLoader 能够在多轮训练中重复获取样本。

from flameai.preprocessing import DataLoader

dt = DataLoader([1, 2, 3, 4, 5])
for i in dt:
    print(i)
for i in dt:
    print(i)

dt.data = [1, 2, 3]
print([e for e in dt])

我的实现很简单,只有 23 行:

class DataLoader:
    def __init__(self, lst: list):
        self.i = 0
        self._data = lst

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, lst: list):
        self._data = lst

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        if self.i < len(self._data):
            self.i += 1
            return self._data[self.i - 1]
        else:
            raise StopIteration

PS: 装饰器所在的两个函数不是必须的。如果把它们去掉,还能变得更短

2.6 Hive 命令行工具

FlameAI 提供了一个 Hive 命令行工具,支持把 hql 文件的查询结果,输出到同名 csv 文件。下面是一个例子:

假设你已经把 HQL 查询语句写到 query.hql 文件里了。执行以下命令,将执行查询,并把查询结果输出到 query.csv 文件。

hve query

2.7 标准日志输出

FlameAI 支持用内部预设的 logging 模版输出日志。

from flameai.util import set_logger

logger = set_logger(__name__)

logger.error('This is an error message!')
logger.warning('This is an warning message!')

直接运行,输出大概长这样:

2024-05-25 23:23:27 ERROR [__main__]: (set_logger:<module>(5)) - This is an error message!
2024-05-25 23:23:27 WARNING [__main__]: (set_logger:<module>(6)) - This is an warning message!

如果在模块中使用,__main__ 会变成当前文件的无后缀文件名,module 会变成模块名。