8. 条件控制之ControlNet

图像生成的一个重要问题就是如何控制生成的图像,前面已经介绍了如何对图像生成模型进行 guidance, 预训练的模型 Stable Diffusion 支持 text guidance。 然而文本的描述能力其实是非常有限的,图像的很多细节是无法用文本进行清晰描述的, 比如人的姿态(姿势)信息,只靠文字去描述是非常困难的,更何况目前的 text prompt 都是关键词(短语), 对于完整的句子支持并不好。 引入如果用更详细的信息去指导图像生成就称为了一个重要的研究方向, 这其中用一张额外的图像去控制图像生成过程就是一种显而易见的手段, 比如用一张人物姿态图去控制生成过程,使其生成一张完全符合姿态图的图像, 如图 8.2 所示,

../_images/controlnet_pose.png

图 8.2 利用人物姿态(pose)图控制图像生成,令其生成符合姿态图的人物图像。图片来自 []

本章要介绍的 ControlNet [1] 就是这样一种的技术,它是在预训练的 Stable Diffusion 模型基础上, 为其增加了一个额外的输入,这个输入是另外一张图像,比如姿态图、线稿、草图等等, 这个额外的输入作为 Stable Diffusion 的一个控制条件(condition), 它可以控制 Stable Diffusion 生成的图像结果,使其符合我们输入的条件图像特征。

8.1. 算法原理

我们知道,Stable Diffusion 本身是支持条件输入的,它的神经网络模型可以用如下函数表示

(8.1.42)\[\epsilon_{\theta}(z_t,t,\tau_{\theta}(y))\]

其中 \(\epsilon_{\theta}\) 代表 Stable Diffusion 的关键神经网络, 它的网络实现采用的是 Unet 网络, 它的输入分别是加噪后的潜空间图像数据 \(z_t\), 表示时刻的 \(t\), 以及表示条件控制信息的 \(\tau_{\theta}(y)\) 。标准版的 Stable Diffusion 本身支持文本条件控制信息, 它是通过引入 OpenClip 模型的文本编码器(Text Encoder)来实现对文本信息的编码 ,也就是 \(\tau_{\theta}(y)\) 就是 OpenClip 的 Text Encoder。

(8.1.43)\[\epsilon_{\theta}(z_t,t,\tau_{\text{OpenClip}}(\text{text prompt}))\]

现在我们期望再额外引入图像控制信息,其实也就是对 \(\tau_{\theta}(y)\) 进行扩展。 由于我们引入的条件控制信息是图像,所以就涉及到图像的编码处理(特征转换、特征抽取)问题, 可以最直接的想法是,引入一个图像编码器(Image Encoder)来处理条件图像信息, 比如把 OpenClip 的 Image Encoder 部分拿过来直接用。

(8.1.44)\[\epsilon_{\theta}(z_t,t,\tau_{\text{OpenClip}}(\text{text prompt}),\tau_{\text{OpenClip}}(\text{image prompt}))\]

但 ControlNet 不是这样的做的,它并没有从别处拿来一个图像编码器(Image Encoder), 而是直接利用 Stable Diffusion 自己的 Unet 来作为图像编码器, 接下来看下具体是怎么做的。

首先回顾一下 Unet 的结构,Unet 的网络结构呈现一个 U 型,按照左侧、中间底部、右侧的划分方式,可以将整个网络结构分为三个子部分:

  • 左侧部分,一共包含12个网络块(blocks),把输入的 \(64 \times 64\) 的 latent image 逐步降维到 \(8 \times 8\), 这个过程类似于压缩过程,所以可以称为编码器(encoder),也可以称为降采样(down sample)。

  • 中间底部部分,U 型结构的最低层,负责处理 \(8 \times 8\) 的数据,一般称为 Middle block。

  • 右侧部分,同样包含12个网络块(blocks),把输入的 \(8 \times 8\) 的 latent image 逐步升维到 \(64 \times 64\), 这个过程类似于解压过程,所以可以称为解码器(decoder),也可以称为上采样(up sample)。

图 8.1.3 中左侧的子图(a),就是标准版 Stable Diffusion 的 Unet 网络结构, 只是这里把 U 型结构给拉直了。右侧子图(b)是增加的 ControlNet 平行网络,ControlNet 是的网络结构是左侧 Unet 的一个镜像,只是去掉了 Decoder 部分,只包含 Encoder 部分和 Middle 部分,并且为每一个 block 套上了一个 \(1 \times 1\) 的卷积层(Convolution)。

../_images/controlnet_unet.png

