在确定了流模型和扩散模型的定义之后,需要设计一个合适的神经网络架构来将我们需要的向量场参数化成utθ(xy)u_t^\theta(x|y)。这个函数处理三个输入并产生一个输出:

  • 输入:空间向量 xRdx \in \mathbb{R}^d(图像像素或潜变量);引导变量 yYy \in \mathcal{Y}(如文本、类别等);时间步长 t[0,1]t \in [0, 1]
  • 输出:预测的向量场 utθ(xy)Rdu_t^\theta(x|y) \in \mathbb{R}^d
    对于低维的任务(如简单的分类),可以使用 MLP 通过拼接 (x,y,t)(x, y, t) 就可以工作。但是在维度更高的图像或者视频生成中,需要更加复杂的设计来提取和处理带噪图像,提示词和时间中的信息,这个时候就需要更加复杂的架构设计。

嵌入条件变量

时间嵌入

对于比较简单的模型,直接将时间步tt的原始值拼接到输入中就足以训练一个可以工作的网络。但是实际情况下会将时间步使用傅立叶特征映射嵌入到一个高维空间中,确保模型能够捕捉不同噪声水平下高频的时间依赖:

TimeEmb(t)=2d[cos(2πw1t)cos(2πwd/2t),sin(2πw1t)sin(2πwd/2t)]TTimeEmb(t) = \sqrt{\frac{2}{d}} [ \cos(2\pi w_1 t) \dots \cos(2\pi w_{d/2} t), \sin(2\pi w_1 t) \dots \sin(2\pi w_{d/2} t) ]^T

其中频率系数 ωi\omega_i表示为:

wi=wmin(wmaxwmin)i1d/21,i=1,,d/2w_i = w_{min} \left( \frac{w_{max}}{w_{min}} \right)^{\frac{i-1}{d/2-1}}, \quad i=1, \dots, d/2

时间嵌入并不需要严格遵守以上的形式,但是这种方法可以简单地将时间映射到维度为dd的隐空间中,且TimeEmb(t)=1(sin2+cos2=1)||TimeEmb(t)||=1 \quad (sin^2 + cos^2 = 1)

类嵌入

如果yraw{0,,N}y_{raw} \in \{0, \dots, N\}只是一个类别标签,一般最简单的方法是为yrawy_{raw}N+1N+1种可能情况学习一个分离的嵌入向量yy。这些参数可以被包含在之前参数化的向量场中,因此可以在训练中学习。

文本提示词嵌入

yrawy_{raw}是一段文本提示词时,则需要使用更加复杂的模型架构——不同于简单的分类标签,我们需要使用冻结的预训练模型,将成段的文本提示词嵌入到一个连续的向量空间中。比较常见的是 CLIP 和 Transformer:

  • CLIP:生成全局语义嵌入 y=CLIP(yraw)RdCLIPy = CLIP(y_{raw}) \in \mathbb{R}^{d_{CLIP}}
  • T5:提供细粒度的序列嵌入,形如 PromptEmbed(yraw)RS×kPromptEmbed(y_{raw}) \in \mathbb{R}^{S \times k},允许模型通过注意力机制关注特定单词 。

Diffusion Transformer

DiT 结构

一个图片可以被标记成一个张量 xRCimage×H×Wx \in \mathbb{R}^{C_{image} \times H \times W}CimageC_{image}是图片的通道数(如一张RGB图片的通道数为3)。DiT使用注意力机制来构造参数化的向量场,首先定义隐藏维度dd,Transformer层数LL,每层的注意力头数hh
和 Vision Transformer 相同,DiT也将图片拆解成若干个 Patch 并嵌入成一个 token 序列,再使用 Transformer 处理,最后使用 Depatchification 操作回复成相同尺寸的图像张量 xRC×H×Wx \in \mathbb{R}^{C \times H \times W}

