
Official Implementation (Pytorch) of "DAVI: Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems", ECCV 2024 Oral paper


Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems

Sojin Lee*, Dogyun Park*, Inho Kong, Hyunwoo J. Kim†.

ECCV 2024 Oral

This repository contains the official PyTorch implementation of DAVI: Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems, accepted at ECCV 2024 as an oral presentation.

Our framework performs efficient posterior sampling with a single evaluation of a neural network and generalizes to both seen and unseen measurements without test-time optimization. We provide five image restoration tasks (Gaussian deblurring, 4x super-resolution, box inpainting, denoising, and colorization) on two benchmark datasets (FFHQ and ImageNet).

Setting

Please follow these steps to set up the repository.

1. Clone the Repository

git clone https://github.com/mlvlab/DAVI.git
cd DAVI

2. Install Environment

conda create -n DAVI python==3.8
conda activate DAVI
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install accelerate ema_pytorch matplotlib piq scikit-image pytorch-fid wandb
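
After installation, it can help to confirm that the packages are visible inside the DAVI environment. The following is a minimal sketch, not part of this repository: the helper name `installed_versions` is hypothetical, and the distribution names are assumed to match the install commands above (the conda package `pytorch` is assumed to register itself as `torch`).

```python
from importlib import metadata

# Distribution names assumed from the conda/pip commands above.
REQUIRED = [
    "torch", "torchvision", "accelerate", "ema-pytorch", "matplotlib",
    "piq", "scikit-image", "pytorch-fid", "wandb",
]


def installed_versions(names=REQUIRED):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15s} {version or 'MISSING'}")
```

Any package reported as MISSING can be reinstalled with the corresponding command from this step.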

3. Download Pre-trained models and Official Checkpoints

We use pre-trained diffusion models for FFHQ (ffhq_10m.pt) and ImageNet (256x256_diffusion_uncond.pt), obtained from DPS and guided_diffusion, respectively.

  • Download the checkpoints for FFHQ and ImageNet from this Google Drive Link.
  • Place the pre-trained models into the model/ directory.
  • Place our checkpoints into model/official_ckpt/ffhq or model/official_ckpt/imagenet.
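
Before running training or evaluation, a quick check that the downloaded files landed in the expected places can save a failed launch. This is a minimal sketch under stated assumptions: the helper `missing_files` is hypothetical, and the list covers only the file names that appear in this README; the remaining per-task checkpoints are not listed here and should be added by hand.

```python
from pathlib import Path

# Only paths named explicitly in this README; extend with the other task checkpoints.
EXPECTED = [
    "model/ffhq_10m.pt",
    "model/256x256_diffusion_uncond.pt",
    "model/official_ckpt/ffhq/gaussian_ema.pt",
    "model/official_ckpt/imagenet/gaussian_ema.pt",
]


def missing_files(root=".", expected=EXPECTED):
    """Return the expected relative paths that do not exist as files under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    print("All files found." if not missing else f"Missing: {missing}")
```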

4. Prepare Data

For amortized optimization, we use the FFHQ 49K dataset and the ImageNet 130K dataset, which are subsets of the training datasets used for the pre-trained models. These subsets are distinct from the validation datasets (ffhq_1K and imagenet_val_1K) used for evaluation.
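
The FFHQ split used in this step (indices 00000-00999 for validation, 01000-49999 for training) can be sketched as follows. This is an illustrative helper, not a script shipped with the repository: `split_ffhq` is a hypothetical name, and it assumes a flat source folder of already-resized 256x256 PNG files named by their zero-padded index (for example `00000.png`).

```python
import shutil
from pathlib import Path


def split_ffhq(src_dir, out_root="data", n_val=1000, n_total=50000):
    """Copy images with index < n_val to ffhq_1K and the rest (< n_total) to ffhq_49K."""
    src = Path(src_dir)
    val_dir = Path(out_root) / "ffhq_1K"
    train_dir = Path(out_root) / "ffhq_49K"
    val_dir.mkdir(parents=True, exist_ok=True)
    train_dir.mkdir(parents=True, exist_ok=True)

    counts = {"val": 0, "train": 0}
    for path in sorted(src.glob("*.png")):
        if not path.stem.isdigit():
            continue  # skip files that are not named by index
        idx = int(path.stem)
        if idx >= n_total:
            continue  # outside the 00000-49999 range used here
        if idx < n_val:
            shutil.copy2(path, val_dir / path.name)
            counts["val"] += 1
        else:
            shutil.copy2(path, train_dir / path.name)
            counts["train"] += 1
    return counts
```

With a complete source folder, the returned counts should be 1000 validation and 49000 training images.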

  • FFHQ 256x256

    • data/ffhq_1K and data/ffhq_49K.

    We downloaded the FFHQ dataset and resized it to 256x256, following the instructions on the ffhq-dataset public site.

    We use 00000-00999 as the validation set (1K) and 01000-49999 (49K) as the training set.

  • ImageNet 256x256

    • data/imagenet_val_1K and data/imagenet_130K.

    We downloaded the ImageNet 100 dataset and use its training set.

  • Measurements in NumPy format

    • data/y_npy

    During amortized training, we load a subset of the training set to monitor the convergence of the training process.

    You can specify degradation types using the --deg option.

    • Gaussian deblur: gaussian
    • 4x super-resolution: sr_averagepooling
    • Box inpainting: inpainting
    • Denoising: deno
    • Colorization: colorization
    python utils/get_measurements.py --deg gaussian --data_dir data/ffhq_49K
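
To generate measurements for all five degradation types in one pass, the command above can be wrapped in a small loop. The sketch below is a convenience wrapper and an assumption on our part, not a script in this repository; it only reuses the `--deg` and `--data_dir` options shown above, and the function names are hypothetical. It defaults to a dry run that prints the commands.

```python
import subprocess
import sys

# Values accepted by --deg, as listed above.
DEGRADATIONS = ["gaussian", "sr_averagepooling", "inpainting", "deno", "colorization"]


def measurement_commands(data_dir, degradations=DEGRADATIONS):
    """Build one get_measurements.py command per degradation type."""
    return [
        [sys.executable, "utils/get_measurements.py", "--deg", deg, "--data_dir", data_dir]
        for deg in degradations
    ]


def run_all(data_dir, dry_run=True):
    """Print each command and, unless dry_run is set, execute it from the repository root."""
    for cmd in measurement_commands(data_dir):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all("data/ffhq_49K", dry_run=True)
```

Set `dry_run=False` once the printed commands look right.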
    

Overall directory

├── results
│
├── model
│   ├── ffhq_10m.pt                    # FFHQ for training
│   ├── 256x256_diffusion_uncond.pt    # ImageNet for training
│   └── official_ckpt                  # For evaluation
│       ├── ffhq
│       │   ├── gaussian_ema.pt
│       │   ├── sr_averagepooling_ema.pt
│       │   └── ...
│       └── imagenet
│           ├── gaussian_ema.pt
│           ├── sr_averagepooling_ema.pt
│           └── ...
│
└── data                               # including training set and evaluation set
    ├── ffhq_1K                        # FFHQ evaluation
    ├── imagenet_val_1K                # ImageNet evaluation
    ├── ffhq_49K                       # FFHQ training
    ├── imagenet_130K                  # ImageNet training
    └── y_npy
        ├── ffhq_1k_npy
        │   ├── gaussian
        │   ├── sr_averagepooling
        │   └── ...
        └── imagenet_val_1k_npy
            ├── gaussian
            ├── sr_averagepooling
            └── ...
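
The empty directory skeleton can be created up front so that downloads and datasets have a place to go. This is a minimal sketch: `make_layout` is a hypothetical helper, and the directory names follow the paths used in the commands of this README (`model/`, `data/`, `results/`).

```python
from pathlib import Path

# Directories taken from the paths used in this README's commands.
LAYOUT = [
    "results",
    "model/official_ckpt/ffhq",
    "model/official_ckpt/imagenet",
    "data/ffhq_1K",
    "data/imagenet_val_1K",
    "data/ffhq_49K",
    "data/imagenet_130K",
    "data/y_npy",
]


def make_layout(root="."):
    """Create every directory in LAYOUT under root and return the created paths."""
    created = []
    for rel in LAYOUT:
        path = Path(root) / rel
        path.mkdir(parents=True, exist_ok=True)  # safe to re-run
        created.append(path)
    return created


if __name__ == "__main__":
    for path in make_layout():
        print(path)
```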

Evaluation

1. Restore degraded images

  • You can specify the directory of measurements with --y_dir data/y_npy
accelerate launch --num_processes=1 eval.py --eval_dir data/ffhq_1K --deg gaussian --perturb_h 0.1 --ckpt model/official_ckpt/ffhq/gaussian_ema.pt

2. Evaluate PSNR, LPIPS, and FID

  • PSNR and LPIPS
    python utils/eval_psnr_lpips.py
    
  • FID: pytorch-fid
    python -m pytorch_fid source_dir recon_dir
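
For reference, PSNR follows the standard definition 10 * log10(MAX^2 / MSE). The sketch below illustrates that formula on flattened pixel values in [0, 1] using only the standard library; it is not the code in utils/eval_psnr_lpips.py, and the function name `psnr` and its `max_value` parameter are assumptions made for illustration.

```python
import math


def psnr(reference, restored, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(restored) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixel values.
    mse = sum((r - x) ** 2 for r, x in zip(reference, restored)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


if __name__ == "__main__":
    # A uniform error of 0.1 on a [0, 1] scale gives roughly 20 dB.
    print(round(psnr([0.0, 0.0], [0.1, 0.1]), 2))
```

In practice the same computation would run on image arrays or tensors; higher PSNR means the restoration is closer to the reference.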
    

Train with Multiple GPUs

  • To check training logs, use the --use_wandb flag.
accelerate launch --multi_gpu --num_processes=4 train.py --data_dir data/ffhq_49K/ --model_path model/ffhq_10m.pt --deg gaussian --t_ikl 400 --weight_con 0.5 --reg_coeff 0.25 --perturb_h 0.1