图 8.1.3 左侧(a)图是标准版 Stable Diffusion 的Unet 网络,右侧(b)图是 ControlNet 的平行网络。图片来自 []

图 8.1.4 是具体一个 block 的展示, 同样左侧子图(a) 是标准版 Stable Diffusion 的一个 block, 右侧子图(b) 是 ControlNet 对它的一个复制镜像,但是在这个 block 的输出后追加了一个 \(1 \times 1\) 的卷积层(Convolution)。 ControlNet 的 block 和 Unet 的 Encoder、Middle 是一一对应的, ControlNet block 的输出会和对应的 Unet Encoder(Middle) block 的输出 元素 加在 一起, 然后一起通过跳线输入给对应的 Unet Decoder block 。 注意,这个 ControlNet 的 block 输出和 Unet Encoder(Middle) block 输出加一起的结果, 并不会向下传递给下一层的 Unet Encoder(Middle) block,只会通过跳线输入给对应的 Unet Decoder block, 所以增加了 ControlNet 并不影响原来 Unet 的网络结构

../_images/controlnet_block.png

图 8.1.4 block 的细节展示,左侧(a)图是标准版 Stable Diffusion 的一个 block,右侧(b)图是 ControlNet 对它的复制镜像, 并在后面追加了一个 \(1 \times 1\) 的卷积层(Convolution)。图片来自 []

ControlNet 并没有从别处拿来一个 Image Encoder 来用, 而是利用 Stable Diffusion 自己的 Unet 网络在作为条件图像信息的编码器, 它把预训练好的 Stable Diffusion 的 Unet(部分) 复制了一份,作为条件图像信息的编码器, 而又由于 Unet 由三部分组成:左侧编码器、中间部分、右侧解码部分, ControlNet 只复制了其中左侧编码器和中间部分,并在每一个block 后面追加一个 \(1 \times 1\) 的卷积。ControlNet 每一个 block 的输出,会通过跳线连接到 Stable Diffusion 的 Unet 右侧解码器对应部分。 图 8.1.4 画成一个 W 型可能更直观一些。

训练过程

ControlNet 在训练时,原始 Stable Diffusion 的 Unet 部分是冻住的,参数不更新, 仅仅 ControlNet 部分的参数进行更新,这里有两个细节

  1. ControlNet 中从 Unet 复制过来的 block,其参数初始值就是 Unet 的参数值,参数值直接复制过来作为初始值。

  2. ControlNet 中 \(1 \times 1\) 的卷积层其初始值为 \(0\)

由于 \(1 \times 1\) 卷积层的初始值为 \(0\), 未训练过的 ControlNet 各个block 输出都是 \(0\), 这时相当于 ControlNet 不起作用,模型等价于原始的 Stable Diffusion。 通过一步步迭代训练,ControlNet 逐步发挥作用。 这样做的好处是,不对原始 Stable Diffusion 的能力造成大的破坏, 能够保持原始 Stable Diffusion 的能力。

当然这里可能有一个问题,初始值为 \(0\) 卷积层能否求出参数梯度,并正常进行参数更新, 论文作者已经给出了解释:是可以正常更新参数的。

这里可以从一个最简单的线性函数进行解释,

(8.1.45)\[y = \omega x + b\]

参数 \(\omega\) ,特征 \(x\) , 偏置 \(b\) 的偏导分别为

(8.1.46)\[\begin{split}\frac{\partial y}{\partial \omega} &= x \\ \frac{\partial y}{\partial x} &= \omega \\ \frac{\partial y}{\partial b} &= 1\end{split}\]

参数 \(\omega\) 的初始值为 \(0\) 并代表它的梯度为 \(0\) ,所以 \(0\) 初始化的卷积层是可以正常进行参数更新的。

增加了 ControlNet 后,整体的噪声预测参数化网络用函数表示为

(8.1.47)\[\epsilon_{\theta}(z_t,t,\tau_{\text{OpenClip}}(\text{text prompt}),\tau_{\text{ControlNet}}(\text{image prompt}))\]

8.2. 代码实现

ControlNet 的官方实现在 Github/ControlNet ,但这个官方实现略显复杂,这里我们还是推荐阅读 diffusers 的实现, 这里我们摘录其中关键的部分,帮助理解。

diffusers 中对原始 StableDiffusion 的实现是类 DiffusionPipeline, ControlNet 的实现是类 StableDiffusionControlNetPipeline ,其在原始类 DiffusionPipeline 的基础上增加了一个子模块 ControlNetModel ,ControlNet 网络的实现就是在类 ControlNetModel 中。

