设备支持时,一般使用低精度的模型运行速度会翻倍,重点是尽可能保证质量。为此在单图过拟合任务下,尝试构建对比实验探究直接量化与量化感知蒸馏的差距。
环境:
torch 2.10.0+cu130 pypi_0 pypi
torchvision 0.25.0+cu130 pypi_0 pypi
对一张1536x1536分辨率图像进行拟合,首先使用FP32精度训练,每50轮评估并保存最优权重。
- 对于INT8精度,量化组对FP32模型用自定义qconfig做FX prepare,校准视图做observer统计,convert成INT8模型;蒸馏组使用量化感知蒸馏QAD,在模型插入FakeQuantize模拟int8量化误差,让模型拟合原始图像并模拟教师模型输出,最终通过convert_fx转为INT8。
- 对于FP16精度,量化组直接转为half;蒸馏组使用FP32权重作为教师模型,使用AMP在cuda上做fp16训练,没有使用QAD的方法,主要是不会有严重数值裁切和分布失配。
我们对FP32模型训练与学生模型训练都进行600轮次。
INT8 结果

可以发现直接量化(PTQ方法)结果偏绿,后续不论我如何优化qconfig与提高calibration参数,也无法做到良好的质量,我也不觉得是int8的容量不够,毕竟蒸馏模型能跑到36.228。我认为可能是PTQ统计方法不对任务,观察图像可以发现轮廓和光影等表现都良好,可能是颜色通道激活值范围被低估而被裁剪了?
至于蒸馏方法,训练时蒸馏损失参数设置为0.2,学习率为1e-3。我们的结果相比教师损失约为10%,这是符合我的预期的,我也没有进行更多调参。
同时我还跑了一次损失参数为0,也就是类似原生训练一个int8精度的模型,他的结果为PSNR=33.0175 SSIM=0.9039 LPIPS=0.0655,并没有打过量化感知蒸馏。
基于上面的结果,能够说明模型在单图拟合任务上,做量化感知蒸馏相比直接蒸馏与原生低精度训练,能得到更好的质量。
FP16 结果

蒸馏方法训练时蒸馏损失参数设置为0.5,学习率为1e-3
根据结果,量化方法的表现要好于蒸馏方法,这可能是因为直接转换精度几乎无损,而蒸馏需要重新训练,收敛可能不足。
如果将训练轮次拉到1200,就可以达到 psnr=42.3066 ssim=0.9893 lpips=0.0062
结合INT8的结果来看,与教师模型训练同轮次时,在更低精度下,量化感知蒸馏的作用应该会更加明显;而半精度下使用直接量化的方法,损失比较微小,更加适合。
尾记
根据资料描述torchao库(0.17)似乎对conv支持有限,没能用GPU原生跑int8的训练,我感觉是少一个对比实验的。
INT8组,我们训练模型时其实都是跑在fp32上,这样最终的评分,在转int8前后的表现是会出现不一致的,一般是会降低一些,毕竟FakeQuant只是近似量化噪声,与真正转换可能不同。
Comments NOTHING