博客
关于我
【单源雷达回波外推1】DTCA: Precipitation Nowcasting Using Diffusion Transformer With Causal Attention
阅读量:451 次
发布时间:2019-03-06

本文共 2896 字,大约阅读时间需要 9 分钟。

论文解读与代码分析

背景与动机

传统数值天气预报(NWP)在细尺度、对流性降水预报上存在初始条件不足和计算资源限制等问题。传统深度学习方法(例如基于 U-Net 的模型)虽然在图像任务上表现良好,但在捕捉长时空依赖、全局信息以及条件与预测结果之间的因果关系上存在不足。为了解决这些问题,近年来生成模型(例如 GAN、VAE、扩散模型)被引入降水预报任务。扩散模型在稳定性、生成质量和多样性上展现出明显优势,并且通过学习数据分布来实现条件采样,从而生成更真实的预报结果。

主要贡献

本文的核心创新体现在以下几点:

  • 引入因果注意力机制:设计了基于条件降水分布特征的查询-键值机制,使得预测时能够通过因果注意力有效地访问历史降水条件(原因),从而动态调整预测结果。这种设计提高了模型对长时空依赖关系的捕获能力与预测解释性。

  • 提出多种时空特征交互变体:为了探讨不同的时空信息建模方式,论文设计了四种变体:

    • 全联合时空 DTCA Block(Variant 1):将空间和时间维度联合处理;
    • 分离时空 DTCA Block(Variant 2):先分别处理空间和时间,再融合;
    • 半联合时空+空间 DTCA Block(Variant 3):先进行联合时空处理,再侧重空间特征提取;
    • 半联合时空+时间 DTCA Block(Variant 4):在联合处理之后重点建模时间演变。
  • 引入 CTBS 操作:该操作将部分通道信息迁移到 batch 维度,从而增加数据多样性、加强特征表示。这一轻量级操作只需几行代码实现,但对捕获复杂降水动态有明显提升。

  • 方法流程概述

    模型采用扩散模型框架,在潜空间内进行降噪预测。核心部分为基于 Transformer 的模块,通过自注意力机制捕捉长程依赖,同时利用因果注意力机制将条件信息引入预测过程,确保预测与历史降水信息之间的因果一致性。具体流程如下:

  • 数据预处理与嵌入:首先利用 AutoEncoder 将高分辨率的降水图像序列压缩到低维潜空间,然后按照 ViT 的方式将潜空间划分为不重叠的时空 patch,再映射成 token 序列。前几帧作为条件输入,其余帧作为预测目标。

  • 扩散模型与 Transformer 架构:模型采用扩散模型框架,在潜空间内进行降噪预测。核心部分为基于 Transformer 的模块,通过自注意力机制捕捉长程依赖,同时利用因果注意力机制将条件信息引入预测过程,确保预测与历史降水信息之间的因果一致性。

  • 预测与重构:最终通过解码器将潜空间中的预测结果映射回像素空间,得到未来时刻的降水场预报。实验结果显示,在重预测重灾降水等关键指标上相较于传统 U-Net 模型有显著提升。

  • 代码解读

    论文对应的代码主要分布在两个文件中,下面分别解释核心模块及实现思路。

    1. PixArt_blocks.py

    该文件主要定义了 PixArt 模块中所使用的一系列基础组件,构成了 Transformer 模型的“构建块”。

    • 调制函数(modulate 系列函数):这些函数的作用是对输入特征进行尺度(scale)与偏移(shift)的调制。通过对特征进行线性变换,结合条件信息,使得网络能够自适应地调整各层输出,防止梯度消失和促进信息融合。

    • MultiHeadCrossAttention 类:实现了多头交叉注意力机制,其思路是:

      • 查询(query)来源于图像 token,键值(key, value)则由条件信息得到。
      • 使用 xformers 库中的高效注意力计算接口,实现内存高效的注意力运算。 这一模块在 DTCA 模型中用于将条件信息与预测信息进行交互,体现论文中因果注意力的设计理念。
    • WindowAttention 类:继承自已有的 Attention 模块,并扩展了相对位置编码功能。通过在注意力计算中引入位置信息,可以更好地捕捉局部与全局空间关系,对于降水图像中不均匀分布的区域尤其重要。

    • FinalLayer、T2IFinalLayer 等终层模块:这些层主要用于将 Transformer 模块的输出映射回最终的图像或潜空间表示,同时结合条件调制(例如 AdaLN 模块),保证生成结果在风格和数值上与条件输入保持一致。

    • TimestepEmbedder 和 LabelEmbedder:分别用于将时间步(diffusion timestep)和类别标签嵌入到高维空间,生成用于条件调制的向量表示。这对扩散模型中时间信息的编码非常关键,保证模型在不同扩散步骤下的行为一致。

    2. DTCA.py

    该文件实现了整个 DTCA 模型,是论文提出方法的完整实现。

    • MyConditiontrsBlock 类:这是一个自定义的 Transformer Block,融合了以下特点:

      • 因果注意力:在块内通过调用 t2i_modulate 函数对输入进行调制,再经过自注意力模块与交叉注意力模块实现条件信息的注入。
      • 多阶段处理:先进行自注意力操作(利用经过调制后的归一化输入),再通过交叉注意力将条件信息(如历史观测、流动信息)与当前特征融合。之后再经过 MLP 模块进一步提取特征。
      • 其中使用了 scale_shift_table(与时间步嵌入 t 相结合)来生成动态调制参数,实现了论文中提到的自适应特征调制。
    • MyTransformerBlock 类:为标准的 Transformer Block,包含 LayerNorm、自注意力和 MLP 模块,并使用 DropPath 实现随机深度丢弃(stochastic depth),以增强模型鲁棒性。

    • DTCA 类(模型主类):这是整个模型的核心实现:

      • 输入处理:通过 PatchEmbed 将输入图像(或潜空间表示)划分为 tokens,并加上空间和时间位置编码。
      • 条件信息嵌入:利用 condition_embedder 将条件图像进行嵌入,再加上专门的条件时空位置编码。
      • 时间步嵌入与调制:通过 TimestepEmbedder 生成时间信息向量,并经过线性层映射后传递给各个 Transformer Block,用于生成调制参数。
      • Transformer 层堆叠:模型由多个 MyConditiontrsBlock 叠加而成,部分模块还会结合交叉注意力进一步整合条件信息。
      • 最终映射:经过多层 Transformer 后,利用 T2IFinalLayer 将 token 序列映射回原始图像空间(或潜空间),并通过 unpatchify 操作还原为完整图像。
      • 扩展功能:代码中还包括了时空分离建模(Dividespacetime)和联合建模(jion_sapce_times)的实现,分别对应论文中不同的时空交互变体设计。

    此外,DTCA 模型通过注册模块(@MODELS.register_module())便于在更大框架中调用,同时结合了渐进式训练、相对位置编码、drop path 等多种技术,综合实现了论文提出的基于扩散 Transformer 和因果注意力的降水预报框架。

    转载地址:http://jnhyz.baihongyu.com/

    你可能感兴趣的文章
    mysql 死锁 Deadlock found when trying to get lock; try restarting transaction
    查看>>
    mysql 死锁(先delete 后insert)日志分析
    查看>>
    MySQL 死锁了,怎么办?
    查看>>
    MySQL 深度分页性能急剧下降,该如何优化?
    查看>>
    MySQL 深度分页性能急剧下降,该如何优化?
    查看>>
    MySQL 添加列,修改列,删除列
    查看>>
    mysql 添加索引
    查看>>
    MySQL 添加索引,删除索引及其用法
    查看>>
    mysql 状态检查,备份,修复
    查看>>
    MySQL 用 limit 为什么会影响性能?
    查看>>
    MySQL 用 limit 为什么会影响性能?有什么优化方案?
    查看>>
    MySQL 用户权限管理:授权、撤销、密码更新和用户删除(图文解析)
    查看>>
    mysql 用户管理和权限设置
    查看>>
    MySQL 的 varchar 水真的太深了!
    查看>>
    mysql 的GROUP_CONCAT函数的使用(group_by 如何显示分组之前的数据)
    查看>>
    MySQL 的instr函数
    查看>>
    MySQL 的mysql_secure_installation安全脚本执行过程介绍
    查看>>
    MySQL 的Rename Table语句
    查看>>
    MySQL 的全局锁、表锁和行锁
    查看>>
    mysql 的存储引擎介绍
    查看>>