class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        # 在 StableDiffusion[DiffusionPipeline]的基础上增加了一个 controlnet 子模块
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):

    ...

接下来看 StableDiffusionControlNetPipeline 图像生成过程中对 controlnet 部分的处理, 这里关键的部分有:

  1. 因为 Unet 处理的是 \(64 \times 64\) 的 Latent 空间的 Image,需要把 Condition Image 的尺寸转换为 \(64 \times 64\)

  2. 支持多个 ControlNetModel 一起生效,由类 MultiControlNetModel 负责处理多个 ControlNetModel

  3. 在降噪循环中,先计算 ControlNetModel ,再把 ControlNetModel 的输出传递给 Unet

def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
):

    ...
    # 4. Prepare image
    # 预处理 Condition image 信息,因为 Unet 处理的是 64 X64 的 Latent 空间的 image,
    # 所以这里需要 Condition image 的尺寸转换为 64 X 64
    if isinstance(self.controlnet, ControlNetModel): # 单个 ControlNetModel
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=self.controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
    elif isinstance(self.controlnet, MultiControlNetModel): # 多个 ControlNetModel
        images = []

        for image_ in image:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )

            images.append(image_)

        image = images
    else:
        assert False

    ...

    # 8. Denoising loop 降噪循环
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            # controlnet 网络的计算,输出
            # down_block_res_samples:list  对应 Unet encoder blocks
            # mid_block_res_sample:list  对应 Unet middle blocks
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                return_dict=False,
            )

            # predict the noise residual
            # Stable Diffusion 的 unet 网络计算
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples, # controlnet 输出的 encoder blocks
                mid_block_additional_residual=mid_block_res_sample, # controlnet 输出的 middle blocks
            ).sample

            ...
        ...

接下来是 self.unet 中的关键部分,这里 self.unet 是类 UNet2DConditionModel 的实例。

def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    ...
    # 3. 先逐层计算 Unet 的全部 encoder blocks
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples
    # 把 ControlNet 对应的输出 和 Unet 对应的  encoder block 的输出 加 在一起
    if down_block_additional_residuals is not None:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            # 对应 block 进行元素加法
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples += (down_block_res_sample,)
        # 累加了 ControlNet 的结果,
        down_block_res_samples = new_down_block_res_samples

    # 4. 先逐层计算 Unet 的全部 middle blocks
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )

    # 把 ControlNet 对应的输出 和 Unet 对应的 middle block 的输出 加 在一起
    if mid_block_additional_residual is not None:
        sample = sample + mid_block_additional_residual

    ...

这里只是对 StableDiffusionControlNetPipeline 的一些关键点进行了展示,方便快速理解 ControlNet 的工作方式, 对更多细节感兴趣的可以直接查阅相关源码。

8.3. 最后的总结

ControlNet 是一个任务相关的端到端方法,即对于每一种控制类型都要训练一个特定的 ControlNet 支持, 比如线图控制、深度图控制、姿态控制等等。这样有好处也有坏处, 单独看一个场景,拥有使用简单、训练成本低等优点。 但是,如果面对一个复杂场景(多场景),反而变得略麻烦,每一个细分场景都要训练和维护一个模型, 成本高昂,也不易用。

我们留下一个思考:

  1. \(1 \times 1\) 的卷积是否必须得,不要行不行?

  2. 一个从 Unet 镜像的 ControlNet 网络是否必须得,不要行不行?

  3. 如果上述两个都可以去掉,那是不是意味着直接在 Stable Diffusion 基础上进行 Image Condition 训练就可以了? 图像生成时,初始化的随机噪声改为 \(x_T = x_{\text{Gaussian}} + x_{\text{condition image}}\) 或者是通道维度的拼接。

ControlNet 原论文其实有做一个对比消融实验(原论文中图20), 对比了 Stable Diffusion 官方实现的 depth2img,具体实现可以参考 diffusers.pipelines.StableDiffusionDepth2ImgPipeline ,在这个实现中就是把深度图和噪声图在通道维度进行拼接,然后一起输入给 UNET 网络, 当然论文中呈现的结果肯定是 ControlNet 效果更好,但是对于这个对比实验的细节没有过多说明, 不太确定到底是 ControlNet 网络结构带来的效果提升,还是两者的训练样本和训练过程不一样导致的, 其实就是怀疑 StableDiffusionDepth2ImgPipeline 效果不好可能是其因为本身没有训练好。 当然以上仅是个人的浅见,不保证正确。

8.4. 参考文献