模型结构如下:

  1. Patch 化:输入一张图片张量 xRC×H×Wx \in \mathbb{R}^{C \times H \times W},以 P×PP \times P 的 Patch 大小,产生 N=(H/P)(W/P)N=(H/P) \cdot (W/P) 个 Patch;每个 Patch 做 Patchify 之后的的维度 C=CP2C' = CP^2。表示为 Patchify(x)RN×CPatchify(x) \in \mathbb{R}^{N \times C'}
  2. Patch 嵌入:学习一个矩阵 WRC×dW \in \mathbb{R}^{C' \times d} 来将每个 Patch Token 潜入到隐藏空间 dd 中。表示为 x~0=PatchEmb(x)=Patchify(x)WRN×d\tilde{x}_0=PatchEmb(x)=Patchify(x)W \in \mathbb{R}^{N \times d}
  3. 时间/提示词嵌入t~=TimeEmb(t)Rd\tilde{t} = TimeEmb(t) \in \mathbb{R}^dy~=PromptEmbed(y)RS×d\tilde{y} = PromptEmbed(y) \in \mathbb{R}^{S \times d}
  4. 输入 DiT:对每一个 DiT Block,其接受 x~0,t~,y~\tilde{x}_0, \tilde{t}, \tilde{y} 作为输入并依次通过 ii 层 DiT Block。对每一层 Block,有:x~i+1=DiTBlock(x~i,t~,y~)RN×d,(i=0,,L1)\tilde{x}_{i+1} = DiTBlock(\tilde{x}_i, \tilde{t}, \tilde{y}) \in \mathbb{R}^{N \times d}, \quad (i=0, \dots, L-1)。简单来说,每个 Block 通过 对 Patches 做自注意力 对提示词做交叉注意力对时间做 AdaLN 自适应层归一化 来更新 xix_i
  5. 解 Patch:学习一个矩阵 W~Rd×C\tilde{W} \in \mathbb{R}^{d \times C'} 来将 DiT 的输出映射回到一张图片,即:u=Depatchify(x~NW~)RC×H×Wu = Depatchify(\tilde{x}_N \tilde{W}) \in \mathbb{R}^{C \times H \times W}。这个 uu 就是模型的输出,也就是我们需要预测的向量场 utθ(xy)u_t^\theta(x|y)

DiT Block

进入到 DiT Block ,了解其内部构造:

  1. 缩放点积注意力

    Attn(Q,K,V)=softmax(QKTdh)VRN×dhAttn(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_h}}\right)V \in \mathbb{R}^{N \times d_h}

  2. 多头注意力:学习投影矩阵 WQ(h),WK(h),WV(h)Rd×dhW_Q^{(h)}, W_K^{(h)}, W_V^{(h)} \in \mathbb{R}^{d \times d_h}。定义:headh(x,z)=Attn(xWQ(h),zWK(h),zWV(h))head_h(x, z) = Attn(xW_Q^{(h)}, zW_K^{(h)}, zW_V^{(h)}) 其中当 z=xz=x 时为自注意力,当 z=yz=y 时为交叉注意力。 最后拼接并投影:

    MultiHeadAttention(x,z)=Concat(head1,,headh)WORN×d\text{MultiHeadAttention}(x, z) = \text{Concat}(head_1, \dots, head_h)W_O \in \mathbb{R}^{N \times d}

  3. AdaLN 自适应层归一化:使用 MLP g:RdR2dg: \mathbb{R}^d \rightarrow \mathbb{R}^{2d} 从时间步嵌入 t~\tilde{t} 预测缩放因子 γ\gamma 和偏移因子 β\beta(γ,β)=g(t~)(\gamma, \beta) = g(\tilde{t}),其中 γ,βRd\gamma, \beta \in \mathbb{R}^d。AdaLN 在初始化时将 gg 设置为输出全零。此时 γ,β=0\gamma, \beta = 0,公式退化为 1Norm(x)+01 \cdot \text{Norm}(x) + 0,使得每个 DiT Block 初始状态下是一个恒等函数,极大增强了大模型的训练稳定性 。

    AdaNormt~(x)=(1+γ)Norm(x)+βAdaNorm_{\tilde{t}}(x) = (1 + \gamma) \odot \text{Norm}(x) + \beta

