电脑基础 · 2023年3月31日

一文讲解thop库计算FLOPs问题

问题

计算模型的FLOPs及参数大小

FLOPS是处理器性能的衡量指标,是“每秒所执行的浮点运算次数”的缩写。

FLOPs是算法复杂度的衡量指标,是“浮点运算次数”的缩写,s代表的是复数。

一般使用thop库来计算,GitHub: https://github.com/Lyken17/pytorch-OpCounter

from thop import profile
from thop import clever_format
input = torch.randn(1, 3, 512, 512)
model = Model()
flops, params = profile(model, inputs=(input, ))
flops, params = clever_format([flops, params], "%.3f")

但官网的Readme中详细写出了是用来计算MACs,而不是FLOPs的

from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))

MACs(Multiply-Accumulates)和 FLOPs(Floating-Point Operations)都是用来度量深度学习模型的计算量的指标。它们都可以用来衡量模型的计算复杂度,但是它们的具体定义略有不同。

MACs 是模型中所有乘加运算(即一个乘法和一个加法)的总数,因此,它可以衡量模型的计算性能和效率。通常,MACs 用于评估卷积神经网络(CNN)和其他基于矩阵乘法的模型,如循环神经网络(RNN)和自注意力模型(transformer)等。在实践中,MACs 被广泛用于模型的优化和压缩。

FLOPs 表示模型中所有浮点运算的总数,包括加、减、乘、除等运算。FLOPs 可以衡量模型的浮点运算量和计算成本。通常,FLOPs 用于评估基于全连接层的模型,如 MLP(多层感知器)和基于线性变换的模型,如语言模型和传统的机器学习模型。

MACs 和 FLOPs 之间没有固定的对应关系,因为它们的定义和应用范围略有不同。然而,在某些情况下,它们之间存在一定的相关性。例如,对于一个基于卷积神经网络的模型,可以通过计算 MACs 和 FLOPs 的比值来大致估计模型的运行时间。因为卷积操作中大多数乘法和加法都是浮点数,所以这个比值通常在 1-2 之间。然而,对于其他类型的模型,这个比值可能会有很大的差异,因此,需要根据具体情况来选择使用 MACs 还是 FLOPs 进行模型评估。

当我们使用yolov5官方模型时,训练时会打印模型的参数及FLOPs

在yolov5源码中,模型的FLOPs是这样计算的

调用thop库中的profile计算FLOPs

try:  # FLOPs
    from thop import profile
    stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
    # input
    # img = torch.zeros((1, model.yaml.get('ch', 3), stride * 8, stride * 8), device=next(model.parameters()).device)  # 帮助理解如何计算FLOPs的尝试
    img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
    flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
    img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
    fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs

主要是这一句:

 flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs

因为GFLOPs 指的是每秒十亿次浮点运算(Giga Floating Point Operations per Second)GFLOPs和FLOPs是1e9关系,后面乘以2则认为FLOPs是MACs的2倍

是不是我们乘以2就可以了

目前大家推荐使用torchstat 库来计算FLOPs

from torchstat import stat

导入模型,输入一张输入图片的尺寸

stat(model, (3, 224, 224))

一文讲解thop库计算FLOPs问题

会输出FLOPs,其实和

flops, params = profile(model, inputs=(input, ))

计算出来的差不多,偏小一点

结论

flops, params = profile(model, inputs=(input, ))

这个命令够用了

Ref:

  1. https://github.com/Lyken17/pytorch-OpCounter

神经网络学习小记录72——Parameters参数量、FLOPs浮点运算次数、FPS每秒传输帧数等计算量衡量指标解析_Bubbliiiing的博客-CSDN博客