구현 일지 - INPAINT

2024. 5. 25. 17:04딥러닝/딥러닝

Inpainting은 주어진 이미지의 특정 부분을 사실성을 훼손하지 않고 재구성하는 과정입니다.
본 포스팅에서는 노이즈 제거 확산 모델(denoising diffusion model)을 사용하여 Inpainting하는 방법을 설명하겠습니다.


Diffusion model에서는 학습이 완료된 후, 새로운 이미지는 샘플링 과정을 통해 생성됩니다.
이 과정에서는 Random noise를 반복적으로 제거하여 최종적으로 실제와 유사한 이미지를 얻게 됩니다.
Inpainting을 위해서는 Sampling pipeline을 커스터마이징해야 합니다. (이 글을 쓴 이유!)
여기에서는 전체 확산 모델 과정을 설명하지 않고 수정 사항만 설명합니다. 전체 과정을 이해하려면 Denoising Diffusion Probabilistic Models을 참조하시기 바랍니다.


Sampling Pipepline Customizing (DDPM 또는 DDIM)

그림 1. Inpaint customization (우리가 구현할 것)

샘플링 과정의 InputReference ImageMask(ROI)입니다.
여기서는 각 time step마다 reverse diffusion과 forward diffusion을 모두 수행해야 합니다. T번째 시간 단계에서 T번째 forward noise image에서 마스크되지 않은 영역은 해당 역방향 노이즈 이미지와 결합되어 노이즈 제거되어 (T-1)번째 역방향 이미지를 얻습니다. 따라서 각 시간 단계마다 우리는 역방향 확산을 위한 방향을 주입하고 모델이 우리의 분포 쪽으로 단계를 밟도록 강제하고 있습니다.

 

그림 1은 이 커스터마이징에 대한 시각적 설명을 제공합니다. 여기서는 최종 출력 이미지를 생성하기 위해 5개의 시간 단계(T1 - T5)가 사용됩니다. 순방향 노이즈 이미지에서 마스크의 검은색 영역(그림 1의 왼쪽에 표시됨)에 해당하는 영역은 노이즈 제거 전에 해당 역방향 노이즈 이미지와 결합됩니다.

마스크를 M, T번째 단계에서의 순방향 및 역방향 확산 노이즈 이미지를 각각 Ft와 Rt라고 하면, Ft-1 = denoise(Ft (1-M) + M Rt, t)가 됩니다.


구현 (ddim 파이프라인 커스터마이징)

샘플링을 위해 참조 이미지와 마스크가 주어집니다. 마스크 이미지는 부드러운 가장자리를 얻기 위해 블러 처리됩니다. 아래 코드는 DDIMPipeline(diffusers 라이브러리)의 수정된 코드입니다. 파이프라인을 호출할 때 ref_image와 mask 텐서가 전달됩니다. 전체 구현은 GitHub https://github.com/aromalma/InPainting.git에 추가되었습니다.

class InPaintDDIM(DiffusionPipeline):


    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()

        # make sure scheduler can always be converted to DDIM
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    ## inputs: ref_image, mask
    def __call__(
        self,
        ref_image=None, ###
        mask=None, ###
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:


        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
        image_T=torch.tensor(image.clone())

        ref_image=ref_image.to(self.device)
        # set step values
        self.scheduler.set_timesteps(num_inference_steps)
        #############################################################
        for t in self.progress_bar(self.scheduler.timesteps):
            # get forward noisy image
            noisy=self.scheduler.add_noise(ref_image,image_T,t)

            # combine reverse anfd forward
            image=image*mask+noisy * (1-mask)
            
            # generate noise to remove
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(
                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
            ).prev_sample

        #############################################################
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)


        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

결과

그림 2. Inpainting 결과 (reference image, mask, 생성된 이미지)

Reference

 

'딥러닝 > 딥러닝' 카테고리의 다른 글

[논문 리뷰] Vision Transformers need registers  (0) 2024.05.15