最终得到完整的计算流程:

  • 自注意力:xx+gself(t~)MultiHeadAttention(AdaNormt~(x),AdaNormt~(x))x \leftarrow x + g_{self}(\tilde{t}) \cdot \text{MultiHeadAttention}(AdaNorm_{\tilde{t}}(x), AdaNorm_{\tilde{t}}(x))
  • 交叉注意力:xx+gcross(t~)MultiHeadAttention(AdaNormt~(x),y~)x \leftarrow x + g_{cross}(\tilde{t}) \cdot \text{MultiHeadAttention}(AdaNorm_{\tilde{t}}(x), \tilde{y})
  • 前馈网络:xx+gMLP(t~)MLP(AdaNormt~(x))x \leftarrow x + g_{MLP}(\tilde{t}) \cdot MLP(AdaNorm_{\tilde{t}}(x)) (注:g(t~)g_{\dots}(\tilde{t}) 为学习到的门控参数)

U-Net

U-Net 架构 是扩散模型中除 DiT 之外的另一种主流选择。它本质上是一种 卷积神经网络 CNN。U-Net 最初是为医学图像分割设计的,其最核心的特征在于:输入和输出的形状完全一致。
在扩散模型中,我们需要构建一个参数化的向量场:

xutθ(xy)x \mapsto u_t^\theta(x|y)

由于在固定条件 yy 和时间 tt 时,输入 xx 是图像形状,输出 uu 也必须是图像形状,这使得 U-Net 成为参数化该向量场的理想选择 。

U-Net 结构

一个典型的 U-Net 由以下三部分序列组成 :

  1. 编码器序列 Encoders:记为 Ei\mathcal{E}_i。负责特征提取和空间压缩。
  2. 解码器序列 Decoders:记为 Di\mathcal{D}_i。负责特征融合和空间还原。
  3. 中层处理块 Midcoder:记为 M\mathcal{M}。这是位于编码器和解码器最底层的连接部分(“Midcoder”是一个非标准术语,这里用来指代 U-Net 最底层的瓶颈层) 。

计算流程

以处理一张尺寸为 256×256256 \times 256 的 RGB 图像 xtx_t 为例,假设输入维度 (Cinput,H,W)=(3,256,256)(C_{input}, H, W) = (3, 256, 256)。数据在网络中的流动过程如下 :

  • 输入阶段xtinputR3×256×256x_t^{input} \in \mathbb{R}^{3 \times 256 \times 256}
    这是初始的带噪图像或潜变量,拥有 3 个颜色通道,分辨率为 256×256256 \times 256
  • 编码压缩阶段xtlatent=E(xtinput)R512×32×32x_t^{latent} = \mathcal{E}(x_t^{input}) \in \mathbb{R}^{512 \times 32 \times 32}
    数据经过一系列编码器 E\mathcal{E}。在此过程中,图像的分辨率(H,WH, W)从 256256 降至 3232,而通道数从 33 激增至 512512。这体现了卷积网络“用空间分辨率换取特征深度”的逻辑。
  • 中间处理阶段xtlatent=M(xtlatent)R512×32×32x_t^{latent} = \mathcal{M}(x_t^{latent}) \in \mathbb{R}^{512 \times 32 \times 32}
    数据通过中层块 M\mathcal{M}。这里的维度保持不变,主要进行深层语义特征的进一步加工。
  • 解码还原阶段xtoutput=D(xtlatent)R3×256×256x_t^{output} = \mathcal{D}(x_t^{latent}) \in \mathbb{R}^{3 \times 256 \times 256}
    数据经过解码器 D\mathcal{D}。通过上采样,分辨率重新从 3232 回复到 256256,通道数也还原回初始的 33

设计细节

  • 层组成:编码器和解码器通常由一系列卷积层组成,中间夹杂着激活函数(如 ReLU)、池化操作(Pooling)等 。
  • 预编码块:在进入第一个编码器块之前,输入 xtinputR3×256×256x_t^{input} \in \mathbb{R}^{3 \times 256 \times 256} 通常会先进入一个初始块,目的是在不改变分辨率的情况下先增加通道数 。
  • 残差连接:这是 U-Net 的灵魂。编码器和解码器之间存在直接的连接 。
    • 作用:编码器在压缩过程中会丢失大量精细的空间位置信息。跳跃连接通过将编码器的特征直接传递给解码器,帮助模型找回“在哪里 (Where)”的信息,从而生成细节锐利、像素对齐准确的图像 。
    • 如果没有它:模型可能只能生成大致的颜色和轮廓(“是什么”),但会丢失精细的纹理和清晰的边缘 。