mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-15 14:31:44 +03:00
The Merge (#1705)
This commit is contained in:
commit
5291437085
2
.gitignore
vendored
2
.gitignore
vendored
@ -54,10 +54,12 @@ condaenv.*.requirements.txt
|
||||
# Repo-specific
|
||||
# =========================================================================== #
|
||||
/configs/webui/userconfig_streamlit.yaml
|
||||
/configs/webui/userconfig_flet.yaml
|
||||
/custom-conda-path.txt
|
||||
!/src/components/*
|
||||
!/src/pages/*
|
||||
/src/*
|
||||
/inputs
|
||||
/outputs
|
||||
/model_cache
|
||||
/log/**/*.png
|
||||
|
24
.gitmodules
vendored
Normal file
24
.gitmodules
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
[submodule "backend"]
|
||||
path = backend
|
||||
url = ../../Sygil-Dev/dalle-flow.git
|
||||
[submodule "backend/clip-as-service"]
|
||||
path = backend/clip-as-service
|
||||
url = ../../jina-ai/clip-as-service.git
|
||||
[submodule "backend/clipseg"]
|
||||
path = backend/clipseg
|
||||
url = ../../timojl/clipseg.git
|
||||
[submodule "backend/dalle_flow"]
|
||||
path = backend/dalle_flow
|
||||
url = ../../Sygil-Dev/dalle-flow.git
|
||||
[submodule "backend/glid-3-xl"]
|
||||
path = backend/glid-3-xl
|
||||
url = ../../jina-ai/glid-3-xl.git
|
||||
[submodule "backend/latent-diffusion"]
|
||||
path = backend/latent-diffusion
|
||||
url = ../../CompVis/latent-diffusion.git
|
||||
[submodule "backend/stable-diffusion"]
|
||||
path = backend/stable-diffusion
|
||||
url = ../../AmericanPresidentJimmyCarter/stable-diffusion.git
|
||||
[submodule "backend/SwinIR"]
|
||||
path = backend/SwinIR
|
||||
url = ../../jina-ai/SwinIR.git
|
@ -9,6 +9,7 @@ SHELL ["/bin/bash", "-c"]
|
||||
ENV PYTHONPATH=/sd
|
||||
|
||||
EXPOSE 8501
|
||||
COPY ./entrypoint.sh /sd/
|
||||
COPY ./data/DejaVuSans.ttf /usr/share/fonts/truetype/
|
||||
COPY ./data/ /sd/data/
|
||||
copy ./images/ /sd/images/
|
||||
@ -16,8 +17,9 @@ copy ./scripts/ /sd/scripts/
|
||||
copy ./ldm/ /sd/ldm/
|
||||
copy ./frontend/ /sd/frontend/
|
||||
copy ./configs/ /sd/configs/
|
||||
copy ./configs/webui/webui_streamlit.yaml /sd/configs/webui/userconfig_streamlit.yaml
|
||||
copy ./.streamlit/ /sd/.streamlit/
|
||||
COPY ./entrypoint.sh /sd/
|
||||
copy ./optimizedSD/ /sd/optimizedSD/
|
||||
ENTRYPOINT /sd/entrypoint.sh
|
||||
|
||||
RUN mkdir -p ~/.streamlit/
|
||||
|
@ -6,11 +6,12 @@ SHELL ["/bin/bash", "-c"]
|
||||
WORKDIR /install
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y wget curl git build-essential zip unzip nano openssh-server libgl1 && \
|
||||
apt-get install -y wget curl git build-essential zip unzip nano openssh-server libgl1 libsndfile1-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY ./requirements.txt /install/
|
||||
COPY ./setup.py /install/
|
||||
|
||||
RUN /opt/conda/bin/python -m pip install -r /install/requirements.txt
|
||||
|
||||
|
@ -9,17 +9,19 @@ SHELL ["/bin/bash", "-c"]
|
||||
ENV PYTHONPATH=/sd
|
||||
|
||||
EXPOSE 8501
|
||||
COPY ./runpod_entrypoint.sh /sd/entrypoint.sh
|
||||
COPY ./data/DejaVuSans.ttf /usr/share/fonts/truetype/
|
||||
COPY ./configs/ /sd/configs/
|
||||
copy ./configs/webui/webui_streamlit.yaml /sd/configs/webui/userconfig_streamlit.yaml
|
||||
COPY ./data/ /sd/data/
|
||||
COPY ./frontend/ /sd/frontend/
|
||||
COPY ./gfpgan/ /sd/gfpgan/
|
||||
COPY ./images/ /sd/images/
|
||||
COPY ./ldm/ /sd/ldm/
|
||||
COPY ./models/ /sd/models/
|
||||
copy ./optimizedSD/ /sd/optimizedSD/
|
||||
COPY ./scripts/ /sd/scripts/
|
||||
COPY ./.streamlit/ /sd/.streamlit/
|
||||
COPY ./runpod_entrypoint.sh /sd/entrypoint.sh
|
||||
ENTRYPOINT /sd/entrypoint.sh
|
||||
|
||||
RUN mkdir -p ~/.streamlit/
|
||||
|
18
README.md
18
README.md
@ -6,8 +6,8 @@
|
||||
|
||||
## Installation instructions for:
|
||||
|
||||
- **[Windows](https://sygil-dev.github.io/sygil-webui/docs/1.windows-installation.html)**
|
||||
- **[Linux](https://sygil-dev.github.io/sygil-webui/docs/2.linux-installation.html)**
|
||||
- **[Windows](https://sygil-dev.github.io/sygil-webui/docs/Installation/windows-installation)**
|
||||
- **[Linux](https://sygil-dev.github.io/sygil-webui/docs/Installation/linux-installation)**
|
||||
|
||||
### Want to ask a question or request a feature?
|
||||
|
||||
@ -118,7 +118,7 @@ Please see the [Streamlit Documentation](docs/4.streamlit-interface.md) to learn
|
||||
|
||||
**Note: the Gradio interface is no longer being actively developed by Sygil.Dev and is only receiving bug fixes.**
|
||||
|
||||
Please see the [Gradio Documentation](docs/5.gradio-interface.md) to learn more.
|
||||
Please see the [Gradio Documentation](https://sygil-dev.github.io/sygil-webui/docs/Gradio/gradio-interface/) to learn more.
|
||||
|
||||
## Image Upscalers
|
||||
|
||||
@ -146,13 +146,13 @@ Put them into the `sygil-webui/models/realesrgan` directory.
|
||||
|
||||
### LSDR
|
||||
|
||||
Download **LDSR** [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.cpkt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1). Rename last.ckpt to model.ckpt and place both under `sygil-webui/models/ldsr/`
|
||||
Download **LDSR** [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.cpkt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1). Rename `last.ckpt` to `model.ckpt` and place both under `sygil-webui/models/ldsr/`
|
||||
|
||||
### GoBig, and GoLatent *(Currently on the Gradio version Only)*
|
||||
|
||||
More powerful upscalers that uses a seperate Latent Diffusion model to more cleanly upscale images.
|
||||
More powerful upscalers that uses a separate Latent Diffusion model to more cleanly upscale images.
|
||||
|
||||
Please see the [Image Enhancers Documentation](docs/6.image_enhancers.md) to learn more.
|
||||
Please see the [Post-Processing Documentation](https://sygil-dev.github.io/sygil-webui/docs/post-processing) to learn more.
|
||||
|
||||
-----
|
||||
|
||||
@ -162,12 +162,12 @@ Please see the [Image Enhancers Documentation](docs/6.image_enhancers.md) to lea
|
||||
|
||||
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
|
||||
|
||||
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)<br/>
|
||||
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
|
||||
[Robin Rombach](https://github.com/rromb)\*,
|
||||
[Andreas Blattmann](https://github.com/ablattmann)\*,
|
||||
[Dominik Lorenz](https://github.com/qp-qp)\,
|
||||
[Patrick Esser](https://github.com/pesser),
|
||||
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
|
||||
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
|
||||
|
||||
**CVPR '22 Oral**
|
||||
|
||||
@ -194,7 +194,7 @@ Details on the training procedure and data, as well as the intended use of the m
|
||||
|
||||
## Comments
|
||||
|
||||
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||
- Our code base for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
||||
Thanks for open-sourcing!
|
||||
|
||||
|
1
backend/SwinIR
Submodule
1
backend/SwinIR
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 41d8c990adfbeeba929f20ae11d3a8494a83d12d
|
1
backend/clip-as-service
Submodule
1
backend/clip-as-service
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 9bb7d1f47d19e15e844108dec5f84cabcce7975d
|
1
backend/clipseg
Submodule
1
backend/clipseg
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 656e0c662bd1c9a5ae511011642da5b7d8503312
|
1
backend/dalle_flow
Submodule
1
backend/dalle_flow
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 491c52af85f6d75d30094974c97a5a0ed53ba6db
|
1
backend/glid-3-xl
Submodule
1
backend/glid-3-xl
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit b21a3acdd478a4fa41c529b55199c8ac3b1b807a
|
1
backend/latent-diffusion
Submodule
1
backend/latent-diffusion
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit a506df5756472e2ebaf9078affdde2c4f1502cd4
|
1
backend/stable-diffusion
Submodule
1
backend/stable-diffusion
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 2de63ea62862106de27706cd280e692f34c12d9f
|
68
configs/stable-diffusion/v2-inference-v.yaml
Normal file
68
configs/stable-diffusion/v2-inference-v.yaml
Normal file
@ -0,0 +1,68 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-4
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False # we set this to false because this is an inference only config
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
67
configs/stable-diffusion/v2-inference.yaml
Normal file
67
configs/stable-diffusion/v2-inference.yaml
Normal file
@ -0,0 +1,67 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-4
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False # we set this to false because this is an inference only config
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
158
configs/stable-diffusion/v2-inpainting-inference.yaml
Normal file
158
configs/stable-diffusion/v2-inpainting-inference.yaml
Normal file
@ -0,0 +1,158 @@
|
||||
model:
|
||||
base_learning_rate: 5.0e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: hybrid
|
||||
scale_factor: 0.18215
|
||||
monitor: val/loss_simple_ema
|
||||
finetune_keys: null
|
||||
use_ema: False
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 9
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
|
||||
data:
|
||||
target: ldm.data.laion.WebDataModuleFromConfig
|
||||
params:
|
||||
tar_base: null # for concat as in LAION-A
|
||||
p_unsafe_threshold: 0.1
|
||||
filter_word_list: "data/filters.yaml"
|
||||
max_pwatermark: 0.45
|
||||
batch_size: 8
|
||||
num_workers: 6
|
||||
multinode: True
|
||||
min_size: 512
|
||||
train:
|
||||
shards:
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||
shuffle: 10000
|
||||
image_key: jpg
|
||||
image_transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 512
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.RandomCrop
|
||||
params:
|
||||
size: 512
|
||||
postprocess:
|
||||
target: ldm.data.laion.AddMask
|
||||
params:
|
||||
mode: "512train-large"
|
||||
p_drop: 0.25
|
||||
# NOTE use enough shards to avoid empty validation loops in workers
|
||||
validation:
|
||||
shards:
|
||||
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||
shuffle: 0
|
||||
image_key: jpg
|
||||
image_transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 512
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.CenterCrop
|
||||
params:
|
||||
size: 512
|
||||
postprocess:
|
||||
target: ldm.data.laion.AddMask
|
||||
params:
|
||||
mode: "512train-large"
|
||||
p_drop: 0.25
|
||||
|
||||
lightning:
|
||||
find_unused_parameters: True
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 5000
|
||||
|
||||
callbacks:
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
params:
|
||||
every_n_train_steps: 10000
|
||||
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
enable_autocast: False
|
||||
disabled: False
|
||||
batch_frequency: 1000
|
||||
max_images: 4
|
||||
increase_log_steps: False
|
||||
log_first_step: False
|
||||
log_images_kwargs:
|
||||
use_ema_scope: False
|
||||
inpaint: False
|
||||
plot_progressive_rows: False
|
||||
plot_diffusion_rows: False
|
||||
N: 4
|
||||
unconditional_guidance_scale: 5.0
|
||||
unconditional_guidance_label: [""]
|
||||
ddim_steps: 50 # todo check these out for depth2img,
|
||||
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
val_check_interval: 5000000
|
||||
num_sanity_val_steps: 0
|
||||
accumulate_grad_batches: 1
|
74
configs/stable-diffusion/v2-midas-inference.yaml
Normal file
74
configs/stable-diffusion/v2-midas-inference.yaml
Normal file
@ -0,0 +1,74 @@
|
||||
model:
|
||||
base_learning_rate: 5.0e-07
|
||||
target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: hybrid
|
||||
scale_factor: 0.18215
|
||||
monitor: val/loss_simple_ema
|
||||
finetune_keys: null
|
||||
use_ema: False
|
||||
|
||||
depth_stage_config:
|
||||
target: ldm.modules.midas.api.MiDaSInference
|
||||
params:
|
||||
model_type: "dpt_hybrid"
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 5
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
|
76
configs/stable-diffusion/x4-upscaling.yaml
Normal file
76
configs/stable-diffusion/x4-upscaling.yaml
Normal file
@ -0,0 +1,76 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
low_scale_key: "lr"
|
||||
linear_start: 0.0001
|
||||
linear_end: 0.02
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 128
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: "hybrid-adm"
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.08333
|
||||
use_ema: False
|
||||
|
||||
low_scale_config:
|
||||
target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
|
||||
params:
|
||||
noise_schedule_config: # image space
|
||||
linear_start: 0.0001
|
||||
linear_end: 0.02
|
||||
max_noise_level: 350
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
|
||||
image_size: 128
|
||||
in_channels: 7
|
||||
out_channels: 4
|
||||
model_channels: 256
|
||||
attention_resolutions: [ 2,4,8]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 2, 4]
|
||||
disable_self_attentions: [True, True, True, False]
|
||||
disable_middle_self_attn: False
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
use_linear_in_transformer: True
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
ddconfig:
|
||||
# attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
|
||||
double_z: True
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
199
configs/webui/webui_flet.yaml
Normal file
199
configs/webui/webui_flet.yaml
Normal file
@ -0,0 +1,199 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_flet.yaml.
|
||||
# Any changes made here will be available automatically on the web app without having to stop it.
|
||||
# You may add overrides in a file named "userconfig_flet.yaml" in this folder, which can contain any subset
|
||||
# of the properties below.
|
||||
|
||||
# any section labeled '_page' will get it's own tab in settings
|
||||
# any section without that suffix will still be read by parser and stored in session
|
||||
#
|
||||
# display types
|
||||
# -- every display type must have 'value: '
|
||||
# -- to do: add 'tooltip : ' to every display type
|
||||
# --(make optional, not everything needs one.)
|
||||
# bool
|
||||
# -value
|
||||
# dropdown
|
||||
# -value
|
||||
# -option_list
|
||||
# slider
|
||||
# -value
|
||||
# -min
|
||||
# -max
|
||||
# -step
|
||||
# textinput
|
||||
# -value
|
||||
#
|
||||
# list of value types
|
||||
# !!bool boolean 'true' 'false'
|
||||
# !!float float '0.01'
|
||||
# !!int integer '23'
|
||||
# !!str string 'foo' 'bar'
|
||||
# !!null None
|
||||
|
||||
webui_page:
|
||||
default_theme:
|
||||
display: dropdown
|
||||
value: 'dark'
|
||||
option_list:
|
||||
- !!str 'dark'
|
||||
- !!str 'light'
|
||||
default_text_size:
|
||||
display: slider
|
||||
value: !!int '20'
|
||||
min: !!int '10'
|
||||
max: !!int '32'
|
||||
step: !!float '2.0'
|
||||
max_message_history:
|
||||
display: slider
|
||||
value: !!int '20'
|
||||
min: !!int '1'
|
||||
max: !!int '100'
|
||||
step: !!int '1'
|
||||
|
||||
general_page:
|
||||
huggingface_token:
|
||||
display: textinput
|
||||
value: !!str ''
|
||||
stable_horde_api:
|
||||
display: textinput
|
||||
value: !!str '0000000000'
|
||||
global_negative_prompt:
|
||||
display: textinput
|
||||
value: !!str " "
|
||||
default_model:
|
||||
display: textinput
|
||||
value: !!str "Stable Diffusion v1.5"
|
||||
base_model:
|
||||
display: textinput
|
||||
value: !!str "Stable Diffusion v1.5"
|
||||
default_model_config:
|
||||
display: textinput
|
||||
value: !!str "configs/stable-diffusion/v1-inference.yaml"
|
||||
default_model_path:
|
||||
display: textinput
|
||||
value: !!str "models/ldm/stable-diffusion-v1/Stable Diffusion v1.5.ckpt"
|
||||
use_sd_concepts_library:
|
||||
display: bool
|
||||
value: !!bool 'true'
|
||||
sd_concepts_library_folder:
|
||||
display: textinput
|
||||
value: !!str "models/custom/sd-concepts-library"
|
||||
GFPGAN_dir:
|
||||
display: textinput
|
||||
value: !!str "./models/gfpgan"
|
||||
GFPGAN_model:
|
||||
display: textinput
|
||||
value: !!str "GFPGANv1.4"
|
||||
LDSR_dir:
|
||||
display: textinput
|
||||
value: !!str "./models/ldsr"
|
||||
LDSR_model:
|
||||
display: textinput
|
||||
value: !!str "model"
|
||||
RealESRGAN_dir:
|
||||
display: textinput
|
||||
value: !!str "./models/realesrgan"
|
||||
RealESRGAN_model:
|
||||
display: textinput
|
||||
value: !!str "RealESRGAN_x4plus"
|
||||
upscaling_method:
|
||||
display: textinput
|
||||
value: !!str "RealESRGAN"
|
||||
|
||||
output_page:
|
||||
outdir:
|
||||
display: textinput
|
||||
value: !!str 'outputs'
|
||||
outdir_txt2img:
|
||||
display: textinput
|
||||
value: !!str "outputs/txt2img"
|
||||
outdir_img2img:
|
||||
display: textinput
|
||||
value: !!str "outputs/img2img"
|
||||
outdir_img2txt:
|
||||
display: textinput
|
||||
value: !!str "outputs/img2txt"
|
||||
save_metadata:
|
||||
display: bool
|
||||
value: !!bool true
|
||||
save_format:
|
||||
display: dropdown
|
||||
value: !!str "png"
|
||||
option_list:
|
||||
- !!str 'png'
|
||||
- !!str 'jpeg'
|
||||
skip_grid:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
skip_save:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
#grid_quality: 95
|
||||
#n_rows: -1
|
||||
#update_preview: True
|
||||
#update_preview_frequency: 10
|
||||
|
||||
performance_page:
|
||||
gpu:
|
||||
display: dropdown
|
||||
value: !!str ''
|
||||
option_list:
|
||||
- !!str '0:'
|
||||
gfpgan_cpu:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
esrgan_cpu:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
extra_models_cpu:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
extra_models_gpu:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
gfpgan_gpu:
|
||||
display: textinput
|
||||
value: !!int 0
|
||||
esrgan_gpu:
|
||||
display: textinput
|
||||
value: !!int 0
|
||||
keep_all_models_loaded:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
#no_verify_input: False
|
||||
#no_half: False
|
||||
#use_float16: False
|
||||
#precision: "autocast"
|
||||
#optimized: False
|
||||
#optimized_turbo: False
|
||||
#optimized_config: "optimizedSD/v1-inference.yaml"
|
||||
#enable_attention_slicing: False
|
||||
#enable_minimal_memory_usage: False
|
||||
|
||||
server_page:
|
||||
hide_server_setting:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
hide_browser_setting:
|
||||
display: bool
|
||||
value: !!bool 'false'
|
||||
|
||||
textual_inversion:
|
||||
pretrained_model_name_or_path: "models/diffusers/stable-diffusion-v1-5"
|
||||
tokenizer_name: "models/clip-vit-large-patch14"
|
@ -59,6 +59,7 @@ general:
|
||||
no_half: False
|
||||
use_float16: False
|
||||
precision: "autocast"
|
||||
use_cudnn: False
|
||||
optimized: False
|
||||
optimized_turbo: False
|
||||
optimized_config: "optimizedSD/v1-inference.yaml"
|
||||
@ -70,6 +71,7 @@ general:
|
||||
admin:
|
||||
hide_server_setting: False
|
||||
hide_browser_setting: False
|
||||
global_negative_prompt: ""
|
||||
|
||||
debug:
|
||||
enable_hydralit: False
|
||||
@ -219,6 +221,7 @@ txt2vid:
|
||||
|
||||
beta_scheduler_type: "scaled_linear"
|
||||
max_duration_in_seconds: 30
|
||||
fps: 30
|
||||
|
||||
LDSR_config:
|
||||
sampling_steps: 50
|
||||
|
@ -1,4 +1,12 @@
|
||||
a 2 koma
|
||||
a 2koma
|
||||
a 3D render
|
||||
a 4 koma
|
||||
a 4koma
|
||||
a 6 koma
|
||||
a 6koma
|
||||
a 8 koma
|
||||
a 8koma
|
||||
a black and white photo
|
||||
a bronze sculpture
|
||||
a cartoon
|
||||
@ -25,6 +33,7 @@ a gouache
|
||||
a hologram
|
||||
a hyperrealistic painting
|
||||
a jigsaw puzzle
|
||||
a koma
|
||||
a low poly render
|
||||
a macro photograph
|
||||
a manga drawing
|
||||
|
@ -8,15 +8,21 @@ Home Page: https://github.com/Sygil-Dev/sygil-webui
|
||||
|
||||
### Installation on Windows:
|
||||
|
||||
|
||||
|
||||
- Clone or download the code from the [Repository](https://github.com/Sygil-Dev/sygil-webui).
|
||||
|
||||
- Double-click the `installer/install.bat` file and wait for it to handle everything for you.
|
||||
- Open the `installer` folder and copy the `install.bat` to the root folder next to the `webui.bat`
|
||||
|
||||
- Double-click the `install.bat` file and wait for it to handle everything for you.
|
||||
|
||||
### Installation on Linux:
|
||||
|
||||
- Clone or download the code from the [Repository](https://github.com/Sygil-Dev/sygil-webui).
|
||||
|
||||
- Open a terminal on the folder where the code is located and run `./installer/install.sh` ,make sure it has the right permissions and can be executed.
|
||||
- Open the `installer` folder and copy the `install.sh` to the root folder next to the `webui.sh`
|
||||
|
||||
- Open a terminal on the folder where the code is located and run `./install.sh` ,make sure it has the right permissions and can be executed.
|
||||
|
||||
- Wait for the installer to handle everything for you.
|
||||
|
||||
|
@ -15,22 +15,21 @@ name: ldm
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
channels:
|
||||
- conda-forge
|
||||
- pytorch
|
||||
- defaults
|
||||
- nvidia
|
||||
# Psst. If you change a dependency, make sure it's mirrored in the docker requirement
|
||||
# files as well.
|
||||
dependencies:
|
||||
- nodejs=18.11.0
|
||||
- conda-forge::nodejs=18.11.0
|
||||
- yarn=1.22.19
|
||||
- cudatoolkit=11.3
|
||||
- cudatoolkit=11.7
|
||||
- git
|
||||
- numpy=1.22.3
|
||||
- numpy=1.23.3
|
||||
- pip=20.3
|
||||
- python=3.8.5
|
||||
- pytorch=1.11.0
|
||||
- pytorch=1.13.0
|
||||
- scikit-image=0.19.2
|
||||
- torchvision=0.12.0
|
||||
- torchvision=0.14.0
|
||||
- pip:
|
||||
- -r requirements.txt
|
||||
- -r requirements.txt
|
||||
|
0
ldm/__init__.py
Normal file
0
ldm/__init__.py
Normal file
@ -1,101 +0,0 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
|
||||
from data.nocaps_dataset import nocaps_eval
|
||||
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
|
||||
from data.vqa_dataset import vqa_dataset
|
||||
from data.nlvr_dataset import nlvr_dataset
|
||||
from data.pretrain_dataset import pretrain_dataset
|
||||
from transform.randaugment import RandomAugment
|
||||
|
||||
def create_dataset(dataset, config, min_scale=0.5):
|
||||
|
||||
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
|
||||
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
transform_test = transforms.Compose([
|
||||
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
if dataset=='pretrain':
|
||||
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
|
||||
return dataset
|
||||
|
||||
elif dataset=='caption_coco':
|
||||
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
|
||||
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
||||
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
||||
return train_dataset, val_dataset, test_dataset
|
||||
|
||||
elif dataset=='nocaps':
|
||||
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
||||
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
||||
return val_dataset, test_dataset
|
||||
|
||||
elif dataset=='retrieval_coco':
|
||||
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
|
||||
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
||||
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
||||
return train_dataset, val_dataset, test_dataset
|
||||
|
||||
elif dataset=='retrieval_flickr':
|
||||
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
|
||||
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
|
||||
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
|
||||
return train_dataset, val_dataset, test_dataset
|
||||
|
||||
elif dataset=='vqa':
|
||||
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
|
||||
train_files = config['train_files'], split='train')
|
||||
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
|
||||
return train_dataset, test_dataset
|
||||
|
||||
elif dataset=='nlvr':
|
||||
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
|
||||
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
|
||||
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
|
||||
return train_dataset, val_dataset, test_dataset
|
||||
|
||||
|
||||
def create_sampler(datasets, shuffles, num_tasks, global_rank):
|
||||
samplers = []
|
||||
for dataset,shuffle in zip(datasets,shuffles):
|
||||
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
|
||||
samplers.append(sampler)
|
||||
return samplers
|
||||
|
||||
|
||||
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
|
||||
loaders = []
|
||||
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
|
||||
if is_train:
|
||||
shuffle = (sampler is None)
|
||||
drop_last = True
|
||||
else:
|
||||
shuffle = False
|
||||
drop_last = False
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=bs,
|
||||
num_workers=n_worker,
|
||||
pin_memory=True,
|
||||
sampler=sampler,
|
||||
shuffle=shuffle,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
loaders.append(loader)
|
||||
return loaders
|
||||
|
@ -1,11 +1,17 @@
|
||||
from abc import abstractmethod
|
||||
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||
from torch.utils.data import (
|
||||
Dataset,
|
||||
ConcatDataset,
|
||||
ChainDataset,
|
||||
IterableDataset,
|
||||
)
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
"""
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||
super().__init__()
|
||||
self.num_records = num_records
|
||||
@ -13,11 +19,13 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
self.sample_ids = valid_ids
|
||||
self.size = size
|
||||
|
||||
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
|
||||
print(
|
||||
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_records
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
pass
|
||||
|
@ -11,24 +11,34 @@ from tqdm import tqdm
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
import taming.data.utils as tdu
|
||||
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
|
||||
from taming.data.imagenet import (
|
||||
str_to_indices,
|
||||
give_synsets_from_indices,
|
||||
download,
|
||||
retrieve,
|
||||
)
|
||||
from taming.data.imagenet import ImagePaths
|
||||
|
||||
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
|
||||
from ldm.modules.image_degradation import (
|
||||
degradation_fn_bsr,
|
||||
degradation_fn_bsr_light,
|
||||
)
|
||||
|
||||
|
||||
def synset2idx(path_to_yaml="data/index_synset.yaml"):
|
||||
def synset2idx(path_to_yaml='data/index_synset.yaml'):
|
||||
with open(path_to_yaml) as f:
|
||||
di2s = yaml.load(f)
|
||||
return dict((v,k) for k,v in di2s.items())
|
||||
return dict((v, k) for k, v in di2s.items())
|
||||
|
||||
|
||||
class ImageNetBase(Dataset):
|
||||
def __init__(self, config=None):
|
||||
self.config = config or OmegaConf.create()
|
||||
if not type(self.config)==dict:
|
||||
if not type(self.config) == dict:
|
||||
self.config = OmegaConf.to_container(self.config)
|
||||
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
|
||||
self.keep_orig_class_label = self.config.get(
|
||||
'keep_orig_class_label', False
|
||||
)
|
||||
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
|
||||
self._prepare()
|
||||
self._prepare_synset_to_human()
|
||||
@ -46,17 +56,23 @@ class ImageNetBase(Dataset):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _filter_relpaths(self, relpaths):
|
||||
ignore = set([
|
||||
"n06596364_9591.JPEG",
|
||||
])
|
||||
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
|
||||
if "sub_indices" in self.config:
|
||||
indices = str_to_indices(self.config["sub_indices"])
|
||||
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
|
||||
ignore = set(
|
||||
[
|
||||
'n06596364_9591.JPEG',
|
||||
]
|
||||
)
|
||||
relpaths = [
|
||||
rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore
|
||||
]
|
||||
if 'sub_indices' in self.config:
|
||||
indices = str_to_indices(self.config['sub_indices'])
|
||||
synsets = give_synsets_from_indices(
|
||||
indices, path_to_yaml=self.idx2syn
|
||||
) # returns a list of strings
|
||||
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
|
||||
files = []
|
||||
for rpath in relpaths:
|
||||
syn = rpath.split("/")[0]
|
||||
syn = rpath.split('/')[0]
|
||||
if syn in synsets:
|
||||
files.append(rpath)
|
||||
return files
|
||||
@ -65,78 +81,89 @@ class ImageNetBase(Dataset):
|
||||
|
||||
def _prepare_synset_to_human(self):
|
||||
SIZE = 2655750
|
||||
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
|
||||
self.human_dict = os.path.join(self.root, "synset_human.txt")
|
||||
if (not os.path.exists(self.human_dict) or
|
||||
not os.path.getsize(self.human_dict)==SIZE):
|
||||
URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1'
|
||||
self.human_dict = os.path.join(self.root, 'synset_human.txt')
|
||||
if (
|
||||
not os.path.exists(self.human_dict)
|
||||
or not os.path.getsize(self.human_dict) == SIZE
|
||||
):
|
||||
download(URL, self.human_dict)
|
||||
|
||||
def _prepare_idx_to_synset(self):
|
||||
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
|
||||
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
|
||||
if (not os.path.exists(self.idx2syn)):
|
||||
URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1'
|
||||
self.idx2syn = os.path.join(self.root, 'index_synset.yaml')
|
||||
if not os.path.exists(self.idx2syn):
|
||||
download(URL, self.idx2syn)
|
||||
|
||||
def _prepare_human_to_integer_label(self):
|
||||
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
|
||||
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
|
||||
if (not os.path.exists(self.human2integer)):
|
||||
URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1'
|
||||
self.human2integer = os.path.join(
|
||||
self.root, 'imagenet1000_clsidx_to_labels.txt'
|
||||
)
|
||||
if not os.path.exists(self.human2integer):
|
||||
download(URL, self.human2integer)
|
||||
with open(self.human2integer, "r") as f:
|
||||
with open(self.human2integer, 'r') as f:
|
||||
lines = f.read().splitlines()
|
||||
assert len(lines) == 1000
|
||||
self.human2integer_dict = dict()
|
||||
for line in lines:
|
||||
value, key = line.split(":")
|
||||
value, key = line.split(':')
|
||||
self.human2integer_dict[key] = int(value)
|
||||
|
||||
def _load(self):
|
||||
with open(self.txt_filelist, "r") as f:
|
||||
with open(self.txt_filelist, 'r') as f:
|
||||
self.relpaths = f.read().splitlines()
|
||||
l1 = len(self.relpaths)
|
||||
self.relpaths = self._filter_relpaths(self.relpaths)
|
||||
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
|
||||
print(
|
||||
'Removed {} files from filelist during filtering.'.format(
|
||||
l1 - len(self.relpaths)
|
||||
)
|
||||
)
|
||||
|
||||
self.synsets = [p.split("/")[0] for p in self.relpaths]
|
||||
self.synsets = [p.split('/')[0] for p in self.relpaths]
|
||||
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
|
||||
|
||||
unique_synsets = np.unique(self.synsets)
|
||||
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
|
||||
class_dict = dict(
|
||||
(synset, i) for i, synset in enumerate(unique_synsets)
|
||||
)
|
||||
if not self.keep_orig_class_label:
|
||||
self.class_labels = [class_dict[s] for s in self.synsets]
|
||||
else:
|
||||
self.class_labels = [self.synset2idx[s] for s in self.synsets]
|
||||
|
||||
with open(self.human_dict, "r") as f:
|
||||
with open(self.human_dict, 'r') as f:
|
||||
human_dict = f.read().splitlines()
|
||||
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
|
||||
|
||||
self.human_labels = [human_dict[s] for s in self.synsets]
|
||||
|
||||
labels = {
|
||||
"relpath": np.array(self.relpaths),
|
||||
"synsets": np.array(self.synsets),
|
||||
"class_label": np.array(self.class_labels),
|
||||
"human_label": np.array(self.human_labels),
|
||||
'relpath': np.array(self.relpaths),
|
||||
'synsets': np.array(self.synsets),
|
||||
'class_label': np.array(self.class_labels),
|
||||
'human_label': np.array(self.human_labels),
|
||||
}
|
||||
|
||||
if self.process_images:
|
||||
self.size = retrieve(self.config, "size", default=256)
|
||||
self.data = ImagePaths(self.abspaths,
|
||||
labels=labels,
|
||||
size=self.size,
|
||||
random_crop=self.random_crop,
|
||||
)
|
||||
self.size = retrieve(self.config, 'size', default=256)
|
||||
self.data = ImagePaths(
|
||||
self.abspaths,
|
||||
labels=labels,
|
||||
size=self.size,
|
||||
random_crop=self.random_crop,
|
||||
)
|
||||
else:
|
||||
self.data = self.abspaths
|
||||
|
||||
|
||||
class ImageNetTrain(ImageNetBase):
|
||||
NAME = "ILSVRC2012_train"
|
||||
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
|
||||
NAME = 'ILSVRC2012_train'
|
||||
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
|
||||
AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
|
||||
FILES = [
|
||||
"ILSVRC2012_img_train.tar",
|
||||
'ILSVRC2012_img_train.tar',
|
||||
]
|
||||
SIZES = [
|
||||
147897477120,
|
||||
@ -151,57 +178,64 @@ class ImageNetTrain(ImageNetBase):
|
||||
if self.data_root:
|
||||
self.root = os.path.join(self.data_root, self.NAME)
|
||||
else:
|
||||
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||
cachedir = os.environ.get(
|
||||
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
|
||||
)
|
||||
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
|
||||
|
||||
self.datadir = os.path.join(self.root, "data")
|
||||
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||
self.datadir = os.path.join(self.root, 'data')
|
||||
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
|
||||
self.expected_length = 1281167
|
||||
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
|
||||
default=True)
|
||||
self.random_crop = retrieve(
|
||||
self.config, 'ImageNetTrain/random_crop', default=True
|
||||
)
|
||||
if not tdu.is_prepared(self.root):
|
||||
# prep
|
||||
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
|
||||
|
||||
datadir = self.datadir
|
||||
if not os.path.exists(datadir):
|
||||
path = os.path.join(self.root, self.FILES[0])
|
||||
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or not os.path.getsize(path) == self.SIZES[0]
|
||||
):
|
||||
import academictorrents as at
|
||||
|
||||
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||
assert atpath == path
|
||||
|
||||
print("Extracting {} to {}".format(path, datadir))
|
||||
print('Extracting {} to {}'.format(path, datadir))
|
||||
os.makedirs(datadir, exist_ok=True)
|
||||
with tarfile.open(path, "r:") as tar:
|
||||
with tarfile.open(path, 'r:') as tar:
|
||||
tar.extractall(path=datadir)
|
||||
|
||||
print("Extracting sub-tars.")
|
||||
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
|
||||
print('Extracting sub-tars.')
|
||||
subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
|
||||
for subpath in tqdm(subpaths):
|
||||
subdir = subpath[:-len(".tar")]
|
||||
subdir = subpath[: -len('.tar')]
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
with tarfile.open(subpath, "r:") as tar:
|
||||
with tarfile.open(subpath, 'r:') as tar:
|
||||
tar.extractall(path=subdir)
|
||||
|
||||
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
|
||||
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||
filelist = sorted(filelist)
|
||||
filelist = "\n".join(filelist)+"\n"
|
||||
with open(self.txt_filelist, "w") as f:
|
||||
filelist = '\n'.join(filelist) + '\n'
|
||||
with open(self.txt_filelist, 'w') as f:
|
||||
f.write(filelist)
|
||||
|
||||
tdu.mark_prepared(self.root)
|
||||
|
||||
|
||||
class ImageNetValidation(ImageNetBase):
|
||||
NAME = "ILSVRC2012_validation"
|
||||
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
|
||||
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
|
||||
NAME = 'ILSVRC2012_validation'
|
||||
URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
|
||||
AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
|
||||
VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
|
||||
FILES = [
|
||||
"ILSVRC2012_img_val.tar",
|
||||
"validation_synset.txt",
|
||||
'ILSVRC2012_img_val.tar',
|
||||
'validation_synset.txt',
|
||||
]
|
||||
SIZES = [
|
||||
6744924160,
|
||||
@ -217,39 +251,49 @@ class ImageNetValidation(ImageNetBase):
|
||||
if self.data_root:
|
||||
self.root = os.path.join(self.data_root, self.NAME)
|
||||
else:
|
||||
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||
self.datadir = os.path.join(self.root, "data")
|
||||
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||
cachedir = os.environ.get(
|
||||
'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
|
||||
)
|
||||
self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
|
||||
self.datadir = os.path.join(self.root, 'data')
|
||||
self.txt_filelist = os.path.join(self.root, 'filelist.txt')
|
||||
self.expected_length = 50000
|
||||
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
|
||||
default=False)
|
||||
self.random_crop = retrieve(
|
||||
self.config, 'ImageNetValidation/random_crop', default=False
|
||||
)
|
||||
if not tdu.is_prepared(self.root):
|
||||
# prep
|
||||
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||
print('Preparing dataset {} in {}'.format(self.NAME, self.root))
|
||||
|
||||
datadir = self.datadir
|
||||
if not os.path.exists(datadir):
|
||||
path = os.path.join(self.root, self.FILES[0])
|
||||
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or not os.path.getsize(path) == self.SIZES[0]
|
||||
):
|
||||
import academictorrents as at
|
||||
|
||||
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||
assert atpath == path
|
||||
|
||||
print("Extracting {} to {}".format(path, datadir))
|
||||
print('Extracting {} to {}'.format(path, datadir))
|
||||
os.makedirs(datadir, exist_ok=True)
|
||||
with tarfile.open(path, "r:") as tar:
|
||||
with tarfile.open(path, 'r:') as tar:
|
||||
tar.extractall(path=datadir)
|
||||
|
||||
vspath = os.path.join(self.root, self.FILES[1])
|
||||
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
|
||||
if (
|
||||
not os.path.exists(vspath)
|
||||
or not os.path.getsize(vspath) == self.SIZES[1]
|
||||
):
|
||||
download(self.VS_URL, vspath)
|
||||
|
||||
with open(vspath, "r") as f:
|
||||
with open(vspath, 'r') as f:
|
||||
synset_dict = f.read().splitlines()
|
||||
synset_dict = dict(line.split() for line in synset_dict)
|
||||
|
||||
print("Reorganizing into synset folders")
|
||||
print('Reorganizing into synset folders')
|
||||
synsets = np.unique(list(synset_dict.values()))
|
||||
for s in synsets:
|
||||
os.makedirs(os.path.join(datadir, s), exist_ok=True)
|
||||
@ -258,21 +302,26 @@ class ImageNetValidation(ImageNetBase):
|
||||
dst = os.path.join(datadir, v)
|
||||
shutil.move(src, dst)
|
||||
|
||||
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||
filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
|
||||
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||
filelist = sorted(filelist)
|
||||
filelist = "\n".join(filelist)+"\n"
|
||||
with open(self.txt_filelist, "w") as f:
|
||||
filelist = '\n'.join(filelist) + '\n'
|
||||
with open(self.txt_filelist, 'w') as f:
|
||||
f.write(filelist)
|
||||
|
||||
tdu.mark_prepared(self.root)
|
||||
|
||||
|
||||
|
||||
class ImageNetSR(Dataset):
|
||||
def __init__(self, size=None,
|
||||
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
|
||||
random_crop=True):
|
||||
def __init__(
|
||||
self,
|
||||
size=None,
|
||||
degradation=None,
|
||||
downscale_f=4,
|
||||
min_crop_f=0.5,
|
||||
max_crop_f=1.0,
|
||||
random_crop=True,
|
||||
):
|
||||
"""
|
||||
Imagenet Superresolution Dataloader
|
||||
Performs following ops in order:
|
||||
@ -296,67 +345,86 @@ class ImageNetSR(Dataset):
|
||||
self.LR_size = int(size / downscale_f)
|
||||
self.min_crop_f = min_crop_f
|
||||
self.max_crop_f = max_crop_f
|
||||
assert(max_crop_f <= 1.)
|
||||
assert max_crop_f <= 1.0
|
||||
self.center_crop = not random_crop
|
||||
|
||||
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
|
||||
self.image_rescaler = albumentations.SmallestMaxSize(
|
||||
max_size=size, interpolation=cv2.INTER_AREA
|
||||
)
|
||||
|
||||
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
|
||||
self.pil_interpolation = (
|
||||
False # gets reset later if incase interp_op is from pillow
|
||||
)
|
||||
|
||||
if degradation == "bsrgan":
|
||||
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
|
||||
if degradation == 'bsrgan':
|
||||
self.degradation_process = partial(
|
||||
degradation_fn_bsr, sf=downscale_f
|
||||
)
|
||||
|
||||
elif degradation == "bsrgan_light":
|
||||
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
|
||||
elif degradation == 'bsrgan_light':
|
||||
self.degradation_process = partial(
|
||||
degradation_fn_bsr_light, sf=downscale_f
|
||||
)
|
||||
|
||||
else:
|
||||
interpolation_fn = {
|
||||
"cv_nearest": cv2.INTER_NEAREST,
|
||||
"cv_bilinear": cv2.INTER_LINEAR,
|
||||
"cv_bicubic": cv2.INTER_CUBIC,
|
||||
"cv_area": cv2.INTER_AREA,
|
||||
"cv_lanczos": cv2.INTER_LANCZOS4,
|
||||
"pil_nearest": PIL.Image.NEAREST,
|
||||
"pil_bilinear": PIL.Image.BILINEAR,
|
||||
"pil_bicubic": PIL.Image.BICUBIC,
|
||||
"pil_box": PIL.Image.BOX,
|
||||
"pil_hamming": PIL.Image.HAMMING,
|
||||
"pil_lanczos": PIL.Image.LANCZOS,
|
||||
'cv_nearest': cv2.INTER_NEAREST,
|
||||
'cv_bilinear': cv2.INTER_LINEAR,
|
||||
'cv_bicubic': cv2.INTER_CUBIC,
|
||||
'cv_area': cv2.INTER_AREA,
|
||||
'cv_lanczos': cv2.INTER_LANCZOS4,
|
||||
'pil_nearest': PIL.Image.NEAREST,
|
||||
'pil_bilinear': PIL.Image.BILINEAR,
|
||||
'pil_bicubic': PIL.Image.BICUBIC,
|
||||
'pil_box': PIL.Image.BOX,
|
||||
'pil_hamming': PIL.Image.HAMMING,
|
||||
'pil_lanczos': PIL.Image.LANCZOS,
|
||||
}[degradation]
|
||||
|
||||
self.pil_interpolation = degradation.startswith("pil_")
|
||||
self.pil_interpolation = degradation.startswith('pil_')
|
||||
|
||||
if self.pil_interpolation:
|
||||
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
|
||||
self.degradation_process = partial(
|
||||
TF.resize,
|
||||
size=self.LR_size,
|
||||
interpolation=interpolation_fn,
|
||||
)
|
||||
|
||||
else:
|
||||
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
|
||||
interpolation=interpolation_fn)
|
||||
self.degradation_process = albumentations.SmallestMaxSize(
|
||||
max_size=self.LR_size, interpolation=interpolation_fn
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.base)
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = self.base[i]
|
||||
image = Image.open(example["file_path_"])
|
||||
image = Image.open(example['file_path_'])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
if not image.mode == 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
image = np.array(image).astype(np.uint8)
|
||||
|
||||
min_side_len = min(image.shape[:2])
|
||||
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
||||
crop_side_len = min_side_len * np.random.uniform(
|
||||
self.min_crop_f, self.max_crop_f, size=None
|
||||
)
|
||||
crop_side_len = int(crop_side_len)
|
||||
|
||||
if self.center_crop:
|
||||
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
||||
self.cropper = albumentations.CenterCrop(
|
||||
height=crop_side_len, width=crop_side_len
|
||||
)
|
||||
|
||||
else:
|
||||
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
||||
self.cropper = albumentations.RandomCrop(
|
||||
height=crop_side_len, width=crop_side_len
|
||||
)
|
||||
|
||||
image = self.cropper(image=image)["image"]
|
||||
image = self.image_rescaler(image=image)["image"]
|
||||
image = self.cropper(image=image)['image']
|
||||
image = self.image_rescaler(image=image)['image']
|
||||
|
||||
if self.pil_interpolation:
|
||||
image_pil = PIL.Image.fromarray(image)
|
||||
@ -364,10 +432,10 @@ class ImageNetSR(Dataset):
|
||||
LR_image = np.array(LR_image).astype(np.uint8)
|
||||
|
||||
else:
|
||||
LR_image = self.degradation_process(image=image)["image"]
|
||||
LR_image = self.degradation_process(image=image)['image']
|
||||
|
||||
example["image"] = (image/127.5 - 1.0).astype(np.float32)
|
||||
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
|
||||
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
return example
|
||||
|
||||
@ -377,9 +445,11 @@ class ImageNetSRTrain(ImageNetSR):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_base(self):
|
||||
with open("data/imagenet_train_hr_indices.p", "rb") as f:
|
||||
with open('data/imagenet_train_hr_indices.p', 'rb') as f:
|
||||
indices = pickle.load(f)
|
||||
dset = ImageNetTrain(process_images=False,)
|
||||
dset = ImageNetTrain(
|
||||
process_images=False,
|
||||
)
|
||||
return Subset(dset, indices)
|
||||
|
||||
|
||||
@ -388,7 +458,9 @@ class ImageNetSRValidation(ImageNetSR):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_base(self):
|
||||
with open("data/imagenet_val_hr_indices.p", "rb") as f:
|
||||
with open('data/imagenet_val_hr_indices.p', 'rb') as f:
|
||||
indices = pickle.load(f)
|
||||
dset = ImageNetValidation(process_images=False,)
|
||||
dset = ImageNetValidation(
|
||||
process_images=False,
|
||||
)
|
||||
return Subset(dset, indices)
|
||||
|
104
ldm/data/lsun.py
104
ldm/data/lsun.py
@ -7,30 +7,33 @@ from torchvision import transforms
|
||||
|
||||
|
||||
class LSUNBase(Dataset):
|
||||
def __init__(self,
|
||||
txt_file,
|
||||
data_root,
|
||||
size=None,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
txt_file,
|
||||
data_root,
|
||||
size=None,
|
||||
interpolation='bicubic',
|
||||
flip_p=0.5,
|
||||
):
|
||||
self.data_paths = txt_file
|
||||
self.data_root = data_root
|
||||
with open(self.data_paths, "r") as f:
|
||||
with open(self.data_paths, 'r') as f:
|
||||
self.image_paths = f.read().splitlines()
|
||||
self._length = len(self.image_paths)
|
||||
self.labels = {
|
||||
"relative_file_path_": [l for l in self.image_paths],
|
||||
"file_path_": [os.path.join(self.data_root, l)
|
||||
for l in self.image_paths],
|
||||
'relative_file_path_': [l for l in self.image_paths],
|
||||
'file_path_': [
|
||||
os.path.join(self.data_root, l) for l in self.image_paths
|
||||
],
|
||||
}
|
||||
|
||||
self.size = size
|
||||
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
self.interpolation = {
|
||||
'linear': PIL.Image.LINEAR,
|
||||
'bilinear': PIL.Image.BILINEAR,
|
||||
'bicubic': PIL.Image.BICUBIC,
|
||||
'lanczos': PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
def __len__(self):
|
||||
@ -38,55 +41,86 @@ class LSUNBase(Dataset):
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = dict((k, self.labels[k][i]) for k in self.labels)
|
||||
image = Image.open(example["file_path_"])
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
image = Image.open(example['file_path_'])
|
||||
if not image.mode == 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = img.shape[0], img.shape[1]
|
||||
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||
(w - crop) // 2:(w + crop) // 2]
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
if self.size is not None:
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return example
|
||||
|
||||
|
||||
class LSUNChurchesTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
||||
super().__init__(
|
||||
txt_file='data/lsun/church_outdoor_train.txt',
|
||||
data_root='data/lsun/churches',
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class LSUNChurchesValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0., **kwargs):
|
||||
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
|
||||
flip_p=flip_p, **kwargs)
|
||||
def __init__(self, flip_p=0.0, **kwargs):
|
||||
super().__init__(
|
||||
txt_file='data/lsun/church_outdoor_val.txt',
|
||||
data_root='data/lsun/churches',
|
||||
flip_p=flip_p,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class LSUNBedroomsTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
||||
super().__init__(
|
||||
txt_file='data/lsun/bedrooms_train.txt',
|
||||
data_root='data/lsun/bedrooms',
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class LSUNBedroomsValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0.0, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
|
||||
flip_p=flip_p, **kwargs)
|
||||
super().__init__(
|
||||
txt_file='data/lsun/bedrooms_val.txt',
|
||||
data_root='data/lsun/bedrooms',
|
||||
flip_p=flip_p,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class LSUNCatsTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
||||
super().__init__(
|
||||
txt_file='data/lsun/cat_train.txt',
|
||||
data_root='data/lsun/cats',
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class LSUNCatsValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0., **kwargs):
|
||||
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
|
||||
flip_p=flip_p, **kwargs)
|
||||
def __init__(self, flip_p=0.0, **kwargs):
|
||||
super().__init__(
|
||||
txt_file='data/lsun/cat_val.txt',
|
||||
data_root='data/lsun/cats',
|
||||
flip_p=flip_p,
|
||||
**kwargs
|
||||
)
|
||||
|
202
ldm/data/personalized.py
Normal file
202
ldm/data/personalized.py
Normal file
@ -0,0 +1,202 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
|
||||
imagenet_templates_smallest = [
|
||||
'a photo of a {}',
|
||||
]
|
||||
|
||||
imagenet_templates_small = [
|
||||
'a photo of a {}',
|
||||
'a rendering of a {}',
|
||||
'a cropped photo of the {}',
|
||||
'the photo of a {}',
|
||||
'a photo of a clean {}',
|
||||
'a photo of a dirty {}',
|
||||
'a dark photo of the {}',
|
||||
'a photo of my {}',
|
||||
'a photo of the cool {}',
|
||||
'a close-up photo of a {}',
|
||||
'a bright photo of the {}',
|
||||
'a cropped photo of a {}',
|
||||
'a photo of the {}',
|
||||
'a good photo of the {}',
|
||||
'a photo of one {}',
|
||||
'a close-up photo of the {}',
|
||||
'a rendition of the {}',
|
||||
'a photo of the clean {}',
|
||||
'a rendition of a {}',
|
||||
'a photo of a nice {}',
|
||||
'a good photo of a {}',
|
||||
'a photo of the nice {}',
|
||||
'a photo of the small {}',
|
||||
'a photo of the weird {}',
|
||||
'a photo of the large {}',
|
||||
'a photo of a cool {}',
|
||||
'a photo of a small {}',
|
||||
]
|
||||
|
||||
imagenet_dual_templates_small = [
|
||||
'a photo of a {} with {}',
|
||||
'a rendering of a {} with {}',
|
||||
'a cropped photo of the {} with {}',
|
||||
'the photo of a {} with {}',
|
||||
'a photo of a clean {} with {}',
|
||||
'a photo of a dirty {} with {}',
|
||||
'a dark photo of the {} with {}',
|
||||
'a photo of my {} with {}',
|
||||
'a photo of the cool {} with {}',
|
||||
'a close-up photo of a {} with {}',
|
||||
'a bright photo of the {} with {}',
|
||||
'a cropped photo of a {} with {}',
|
||||
'a photo of the {} with {}',
|
||||
'a good photo of the {} with {}',
|
||||
'a photo of one {} with {}',
|
||||
'a close-up photo of the {} with {}',
|
||||
'a rendition of the {} with {}',
|
||||
'a photo of the clean {} with {}',
|
||||
'a rendition of a {} with {}',
|
||||
'a photo of a nice {} with {}',
|
||||
'a good photo of a {} with {}',
|
||||
'a photo of the nice {} with {}',
|
||||
'a photo of the small {} with {}',
|
||||
'a photo of the weird {} with {}',
|
||||
'a photo of the large {} with {}',
|
||||
'a photo of a cool {} with {}',
|
||||
'a photo of a small {} with {}',
|
||||
]
|
||||
|
||||
per_img_token_list = [
|
||||
'א',
|
||||
'ב',
|
||||
'ג',
|
||||
'ד',
|
||||
'ה',
|
||||
'ו',
|
||||
'ז',
|
||||
'ח',
|
||||
'ט',
|
||||
'י',
|
||||
'כ',
|
||||
'ל',
|
||||
'מ',
|
||||
'נ',
|
||||
'ס',
|
||||
'ע',
|
||||
'פ',
|
||||
'צ',
|
||||
'ק',
|
||||
'ר',
|
||||
'ש',
|
||||
'ת',
|
||||
]
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
size=None,
|
||||
repeats=100,
|
||||
interpolation='bicubic',
|
||||
flip_p=0.5,
|
||||
set='train',
|
||||
placeholder_token='*',
|
||||
per_image_tokens=False,
|
||||
center_crop=False,
|
||||
mixing_prob=0.25,
|
||||
coarse_class_text=None,
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
# self._length = len(self.image_paths)
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.per_image_tokens = per_image_tokens
|
||||
self.center_crop = center_crop
|
||||
self.mixing_prob = mixing_prob
|
||||
|
||||
self.coarse_class_text = coarse_class_text
|
||||
|
||||
if per_image_tokens:
|
||||
assert self.num_images < len(
|
||||
per_img_token_list
|
||||
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||
|
||||
if set == 'train':
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.size = size
|
||||
self.interpolation = {
|
||||
'linear': PIL.Image.LINEAR,
|
||||
'bilinear': PIL.Image.BILINEAR,
|
||||
'bicubic': PIL.Image.BICUBIC,
|
||||
'lanczos': PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
if self.coarse_class_text:
|
||||
placeholder_string = (
|
||||
f'{self.coarse_class_text} {placeholder_string}'
|
||||
)
|
||||
|
||||
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
|
||||
text = random.choice(imagenet_dual_templates_small).format(
|
||||
placeholder_string, per_img_token_list[i % self.num_images]
|
||||
)
|
||||
else:
|
||||
text = random.choice(imagenet_templates_small).format(
|
||||
placeholder_string
|
||||
)
|
||||
|
||||
example['caption'] = text
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
if self.size is not None:
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return example
|
169
ldm/data/personalized_file.py
Normal file
169
ldm/data/personalized_file.py
Normal file
@ -0,0 +1,169 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
|
||||
imagenet_templates_small = [
|
||||
'a painting in the style of {}',
|
||||
'a rendering in the style of {}',
|
||||
'a cropped painting in the style of {}',
|
||||
'the painting in the style of {}',
|
||||
'a clean painting in the style of {}',
|
||||
'a dirty painting in the style of {}',
|
||||
'a dark painting in the style of {}',
|
||||
'a picture in the style of {}',
|
||||
'a cool painting in the style of {}',
|
||||
'a close-up painting in the style of {}',
|
||||
'a bright painting in the style of {}',
|
||||
'a cropped painting in the style of {}',
|
||||
'a good painting in the style of {}',
|
||||
'a close-up painting in the style of {}',
|
||||
'a rendition in the style of {}',
|
||||
'a nice painting in the style of {}',
|
||||
'a small painting in the style of {}',
|
||||
'a weird painting in the style of {}',
|
||||
'a large painting in the style of {}',
|
||||
]
|
||||
|
||||
imagenet_dual_templates_small = [
|
||||
'a painting in the style of {} with {}',
|
||||
'a rendering in the style of {} with {}',
|
||||
'a cropped painting in the style of {} with {}',
|
||||
'the painting in the style of {} with {}',
|
||||
'a clean painting in the style of {} with {}',
|
||||
'a dirty painting in the style of {} with {}',
|
||||
'a dark painting in the style of {} with {}',
|
||||
'a cool painting in the style of {} with {}',
|
||||
'a close-up painting in the style of {} with {}',
|
||||
'a bright painting in the style of {} with {}',
|
||||
'a cropped painting in the style of {} with {}',
|
||||
'a good painting in the style of {} with {}',
|
||||
'a painting of one {} in the style of {}',
|
||||
'a nice painting in the style of {} with {}',
|
||||
'a small painting in the style of {} with {}',
|
||||
'a weird painting in the style of {} with {}',
|
||||
'a large painting in the style of {} with {}',
|
||||
]
|
||||
|
||||
per_img_token_list = [
|
||||
'א',
|
||||
'ב',
|
||||
'ג',
|
||||
'ד',
|
||||
'ה',
|
||||
'ו',
|
||||
'ז',
|
||||
'ח',
|
||||
'ט',
|
||||
'י',
|
||||
'כ',
|
||||
'ל',
|
||||
'מ',
|
||||
'נ',
|
||||
'ס',
|
||||
'ע',
|
||||
'פ',
|
||||
'צ',
|
||||
'ק',
|
||||
'ר',
|
||||
'ש',
|
||||
'ת',
|
||||
]
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
size=None,
|
||||
repeats=100,
|
||||
interpolation='bicubic',
|
||||
flip_p=0.5,
|
||||
set='train',
|
||||
placeholder_token='*',
|
||||
per_image_tokens=False,
|
||||
center_crop=False,
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
# self._length = len(self.image_paths)
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.per_image_tokens = per_image_tokens
|
||||
self.center_crop = center_crop
|
||||
|
||||
if per_image_tokens:
|
||||
assert self.num_images < len(
|
||||
per_img_token_list
|
||||
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
|
||||
|
||||
if set == 'train':
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.size = size
|
||||
self.interpolation = {
|
||||
'linear': PIL.Image.LINEAR,
|
||||
'bilinear': PIL.Image.BILINEAR,
|
||||
'bicubic': PIL.Image.BICUBIC,
|
||||
'lanczos': PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
if self.per_image_tokens and np.random.uniform() < 0.25:
|
||||
text = random.choice(imagenet_dual_templates_small).format(
|
||||
self.placeholder_token, per_img_token_list[i % self.num_images]
|
||||
)
|
||||
else:
|
||||
text = random.choice(imagenet_templates_small).format(
|
||||
self.placeholder_token
|
||||
)
|
||||
|
||||
example['caption'] = text
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
if self.size is not None:
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return example
|
24
ldm/data/util.py
Normal file
24
ldm/data/util.py
Normal file
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
from ldm.modules.midas.api import load_midas_transform
|
||||
|
||||
|
||||
class AddMiDaS(object):
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
self.transform = load_midas_transform(model_type)
|
||||
|
||||
def pt2np(self, x):
|
||||
x = ((x + 1.0) * .5).detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
def np2pt(self, x):
|
||||
x = torch.from_numpy(x) * 2 - 1.
|
||||
return x
|
||||
|
||||
def __call__(self, sample):
|
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||
x = self.pt2np(sample['jpg'])
|
||||
x = self.transform({"image": x})["image"]
|
||||
sample['midas_in'] = x
|
||||
return sample
|
1
ldm/devices/__init__.py
Normal file
1
ldm/devices/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from ldm.devices.devices import choose_autocast_device, choose_torch_device
|
24
ldm/devices/devices.py
Normal file
24
ldm/devices/devices.py
Normal file
@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
def choose_torch_device() -> str:
|
||||
'''Convenience routine for guessing which GPU device to run model on'''
|
||||
if torch.cuda.is_available():
|
||||
return 'cuda'
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
def choose_autocast_device(device):
|
||||
'''Returns an autocast compatible device from a torch device'''
|
||||
device_type = device.type # this returns 'mps' on M1
|
||||
# autocast only for cuda, but GTX 16xx have issues with it
|
||||
if device_type == 'cuda':
|
||||
device_name = torch.cuda.get_device_name()
|
||||
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
|
||||
return device_type,nullcontext
|
||||
else:
|
||||
return device_type,autocast
|
||||
else:
|
||||
return 'cpu',nullcontext
|
@ -5,32 +5,49 @@ class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0
|
||||
"""
|
||||
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warm_up_steps,
|
||||
lr_min,
|
||||
lr_max,
|
||||
lr_start,
|
||||
max_decay_steps,
|
||||
verbosity_interval=0,
|
||||
):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.
|
||||
self.last_lr = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f'current step: {n}, recent lr-multiplier: {self.last_lr}'
|
||||
)
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||
lr = (
|
||||
self.lr_max - self.lr_start
|
||||
) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||
t = (n - self.lr_warm_up_steps) / (
|
||||
self.lr_max_decay_steps - self.lr_warm_up_steps
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
1 + np.cos(t * np.pi)
|
||||
)
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n,**kwargs)
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
@ -38,15 +55,30 @@ class LambdaWarmUpCosineScheduler2:
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warm_up_steps,
|
||||
f_min,
|
||||
f_max,
|
||||
f_start,
|
||||
cycle_lengths,
|
||||
verbosity_interval=0,
|
||||
):
|
||||
assert (
|
||||
len(warm_up_steps)
|
||||
== len(f_min)
|
||||
== len(f_max)
|
||||
== len(f_start)
|
||||
== len(cycle_lengths)
|
||||
)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.
|
||||
self.last_f = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
@ -60,17 +92,25 @@ class LambdaWarmUpCosineScheduler2:
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
|
||||
f'current cycle {cycle}'
|
||||
)
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
f = (
|
||||
self.f_max[cycle] - self.f_start[cycle]
|
||||
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (
|
||||
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
f = self.f_min[cycle] + 0.5 * (
|
||||
self.f_max[cycle] - self.f_min[cycle]
|
||||
) * (1 + np.cos(t * np.pi))
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
@ -79,20 +119,25 @@ class LambdaWarmUpCosineScheduler2:
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f'current step: {n}, recent lr-multiplier: {self.last_f}, '
|
||||
f'current cycle {cycle}'
|
||||
)
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
f = (
|
||||
self.f_max[cycle] - self.f_start[cycle]
|
||||
) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
self.cycle_lengths[cycle] - n
|
||||
) / (self.cycle_lengths[cycle])
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
|
@ -6,29 +6,32 @@ from contextlib import contextmanager
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from ldm.modules.distributions.distributions import (
|
||||
DiagonalGaussianDistribution,
|
||||
)
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
@ -36,24 +39,34 @@ class VQModel(pl.LightningModule):
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.quantize = VectorQuantizer(
|
||||
n_embed,
|
||||
embed_dim,
|
||||
beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape,
|
||||
)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
embed_dim, ddconfig['z_channels'], 1
|
||||
)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||
)
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
print(
|
||||
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
|
||||
)
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
@ -66,28 +79,30 @@ class VQModel(pl.LightningModule):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
print(f'{context}: Switched to EMA weights')
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
print(f'{context}: Restored training weights')
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
print(
|
||||
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
print(f'Missing Keys: {missing}')
|
||||
print(f'Unexpected Keys: {unexpected}')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
@ -115,7 +130,7 @@ class VQModel(pl.LightningModule):
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
quant, diff, (_, _, ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
@ -125,7 +140,11 @@ class VQModel(pl.LightningModule):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
x = (
|
||||
x.permute(0, 3, 1, 2)
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
@ -133,9 +152,11 @@ class VQModel(pl.LightningModule):
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
new_resize = np.random.choice(
|
||||
np.arange(lower_size, upper_size + 16, 16)
|
||||
)
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = F.interpolate(x, size=new_resize, mode='bicubic')
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
@ -147,81 +168,139 @@ class VQModel(pl.LightningModule):
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
log_dict_ema = self._validation_step(
|
||||
batch, batch_idx, suffix='_ema'
|
||||
)
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
def _validation_step(self, batch, batch_idx, suffix=''):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log(
|
||||
f'val{suffix}/rec_loss',
|
||||
rec_loss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log(
|
||||
f'val{suffix}/aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
del log_dict_ae[f'val{suffix}/rec_loss']
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
lr_g = self.lr_g_factor * self.learning_rate
|
||||
print('lr_d', lr_d)
|
||||
print('lr_g', lr_g)
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quantize.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
|
||||
)
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
print('Setting up LambdaLR scheduler...')
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'scheduler': LambdaLR(
|
||||
opt_ae, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
'frequency': 1,
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'scheduler': LambdaLR(
|
||||
opt_disc, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
'frequency': 1,
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
@ -235,7 +314,7 @@ class VQModel(pl.LightningModule):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
log['inputs'] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
@ -243,21 +322,24 @@ class VQModel(pl.LightningModule):
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
log['inputs'] = x
|
||||
log['reconstructions'] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
if x.shape[1] > 3:
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||
)
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
@ -283,43 +365,50 @@ class VQModelInterface(VQModel):
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(
|
||||
2 * ddconfig['z_channels'], 2 * embed_dim, 1
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
embed_dim, ddconfig['z_channels'], 1
|
||||
)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, colorize_nlabels, 1, 1)
|
||||
)
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
print(f'Restored from {path}')
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
@ -345,7 +434,11 @@ class AutoencoderKL(pl.LightningModule):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
x = (
|
||||
x.permute(0, 3, 1, 2)
|
||||
.to(memory_format=torch.contiguous_format)
|
||||
.float()
|
||||
)
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
@ -354,44 +447,102 @@ class AutoencoderKL(pl.LightningModule):
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
self.log(
|
||||
'aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False,
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train',
|
||||
)
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
self.log(
|
||||
'discloss',
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False,
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val',
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val',
|
||||
)
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
||||
)
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
@ -409,17 +560,19 @@ class AutoencoderKL(pl.LightningModule):
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer(
|
||||
'colorize', torch.randn(3, x.shape[1], 1, 1).to(x)
|
||||
)
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
|
@ -10,13 +10,13 @@ from einops import rearrange
|
||||
from glob import glob
|
||||
from natsort import natsorted
|
||||
|
||||
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
||||
from ldm.modules.diffusionmodules.openaimodel import (
|
||||
EncoderUNetModel,
|
||||
UNetModel,
|
||||
)
|
||||
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||
|
||||
__models__ = {
|
||||
'class_label': EncoderUNetModel,
|
||||
'segmentation': UNetModel
|
||||
}
|
||||
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
@ -26,37 +26,49 @@ def disabled_train(self, mode=True):
|
||||
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
|
||||
def __init__(self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool='attention',
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.e-2,
|
||||
log_steps=10,
|
||||
monitor='val/loss',
|
||||
*args,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool='attention',
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.0e-2,
|
||||
log_steps=10,
|
||||
monitor='val/loss',
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_classes = num_classes
|
||||
# get latest config of diffusion model
|
||||
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
||||
diffusion_config = natsorted(
|
||||
glob(os.path.join(diffusion_path, 'configs', '*-project.yaml'))
|
||||
)[-1]
|
||||
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||
self.load_diffusion()
|
||||
|
||||
self.monitor = monitor
|
||||
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
||||
self.numd = (
|
||||
self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
)
|
||||
self.log_time_interval = (
|
||||
self.diffusion_model.num_timesteps // log_steps
|
||||
)
|
||||
self.log_steps = log_steps
|
||||
|
||||
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
||||
self.label_key = (
|
||||
label_key
|
||||
if not hasattr(self.diffusion_model, 'cond_stage_key')
|
||||
else self.diffusion_model.cond_stage_key
|
||||
)
|
||||
|
||||
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
||||
assert (
|
||||
self.label_key is not None
|
||||
), 'label_key neither in diffusion model nor in model.params'
|
||||
|
||||
if self.label_key not in __models__:
|
||||
raise NotImplementedError()
|
||||
@ -68,22 +80,27 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
sd = torch.load(path, map_location='cpu')
|
||||
if 'state_dict' in list(sd.keys()):
|
||||
sd = sd['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||
sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
missing, unexpected = (
|
||||
self.load_state_dict(sd, strict=False)
|
||||
if not only_model
|
||||
else self.model.load_state_dict(sd, strict=False)
|
||||
)
|
||||
print(
|
||||
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f'Missing Keys: {missing}')
|
||||
if len(unexpected) > 0:
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
print(f'Unexpected Keys: {unexpected}')
|
||||
|
||||
def load_diffusion(self):
|
||||
model = instantiate_from_config(self.diffusion_config)
|
||||
@ -93,17 +110,25 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
param.requires_grad = False
|
||||
|
||||
def load_classifier(self, ckpt_path, pool):
|
||||
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
||||
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
||||
model_config = deepcopy(
|
||||
self.diffusion_config.params.unet_config.params
|
||||
)
|
||||
model_config.in_channels = (
|
||||
self.diffusion_config.params.unet_config.params.out_channels
|
||||
)
|
||||
model_config.out_channels = self.num_classes
|
||||
if self.label_key == 'class_label':
|
||||
model_config.pool = pool
|
||||
|
||||
self.model = __models__[self.label_key](**model_config)
|
||||
if ckpt_path is not None:
|
||||
print('#####################################################################')
|
||||
print(
|
||||
'#####################################################################'
|
||||
)
|
||||
print(f'load from ckpt "{ckpt_path}"')
|
||||
print('#####################################################################')
|
||||
print(
|
||||
'#####################################################################'
|
||||
)
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
@torch.no_grad()
|
||||
@ -111,11 +136,19 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
noise = default(noise, lambda: torch.randn_like(x))
|
||||
continuous_sqrt_alpha_cumprod = None
|
||||
if self.diffusion_model.use_continuous_noise:
|
||||
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
||||
continuous_sqrt_alpha_cumprod = (
|
||||
self.diffusion_model.sample_continuous_noise_level(
|
||||
x.shape[0], t + 1
|
||||
)
|
||||
)
|
||||
# todo: make sure t+1 is correct here
|
||||
|
||||
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
||||
return self.diffusion_model.q_sample(
|
||||
x_start=x,
|
||||
t=t,
|
||||
noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
|
||||
)
|
||||
|
||||
def forward(self, x_noisy, t, *args, **kwargs):
|
||||
return self.model(x_noisy, t)
|
||||
@ -141,17 +174,21 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||
for down in range(self.numd):
|
||||
h, w = targets.shape[-2:]
|
||||
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
||||
targets = F.interpolate(
|
||||
targets, size=(h // 2, w // 2), mode='nearest'
|
||||
)
|
||||
|
||||
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||
|
||||
return targets
|
||||
|
||||
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
||||
def compute_top_k(self, logits, labels, k, reduction='mean'):
|
||||
_, top_ks = torch.topk(logits, k, dim=1)
|
||||
if reduction == "mean":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
elif reduction == "none":
|
||||
if reduction == 'mean':
|
||||
return (
|
||||
(top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
)
|
||||
elif reduction == 'none':
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
@ -162,29 +199,59 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
def write_logs(self, loss, logits, targets):
|
||||
log_prefix = 'train' if self.training else 'val'
|
||||
log = {}
|
||||
log[f"{log_prefix}/loss"] = loss.mean()
|
||||
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction="mean"
|
||||
log[f'{log_prefix}/loss'] = loss.mean()
|
||||
log[f'{log_prefix}/acc@1'] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction='mean'
|
||||
)
|
||||
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction="mean"
|
||||
log[f'{log_prefix}/acc@5'] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction='mean'
|
||||
)
|
||||
|
||||
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
||||
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
||||
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
||||
self.log_dict(
|
||||
log,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=self.training,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log(
|
||||
'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
|
||||
)
|
||||
self.log(
|
||||
'global_step',
|
||||
self.global_step,
|
||||
logger=False,
|
||||
on_epoch=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
lr = self.optimizers().param_groups[0]['lr']
|
||||
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
||||
self.log(
|
||||
'lr_abs',
|
||||
lr,
|
||||
on_step=True,
|
||||
logger=True,
|
||||
on_epoch=False,
|
||||
prog_bar=True,
|
||||
)
|
||||
|
||||
def shared_step(self, batch, t=None):
|
||||
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
||||
x, *_ = self.diffusion_model.get_input(
|
||||
batch, k=self.diffusion_model.first_stage_key
|
||||
)
|
||||
targets = self.get_conditioning(batch)
|
||||
if targets.dim() == 4:
|
||||
targets = targets.argmax(dim=1)
|
||||
if t is None:
|
||||
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
||||
t = torch.randint(
|
||||
0,
|
||||
self.diffusion_model.num_timesteps,
|
||||
(x.shape[0],),
|
||||
device=self.device,
|
||||
).long()
|
||||
else:
|
||||
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
||||
t = torch.full(
|
||||
size=(x.shape[0],), fill_value=t, device=self.device
|
||||
).long()
|
||||
x_noisy = self.get_x_noisy(x, t)
|
||||
logits = self(x_noisy, t)
|
||||
|
||||
@ -200,8 +267,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
return loss
|
||||
|
||||
def reset_noise_accs(self):
|
||||
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
||||
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
||||
self.noisy_acc = {
|
||||
t: {'acc@1': [], 'acc@5': []}
|
||||
for t in range(
|
||||
0,
|
||||
self.diffusion_model.num_timesteps,
|
||||
self.diffusion_model.log_every_t,
|
||||
)
|
||||
}
|
||||
|
||||
def on_validation_start(self):
|
||||
self.reset_noise_accs()
|
||||
@ -212,24 +285,35 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
|
||||
for t in self.noisy_acc:
|
||||
_, logits, _, targets = self.shared_step(batch, t)
|
||||
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
||||
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
||||
self.noisy_acc[t]['acc@1'].append(
|
||||
self.compute_top_k(logits, targets, k=1, reduction='mean')
|
||||
)
|
||||
self.noisy_acc[t]['acc@5'].append(
|
||||
self.compute_top_k(logits, targets, k=5, reduction='mean')
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||
optimizer = AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
if self.use_scheduler:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
print('Setting up LambdaLR scheduler...')
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
||||
'scheduler': LambdaLR(
|
||||
optimizer, lr_lambda=scheduler.schedule
|
||||
),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
}]
|
||||
'frequency': 1,
|
||||
}
|
||||
]
|
||||
return [optimizer], scheduler
|
||||
|
||||
return optimizer
|
||||
@ -243,7 +327,7 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
y = self.get_conditioning(batch)
|
||||
|
||||
if self.label_key == 'class_label':
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
|
||||
log['labels'] = y
|
||||
|
||||
if ismap(y):
|
||||
@ -256,10 +340,14 @@ class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
|
||||
log[f'inputs@t{current_time}'] = x_noisy
|
||||
|
||||
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
||||
pred = F.one_hot(
|
||||
logits.argmax(dim=1), num_classes=self.num_classes
|
||||
)
|
||||
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||
|
||||
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
||||
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(
|
||||
pred
|
||||
)
|
||||
|
||||
for key in log:
|
||||
log[key] = log[key][:N]
|
||||
|
@ -4,88 +4,146 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
||||
extract_into_tensor
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device or choose_torch_device()
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(dtype=torch.float32, device=self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=True,
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
@ -93,30 +151,48 @@ class DDIMSampler(object):
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
# This routine gets called from img2img
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
@ -125,17 +201,38 @@ class DDIMSampler(object):
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = (
|
||||
timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
)
|
||||
print(f'\nRunning DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc='DDIM Sampler',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
@ -143,18 +240,30 @@ class DDIMSampler(object):
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
@ -162,43 +271,84 @@ class DDIMSampler(object):
|
||||
|
||||
return img, intermediates
|
||||
|
||||
# This routine gets called from ddim_sampling() and decode()
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
if (
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = (
|
||||
self.model.alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_alphas
|
||||
)
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = (
|
||||
sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
)
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
@ -216,33 +366,68 @@ class DDIMSampler(object):
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, z_mask = None, x0=None):
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
img_callback=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if z_mask is not None and i < total_steps - 2:
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
mask_inv = 1. - z_mask
|
||||
x_dec = (img_orig * mask_inv) + (z_mask * x_dec)
|
||||
xdec_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
|
||||
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
return x_dec
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
|
||||
if img_callback:
|
||||
img_callback(x_dec, i)
|
||||
|
||||
return x_dec
|
||||
|
File diff suppressed because it is too large
Load Diff
1
ldm/models/diffusion/dpm_solver/__init__.py
Normal file
1
ldm/models/diffusion/dpm_solver/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .sampler import DPMSolverSampler
|
1154
ldm/models/diffusion/dpm_solver/dpm_solver.py
Normal file
1154
ldm/models/diffusion/dpm_solver/dpm_solver.py
Normal file
File diff suppressed because it is too large
Load Diff
87
ldm/models/diffusion/dpm_solver/sampler.py
Normal file
87
ldm/models/diffusion/dpm_solver/sampler.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""SAMPLING ONLY."""
|
||||
import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
}
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=MODEL_TYPES[self.model.parameterization],
|
||||
guidance_type="classifier-free",
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
@ -4,120 +4,195 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
)
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device if device else choose_torch_device()
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.float32).to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.0,
|
||||
verbose=True,
|
||||
):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), 'alphas have to be defined for each timestep'
|
||||
to_torch = (
|
||||
lambda x: x.clone()
|
||||
.detach()
|
||||
.to(torch.float32)
|
||||
.to(self.model.device)
|
||||
)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
self.register_buffer(
|
||||
'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
self.register_buffer(
|
||||
'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
'sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
(
|
||||
ddim_sigmas,
|
||||
ddim_alphas,
|
||||
ddim_alphas_prev,
|
||||
) = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
self.register_buffer(
|
||||
'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas)
|
||||
)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
'ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
# print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
samples, intermediates = self.plms_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
def plms_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
@ -126,42 +201,81 @@ class PLMSSampler(object):
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
time_range = (
|
||||
list(reversed(range(0, timesteps)))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = (
|
||||
timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
)
|
||||
# print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
iterator = tqdm(
|
||||
time_range,
|
||||
desc='PLMS Sampler',
|
||||
total=total_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
ts_next = torch.full(
|
||||
(b,),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next)
|
||||
outs = self.p_sample_plms(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
@ -170,47 +284,95 @@ class PLMSSampler(object):
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
def p_sample_plms(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
if (
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
e_t_uncond, e_t = self.model.apply_model(
|
||||
x_in, t_in, c_in
|
||||
).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = (
|
||||
self.model.alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_alphas
|
||||
)
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
a_prev = torch.full(
|
||||
(b, 1, 1, 1), alphas_prev[index], device=device
|
||||
)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = (
|
||||
sigma_t
|
||||
* noise_like(x.shape, device, repeat_noise)
|
||||
* temperature
|
||||
)
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
@ -229,7 +391,12 @@ class PLMSSampler(object):
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
e_t_prime = (
|
||||
55 * e_t
|
||||
- 59 * old_eps[-1]
|
||||
+ 37 * old_eps[-2]
|
||||
- 9 * old_eps[-3]
|
||||
) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
|
22
ldm/models/diffusion/sampling_util.py
Normal file
22
ldm/models/diffusion/sampling_util.py
Normal file
@ -0,0 +1,22 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
0
ldm/modules/__init__.py
Normal file
0
ldm/modules/__init__.py
Normal file
@ -1,3 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
@ -9,7 +11,6 @@ from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@ -91,7 +92,7 @@ class LinearAttention(nn.Module):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
@ -169,98 +170,84 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
def einsum_op_compvis(self, q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
def einsum_op_slice_0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
return self.einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
return self.einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
def einsum_op(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if q.device.type == 'mps':
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)# * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
del x
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
return self.to_out(r2)
|
||||
r = self.einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
@ -280,12 +267,118 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = x.contiguous() if x.device.type == 'mps' else x
|
||||
x += self.attn1(self.norm1(x))
|
||||
x += self.attn2(self.norm2(x), context=context)
|
||||
x += self.ff(self.norm3(x))
|
||||
x += self.attn1(self.norm1(x.clone()))
|
||||
x += self.attn2(self.norm2(x.clone()), context=context)
|
||||
x += self.ff(self.norm3(x.clone()))
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlockMECA(nn.Module):
|
||||
'''
|
||||
Memory efficient cross-attention transformer block.
|
||||
'''
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
AttentionBuilder = MemoryEfficientCrossAttention
|
||||
self.attn1 = AttentionBuilder(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = AttentionBuilder(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
self.attn1._slice_size = slice_size
|
||||
self.attn2._slice_size = slice_size
|
||||
|
||||
def forward(self, hidden_states, context=None):
|
||||
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
|
||||
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
|
||||
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def _maybe_init(self, x):
|
||||
"""
|
||||
Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x
|
||||
: B, Head, Length
|
||||
"""
|
||||
from xformers.ops import AttentionOpDispatch
|
||||
if self.attention_op is not None:
|
||||
return
|
||||
|
||||
# _, K, M = x.shape
|
||||
_, M, K = x.shape
|
||||
try:
|
||||
self.attention_op = AttentionOpDispatch(
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
k=K,
|
||||
attn_bias_type=type(None),
|
||||
has_dropout=False,
|
||||
kv_len=M,
|
||||
q_len=M,
|
||||
).op
|
||||
|
||||
except NotImplementedError as err:
|
||||
raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
from xformers.ops import memory_efficient_attention
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# init the attention op, if required, using the proper dimensions
|
||||
self._maybe_init(q)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
# TODO: Use this directly in the attention operation, as a bias
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
@ -307,10 +400,16 @@ class SpatialTransformer(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
if os.environ.get('MEMORY_EFFICIENT_CROSS_ATTENTION', False):
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlockMECA(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
else:
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
|
@ -3,12 +3,14 @@ import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.functional import silu
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.attention import LinearAttention
|
||||
|
||||
import psutil
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
@ -31,11 +33,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
@ -120,30 +117,17 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h1 = x
|
||||
h2 = self.norm1(h1)
|
||||
del h1
|
||||
|
||||
h3 = nonlinearity(h2)
|
||||
del h2
|
||||
|
||||
h4 = self.conv1(h3)
|
||||
del h3
|
||||
h = self.norm1(x)
|
||||
h = silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(silu(temb))[:,:,None,None]
|
||||
|
||||
h5 = self.norm2(h4)
|
||||
del h4
|
||||
|
||||
h6 = nonlinearity(h5)
|
||||
del h5
|
||||
|
||||
h7 = self.dropout(h6)
|
||||
del h6
|
||||
|
||||
h8 = self.conv2(h7)
|
||||
del h7
|
||||
h = self.norm2(h)
|
||||
h = silu(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
@ -151,8 +135,7 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h8
|
||||
|
||||
return x + h
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
@ -209,21 +192,29 @@ class AttnBlock(nn.Module):
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
if q.device.type == 'cuda':
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
else:
|
||||
if psutil.virtual_memory().available / (1024**3) < 12:
|
||||
slice_size = 1
|
||||
else:
|
||||
slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))
|
||||
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
@ -373,7 +364,7 @@ class Model(nn.Module):
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = silu(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
@ -407,7 +398,7 @@ class Model(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -504,7 +495,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -590,54 +581,36 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h1 = self.conv_in(z)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h2 = self.mid.block_1(h1, temb)
|
||||
del h1
|
||||
|
||||
h3 = self.mid.attn_1(h2)
|
||||
del h2
|
||||
|
||||
h = self.mid.block_2(h3, temb)
|
||||
del h3
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# prepare for up sampling
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
if h.device.type == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
t = h
|
||||
h = self.up[i_level].attn[i_block](t)
|
||||
del t
|
||||
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
t = h
|
||||
h = self.up[i_level].upsample(t)
|
||||
del t
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h1 = self.norm_out(h)
|
||||
del h
|
||||
|
||||
h2 = nonlinearity(h1)
|
||||
del h1
|
||||
|
||||
h = self.conv_out(h2)
|
||||
del h2
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
t = h
|
||||
h = torch.tanh(t)
|
||||
del t
|
||||
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
@ -672,7 +645,7 @@ class SimpleDecoder(nn.Module):
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
@ -720,7 +693,7 @@ class UpsampleDecoder(nn.Module):
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -896,7 +869,7 @@ class FirstStagePostProcessor(nn.Module):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
z = silu(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
@ -905,4 +878,3 @@ class FirstStagePostProcessor(nn.Module):
|
||||
if self.do_reshape:
|
||||
z = rearrange(z,'b c h w -> b (h w) c')
|
||||
return z
|
||||
|
||||
|
@ -24,6 +24,7 @@ from ldm.modules.attention import SpatialTransformer
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
@ -42,7 +43,9 @@ class AttentionPool2d(nn.Module):
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
|
||||
)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
@ -97,37 +100,45 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
def __init__(
|
||||
self, channels, use_conv, dims=2, out_channels=None, padding=1
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
self.conv = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
"""Learned 2x upsampling without padding"""
|
||||
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
self.up = nn.ConvTranspose2d(
|
||||
self.channels, self.out_channels, kernel_size=ks, stride=2
|
||||
)
|
||||
|
||||
def forward(self,x):
|
||||
def forward(self, x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
@ -140,7 +151,9 @@ class Downsample(nn.Module):
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
def __init__(
|
||||
self, channels, use_conv, dims=2, out_channels=None, padding=1
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@ -149,7 +162,12 @@ class Downsample(nn.Module):
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
@ -219,7 +237,9 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
2 * self.out_channels
|
||||
if use_scale_shift_norm
|
||||
else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
@ -227,7 +247,9 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
conv_nd(
|
||||
dims, self.out_channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -238,7 +260,9 @@ class ResBlock(TimestepBlock):
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 1
|
||||
)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
@ -251,7 +275,6 @@ class ResBlock(TimestepBlock):
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
@ -297,7 +320,7 @@ class AttentionBlock(nn.Module):
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
@ -312,8 +335,10 @@ class AttentionBlock(nn.Module):
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
return checkpoint(
|
||||
self._forward, (x,), self.parameters(), True
|
||||
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
# return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
@ -340,7 +365,7 @@ def count_flops_attn(model, _x, y):
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
@ -362,13 +387,15 @@ class QKVAttentionLegacy(nn.Module):
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
|
||||
ch, dim=1
|
||||
)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
'bct,bcs->bts', q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
a = th.einsum('bts,bcs->bct', weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
@ -397,12 +424,14 @@ class QKVAttention(nn.Module):
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
'bct,bcs->bts',
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
a = th.einsum(
|
||||
'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
|
||||
)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
@ -461,19 +490,24 @@ class UNetModel(nn.Module):
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
assert (
|
||||
context_dim is not None
|
||||
), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
assert (
|
||||
use_spatial_transformer
|
||||
), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
@ -481,10 +515,14 @@ class UNetModel(nn.Module):
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
assert (
|
||||
num_head_channels != -1
|
||||
), 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
assert (
|
||||
num_heads != -1
|
||||
), 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
@ -545,8 +583,12 @@ class UNetModel(nn.Module):
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
# num_heads = 1
|
||||
dim_head = (
|
||||
ch // num_heads
|
||||
if use_spatial_transformer
|
||||
else num_head_channels
|
||||
)
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@ -554,8 +596,14 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@ -592,8 +640,12 @@ class UNetModel(nn.Module):
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
# num_heads = 1
|
||||
dim_head = (
|
||||
ch // num_heads
|
||||
if use_spatial_transformer
|
||||
else num_head_channels
|
||||
)
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
@ -609,9 +661,15 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
),
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@ -646,8 +704,12 @@ class UNetModel(nn.Module):
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
# num_heads = 1
|
||||
dim_head = (
|
||||
ch // num_heads
|
||||
if use_spatial_transformer
|
||||
else num_head_channels
|
||||
)
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@ -655,8 +717,14 @@ class UNetModel(nn.Module):
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
@ -673,7 +741,9 @@ class UNetModel(nn.Module):
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Upsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@ -682,14 +752,16 @@ class UNetModel(nn.Module):
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(
|
||||
conv_nd(dims, model_channels, out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
@ -707,7 +779,7 @@ class UNetModel(nn.Module):
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@ -718,9 +790,11 @@ class UNetModel(nn.Module):
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
), 'must specify y if and only if the model is class-conditional'
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False
|
||||
)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
@ -733,6 +807,8 @@ class UNetModel(nn.Module):
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
if h.shape[-2:] != hs[-1].shape[-2:]:
|
||||
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
@ -768,9 +844,9 @@ class EncoderUNetModel(nn.Module):
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
pool='adaptive',
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -888,7 +964,7 @@ class EncoderUNetModel(nn.Module):
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
if pool == 'adaptive':
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
@ -896,7 +972,7 @@ class EncoderUNetModel(nn.Module):
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
elif pool == 'attention':
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
@ -905,13 +981,13 @@ class EncoderUNetModel(nn.Module):
|
||||
(image_size // ds), ch, num_head_channels, out_channels
|
||||
),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
elif pool == 'spatial':
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
elif pool == 'spatial_v2':
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
@ -919,7 +995,7 @@ class EncoderUNetModel(nn.Module):
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
raise NotImplementedError(f'Unexpected {pool} pooling')
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
@ -942,20 +1018,21 @@ class EncoderUNetModel(nn.Module):
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb = self.time_embed(
|
||||
timestep_embedding(timesteps, self.model_channels)
|
||||
)
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
if self.pool.startswith('spatial'):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
if self.pool.startswith('spatial'):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
81
ldm/modules/diffusionmodules/upscaling.py
Normal file
81
ldm/modules/diffusionmodules/upscaling.py
Normal file
@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
|
||||
from ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
# for concatenating a downsampled image to the latent representation
|
||||
def __init__(self, noise_schedule_config=None):
|
||||
super(AbstractLowScaleModel, self).__init__()
|
||||
if noise_schedule_config is not None:
|
||||
self.register_schedule(**noise_schedule_config)
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageConcat(AbstractLowScaleModel):
|
||||
# no noise level conditioning
|
||||
def __init__(self):
|
||||
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
||||
self.max_noise_level = 0
|
||||
|
||||
def forward(self, x):
|
||||
# fix to constant noise level
|
||||
return x, torch.zeros(x.shape[0], device=x.device).long()
|
||||
|
||||
|
||||
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
|
@ -18,15 +18,24 @@ from einops import repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
def make_beta_schedule(
|
||||
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
||||
):
|
||||
if schedule == 'linear':
|
||||
betas = (
|
||||
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
torch.linspace(
|
||||
linear_start**0.5,
|
||||
linear_end**0.5,
|
||||
n_timestep,
|
||||
dtype=torch.float64,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
elif schedule == 'cosine':
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
|
||||
+ cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
@ -34,44 +43,73 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
elif schedule == 'sqrt_linear':
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64
|
||||
)
|
||||
elif schedule == 'sqrt':
|
||||
betas = (
|
||||
torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64
|
||||
)
|
||||
** 0.5
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
def make_ddim_timesteps(
|
||||
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
|
||||
):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
ddim_timesteps = (
|
||||
(
|
||||
np.linspace(
|
||||
0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
raise NotImplementedError(
|
||||
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
||||
)
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
# steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000
|
||||
steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps)
|
||||
# steps_out = ddim_timesteps + 1
|
||||
steps_out = ddim_timesteps
|
||||
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
def make_ddim_sampling_parameters(
|
||||
alphacums, ddim_timesteps, eta, verbose=True
|
||||
):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
alphas_prev = np.asarray(
|
||||
[alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
|
||||
)
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
sigmas = eta * np.sqrt(
|
||||
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
||||
)
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
print(
|
||||
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
|
||||
)
|
||||
print(
|
||||
f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
|
||||
)
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
@ -110,7 +148,9 @@ def checkpoint(func, inputs, params, flag):
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
if (
|
||||
False
|
||||
): # disabled checkpointing to allow requires_grad = False for main model
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
@ -130,7 +170,9 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
ctx.input_tensors = [
|
||||
x.detach().requires_grad_(True) for x in ctx.input_tensors
|
||||
]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
@ -161,12 +203,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
||||
@ -206,16 +252,11 @@ def normalization(channels):
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
@ -226,7 +267,7 @@ def conv_nd(dims, *args, **kwargs):
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
@ -246,15 +287,16 @@ def avg_pool_nd(dims, *args, **kwargs):
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(
|
||||
c_crossattn_config
|
||||
)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
@ -263,6 +305,8 @@ class HybridConditioner(nn.Module):
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
||||
shape[0], *((1,) * (len(shape) - 1))
|
||||
)
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
@ -30,33 +30,45 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(
|
||||
device=self.parameters.device
|
||||
)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
||||
device=self.parameters.device
|
||||
)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
logtwopi
|
||||
+ self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims,
|
||||
)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
@ -74,7 +86,7 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
assert tensor is not None, 'at least one argument must be a Tensor'
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
|
@ -10,24 +10,30 @@ class LitEma(nn.Module):
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
|
||||
else torch.tensor(-1,dtype=torch.int))
|
||||
self.register_buffer(
|
||||
'num_updates',
|
||||
torch.tensor(0, dtype=torch.int)
|
||||
if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int),
|
||||
)
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
#remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.','')
|
||||
self.m_name2s_name.update({name:s_name})
|
||||
self.register_buffer(s_name,p.clone().detach().data)
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.', '')
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def forward(self,model):
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
|
||||
decay = min(
|
||||
self.decay, (1 + self.num_updates) / (10 + self.num_updates)
|
||||
)
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
@ -38,8 +44,12 @@ class LitEma(nn.Module):
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
shadow_params[sname] = shadow_params[sname].type_as(
|
||||
m_param[key]
|
||||
)
|
||||
shadow_params[sname].sub_(
|
||||
one_minus_decay * (shadow_params[sname] - m_param[key])
|
||||
)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
@ -48,7 +58,9 @@ class LitEma(nn.Module):
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
m_param[key].data.copy_(
|
||||
shadow_params[self.m_name2s_name[key]].data
|
||||
)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
|
273
ldm/modules/embedding_manager.py
Normal file
273
ldm/modules/embedding_manager.py
Normal file
@ -0,0 +1,273 @@
|
||||
from cmath import log
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import sys
|
||||
|
||||
from ldm.data.personalized import per_img_token_list
|
||||
from transformers import CLIPTokenizer
|
||||
from functools import partial
|
||||
|
||||
DEFAULT_PLACEHOLDER_TOKEN = ['*']
|
||||
|
||||
PROGRESSIVE_SCALE = 2000
|
||||
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(
|
||||
string,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt',
|
||||
)
|
||||
tokens = batch_encoding['input_ids']
|
||||
""" assert (
|
||||
torch.count_nonzero(tokens - 49407) == 2
|
||||
), f"String '{string}' maps to more than a single token. Please use another string" """
|
||||
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
|
||||
token = token[0, 1]
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_embedding_for_clip_token(embedder, token):
|
||||
return embedder(token.unsqueeze(0))[0, 0]
|
||||
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedder,
|
||||
placeholder_strings=None,
|
||||
initializer_words=None,
|
||||
per_image_tokens=False,
|
||||
num_vectors_per_token=1,
|
||||
progressive_words=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedder = embedder
|
||||
device = embedder.device
|
||||
|
||||
self.string_to_token_dict = {}
|
||||
self.string_to_param_dict = nn.ParameterDict()
|
||||
|
||||
self.initial_embeddings = (
|
||||
nn.ParameterDict()
|
||||
) # These should not be optimized
|
||||
|
||||
self.progressive_words = progressive_words
|
||||
self.progressive_counter = 0
|
||||
|
||||
self.max_vectors_per_token = num_vectors_per_token
|
||||
|
||||
if hasattr(
|
||||
embedder, 'tokenizer'
|
||||
): # using Stable Diffusion's CLIP encoder
|
||||
self.is_clip = True
|
||||
get_token_for_string = partial(
|
||||
get_clip_token_for_string, embedder.tokenizer
|
||||
)
|
||||
get_embedding_for_tkn = partial(
|
||||
get_embedding_for_clip_token,
|
||||
embedder.transformer.text_model.embeddings,
|
||||
)
|
||||
# per bug report #572
|
||||
#token_dim = 1280
|
||||
token_dim = 768
|
||||
else: # using LDM's BERT encoder
|
||||
self.is_clip = False
|
||||
get_token_for_string = partial(
|
||||
get_bert_token_for_string, embedder.tknz_fn
|
||||
)
|
||||
get_embedding_for_tkn = embedder.transformer.token_emb
|
||||
token_dim = 1280
|
||||
|
||||
if per_image_tokens:
|
||||
placeholder_strings.extend(per_img_token_list)
|
||||
|
||||
for idx, placeholder_string in enumerate(placeholder_strings):
|
||||
|
||||
token = get_token_for_string(placeholder_string)
|
||||
|
||||
if initializer_words and idx < len(initializer_words):
|
||||
init_word_token = get_token_for_string(initializer_words[idx])
|
||||
|
||||
with torch.no_grad():
|
||||
init_word_embedding = get_embedding_for_tkn(
|
||||
init_word_token.to(device)
|
||||
)
|
||||
|
||||
token_params = torch.nn.Parameter(
|
||||
init_word_embedding.unsqueeze(0).repeat(
|
||||
num_vectors_per_token, 1
|
||||
),
|
||||
requires_grad=True,
|
||||
)
|
||||
self.initial_embeddings[
|
||||
placeholder_string
|
||||
] = torch.nn.Parameter(
|
||||
init_word_embedding.unsqueeze(0).repeat(
|
||||
num_vectors_per_token, 1
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
token_params = torch.nn.Parameter(
|
||||
torch.rand(
|
||||
size=(num_vectors_per_token, token_dim),
|
||||
requires_grad=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.string_to_token_dict[placeholder_string] = token
|
||||
self.string_to_param_dict[placeholder_string] = token_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokenized_text,
|
||||
embedded_text,
|
||||
):
|
||||
b, n, device = *tokenized_text.shape, tokenized_text.device
|
||||
|
||||
for (
|
||||
placeholder_string,
|
||||
placeholder_token,
|
||||
) in self.string_to_token_dict.items():
|
||||
|
||||
placeholder_embedding = self.string_to_param_dict[
|
||||
placeholder_string
|
||||
].to(device)
|
||||
|
||||
if (
|
||||
self.max_vectors_per_token == 1
|
||||
): # If there's only one vector per token, we can do a simple replacement
|
||||
placeholder_idx = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
embedded_text[placeholder_idx] = placeholder_embedding
|
||||
else: # otherwise, need to insert and keep track of changing indices
|
||||
if self.progressive_words:
|
||||
self.progressive_counter += 1
|
||||
max_step_tokens = (
|
||||
1 + self.progressive_counter // PROGRESSIVE_SCALE
|
||||
)
|
||||
else:
|
||||
max_step_tokens = self.max_vectors_per_token
|
||||
|
||||
num_vectors_for_token = min(
|
||||
placeholder_embedding.shape[0], max_step_tokens
|
||||
)
|
||||
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
|
||||
if placeholder_rows.nelement() == 0:
|
||||
continue
|
||||
|
||||
sorted_cols, sort_idx = torch.sort(
|
||||
placeholder_cols, descending=True
|
||||
)
|
||||
sorted_rows = placeholder_rows[sort_idx]
|
||||
|
||||
for idx in range(len(sorted_rows)):
|
||||
row = sorted_rows[idx]
|
||||
col = sorted_cols[idx]
|
||||
|
||||
new_token_row = torch.cat(
|
||||
[
|
||||
tokenized_text[row][:col],
|
||||
placeholder_token.repeat(num_vectors_for_token).to(
|
||||
device
|
||||
),
|
||||
tokenized_text[row][col + 1 :],
|
||||
],
|
||||
axis=0,
|
||||
)[:n]
|
||||
new_embed_row = torch.cat(
|
||||
[
|
||||
embedded_text[row][:col],
|
||||
placeholder_embedding[:num_vectors_for_token],
|
||||
embedded_text[row][col + 1 :],
|
||||
],
|
||||
axis=0,
|
||||
)[:n]
|
||||
|
||||
embedded_text[row] = new_embed_row
|
||||
tokenized_text[row] = new_token_row
|
||||
|
||||
return embedded_text
|
||||
|
||||
def save(self, ckpt_path):
|
||||
torch.save(
|
||||
{
|
||||
'string_to_token': self.string_to_token_dict,
|
||||
'string_to_param': self.string_to_param_dict,
|
||||
},
|
||||
ckpt_path,
|
||||
)
|
||||
|
||||
def load(self, ckpt_path, full=True):
|
||||
ckpt = torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
# Handle .pt textual inversion files
|
||||
if 'string_to_token' in ckpt and 'string_to_param' in ckpt:
|
||||
self.string_to_token_dict = ckpt["string_to_token"]
|
||||
self.string_to_param_dict = ckpt["string_to_param"]
|
||||
|
||||
# Handle .bin textual inversion files from Huggingface Concepts
|
||||
# https://huggingface.co/sd-concepts-library
|
||||
else:
|
||||
for token_str in list(ckpt.keys()):
|
||||
token = get_clip_token_for_string(self.embedder.tokenizer, token_str)
|
||||
self.string_to_token_dict[token_str] = token
|
||||
ckpt[token_str] = torch.nn.Parameter(ckpt[token_str])
|
||||
|
||||
self.string_to_param_dict.update(ckpt)
|
||||
|
||||
if not full:
|
||||
for key, value in self.string_to_param_dict.items():
|
||||
self.string_to_param_dict[key] = torch.nn.Parameter(value.half())
|
||||
|
||||
def get_embedding_norms_squared(self):
|
||||
all_params = torch.cat(
|
||||
list(self.string_to_param_dict.values()), axis=0
|
||||
) # num_placeholders x embedding_dim
|
||||
param_norm_squared = (all_params * all_params).sum(
|
||||
axis=-1
|
||||
) # num_placeholders
|
||||
|
||||
return param_norm_squared
|
||||
|
||||
def embedding_parameters(self):
|
||||
return self.string_to_param_dict.parameters()
|
||||
|
||||
def embedding_to_coarse_loss(self):
|
||||
|
||||
loss = 0.0
|
||||
num_embeddings = len(self.initial_embeddings)
|
||||
|
||||
for key in self.initial_embeddings:
|
||||
optimized = self.string_to_param_dict[key]
|
||||
coarse = self.initial_embeddings[key].clone().to(optimized.device)
|
||||
|
||||
loss = (
|
||||
loss
|
||||
+ (optimized - coarse)
|
||||
@ (optimized - coarse).T
|
||||
/ num_embeddings
|
||||
)
|
||||
|
||||
return loss
|
@ -5,8 +5,40 @@ import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
import os
|
||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
from ldm.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
TransformerWrapper,
|
||||
) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
|
||||
|
||||
def _expand_mask(mask, dtype, tgt_len=None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
||||
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
@ -17,7 +49,6 @@ class AbstractEncoder(nn.Module):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
||||
super().__init__()
|
||||
@ -35,11 +66,22 @@ class ClassEmbedder(nn.Module):
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
|
||||
"""Some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size,
|
||||
max_seq_len=77,
|
||||
device=choose_torch_device(),
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = tokens.to(self.device) # meh
|
||||
@ -51,19 +93,44 @@ class TransformerEmbedder(AbstractEncoder):
|
||||
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
|
||||
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
|
||||
def __init__(
|
||||
self, device=choose_torch_device(), vq_interface=True, max_length=77
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
from transformers import (
|
||||
BertTokenizerFast,
|
||||
) # TODO: add to reuquirements
|
||||
|
||||
# Modified to allow to run on non-internet connected compute nodes.
|
||||
# Model needs to be loaded into cache from an internet-connected machine
|
||||
# by running:
|
||||
# from transformers import BertTokenizerFast
|
||||
# BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
try:
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained(
|
||||
'bert-base-uncased', local_files_only=False
|
||||
)
|
||||
except OSError:
|
||||
raise SystemExit(
|
||||
"* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine."
|
||||
)
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt',
|
||||
)
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
return tokens
|
||||
|
||||
@torch.no_grad()
|
||||
@ -79,54 +146,84 @@ class BERTTokenizer(AbstractEncoder):
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
||||
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size=30522,
|
||||
max_seq_len=77,
|
||||
device=choose_torch_device(),
|
||||
use_tokenizer=True,
|
||||
embedding_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
if self.use_tknz_fn:
|
||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
||||
self.tknz_fn = BERTTokenizer(
|
||||
vq_interface=False, max_length=max_seq_len
|
||||
)
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout)
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout,
|
||||
)
|
||||
|
||||
def forward(self, text):
|
||||
def forward(self, text, embedding_manager=None):
|
||||
if self.use_tknz_fn:
|
||||
tokens = self.tknz_fn(text)#.to(self.device)
|
||||
tokens = self.tknz_fn(text) # .to(self.device)
|
||||
else:
|
||||
tokens = text
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
z = self.transformer(
|
||||
tokens, return_embeddings=True, embedding_manager=embedding_manager
|
||||
)
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
def encode(self, text, **kwargs):
|
||||
# output of length 77
|
||||
return self(text)
|
||||
return self(text, **kwargs)
|
||||
|
||||
|
||||
class SpatialRescaler(nn.Module):
|
||||
def __init__(self,
|
||||
n_stages=1,
|
||||
method='bilinear',
|
||||
multiplier=0.5,
|
||||
in_channels=3,
|
||||
out_channels=None,
|
||||
bias=False):
|
||||
def __init__(
|
||||
self,
|
||||
n_stages=1,
|
||||
method='bilinear',
|
||||
multiplier=0.5,
|
||||
in_channels=3,
|
||||
out_channels=None,
|
||||
bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_stages = n_stages
|
||||
assert self.n_stages >= 0
|
||||
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
|
||||
assert method in [
|
||||
'nearest',
|
||||
'linear',
|
||||
'bilinear',
|
||||
'trilinear',
|
||||
'bicubic',
|
||||
'area',
|
||||
]
|
||||
self.multiplier = multiplier
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.interpolator = partial(
|
||||
torch.nn.functional.interpolate, mode=method
|
||||
)
|
||||
self.remap_output = out_channels is not None
|
||||
if self.remap_output:
|
||||
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
|
||||
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
|
||||
print(
|
||||
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
|
||||
)
|
||||
self.channel_mapper = nn.Conv2d(
|
||||
in_channels, out_channels, 1, bias=bias
|
||||
)
|
||||
|
||||
def forward(self,x):
|
||||
def forward(self, x):
|
||||
for stage in range(self.n_stages):
|
||||
x = self.interpolator(x, scale_factor=self.multiplier)
|
||||
|
||||
|
||||
if self.remap_output:
|
||||
x = self.channel_mapper(x)
|
||||
return x
|
||||
@ -134,45 +231,244 @@ class SpatialRescaler(nn.Module):
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device=choose_torch_device(),
|
||||
max_length=77,
|
||||
):
|
||||
super().__init__()
|
||||
if os.path.exists("models/clip-vit-large-patch14"):
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("models/clip-vit-large-patch14")
|
||||
self.transformer = CLIPTextModel.from_pretrained("models/clip-vit-large-patch14")
|
||||
else:
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
version, local_files_only=False
|
||||
)
|
||||
self.transformer = CLIPTextModel.from_pretrained(
|
||||
version, local_files_only=False
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.freeze()
|
||||
|
||||
def embedding_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
embedding_manager=None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = (
|
||||
input_ids.shape[-1]
|
||||
if input_ids is not None
|
||||
else inputs_embeds.shape[-2]
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
if embedding_manager is not None:
|
||||
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
self.transformer.text_model.embeddings.forward = (
|
||||
embedding_forward.__get__(self.transformer.text_model.embeddings)
|
||||
)
|
||||
|
||||
def encoder_forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask=None,
|
||||
causal_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None
|
||||
else self.config.use_return_dict
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
return hidden_states
|
||||
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
|
||||
self.transformer.text_model.encoder
|
||||
)
|
||||
|
||||
def text_encoder_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None
|
||||
else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError('You have to specify either input_ids')
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
bsz, seq_len = input_shape
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _build_causal_attention_mask(
|
||||
bsz, seq_len, hidden_states.dtype
|
||||
).to(hidden_states.device)
|
||||
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(
|
||||
attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
last_hidden_state = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(
|
||||
self.transformer.text_model
|
||||
)
|
||||
|
||||
def transformer_forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
self.transformer.forward = transformer_forward.__get__(
|
||||
self.transformer
|
||||
)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
def forward(self, text, **kwargs):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt',
|
||||
)
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the CLIP transformer encoder for text.
|
||||
"""
|
||||
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version='ViT-L/14',
|
||||
device=choose_torch_device(),
|
||||
max_length=77,
|
||||
n_repeat=1,
|
||||
normalize=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(version, jit=False, device="cpu")
|
||||
self.model, _ = clip.load(version, jit=False, device=device)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.n_repeat = n_repeat
|
||||
@ -192,7 +488,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
|
||||
|
||||
def encode(self, text):
|
||||
z = self(text)
|
||||
if z.ndim==2:
|
||||
if z.ndim == 2:
|
||||
z = z[:, None, :]
|
||||
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
|
||||
return z
|
||||
@ -200,29 +496,42 @@ class FrozenCLIPTextEmbedder(nn.Module):
|
||||
|
||||
class FrozenClipImageEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
jit=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
antialias=False,
|
||||
):
|
||||
self,
|
||||
model,
|
||||
jit=False,
|
||||
device=choose_torch_device(),
|
||||
antialias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
'std',
|
||||
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
x = kornia.geometry.resize(
|
||||
x,
|
||||
(224, 224),
|
||||
interpolation='bicubic',
|
||||
align_corners=True,
|
||||
antialias=self.antialias,
|
||||
)
|
||||
x = (x + 1.0) / 2.0
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
@ -232,7 +541,8 @@ class FrozenClipImageEmbedder(nn.Module):
|
||||
return self.model.encode_image(self.preprocess(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
from ldm.util import count_params
|
||||
|
||||
model = FrozenCLIPEmbedder()
|
||||
count_params(model, verbose=True)
|
||||
count_params(model, verbose=True)
|
||||
|
170
ldm/modules/midas/api.py
Normal file
170
ldm/modules/midas/api.py
Normal file
@ -0,0 +1,170 @@
|
||||
# based on https://github.com/isl-org/MiDaS
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
||||
from ldm.modules.midas.midas.midas_net import MidasNet
|
||||
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
||||
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
||||
|
||||
|
||||
ISL_PATHS = {
|
||||
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
|
||||
"midas_v21": "",
|
||||
"midas_v21_small": "",
|
||||
}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
def load_midas_transform(model_type):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load transform only
|
||||
if model_type == "dpt_large": # DPT-Large
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "midas_v21":
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
elif model_type == "midas_v21_small":
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
else:
|
||||
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def load_model(model_type):
|
||||
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
||||
# load network
|
||||
model_path = ISL_PATHS[model_type]
|
||||
if model_type == "dpt_large": # DPT-Large
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone="vitl16_384",
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
||||
model = DPTDepthModel(
|
||||
path=model_path,
|
||||
backbone="vitb_rn50_384",
|
||||
non_negative=True,
|
||||
)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "minimal"
|
||||
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
elif model_type == "midas_v21":
|
||||
model = MidasNet(model_path, non_negative=True)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
elif model_type == "midas_v21_small":
|
||||
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
||||
non_negative=True, blocks={'expand': True})
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
else:
|
||||
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
||||
assert False
|
||||
|
||||
transform = Compose(
|
||||
[
|
||||
Resize(
|
||||
net_w,
|
||||
net_h,
|
||||
resize_target=None,
|
||||
keep_aspect_ratio=True,
|
||||
ensure_multiple_of=32,
|
||||
resize_method=resize_mode,
|
||||
image_interpolation_method=cv2.INTER_CUBIC,
|
||||
),
|
||||
normalization,
|
||||
PrepareForNet(),
|
||||
]
|
||||
)
|
||||
|
||||
return model.eval(), transform
|
||||
|
||||
|
||||
class MiDaSInference(nn.Module):
|
||||
MODEL_TYPES_TORCH_HUB = [
|
||||
"DPT_Large",
|
||||
"DPT_Hybrid",
|
||||
"MiDaS_small"
|
||||
]
|
||||
MODEL_TYPES_ISL = [
|
||||
"dpt_large",
|
||||
"dpt_hybrid",
|
||||
"midas_v21",
|
||||
"midas_v21_small",
|
||||
]
|
||||
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
assert (model_type in self.MODEL_TYPES_ISL)
|
||||
model, _ = load_model(model_type)
|
||||
self.model = model
|
||||
self.model.train = disabled_train
|
||||
|
||||
def forward(self, x):
|
||||
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
|
||||
# NOTE: we expect that the correct transform has been called during dataloading.
|
||||
with torch.no_grad():
|
||||
prediction = self.model(x)
|
||||
prediction = torch.nn.functional.interpolate(
|
||||
prediction.unsqueeze(1),
|
||||
size=x.shape[2:],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
||||
return prediction
|
||||
|
0
ldm/modules/midas/midas/__init__.py
Normal file
0
ldm/modules/midas/midas/__init__.py
Normal file
16
ldm/modules/midas/midas/base_model.py
Normal file
16
ldm/modules/midas/midas/base_model.py
Normal file
@ -0,0 +1,16 @@
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def load(self, path):
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
path (str): file path
|
||||
"""
|
||||
parameters = torch.load(path, map_location=torch.device('cpu'))
|
||||
|
||||
if "optimizer" in parameters:
|
||||
parameters = parameters["model"]
|
||||
|
||||
self.load_state_dict(parameters)
|
342
ldm/modules/midas/midas/blocks.py
Normal file
342
ldm/modules/midas/midas/blocks.py
Normal file
@ -0,0 +1,342 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .vit import (
|
||||
_make_pretrained_vitb_rn50_384,
|
||||
_make_pretrained_vitl16_384,
|
||||
_make_pretrained_vitb16_384,
|
||||
forward_vit,
|
||||
)
|
||||
|
||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
||||
if backbone == "vitl16_384":
|
||||
pretrained = _make_pretrained_vitl16_384(
|
||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
||||
) # ViT-L/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == "vitb_rn50_384":
|
||||
pretrained = _make_pretrained_vitb_rn50_384(
|
||||
use_pretrained,
|
||||
hooks=hooks,
|
||||
use_vit_only=use_vit_only,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[256, 512, 768, 768], features, groups=groups, expand=expand
|
||||
) # ViT-H/16 - 85.0% Top1 (backbone)
|
||||
elif backbone == "vitb16_384":
|
||||
pretrained = _make_pretrained_vitb16_384(
|
||||
use_pretrained, hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
scratch = _make_scratch(
|
||||
[96, 192, 384, 768], features, groups=groups, expand=expand
|
||||
) # ViT-B/16 - 84.6% Top1 (backbone)
|
||||
elif backbone == "resnext101_wsl":
|
||||
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
||||
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||
elif backbone == "efficientnet_lite3":
|
||||
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
||||
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
||||
else:
|
||||
print(f"Backbone '{backbone}' not implemented")
|
||||
assert False
|
||||
|
||||
return pretrained, scratch
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
out_shape4 = out_shape
|
||||
if expand==True:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape*2
|
||||
out_shape3 = out_shape*4
|
||||
out_shape4 = out_shape*8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
||||
efficientnet = torch.hub.load(
|
||||
"rwightman/gen-efficientnet-pytorch",
|
||||
"tf_efficientnet_lite3",
|
||||
pretrained=use_pretrained,
|
||||
exportable=exportable
|
||||
)
|
||||
return _make_efficientnet_backbone(efficientnet)
|
||||
|
||||
|
||||
def _make_efficientnet_backbone(effnet):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
||||
)
|
||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_resnet_backbone(resnet):
|
||||
pretrained = nn.Module()
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
||||
)
|
||||
|
||||
pretrained.layer2 = resnet.layer2
|
||||
pretrained.layer3 = resnet.layer3
|
||||
pretrained.layer4 = resnet.layer4
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_resnext101_wsl(use_pretrained):
|
||||
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
||||
return _make_resnet_backbone(resnet)
|
||||
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
"""Interpolation module.
|
||||
"""
|
||||
|
||||
def __init__(self, scale_factor, mode, align_corners=False):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
scale_factor (float): scaling
|
||||
mode (str): interpolation mode
|
||||
"""
|
||||
super(Interpolate, self).__init__()
|
||||
|
||||
self.interp = nn.functional.interpolate
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: interpolated data
|
||||
"""
|
||||
|
||||
x = self.interp(
|
||||
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
"""Residual convolution module.
|
||||
"""
|
||||
|
||||
def __init__(self, features):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||
)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
||||
)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
out = self.relu(x)
|
||||
out = self.conv1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
"""Feature fusion block.
|
||||
"""
|
||||
|
||||
def __init__(self, features):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock, self).__init__()
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit(features)
|
||||
self.resConfUnit2 = ResidualConvUnit(features)
|
||||
|
||||
def forward(self, *xs):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
output += self.resConfUnit1(xs[1])
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
output = nn.functional.interpolate(
|
||||
output, scale_factor=2, mode="bilinear", align_corners=True
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
||||
class ResidualConvUnit_custom(nn.Module):
|
||||
"""Residual convolution module.
|
||||
"""
|
||||
|
||||
def __init__(self, features, activation, bn):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.bn = bn
|
||||
|
||||
self.groups=1
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||
)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
||||
)
|
||||
|
||||
if self.bn==True:
|
||||
self.bn1 = nn.BatchNorm2d(features)
|
||||
self.bn2 = nn.BatchNorm2d(features)
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
if self.bn==True:
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
if self.bn==True:
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.groups > 1:
|
||||
out = self.conv_merge(out)
|
||||
|
||||
return self.skip_add.add(out, x)
|
||||
|
||||
# return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock_custom(nn.Module):
|
||||
"""Feature fusion block.
|
||||
"""
|
||||
|
||||
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
features (int): number of features
|
||||
"""
|
||||
super(FeatureFusionBlock_custom, self).__init__()
|
||||
|
||||
self.deconv = deconv
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.groups=1
|
||||
|
||||
self.expand = expand
|
||||
out_features = features
|
||||
if self.expand==True:
|
||||
out_features = features//2
|
||||
|
||||
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||
|
||||
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
||||
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
||||
|
||||
self.skip_add = nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, *xs):
|
||||
"""Forward pass.
|
||||
|
||||
Returns:
|
||||
tensor: output
|
||||
"""
|
||||
output = xs[0]
|
||||
|
||||
if len(xs) == 2:
|
||||
res = self.resConfUnit1(xs[1])
|
||||
output = self.skip_add.add(output, res)
|
||||
# output += res
|
||||
|
||||
output = self.resConfUnit2(output)
|
||||
|
||||
output = nn.functional.interpolate(
|
||||
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
||||
)
|
||||
|
||||
output = self.out_conv(output)
|
||||
|
||||
return output
|
||||
|
109
ldm/modules/midas/midas/dpt_depth.py
Normal file
109
ldm/modules/midas/midas/dpt_depth.py
Normal file
@ -0,0 +1,109 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import (
|
||||
FeatureFusionBlock,
|
||||
FeatureFusionBlock_custom,
|
||||
Interpolate,
|
||||
_make_encoder,
|
||||
forward_vit,
|
||||
)
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn):
|
||||
return FeatureFusionBlock_custom(
|
||||
features,
|
||||
nn.ReLU(False),
|
||||
deconv=False,
|
||||
bn=use_bn,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
|
||||
class DPT(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
head,
|
||||
features=256,
|
||||
backbone="vitb_rn50_384",
|
||||
readout="project",
|
||||
channels_last=False,
|
||||
use_bn=False,
|
||||
):
|
||||
|
||||
super(DPT, self).__init__()
|
||||
|
||||
self.channels_last = channels_last
|
||||
|
||||
hooks = {
|
||||
"vitb_rn50_384": [0, 1, 8, 11],
|
||||
"vitb16_384": [2, 5, 8, 11],
|
||||
"vitl16_384": [5, 11, 17, 23],
|
||||
}
|
||||
|
||||
# Instantiate backbone and reassemble blocks
|
||||
self.pretrained, self.scratch = _make_encoder(
|
||||
backbone,
|
||||
features,
|
||||
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
||||
groups=1,
|
||||
expand=False,
|
||||
exportable=False,
|
||||
hooks=hooks[backbone],
|
||||
use_readout=readout,
|
||||
)
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
self.scratch.output_conv = head
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.channels_last == True:
|
||||
x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DPTDepthModel(DPT):
|
||||
def __init__(self, path=None, non_negative=True, **kwargs):
|
||||
features = kwargs["features"] if "features" in kwargs else 256
|
||||
|
||||
head = nn.Sequential(
|
||||
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
||||
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
||||
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
super().__init__(head, **kwargs)
|
||||
|
||||
if path is not None:
|
||||
self.load(path)
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x).squeeze(dim=1)
|
||||
|
76
ldm/modules/midas/midas/midas_net.py
Normal file
76
ldm/modules/midas/midas/midas_net.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet(BaseModel):
|
||||
"""Network for monocular depth estimation.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None, features=256, non_negative=True):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
path (str, optional): Path to saved model. Defaults to None.
|
||||
features (int, optional): Number of features. Defaults to 256.
|
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
||||
"""
|
||||
print("Loading weights: ", path)
|
||||
|
||||
super(MidasNet, self).__init__()
|
||||
|
||||
use_pretrained = False if path is None else True
|
||||
|
||||
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
||||
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
||||
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
||||
Interpolate(scale_factor=2, mode="bilinear"),
|
||||
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
)
|
||||
|
||||
if path:
|
||||
self.load(path)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input data (image)
|
||||
|
||||
Returns:
|
||||
tensor: depth
|
||||
"""
|
||||
|
||||
layer_1 = self.pretrained.layer1(x)
|
||||
layer_2 = self.pretrained.layer2(layer_1)
|
||||
layer_3 = self.pretrained.layer3(layer_2)
|
||||
layer_4 = self.pretrained.layer4(layer_3)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return torch.squeeze(out, dim=1)
|
128
ldm/modules/midas/midas/midas_net_custom.py
Normal file
128
ldm/modules/midas/midas/midas_net_custom.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet_small(BaseModel):
|
||||
"""Network for monocular depth estimation.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
||||
blocks={'expand': True}):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
path (str, optional): Path to saved model. Defaults to None.
|
||||
features (int, optional): Number of features. Defaults to 256.
|
||||
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
||||
"""
|
||||
print("Loading weights: ", path)
|
||||
|
||||
super(MidasNet_small, self).__init__()
|
||||
|
||||
use_pretrained = False if path else True
|
||||
|
||||
self.channels_last = channels_last
|
||||
self.blocks = blocks
|
||||
self.backbone = backbone
|
||||
|
||||
self.groups = 1
|
||||
|
||||
features1=features
|
||||
features2=features
|
||||
features3=features
|
||||
features4=features
|
||||
self.expand = False
|
||||
if "expand" in self.blocks and self.blocks['expand'] == True:
|
||||
self.expand = True
|
||||
features1=features
|
||||
features2=features*2
|
||||
features3=features*4
|
||||
features4=features*8
|
||||
|
||||
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
||||
|
||||
self.scratch.activation = nn.ReLU(False)
|
||||
|
||||
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
||||
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
||||
|
||||
|
||||
self.scratch.output_conv = nn.Sequential(
|
||||
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
||||
Interpolate(scale_factor=2, mode="bilinear"),
|
||||
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
||||
self.scratch.activation,
|
||||
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
||||
nn.ReLU(True) if non_negative else nn.Identity(),
|
||||
nn.Identity(),
|
||||
)
|
||||
|
||||
if path:
|
||||
self.load(path)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (tensor): input data (image)
|
||||
|
||||
Returns:
|
||||
tensor: depth
|
||||
"""
|
||||
if self.channels_last==True:
|
||||
print("self.channels_last = ", self.channels_last)
|
||||
x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
|
||||
layer_1 = self.pretrained.layer1(x)
|
||||
layer_2 = self.pretrained.layer2(layer_1)
|
||||
layer_3 = self.pretrained.layer3(layer_2)
|
||||
layer_4 = self.pretrained.layer4(layer_3)
|
||||
|
||||
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||
|
||||
|
||||
path_4 = self.scratch.refinenet4(layer_4_rn)
|
||||
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
||||
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
||||
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||
|
||||
out = self.scratch.output_conv(path_1)
|
||||
|
||||
return torch.squeeze(out, dim=1)
|
||||
|
||||
|
||||
|
||||
def fuse_model(m):
|
||||
prev_previous_type = nn.Identity()
|
||||
prev_previous_name = ''
|
||||
previous_type = nn.Identity()
|
||||
previous_name = ''
|
||||
for name, module in m.named_modules():
|
||||
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
||||
# print("FUSED ", prev_previous_name, previous_name, name)
|
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
||||
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
||||
# print("FUSED ", prev_previous_name, previous_name)
|
||||
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
||||
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
||||
# print("FUSED ", previous_name, name)
|
||||
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
||||
|
||||
prev_previous_type = previous_type
|
||||
prev_previous_name = previous_name
|
||||
previous_type = type(module)
|
||||
previous_name = name
|
234
ldm/modules/midas/midas/transforms.py
Normal file
234
ldm/modules/midas/midas/transforms.py
Normal file
@ -0,0 +1,234 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||
|
||||
Args:
|
||||
sample (dict): sample
|
||||
size (tuple): image size
|
||||
|
||||
Returns:
|
||||
tuple: new size
|
||||
"""
|
||||
shape = list(sample["disparity"].shape)
|
||||
|
||||
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||
return sample
|
||||
|
||||
scale = [0, 0]
|
||||
scale[0] = size[0] / shape[0]
|
||||
scale[1] = size[1] / shape[1]
|
||||
|
||||
scale = max(scale)
|
||||
|
||||
shape[0] = math.ceil(scale * shape[0])
|
||||
shape[1] = math.ceil(scale * shape[1])
|
||||
|
||||
# resize
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
||||
)
|
||||
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
||||
)
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
tuple(shape[::-1]),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return tuple(shape)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
"""Resize sample to given size (width, height).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
resize_target=True,
|
||||
keep_aspect_ratio=False,
|
||||
ensure_multiple_of=1,
|
||||
resize_method="lower_bound",
|
||||
image_interpolation_method=cv2.INTER_AREA,
|
||||
):
|
||||
"""Init.
|
||||
|
||||
Args:
|
||||
width (int): desired output width
|
||||
height (int): desired output height
|
||||
resize_target (bool, optional):
|
||||
True: Resize the full sample (image, mask, target).
|
||||
False: Resize image only.
|
||||
Defaults to True.
|
||||
keep_aspect_ratio (bool, optional):
|
||||
True: Keep the aspect ratio of the input sample.
|
||||
Output sample might not have the given width and height, and
|
||||
resize behaviour depends on the parameter 'resize_method'.
|
||||
Defaults to False.
|
||||
ensure_multiple_of (int, optional):
|
||||
Output width and height is constrained to be multiple of this parameter.
|
||||
Defaults to 1.
|
||||
resize_method (str, optional):
|
||||
"lower_bound": Output will be at least as large as the given size.
|
||||
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
||||
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||
Defaults to "lower_bound".
|
||||
"""
|
||||
self.__width = width
|
||||
self.__height = height
|
||||
|
||||
self.__resize_target = resize_target
|
||||
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||
self.__multiple_of = ensure_multiple_of
|
||||
self.__resize_method = resize_method
|
||||
self.__image_interpolation_method = image_interpolation_method
|
||||
|
||||
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if max_val is not None and y > max_val:
|
||||
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
if y < min_val:
|
||||
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||
|
||||
return y
|
||||
|
||||
def get_size(self, width, height):
|
||||
# determine new height and width
|
||||
scale_height = self.__height / height
|
||||
scale_width = self.__width / width
|
||||
|
||||
if self.__keep_aspect_ratio:
|
||||
if self.__resize_method == "lower_bound":
|
||||
# scale such that output size is lower bound
|
||||
if scale_width > scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "upper_bound":
|
||||
# scale such that output size is upper bound
|
||||
if scale_width < scale_height:
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
elif self.__resize_method == "minimal":
|
||||
# scale as least as possbile
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
else:
|
||||
raise ValueError(
|
||||
f"resize_method {self.__resize_method} not implemented"
|
||||
)
|
||||
|
||||
if self.__resize_method == "lower_bound":
|
||||
new_height = self.constrain_to_multiple_of(
|
||||
scale_height * height, min_val=self.__height
|
||||
)
|
||||
new_width = self.constrain_to_multiple_of(
|
||||
scale_width * width, min_val=self.__width
|
||||
)
|
||||
elif self.__resize_method == "upper_bound":
|
||||
new_height = self.constrain_to_multiple_of(
|
||||
scale_height * height, max_val=self.__height
|
||||
)
|
||||
new_width = self.constrain_to_multiple_of(
|
||||
scale_width * width, max_val=self.__width
|
||||
)
|
||||
elif self.__resize_method == "minimal":
|
||||
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||
else:
|
||||
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||
|
||||
return (new_width, new_height)
|
||||
|
||||
def __call__(self, sample):
|
||||
width, height = self.get_size(
|
||||
sample["image"].shape[1], sample["image"].shape[0]
|
||||
)
|
||||
|
||||
# resize sample
|
||||
sample["image"] = cv2.resize(
|
||||
sample["image"],
|
||||
(width, height),
|
||||
interpolation=self.__image_interpolation_method,
|
||||
)
|
||||
|
||||
if self.__resize_target:
|
||||
if "disparity" in sample:
|
||||
sample["disparity"] = cv2.resize(
|
||||
sample["disparity"],
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
|
||||
if "depth" in sample:
|
||||
sample["depth"] = cv2.resize(
|
||||
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||
)
|
||||
|
||||
sample["mask"] = cv2.resize(
|
||||
sample["mask"].astype(np.float32),
|
||||
(width, height),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
sample["mask"] = sample["mask"].astype(bool)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
"""Normlize image by given mean and std.
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.__mean = mean
|
||||
self.__std = std
|
||||
|
||||
def __call__(self, sample):
|
||||
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
|
||||
"""Prepare sample for usage as network input.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, sample):
|
||||
image = np.transpose(sample["image"], (2, 0, 1))
|
||||
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||
|
||||
if "mask" in sample:
|
||||
sample["mask"] = sample["mask"].astype(np.float32)
|
||||
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||
|
||||
if "disparity" in sample:
|
||||
disparity = sample["disparity"].astype(np.float32)
|
||||
sample["disparity"] = np.ascontiguousarray(disparity)
|
||||
|
||||
if "depth" in sample:
|
||||
depth = sample["depth"].astype(np.float32)
|
||||
sample["depth"] = np.ascontiguousarray(depth)
|
||||
|
||||
return sample
|
491
ldm/modules/midas/midas/vit.py
Normal file
491
ldm/modules/midas/midas/vit.py
Normal file
@ -0,0 +1,491 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
import types
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Slice(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(Slice, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, self.start_index :]
|
||||
|
||||
|
||||
class AddReadout(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(AddReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
if self.start_index == 2:
|
||||
readout = (x[:, 0] + x[:, 1]) / 2
|
||||
else:
|
||||
readout = x[:, 0]
|
||||
return x[:, self.start_index :] + readout.unsqueeze(1)
|
||||
|
||||
|
||||
class ProjectReadout(nn.Module):
|
||||
def __init__(self, in_features, start_index=1):
|
||||
super(ProjectReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
||||
features = torch.cat((x[:, self.start_index :], readout), -1)
|
||||
|
||||
return self.project(features)
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim0, dim1):
|
||||
super(Transpose, self).__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(self.dim0, self.dim1)
|
||||
return x
|
||||
|
||||
|
||||
def forward_vit(pretrained, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
glob = pretrained.model.forward_flex(x)
|
||||
|
||||
layer_1 = pretrained.activations["1"]
|
||||
layer_2 = pretrained.activations["2"]
|
||||
layer_3 = pretrained.activations["3"]
|
||||
layer_4 = pretrained.activations["4"]
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
||||
|
||||
unflatten = nn.Sequential(
|
||||
nn.Unflatten(
|
||||
2,
|
||||
torch.Size(
|
||||
[
|
||||
h // pretrained.model.patch_size[1],
|
||||
w // pretrained.model.patch_size[0],
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if layer_1.ndim == 3:
|
||||
layer_1 = unflatten(layer_1)
|
||||
if layer_2.ndim == 3:
|
||||
layer_2 = unflatten(layer_2)
|
||||
if layer_3.ndim == 3:
|
||||
layer_3 = unflatten(layer_3)
|
||||
if layer_4.ndim == 3:
|
||||
layer_4 = unflatten(layer_4)
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
||||
|
||||
return layer_1, layer_2, layer_3, layer_4
|
||||
|
||||
|
||||
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
||||
posemb_tok, posemb_grid = (
|
||||
posemb[:, : self.start_index],
|
||||
posemb[0, self.start_index :],
|
||||
)
|
||||
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
||||
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
|
||||
return posemb
|
||||
|
||||
|
||||
def forward_flex(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
pos_embed = self._resize_pos_embed(
|
||||
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
||||
)
|
||||
|
||||
B = x.shape[0]
|
||||
|
||||
if hasattr(self.patch_embed, "backbone"):
|
||||
x = self.patch_embed.backbone(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||
|
||||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
||||
|
||||
if getattr(self, "dist_token", None) is not None:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
else:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
activations = {}
|
||||
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activations[name] = output
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
||||
if use_readout == "ignore":
|
||||
readout_oper = [Slice(start_index)] * len(features)
|
||||
elif use_readout == "add":
|
||||
readout_oper = [AddReadout(start_index)] * len(features)
|
||||
elif use_readout == "project":
|
||||
readout_oper = [
|
||||
ProjectReadout(vit_features, start_index) for out_feat in features
|
||||
]
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
||||
|
||||
return readout_oper
|
||||
|
||||
|
||||
def _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
size=[384, 384],
|
||||
hooks=[2, 5, 8, 11],
|
||||
vit_features=768,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
# 32, 48, 136, 384
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[256, 512, 1024, 1024],
|
||||
hooks=hooks,
|
||||
vit_features=1024,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model(
|
||||
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
||||
)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
hooks=hooks,
|
||||
use_readout=use_readout,
|
||||
start_index=2,
|
||||
)
|
||||
|
||||
|
||||
def _make_vit_b_rn50_backbone(
|
||||
model,
|
||||
features=[256, 512, 768, 768],
|
||||
size=[384, 384],
|
||||
hooks=[0, 1, 8, 11],
|
||||
vit_features=768,
|
||||
use_vit_only=False,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
else:
|
||||
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
||||
get_activation("1")
|
||||
)
|
||||
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
||||
get_activation("2")
|
||||
)
|
||||
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
else:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitb_rn50_384(
|
||||
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
||||
):
|
||||
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
||||
|
||||
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b_rn50_backbone(
|
||||
model,
|
||||
features=[256, 512, 768, 768],
|
||||
size=[384, 384],
|
||||
hooks=hooks,
|
||||
use_vit_only=use_vit_only,
|
||||
use_readout=use_readout,
|
||||
)
|
189
ldm/modules/midas/utils.py
Normal file
189
ldm/modules/midas/utils.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""Utils for monoDepth."""
|
||||
import sys
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
|
||||
def read_pfm(path):
|
||||
"""Read pfm file.
|
||||
|
||||
Args:
|
||||
path (str): path to file
|
||||
|
||||
Returns:
|
||||
tuple: (data, scale)
|
||||
"""
|
||||
with open(path, "rb") as file:
|
||||
|
||||
color = None
|
||||
width = None
|
||||
height = None
|
||||
scale = None
|
||||
endian = None
|
||||
|
||||
header = file.readline().rstrip()
|
||||
if header.decode("ascii") == "PF":
|
||||
color = True
|
||||
elif header.decode("ascii") == "Pf":
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Not a PFM file: " + path)
|
||||
|
||||
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
||||
if dim_match:
|
||||
width, height = list(map(int, dim_match.groups()))
|
||||
else:
|
||||
raise Exception("Malformed PFM header.")
|
||||
|
||||
scale = float(file.readline().decode("ascii").rstrip())
|
||||
if scale < 0:
|
||||
# little-endian
|
||||
endian = "<"
|
||||
scale = -scale
|
||||
else:
|
||||
# big-endian
|
||||
endian = ">"
|
||||
|
||||
data = np.fromfile(file, endian + "f")
|
||||
shape = (height, width, 3) if color else (height, width)
|
||||
|
||||
data = np.reshape(data, shape)
|
||||
data = np.flipud(data)
|
||||
|
||||
return data, scale
|
||||
|
||||
|
||||
def write_pfm(path, image, scale=1):
|
||||
"""Write pfm file.
|
||||
|
||||
Args:
|
||||
path (str): pathto file
|
||||
image (array): data
|
||||
scale (int, optional): Scale. Defaults to 1.
|
||||
"""
|
||||
|
||||
with open(path, "wb") as file:
|
||||
color = None
|
||||
|
||||
if image.dtype.name != "float32":
|
||||
raise Exception("Image dtype must be float32.")
|
||||
|
||||
image = np.flipud(image)
|
||||
|
||||
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
||||
color = True
|
||||
elif (
|
||||
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
||||
): # greyscale
|
||||
color = False
|
||||
else:
|
||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
||||
|
||||
file.write("PF\n" if color else "Pf\n".encode())
|
||||
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
||||
|
||||
endian = image.dtype.byteorder
|
||||
|
||||
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
||||
scale = -scale
|
||||
|
||||
file.write("%f\n".encode() % scale)
|
||||
|
||||
image.tofile(file)
|
||||
|
||||
|
||||
def read_image(path):
|
||||
"""Read image and output RGB image (0-1).
|
||||
|
||||
Args:
|
||||
path (str): path to file
|
||||
|
||||
Returns:
|
||||
array: RGB image (0-1)
|
||||
"""
|
||||
img = cv2.imread(path)
|
||||
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def resize_image(img):
|
||||
"""Resize image and make it fit for network.
|
||||
|
||||
Args:
|
||||
img (array): image
|
||||
|
||||
Returns:
|
||||
tensor: data ready for network
|
||||
"""
|
||||
height_orig = img.shape[0]
|
||||
width_orig = img.shape[1]
|
||||
|
||||
if width_orig > height_orig:
|
||||
scale = width_orig / 384
|
||||
else:
|
||||
scale = height_orig / 384
|
||||
|
||||
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
||||
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
||||
|
||||
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
img_resized = (
|
||||
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
||||
)
|
||||
img_resized = img_resized.unsqueeze(0)
|
||||
|
||||
return img_resized
|
||||
|
||||
|
||||
def resize_depth(depth, width, height):
|
||||
"""Resize depth map and bring to CPU (numpy).
|
||||
|
||||
Args:
|
||||
depth (tensor): depth
|
||||
width (int): image width
|
||||
height (int): image height
|
||||
|
||||
Returns:
|
||||
array: processed depth
|
||||
"""
|
||||
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
||||
|
||||
depth_resized = cv2.resize(
|
||||
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
||||
)
|
||||
|
||||
return depth_resized
|
||||
|
||||
def write_depth(path, depth, bits=1):
|
||||
"""Write depth map to pfm and png file.
|
||||
|
||||
Args:
|
||||
path (str): filepath without extension
|
||||
depth (array): depth
|
||||
"""
|
||||
write_pfm(path + ".pfm", depth.astype(np.float32))
|
||||
|
||||
depth_min = depth.min()
|
||||
depth_max = depth.max()
|
||||
|
||||
max_val = (2**(8*bits))-1
|
||||
|
||||
if depth_max - depth_min > np.finfo("float").eps:
|
||||
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
||||
else:
|
||||
out = np.zeros(depth.shape, dtype=depth.type)
|
||||
|
||||
if bits == 1:
|
||||
cv2.imwrite(path + ".png", out.astype("uint8"))
|
||||
elif bits == 2:
|
||||
cv2.imwrite(path + ".png", out.astype("uint16"))
|
||||
|
||||
return
|
@ -11,15 +11,13 @@ from einops import rearrange, repeat, reduce
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
'pre_softmax_attn',
|
||||
'post_softmax_attn'
|
||||
])
|
||||
Intermediates = namedtuple(
|
||||
'Intermediates', ['pre_softmax_attn', 'post_softmax_attn']
|
||||
)
|
||||
|
||||
LayerIntermediates = namedtuple('Intermediates', [
|
||||
'hiddens',
|
||||
'attn_intermediates'
|
||||
])
|
||||
LayerIntermediates = namedtuple(
|
||||
'Intermediates', ['hiddens', 'attn_intermediates']
|
||||
)
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
@ -39,11 +37,16 @@ class AbsolutePositionalEmbedding(nn.Module):
|
||||
class FixedPositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x, seq_dim=1, offset=0):
|
||||
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
||||
t = (
|
||||
torch.arange(x.shape[seq_dim], device=x.device).type_as(
|
||||
self.inv_freq
|
||||
)
|
||||
+ offset
|
||||
)
|
||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
||||
return emb[None, :, :]
|
||||
@ -51,6 +54,7 @@ class FixedPositionalEmbedding(nn.Module):
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@ -64,18 +68,21 @@ def default(val, d):
|
||||
def always(val):
|
||||
def inner(*args, **kwargs):
|
||||
return val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def not_equals(val):
|
||||
def inner(x):
|
||||
return x != val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def equals(val):
|
||||
def inner(x):
|
||||
return x == val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@ -85,6 +92,7 @@ def max_neg_value(tensor):
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
@ -108,8 +116,15 @@ def group_by_key_prefix(prefix, d):
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(
|
||||
partial(string_begins_with, prefix), d
|
||||
)
|
||||
kwargs_without_prefix = dict(
|
||||
map(
|
||||
lambda x: (x[0][len(prefix) :], x[1]),
|
||||
tuple(kwargs_with_prefix.items()),
|
||||
)
|
||||
)
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
|
||||
@ -139,7 +154,7 @@ class Rezero(nn.Module):
|
||||
class ScaleNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.scale = dim**-0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
@ -151,7 +166,7 @@ class ScaleNorm(nn.Module):
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.scale = dim**-0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
@ -173,7 +188,7 @@ class GRUGating(nn.Module):
|
||||
def forward(self, x, residual):
|
||||
gated_output = self.gru(
|
||||
rearrange(x, 'b n d -> (b n) d'),
|
||||
rearrange(residual, 'b n d -> (b n) d')
|
||||
rearrange(residual, 'b n d -> (b n) d'),
|
||||
)
|
||||
|
||||
return gated_output.reshape_as(x)
|
||||
@ -181,6 +196,7 @@ class GRUGating(nn.Module):
|
||||
|
||||
# feedforward
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
@ -192,19 +208,18 @@ class GEGLU(nn.Module):
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@ -214,23 +229,25 @@ class FeedForward(nn.Module):
|
||||
# attention.
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.,
|
||||
on_attn=False
|
||||
self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.0,
|
||||
on_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_entmax15:
|
||||
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
|
||||
self.scale = dim_head ** -0.5
|
||||
raise NotImplementedError(
|
||||
'Check out entmax activation instead of softmax activation!'
|
||||
)
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.causal = causal
|
||||
self.mask = mask
|
||||
@ -252,7 +269,7 @@ class Attention(nn.Module):
|
||||
self.sparse_topk = sparse_topk
|
||||
|
||||
# entmax
|
||||
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
||||
# self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
||||
self.attn_fn = F.softmax
|
||||
|
||||
# add memory key / values
|
||||
@ -263,20 +280,29 @@ class Attention(nn.Module):
|
||||
|
||||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
||||
self.to_out = (
|
||||
nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())
|
||||
if on_attn
|
||||
else nn.Linear(inner_dim, dim)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None,
|
||||
):
|
||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
||||
b, n, _, h, talking_heads, device = (
|
||||
*x.shape,
|
||||
self.heads,
|
||||
self.talking_heads,
|
||||
x.device,
|
||||
)
|
||||
kv_input = default(context, x)
|
||||
|
||||
q_input = x
|
||||
@ -297,23 +323,35 @@ class Attention(nn.Module):
|
||||
k = self.to_k(k_input)
|
||||
v = self.to_v(v_input)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)
|
||||
)
|
||||
|
||||
input_mask = None
|
||||
if any(map(exists, (mask, context_mask))):
|
||||
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
||||
q_mask = default(
|
||||
mask, lambda: torch.ones((b, n), device=device).bool()
|
||||
)
|
||||
k_mask = q_mask if not exists(context) else context_mask
|
||||
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
||||
k_mask = default(
|
||||
k_mask,
|
||||
lambda: torch.ones((b, k.shape[-2]), device=device).bool(),
|
||||
)
|
||||
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
||||
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
||||
input_mask = q_mask * k_mask
|
||||
|
||||
if self.num_mem_kv > 0:
|
||||
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
||||
mem_k, mem_v = map(
|
||||
lambda t: repeat(t, 'h n d -> b h n d', b=b),
|
||||
(self.mem_k, self.mem_v),
|
||||
)
|
||||
k = torch.cat((mem_k, k), dim=-2)
|
||||
v = torch.cat((mem_v, v), dim=-2)
|
||||
if exists(input_mask):
|
||||
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
||||
input_mask = F.pad(
|
||||
input_mask, (self.num_mem_kv, 0), value=True
|
||||
)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = max_neg_value(dots)
|
||||
@ -324,7 +362,9 @@ class Attention(nn.Module):
|
||||
pre_softmax_attn = dots
|
||||
|
||||
if talking_heads:
|
||||
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
||||
dots = einsum(
|
||||
'b h i j, h k -> b k i j', dots, self.pre_softmax_proj
|
||||
).contiguous()
|
||||
|
||||
if exists(rel_pos):
|
||||
dots = rel_pos(dots)
|
||||
@ -336,7 +376,9 @@ class Attention(nn.Module):
|
||||
if self.causal:
|
||||
i, j = dots.shape[-2:]
|
||||
r = torch.arange(i, device=device)
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(
|
||||
r, 'j -> () () () j'
|
||||
)
|
||||
mask = F.pad(mask, (j - i, 0), value=False)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
@ -354,14 +396,16 @@ class Attention(nn.Module):
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if talking_heads:
|
||||
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
||||
attn = einsum(
|
||||
'b h i j, h k -> b k i j', attn, self.post_softmax_proj
|
||||
).contiguous()
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
intermediates = Intermediates(
|
||||
pre_softmax_attn=pre_softmax_attn,
|
||||
post_softmax_attn=post_softmax_attn
|
||||
post_softmax_attn=post_softmax_attn,
|
||||
)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
@ -369,28 +413,28 @@ class Attention(nn.Module):
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
||||
@ -403,10 +447,14 @@ class AttentionLayers(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.has_pos_emb = position_infused_attn
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
self.pia_pos_emb = (
|
||||
FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
)
|
||||
self.rotary_pos_emb = always(None)
|
||||
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||
assert (
|
||||
rel_pos_num_buckets <= rel_pos_max_distance
|
||||
), 'number of relative position buckets must be less than the relative position max distance'
|
||||
self.rel_pos = None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
@ -438,15 +486,27 @@ class AttentionLayers(nn.Module):
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
depth_cut = (
|
||||
par_depth * 2 // 3
|
||||
) # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (par_width - len(default_block))
|
||||
assert (
|
||||
len(default_block) <= par_width
|
||||
), 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (
|
||||
par_width - len(default_block)
|
||||
)
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
assert (
|
||||
sandwich_coef > 0 and sandwich_coef <= depth
|
||||
), 'sandwich coefficient should be less than the depth'
|
||||
layer_types = (
|
||||
('a',) * sandwich_coef
|
||||
+ default_block * (depth - sandwich_coef)
|
||||
+ ('f',) * sandwich_coef
|
||||
)
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
@ -455,7 +515,9 @@ class AttentionLayers(nn.Module):
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
if layer_type == 'a':
|
||||
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
||||
layer = Attention(
|
||||
dim, heads=heads, causal=causal, **attn_kwargs
|
||||
)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
@ -472,20 +534,17 @@ class AttentionLayers(nn.Module):
|
||||
else:
|
||||
residual_fn = Residual()
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
norm_fn(),
|
||||
layer,
|
||||
residual_fn
|
||||
]))
|
||||
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False,
|
||||
**kwargs,
|
||||
):
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
@ -494,7 +553,9 @@ class AttentionLayers(nn.Module):
|
||||
|
||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
|
||||
zip(self.layer_types, self.layers)
|
||||
):
|
||||
is_last = ind == (len(self.layers) - 1)
|
||||
|
||||
if layer_type == 'a':
|
||||
@ -507,10 +568,22 @@ class AttentionLayers(nn.Module):
|
||||
x = norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn, mem=layer_mem)
|
||||
out, inter = block(
|
||||
x,
|
||||
mask=mask,
|
||||
sinusoidal_emb=self.pia_pos_emb,
|
||||
rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn,
|
||||
mem=layer_mem,
|
||||
)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
|
||||
out, inter = block(
|
||||
x,
|
||||
context=context,
|
||||
mask=mask,
|
||||
context_mask=context_mask,
|
||||
prev_attn=prev_cross_attn,
|
||||
)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
@ -529,8 +602,7 @@ class AttentionLayers(nn.Module):
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
hiddens=hiddens, attn_intermediates=intermediates
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
@ -544,23 +616,24 @@ class Encoder(AttentionLayers):
|
||||
super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.,
|
||||
emb_dropout=0.,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True
|
||||
self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.0,
|
||||
emb_dropout=0.0,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
||||
assert isinstance(
|
||||
attn_layers, AttentionLayers
|
||||
), 'attention layers must be one of Encoder or Decoder'
|
||||
|
||||
dim = attn_layers.dim
|
||||
emb_dim = default(emb_dim, dim)
|
||||
@ -570,23 +643,34 @@ class TransformerWrapper(nn.Module):
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.pos_emb = (
|
||||
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
|
||||
if (use_pos_emb and not attn_layers.has_pos_emb)
|
||||
else always(0)
|
||||
)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.project_emb = (
|
||||
nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
)
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
self.to_logits = (
|
||||
nn.Linear(dim, num_tokens)
|
||||
if not tie_embedding
|
||||
else lambda t: t @ self.token_emb.weight.t()
|
||||
)
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
self.memory_tokens = nn.Parameter(
|
||||
torch.randn(num_memory_tokens, dim)
|
||||
)
|
||||
|
||||
# let funnel encoder know number of memory tokens, if specified
|
||||
if hasattr(attn_layers, 'num_memory_tokens'):
|
||||
@ -596,18 +680,26 @@ class TransformerWrapper(nn.Module):
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
**kwargs
|
||||
self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
embedding_manager=None,
|
||||
**kwargs,
|
||||
):
|
||||
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
||||
x = self.token_emb(x)
|
||||
x += self.pos_emb(x)
|
||||
|
||||
embedded_x = self.token_emb(x)
|
||||
|
||||
if embedding_manager:
|
||||
x = embedding_manager(x, embedded_x)
|
||||
else:
|
||||
x = embedded_x
|
||||
|
||||
x = x + self.pos_emb(x)
|
||||
x = self.emb_dropout(x)
|
||||
|
||||
x = self.project_emb(x)
|
||||
@ -620,7 +712,9 @@ class TransformerWrapper(nn.Module):
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (num_mem, 0), value=True)
|
||||
|
||||
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
||||
x, intermediates = self.attn_layers(
|
||||
x, mask=mask, mems=mems, return_hiddens=True, **kwargs
|
||||
)
|
||||
x = self.norm(x)
|
||||
|
||||
mem, x = x[:, :num_mem], x[:, num_mem:]
|
||||
@ -629,13 +723,30 @@ class TransformerWrapper(nn.Module):
|
||||
|
||||
if return_mems:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
||||
new_mems = (
|
||||
list(
|
||||
map(
|
||||
lambda pair: torch.cat(pair, dim=-2),
|
||||
zip(mems, hiddens),
|
||||
)
|
||||
)
|
||||
if exists(mems)
|
||||
else hiddens
|
||||
)
|
||||
new_mems = list(
|
||||
map(
|
||||
lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems
|
||||
)
|
||||
)
|
||||
return out, new_mems
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
attn_maps = list(
|
||||
map(
|
||||
lambda t: t.post_softmax_attn,
|
||||
intermediates.attn_intermediates,
|
||||
)
|
||||
)
|
||||
return out, attn_maps
|
||||
|
||||
return out
|
||||
|
||||
|
59
ldm/util.py
59
ldm/util.py
@ -20,16 +20,18 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
txt = Image.new('RGB', wh, color='white')
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
font = ImageFont.load_default()
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
lines = '\n'.join(
|
||||
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
|
||||
)
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
draw.text((0, 0), lines, fill='black', font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
print('Cant encode string for logging. Skipping.')
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@ -71,22 +73,26 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
def instantiate_from_config(config, **kwargs):
|
||||
if not 'target' in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
elif config == '__is_unconditional__':
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
raise KeyError('Expected key `target` to instantiate.')
|
||||
return get_obj_from_str(config['target'])(
|
||||
**config.get('params', dict()), **kwargs
|
||||
)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
module, cls = string.rsplit('.', 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
@ -102,31 +108,36 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||
else:
|
||||
res = func(data)
|
||||
Q.put([idx, res])
|
||||
Q.put("Done")
|
||||
Q.put('Done')
|
||||
|
||||
|
||||
def parallel_data_prefetch(
|
||||
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
|
||||
func: callable,
|
||||
data,
|
||||
n_proc,
|
||||
target_data_type='ndarray',
|
||||
cpu_intensive=True,
|
||||
use_worker_id=False,
|
||||
):
|
||||
# if target_data_type not in ["ndarray", "list"]:
|
||||
# raise ValueError(
|
||||
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||
# )
|
||||
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
if isinstance(data, np.ndarray) and target_data_type == 'list':
|
||||
raise ValueError('list expected but function got ndarray.')
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
print(
|
||||
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
if target_data_type == 'ndarray':
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = list(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||
f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
|
||||
)
|
||||
|
||||
if cpu_intensive:
|
||||
@ -136,7 +147,7 @@ def parallel_data_prefetch(
|
||||
Q = Queue(1000)
|
||||
proc = Thread
|
||||
# spawn processes
|
||||
if target_data_type == "ndarray":
|
||||
if target_data_type == 'ndarray':
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(np.array_split(data, n_proc))
|
||||
@ -150,7 +161,7 @@ def parallel_data_prefetch(
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(
|
||||
[data[i: i + step] for i in range(0, len(data), step)]
|
||||
[data[i : i + step] for i in range(0, len(data), step)]
|
||||
)
|
||||
]
|
||||
processes = []
|
||||
@ -159,7 +170,7 @@ def parallel_data_prefetch(
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
print(f"Start prefetching...")
|
||||
print(f'Start prefetching...')
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
@ -172,13 +183,13 @@ def parallel_data_prefetch(
|
||||
while k < n_proc:
|
||||
# get result
|
||||
res = Q.get()
|
||||
if res == "Done":
|
||||
if res == 'Done':
|
||||
k += 1
|
||||
else:
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e)
|
||||
print('Exception: ', e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
@ -186,7 +197,7 @@ def parallel_data_prefetch(
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
print(f'Prefetching complete. [{time.time() - start} sec.]')
|
||||
|
||||
if target_data_type == 'ndarray':
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
|
@ -6,7 +6,8 @@ https://github.com/CompVis/taming-transformers
|
||||
-- merci
|
||||
"""
|
||||
|
||||
import time
|
||||
import time, math
|
||||
from tqdm.auto import trange, tqdm
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
@ -21,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
|
||||
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff
|
||||
|
||||
def disabled_train(self):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
@ -92,7 +93,6 @@ class DDPM(pl.LightningModule):
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
@ -104,7 +104,6 @@ class DDPM(pl.LightningModule):
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
|
||||
class FirstStage(DDPM):
|
||||
@ -403,7 +402,7 @@ class UNet(DDPM):
|
||||
h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
|
||||
bs = cond.shape[0]
|
||||
|
||||
assert bs%2 == 0
|
||||
# assert bs%2 == 0
|
||||
lenhs = len(hs)
|
||||
|
||||
for i in range(step,bs,step):
|
||||
@ -446,15 +445,14 @@ class UNet(DDPM):
|
||||
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
|
||||
assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
|
||||
to_torch = lambda x: x.to(self.cdevice)
|
||||
|
||||
self.register_buffer1('betas', to_torch(self.betas))
|
||||
self.register_buffer1('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer1('alphas_cumprod_prev', to_torch(self.alphas_cumprod_prev))
|
||||
self.register_buffer1('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
|
||||
self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
@ -463,25 +461,21 @@ class UNet(DDPM):
|
||||
self.register_buffer1('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
self.ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer1('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
seed,
|
||||
conditioning=None,
|
||||
conditioning,
|
||||
x0=None,
|
||||
shape = None,
|
||||
seed=1234,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
sampler = "plms",
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
@ -492,41 +486,74 @@ class UNet(DDPM):
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to(self.cdevice)
|
||||
self.model2.to(self.cdevice)
|
||||
|
||||
samples = self.plms_sampling(conditioning, size, seed,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
if x0 is None:
|
||||
batch_size, b1, b2, b3 = shape
|
||||
img_shape = (1, b1, b2, b3)
|
||||
tens = []
|
||||
print("seeds used = ", [seed+s for s in range(batch_size)])
|
||||
for _ in range(batch_size):
|
||||
torch.manual_seed(seed)
|
||||
tens.append(torch.randn(img_shape, device=self.cdevice))
|
||||
seed+=1
|
||||
noise = torch.cat(tens)
|
||||
del tens
|
||||
|
||||
x_latent = noise if x0 is None else x0
|
||||
# sampling
|
||||
|
||||
if sampler == "plms":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
print(f'Data shape for PLMS sampling is {shape}')
|
||||
samples = self.plms_sampling(conditioning, batch_size, x_latent,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
|
||||
elif sampler == "ddim":
|
||||
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask = mask,init_latent=x_T,use_original_steps=False)
|
||||
|
||||
elif sampler == "euler":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
elif sampler == "euler_a":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
elif sampler == "dpm2":
|
||||
samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
elif sampler == "heun":
|
||||
samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
elif sampler == "dpm2_a":
|
||||
samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
|
||||
elif sampler == "lms":
|
||||
samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to("cpu")
|
||||
@ -535,36 +562,17 @@ class UNet(DDPM):
|
||||
return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape, seed,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
def plms_sampling(self, cond,b, img,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
|
||||
device = self.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
_, b1, b2, b3 = shape
|
||||
img_shape = (1, b1, b2, b3)
|
||||
tens = []
|
||||
print("seeds used = ", [seed+s for s in range(b)])
|
||||
for _ in range(b):
|
||||
torch.manual_seed(seed)
|
||||
tens.append(torch.randn(img_shape, device=device))
|
||||
seed+=1
|
||||
img = torch.cat(tens)
|
||||
del tens
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
timesteps = self.ddim_timesteps
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
@ -618,10 +626,10 @@ class UNet(DDPM):
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = self.ddim_alphas
|
||||
alphas_prev = self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
@ -664,17 +672,11 @@ class UNet(DDPM):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None, mask=None):
|
||||
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
|
||||
if noise is None:
|
||||
b0, b1, b2, b3 = x0.shape
|
||||
@ -687,50 +689,53 @@ class UNet(DDPM):
|
||||
seed+=1
|
||||
noise = torch.cat(tens)
|
||||
del tens
|
||||
if mask is not None:
|
||||
noise = noise*mask
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(self.cdevice), t, x0.shape) * noise)
|
||||
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
mask = None,use_original_steps=False):
|
||||
def add_noise(self, x0, t):
|
||||
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to(self.cdevice)
|
||||
self.model2.to(self.cdevice)
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
noise = torch.randn(x0.shape, device=x0.device)
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
# print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape),
|
||||
# extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape))
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
mask = None,init_latent=None,use_original_steps=False):
|
||||
|
||||
timesteps = self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
# x0 = x_latent
|
||||
x0 = init_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
|
||||
# if mask is not None:
|
||||
# x_dec = x0 * mask + (1. - mask) * x_dec
|
||||
if mask is not None:
|
||||
# x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
|
||||
x0_noisy = x0
|
||||
x_dec = x0_noisy* mask + (1. - mask) * x_dec
|
||||
|
||||
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
# if mask is not None:
|
||||
# return x0 * mask + (1. - mask) * x_dec
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to("cpu")
|
||||
self.model2.to("cpu")
|
||||
if mask is not None:
|
||||
return x0 * mask + (1. - mask) * x_dec
|
||||
|
||||
return x_dec
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
@ -743,7 +748,6 @@ class UNet(DDPM):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
# print("xin shape = ", x_in.shape)
|
||||
e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
@ -751,10 +755,10 @@ class UNet(DDPM):
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = self.ddim_alphas
|
||||
alphas_prev = self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
@ -771,4 +775,256 @@ class UNet(DDPM):
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev
|
||||
return x_prev
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = (sigmas[i] * (gamma + 1)).half()
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(alphas_cumprod=ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = (sigmas[i] * (gamma + 1)).half()
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
s_i = sigmas[i + 1] * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
|
||||
s_i = sigma_mid * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
|
||||
s_i = sigma_mid * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
13
optimizedSD/diffusers_txt2img.py
Normal file
13
optimizedSD/diffusers_txt2img.py
Normal file
@ -0,0 +1,13 @@
|
||||
import torch
|
||||
from diffusers import LDMTextToImagePipeline
|
||||
|
||||
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
|
||||
|
||||
prompt = "19th Century wooden engraving of Elon musk"
|
||||
|
||||
seed = torch.manual_seed(1024)
|
||||
images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"]
|
||||
|
||||
# save images
|
||||
for idx, image in enumerate(images):
|
||||
image.save(f"image-{idx}.png")
|
@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import (
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from .splitAttention import SpatialTransformer
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
362
optimizedSD/optimized_img2img.py
Normal file
362
optimizedSD/optimized_img2img.py
Normal file
@ -0,0 +1,362 @@
|
||||
import argparse, os, re
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimUtils import split_weighted_subprompts, logger
|
||||
from transformers import logging
|
||||
import pandas as pd
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
return sd
|
||||
|
||||
|
||||
def load_img(path, h0, w0):
|
||||
|
||||
image = Image.open(path).convert("RGB")
|
||||
w, h = image.size
|
||||
|
||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||
if h0 is not None and w0 is not None:
|
||||
h, w = h0, w0
|
||||
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
|
||||
|
||||
print(f"New image size ({w}, {h})")
|
||||
image = image.resize((w, h), resample=Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
config = "optimizedSD/v1-inference.yaml"
|
||||
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
|
||||
)
|
||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples")
|
||||
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action="store_true",
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action="store_true",
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=None,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=None,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="CPU or GPU (cuda/cuda:0/cuda:1/...)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--turbo",
|
||||
action="store_true",
|
||||
help="Reduces inference time on the expense of 1GB VRAM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
help="output image format",
|
||||
choices=["jpg", "png"],
|
||||
default="png",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
type=str,
|
||||
help="sampler",
|
||||
choices=["ddim"],
|
||||
default="ddim",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
tic = time.time()
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
if opt.seed == None:
|
||||
opt.seed = randint(0, 1000000)
|
||||
seed_everything(opt.seed)
|
||||
|
||||
# Logging
|
||||
logger(vars(opt), log_csv = "logs/img2img_logs.csv")
|
||||
|
||||
sd = load_model_from_config(f"{ckpt}")
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
if (sp[0]) == "model":
|
||||
if "input_blocks" in sp:
|
||||
li.append(key)
|
||||
elif "middle_block" in sp:
|
||||
li.append(key)
|
||||
elif "time_embed" in sp:
|
||||
li.append(key)
|
||||
else:
|
||||
lo.append(key)
|
||||
for key in li:
|
||||
sd["model1." + key[6:]] = sd.pop(key)
|
||||
for key in lo:
|
||||
sd["model2." + key[6:]] = sd.pop(key)
|
||||
|
||||
config = OmegaConf.load(f"{config}")
|
||||
|
||||
assert os.path.isfile(opt.init_img)
|
||||
init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device)
|
||||
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
model.cdevice = opt.device
|
||||
model.unet_bs = opt.unet_bs
|
||||
model.turbo = opt.turbo
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
modelCS.cond_stage_model.device = opt.device
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
del sd
|
||||
if opt.device != "cpu" and opt.precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
modelFS.half()
|
||||
init_image = init_image.half()
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
assert opt.prompt is not None
|
||||
prompt = opt.prompt
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = batch_size * list(data)
|
||||
data = list(chunk(sorted(data), batch_size))
|
||||
|
||||
modelFS.to(opt.device)
|
||||
|
||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
|
||||
t_enc = int(opt.strength * opt.ddim_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
|
||||
if opt.precision == "autocast" and opt.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(opt.device)
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
subprompts, weights = split_weighted_subprompts(prompts[0])
|
||||
if len(subprompts) > 1:
|
||||
c = torch.zeros_like(uc)
|
||||
totalWeight = sum(weights)
|
||||
# normalize each "sub prompt" and add it
|
||||
for i in range(len(subprompts)):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc] * batch_size).to(opt.device),
|
||||
opt.seed,
|
||||
opt.ddim_eta,
|
||||
opt.ddim_steps,
|
||||
)
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
t_enc,
|
||||
c,
|
||||
z_enc,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
sampler = opt.sampler
|
||||
)
|
||||
|
||||
modelFS.to(opt.device)
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
|
||||
)
|
||||
seeds += str(opt.seed) + ","
|
||||
opt.seed += 1
|
||||
base_count += 1
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
del samples_ddim
|
||||
print("memory_final = ", torch.cuda.memory_allocated(device=opt.device) / 1e6)
|
||||
|
||||
toc = time.time()
|
||||
|
||||
time_taken = (toc - tic) / 60.0
|
||||
|
||||
print(
|
||||
(
|
||||
"Samples finished in {0:.2f} minutes and exported to "
|
||||
+ sample_path
|
||||
+ "\n Seeds used = "
|
||||
+ seeds[:-1]
|
||||
).format(time_taken)
|
||||
)
|
@ -1,7 +1,7 @@
|
||||
import argparse, os, sys, glob, random
|
||||
import argparse, os, re
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from random import randint
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
@ -13,6 +13,10 @@ from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimUtils import split_weighted_subprompts, logger
|
||||
from transformers import logging
|
||||
# from samplers import CompVisDenoiser
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
@ -30,33 +34,22 @@ def load_model_from_config(ckpt, verbose=False):
|
||||
|
||||
|
||||
config = "optimizedSD/v1-inference.yaml"
|
||||
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
device = "cuda"
|
||||
DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
"--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
|
||||
)
|
||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -68,7 +61,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="if enabled, uses the same starting code across samples ",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -125,6 +118,12 @@ parser.add_argument(
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="specify GPU (cuda/cuda:0/cuda:1/...)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
@ -133,165 +132,216 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
default=None,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--small_batch",
|
||||
action='store_true',
|
||||
help="Reduce inference time when generate a smaller batch of images",
|
||||
"--unet_bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"--turbo",
|
||||
action="store_true",
|
||||
help="Reduces inference time on the expense of 1GB VRAM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
help="output image format",
|
||||
choices=["jpg", "png"],
|
||||
default="png",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
type=str,
|
||||
help="sampler",
|
||||
choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
|
||||
default="plms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
help="path to checkpoint of model",
|
||||
default=DEFAULT_CKPT,
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
tic = time.time()
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
sample_path = os.path.join(outpath, "samples", "_".join(opt.prompt.split())[:255])
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
if opt.seed == None:
|
||||
opt.seed = randint(0, 1000000)
|
||||
seed_everything(opt.seed)
|
||||
|
||||
sd = load_model_from_config(f"{ckpt}")
|
||||
li = []
|
||||
lo = []
|
||||
# Logging
|
||||
logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
|
||||
|
||||
sd = load_model_from_config(f"{opt.ckpt}")
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split('.')
|
||||
if(sp[0]) == 'model':
|
||||
if('input_blocks' in sp):
|
||||
sp = key.split(".")
|
||||
if (sp[0]) == "model":
|
||||
if "input_blocks" in sp:
|
||||
li.append(key)
|
||||
elif('middle_block' in sp):
|
||||
elif "middle_block" in sp:
|
||||
li.append(key)
|
||||
elif('time_embed' in sp):
|
||||
elif "time_embed" in sp:
|
||||
li.append(key)
|
||||
else:
|
||||
lo.append(key)
|
||||
for key in li:
|
||||
sd['model1.' + key[6:]] = sd.pop(key)
|
||||
sd["model1." + key[6:]] = sd.pop(key)
|
||||
for key in lo:
|
||||
sd['model2.' + key[6:]] = sd.pop(key)
|
||||
sd["model2." + key[6:]] = sd.pop(key)
|
||||
|
||||
config = OmegaConf.load(f"{config}")
|
||||
config.modelUNet.params.ddim_steps = opt.ddim_steps
|
||||
|
||||
if opt.small_batch:
|
||||
config.modelUNet.params.small_batch = True
|
||||
else:
|
||||
config.modelUNet.params.small_batch = False
|
||||
|
||||
|
||||
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
|
||||
model.unet_bs = opt.unet_bs
|
||||
model.cdevice = opt.device
|
||||
model.turbo = opt.turbo
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
|
||||
modelCS.cond_stage_model.device = opt.device
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
del sd
|
||||
|
||||
if opt.precision == "autocast":
|
||||
if opt.device != "cpu" and opt.precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device)
|
||||
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
assert opt.prompt is not None
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
print(f"Using prompt: {prompt}")
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
text = f.read()
|
||||
print(f"Using prompt: {text.strip()}")
|
||||
data = text.splitlines()
|
||||
data = batch_size * list(data)
|
||||
data = list(chunk(sorted(data), batch_size))
|
||||
|
||||
|
||||
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||
if opt.precision == "autocast" and opt.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(device)
|
||||
|
||||
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(opt.device)
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelCS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
subprompts, weights = split_weighted_subprompts(prompts[0])
|
||||
if len(subprompts) > 1:
|
||||
c = torch.zeros_like(uc)
|
||||
totalWeight = sum(weights)
|
||||
# normalize each "sub prompt" and add it
|
||||
for i in range(len(subprompts)):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
samples_ddim = model.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=opt.n_samples,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code)
|
||||
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
|
||||
modelFS.to(device)
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
samples_ddim = model.sample(
|
||||
S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
seed=opt.seed,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code,
|
||||
sampler = opt.sampler,
|
||||
)
|
||||
|
||||
modelFS.to(opt.device)
|
||||
|
||||
print(samples_ddim.shape)
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
|
||||
)
|
||||
seeds += str(opt.seed) + ","
|
||||
opt.seed += 1
|
||||
base_count += 1
|
||||
|
||||
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelFS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
# if not opt.skip_grid:
|
||||
# all_samples.append(x_samples_ddim)
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
del samples_ddim
|
||||
print("memory_final = ", torch.cuda.memory_allocated()/1e6)
|
||||
|
||||
# if not skip_grid:
|
||||
# # additionally, save as grid
|
||||
# grid = torch.stack(all_samples, 0)
|
||||
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
# grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# # to image
|
||||
# grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
# Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
# grid_count += 1
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
|
||||
toc = time.time()
|
||||
|
||||
time_taken = (toc-tic)/60.0
|
||||
time_taken = (toc - tic) / 60.0
|
||||
|
||||
print(("Your samples are ready in {0:.2f} minutes and waiting for you here \n" + sample_path).format(time_taken))
|
||||
print(
|
||||
(
|
||||
"Samples finished in {0:.2f} minutes and exported to "
|
||||
+ sample_path
|
||||
+ "\n Seeds used = "
|
||||
+ seeds[:-1]
|
||||
).format(time_taken)
|
||||
)
|
||||
|
252
optimizedSD/samplers.py
Normal file
252
optimizedSD/samplers.py
Normal file
@ -0,0 +1,252 @@
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from tqdm.auto import trange, tqdm
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
|
||||
"""Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
||||
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
|
||||
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
|
||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||
levels."""
|
||||
|
||||
def __init__(self, sigmas, quantize):
|
||||
super().__init__()
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.quantize = quantize
|
||||
|
||||
def get_sigmas(self, n=None):
|
||||
if n is None:
|
||||
return append_zero(self.sigmas.flip(0))
|
||||
t_max = len(self.sigmas) - 1
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
dists = torch.abs(sigma - self.sigmas[:, None])
|
||||
if quantize:
|
||||
return torch.argmin(dists, dim=0).view(sigma.shape)
|
||||
low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
|
||||
low, high = self.sigmas[low_idx], self.sigmas[high_idx]
|
||||
w = (low - sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
return t.view(sigma.shape)
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
||||
# print(low_idx, high_idx, w )
|
||||
return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
|
||||
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
noise)."""
|
||||
|
||||
def __init__(self, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_out, c_in
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return input + eps * c_out
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, alphas_cumprod, quantize=False, device='cpu'):
|
||||
super().__init__(alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
|
||||
"""Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
||||
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
|
||||
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f'Order {order} too high for step {i}')
|
||||
def fn(tau):
|
||||
prod = 1.
|
||||
for k in range(order):
|
||||
if j == k:
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
280
optimizedSD/splitAttention.py
Normal file
280
optimizedSD/splitAttention.py
Normal file
@ -0,0 +1,280 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.att_step = att_step
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = self.att_step
|
||||
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range (0, limit, att_step):
|
||||
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i+att_step,:,:] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
@ -29,7 +29,7 @@ streamlit==1.14.0
|
||||
streamlit-on-Hover-tabs==1.0.1
|
||||
streamlit-option-menu==0.3.2
|
||||
streamlit_nested_layout==0.1.1
|
||||
streamlit-server-state==0.14.2
|
||||
streamlit-server-state==0.15.0
|
||||
streamlit-tensorboard==0.0.2
|
||||
streamlit-elements==0.1.* # used for the draggable dashboard and new UI design (WIP)
|
||||
streamlit-ace==0.1.1 # used to replace the text area on the prompt and also for the code editor tool.
|
||||
@ -43,11 +43,12 @@ jsonmerge==1.8.
|
||||
matplotlib==3.6.
|
||||
resize-right==0.0.2
|
||||
torchdiffeq==0.2.3
|
||||
barfi==0.7.0
|
||||
|
||||
# Environment Dependencies for WebUI (flet)
|
||||
|
||||
# txt2vid
|
||||
diffusers==0.6.0
|
||||
diffusers==0.7.2
|
||||
librosa==0.9.2
|
||||
|
||||
# img2img inpainting
|
||||
@ -66,6 +67,7 @@ retry==0.9.2 # used by sd_utils
|
||||
python-slugify==6.1.2 # used by sd_utils
|
||||
piexif==1.1.3 # used by sd_utils
|
||||
pywebview==3.6.3 # used by streamlit_webview.py
|
||||
shutup==0.2.0 # remove all the annoying warnings
|
||||
|
||||
accelerate==0.12.0
|
||||
albumentations==0.4.3
|
||||
|
@ -125,8 +125,13 @@ def layout():
|
||||
st.session_state["defaults"].general.no_half = st.checkbox("No Half", value=st.session_state['defaults'].general.no_half,
|
||||
help="DO NOT switch the model to 16-bit floats. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.use_cudnn = st.checkbox("Use cudnn", value=st.session_state['defaults'].general.use_cudnn,
|
||||
help="Switch the pytorch backend to use cudnn, this should help with fixing Nvidia 16xx cards getting"
|
||||
"a black or green image. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.use_float16 = st.checkbox("Use float16", value=st.session_state['defaults'].general.use_float16,
|
||||
help="Switch the model to 16-bit floats. Default: False")
|
||||
|
||||
|
||||
precision_list = ['full', 'autocast']
|
||||
st.session_state["defaults"].general.precision = st.selectbox("Precision", precision_list, index=precision_list.index(st.session_state['defaults'].general.precision),
|
||||
|
754
scripts/convert_original_stable_diffusion_to_diffusers.py
Normal file
754
scripts/convert_original_stable_diffusion_to_diffusers.py
Normal file
@ -0,0 +1,754 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
#DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
"""
|
||||
if n_shave_prefix_segments >= 0:
|
||||
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
||||
else:
|
||||
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
||||
|
||||
|
||||
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item.replace("in_layers.0", "norm1")
|
||||
new_item = new_item.replace("in_layers.2", "conv1")
|
||||
|
||||
new_item = new_item.replace("out_layers.0", "norm2")
|
||||
new_item = new_item.replace("out_layers.3", "conv2")
|
||||
|
||||
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
||||
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
||||
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
||||
|
||||
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
||||
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
||||
|
||||
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
"""
|
||||
mapping = []
|
||||
for old_item in old_list:
|
||||
new_item = old_item
|
||||
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
mapping.append({"old": old_item, "new": new_item})
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def assign_to_checkpoint(
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming
|
||||
to them. It splits attention layers, and takes into account additional replacements
|
||||
that may arise.
|
||||
|
||||
Assigns the weights to the new checkpoint.
|
||||
"""
|
||||
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
||||
|
||||
if additional_replacements is not None:
|
||||
for replacement in additional_replacements:
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_config):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
|
||||
down_block_types = []
|
||||
resolution = 1
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
||||
down_block_types.append(block_type)
|
||||
if i != len(block_out_channels) - 1:
|
||||
resolution *= 2
|
||||
|
||||
up_block_types = []
|
||||
for i in range(len(block_out_channels)):
|
||||
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
||||
up_block_types.append(block_type)
|
||||
resolution //= 2
|
||||
|
||||
config = dict(
|
||||
sample_size=unet_params.image_size,
|
||||
in_channels=unet_params.in_channels,
|
||||
out_channels=unet_params.out_channels,
|
||||
down_block_types=tuple(down_block_types),
|
||||
up_block_types=tuple(up_block_types),
|
||||
block_out_channels=tuple(block_out_channels),
|
||||
layers_per_block=unet_params.num_res_blocks,
|
||||
cross_attention_dim=unet_params.context_dim,
|
||||
attention_head_dim=unet_params.num_heads,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_vae_diffusers_config(original_config):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
||||
_ = original_config.model.params.first_stage_config.params.embed_dim
|
||||
|
||||
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
||||
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
||||
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
||||
|
||||
config = dict(
|
||||
sample_size=vae_params.resolution,
|
||||
in_channels=vae_params.in_channels,
|
||||
out_channels=vae_params.out_ch,
|
||||
down_block_types=tuple(down_block_types),
|
||||
up_block_types=tuple(up_block_types),
|
||||
block_out_channels=tuple(block_out_channels),
|
||||
latent_channels=vae_params.z_channels,
|
||||
layers_per_block=vae_params.num_res_blocks,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def create_diffusers_schedular(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
beta_end=original_config.model.params.linear_end,
|
||||
beta_schedule="scaled_linear",
|
||||
)
|
||||
return schedular
|
||||
|
||||
|
||||
def create_ldm_bert_config(original_config):
|
||||
bert_params = original_config.model.parms.cond_stage_config.params
|
||||
config = LDMBertConfig(
|
||||
d_model=bert_params.n_embed,
|
||||
encoder_layers=bert_params.n_layer,
|
||||
encoder_ffn_dim=bert_params.n_embed * 4,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
|
||||
# extract state_dict for UNet
|
||||
unet_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
unet_key = "model.diffusion_model."
|
||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
||||
if extract_ema:
|
||||
print(
|
||||
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
||||
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
||||
)
|
||||
for key in keys:
|
||||
if key.startswith("model.diffusion_model"):
|
||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
||||
else:
|
||||
print(
|
||||
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
||||
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
if key.startswith(unet_key):
|
||||
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
||||
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
||||
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
||||
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
||||
|
||||
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
||||
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
||||
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
||||
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
||||
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the output blocks only
|
||||
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
||||
output_blocks = {
|
||||
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_output_blocks)
|
||||
}
|
||||
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
)
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
resnet_0 = middle_blocks[0]
|
||||
attentions = middle_blocks[1]
|
||||
resnet_1 = middle_blocks[2]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnet_0)
|
||||
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
||||
|
||||
resnet_1_paths = renew_resnet_paths(resnet_1)
|
||||
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
||||
|
||||
attentions_paths = renew_attention_paths(attentions)
|
||||
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(
|
||||
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
for i in range(num_output_blocks):
|
||||
block_id = i // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
||||
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
||||
output_block_list = {}
|
||||
|
||||
for layer in output_block_layers:
|
||||
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
||||
if layer_id in output_block_list:
|
||||
output_block_list[layer_id].append(layer_name)
|
||||
else:
|
||||
output_block_list[layer_id] = [layer_name]
|
||||
|
||||
if len(output_block_list) > 1:
|
||||
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
||||
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
||||
|
||||
resnet_0_paths = renew_resnet_paths(resnets)
|
||||
paths = renew_resnet_paths(resnets)
|
||||
|
||||
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
|
||||
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
||||
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
||||
f"output_blocks.{i}.{index}.conv.bias"
|
||||
]
|
||||
|
||||
# Clear attentions as they have been attributed above.
|
||||
if len(attentions) == 2:
|
||||
attentions = []
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
"old": f"output_blocks.{i}.1",
|
||||
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||
)
|
||||
else:
|
||||
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
||||
for path in resnet_0_paths:
|
||||
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
||||
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
||||
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
|
||||
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
||||
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
||||
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
||||
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
||||
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
||||
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
||||
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
||||
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
||||
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
||||
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
||||
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
||||
|
||||
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
||||
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
||||
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
||||
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||
paths = renew_vae_attention_paths(mid_attentions)
|
||||
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||
|
||||
paths = renew_vae_resnet_paths(resnets)
|
||||
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||
paths = renew_vae_attention_paths(mid_attentions)
|
||||
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
||||
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
||||
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
||||
|
||||
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
||||
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
||||
|
||||
def _copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
def _copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
||||
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
||||
|
||||
# copy attn
|
||||
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
||||
|
||||
# copy MLP
|
||||
pt_mlp = pt_layer[1][1]
|
||||
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
||||
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
||||
|
||||
def _copy_layers(hf_layers, pt_layers):
|
||||
for i, hf_layer in enumerate(hf_layers):
|
||||
if i != 0:
|
||||
i += i
|
||||
pt_layer = pt_layers[i : i + 2]
|
||||
_copy_layer(hf_layer, pt_layer)
|
||||
|
||||
hf_model = LDMBertModel(config).eval()
|
||||
|
||||
# copy embeds
|
||||
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
||||
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
||||
|
||||
# copy layer norm
|
||||
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
||||
|
||||
# copy hidden layers
|
||||
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
||||
|
||||
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
try:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
||||
parser.add_argument(
|
||||
"--original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type",
|
||||
default="pndm",
|
||||
type=str,
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extract_ema",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
||||
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
||||
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.original_config_file is None:
|
||||
os.system(
|
||||
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
args.original_config_file = "./v1-inference.yaml"
|
||||
|
||||
original_config = OmegaConf.load(args.original_config_file)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint_path)
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
if args.scheduler_type == "pndm":
|
||||
scheduler = PNDMScheduler(
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=beta_start,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
elif args.scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
||||
elif args.scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
||||
elif args.scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
|
||||
)
|
||||
elif args.scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
|
||||
)
|
||||
elif args.scheduler_type == "ddim":
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config)
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
# Convert the text model.
|
||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
if text_model_type == "FrozenCLIPEmbedder":
|
||||
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
pipe.save_pretrained(args.dump_path)
|
@ -14,7 +14,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, server_state, \
|
||||
from sd_utils import st, server_state, no_rerun, \
|
||||
generation_callback, process_images, KDiffusionSampler, \
|
||||
custom_models_available, RealESRGAN_available, GFPGAN_available, \
|
||||
LDSR_available, load_models, hc, seed_to_int, logger, \
|
||||
@ -378,6 +378,10 @@ def layout():
|
||||
placeholder = "A corgi wearing a top hat as an oil painting."
|
||||
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].admin.global_negative_prompt:
|
||||
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
img2img_generate_col.write("")
|
||||
@ -690,11 +694,12 @@ def layout():
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
with col3_img2img_layout:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
with no_rerun:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
|
||||
if uploaded_images:
|
||||
#image = Image.fromarray(image).convert('RGBA')
|
||||
|
58
scripts/prune-ckpt.py
Normal file
58
scripts/prune-ckpt.py
Normal file
@ -0,0 +1,58 @@
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Pruning')
|
||||
parser.add_argument('--ckpt', type=str, default=None, help='path to model ckpt')
|
||||
args = parser.parse_args()
|
||||
ckpt = args.ckpt
|
||||
|
||||
def prune_it(p, keep_only_ema=False):
|
||||
print(f"prunin' in path: {p}")
|
||||
size_initial = os.path.getsize(p)
|
||||
nsd = dict()
|
||||
sd = torch.load(p, map_location="cpu")
|
||||
print(sd.keys())
|
||||
for k in sd.keys():
|
||||
if k != "optimizer_states":
|
||||
nsd[k] = sd[k]
|
||||
else:
|
||||
print(f"removing optimizer states for path {p}")
|
||||
if "global_step" in sd:
|
||||
print(f"This is global step {sd['global_step']}.")
|
||||
if keep_only_ema:
|
||||
sd = nsd["state_dict"].copy()
|
||||
# infer ema keys
|
||||
ema_keys = {k: "model_ema." + k[6:].replace(".", ".") for k in sd.keys() if k.startswith("model.")}
|
||||
new_sd = dict()
|
||||
|
||||
for k in sd:
|
||||
if k in ema_keys:
|
||||
new_sd[k] = sd[ema_keys[k]].half()
|
||||
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
|
||||
new_sd[k] = sd[k].half()
|
||||
|
||||
assert len(new_sd) == len(sd) - len(ema_keys)
|
||||
nsd["state_dict"] = new_sd
|
||||
else:
|
||||
sd = nsd['state_dict'].copy()
|
||||
new_sd = dict()
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k].half()
|
||||
nsd['state_dict'] = new_sd
|
||||
|
||||
fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
|
||||
print(f"saving pruned checkpoint at: {fn}")
|
||||
torch.save(nsd, fn)
|
||||
newsize = os.path.getsize(fn)
|
||||
MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
|
||||
f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
|
||||
if keep_only_ema:
|
||||
MSG += " and non-EMA weights"
|
||||
print(MSG)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prune_it(ckpt)
|
@ -107,7 +107,7 @@ def getConceptsFromPath(page, conceptPerPage, searchText=""):
|
||||
|
||||
# Maintain the aspect ratio (max 200x200)
|
||||
resizedImage = originalImage.copy()
|
||||
resizedImage.thumbnail((200, 200), Image.ANTIALIAS)
|
||||
resizedImage.thumbnail((200, 200), Image.Resampling.LANCZOS)
|
||||
|
||||
# concept["images"].append(resizedImage)
|
||||
|
||||
|
@ -22,7 +22,7 @@ from streamlit.runtime.scriptrunner import StopException
|
||||
#from streamlit.runtime.scriptrunner import script_run_context
|
||||
|
||||
#streamlit components section
|
||||
from streamlit_server_state import server_state, server_state_lock
|
||||
from streamlit_server_state import server_state, server_state_lock, no_rerun
|
||||
import hydralit_components as hc
|
||||
from hydralit import HydraHeadApp
|
||||
import streamlit_nested_layout
|
||||
@ -72,6 +72,7 @@ from io import BytesIO
|
||||
from packaging import version
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
import shutup
|
||||
|
||||
#import librosa
|
||||
from logger import logger, set_logger_verbosity, quiesce_logger
|
||||
@ -91,6 +92,15 @@ except ImportError as e:
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# remove all the annoying python warnings.
|
||||
shutup.please()
|
||||
|
||||
# the following lines should help fixing an issue with nvidia 16xx cards.
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state["defaults"].general.use_cudnn:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
@ -261,10 +271,13 @@ def set_page_title(title):
|
||||
|
||||
|
||||
def make_grid(n_items=5, n_cols=5):
|
||||
# Compute number of rows
|
||||
n_rows = 1 + n_items // int(n_cols)
|
||||
|
||||
# Create rows
|
||||
rows = [st.container() for _ in range(n_rows)]
|
||||
|
||||
# Create columns in each row
|
||||
cols_per_row = [r.columns(n_cols) for r in rows]
|
||||
cols = [column for row in cols_per_row for column in row]
|
||||
|
||||
@ -272,29 +285,29 @@ def make_grid(n_items=5, n_cols=5):
|
||||
|
||||
|
||||
def merge(file1, file2, out, weight):
|
||||
alpha = (weight)/100
|
||||
if not(file1.endswith(".ckpt")):
|
||||
file1 += ".ckpt"
|
||||
if not(file2.endswith(".ckpt")):
|
||||
file2 += ".ckpt"
|
||||
if not(out.endswith(".ckpt")):
|
||||
out += ".ckpt"
|
||||
#Load Models
|
||||
model_0 = torch.load(file1)
|
||||
model_1 = torch.load(file2)
|
||||
theta_0 = model_0['state_dict']
|
||||
theta_1 = model_1['state_dict']
|
||||
|
||||
for key in theta_0.keys():
|
||||
if 'model' in key and key in theta_1:
|
||||
theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
|
||||
|
||||
logger.info("RUNNING...\n(STAGE 2)")
|
||||
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
theta_0[key] = theta_1[key]
|
||||
torch.save(model_0, out)
|
||||
try:
|
||||
#Load Models
|
||||
model_0 = torch.load(file1)
|
||||
model_1 = torch.load(file2)
|
||||
theta_0 = model_0['state_dict']
|
||||
theta_1 = model_1['state_dict']
|
||||
alpha = (weight)/100
|
||||
for key in theta_0.keys():
|
||||
if 'model' in key and key in theta_1:
|
||||
theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
|
||||
logger.info("RUNNING...\n(STAGE 2)")
|
||||
for key in theta_1.keys():
|
||||
if 'model' in key and key not in theta_0:
|
||||
theta_0[key] = theta_1[key]
|
||||
torch.save(model_0, out)
|
||||
except:
|
||||
logger.error("Error in merging")
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=3):
|
||||
@ -483,7 +496,7 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
if "global_step" in pl_sd:
|
||||
logger.info(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
model = instantiate_from_config(config.model, personalization_config='')
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
logger.info("missing keys:")
|
||||
@ -1606,6 +1619,10 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
|
||||
#
|
||||
@retry(tries=5)
|
||||
def generation_callback(img, i=0):
|
||||
|
||||
# try to do garbage collection before decoding the image
|
||||
torch_gc()
|
||||
|
||||
if "update_preview_frequency" not in st.session_state:
|
||||
raise StopException
|
||||
|
||||
@ -2395,7 +2412,7 @@ def process_images(
|
||||
else: # just behave like usual
|
||||
c = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(prompts)
|
||||
|
||||
|
||||
|
||||
shape = [opt_C, height // opt_f, width // opt_f]
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
|
@ -14,7 +14,7 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, MemUsageMonitor, server_state, \
|
||||
from sd_utils import st, MemUsageMonitor, server_state, no_rerun, \
|
||||
get_next_sequence_number, check_prompt_length, torch_gc, \
|
||||
save_sample, generation_callback, process_images, \
|
||||
KDiffusionSampler, \
|
||||
@ -426,6 +426,12 @@ def layout():
|
||||
placeholder = "A corgi wearing a top hat as an oil painting."
|
||||
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].admin.global_negative_prompt:
|
||||
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
|
||||
|
||||
#print(prompt)
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([2,5,2], gap="large")
|
||||
@ -652,12 +658,13 @@ def layout():
|
||||
if generate_button:
|
||||
|
||||
with col2:
|
||||
if not use_stable_horde:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
with no_rerun:
|
||||
if not use_stable_horde:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
|
||||
#print(st.session_state['use_RealESRGAN'])
|
||||
#print(st.session_state['use_LDSR'])
|
||||
|
@ -21,7 +21,7 @@ https://github.com/nateraw/stable-diffusion-videos
|
||||
repo and the original gist script from
|
||||
https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
|
||||
"""
|
||||
from sd_utils import st, MemUsageMonitor, server_state, torch_gc, \
|
||||
from sd_utils import st, MemUsageMonitor, server_state, no_rerun, torch_gc, \
|
||||
custom_models_available, RealESRGAN_available, GFPGAN_available, \
|
||||
LDSR_available, hc, seed_to_int, logger, slerp, optimize_update_preview_frequency, \
|
||||
load_learned_embed_in_clip, load_GFPGAN, RealESRGANModel
|
||||
@ -54,7 +54,7 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
||||
#from stable_diffusion_videos import StableDiffusionWalkPipeline
|
||||
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
|
||||
PNDMScheduler
|
||||
PNDMScheduler, DDPMScheduler
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
@ -189,10 +189,10 @@ def make_video_pyav(
|
||||
|
||||
write_video(
|
||||
output_filepath,
|
||||
frames,
|
||||
fps=fps,
|
||||
audio_array=audio_tensor,
|
||||
audio_fps=sr,
|
||||
frames,
|
||||
fps=fps,
|
||||
audio_array=audio_tensor,
|
||||
audio_fps=sr,
|
||||
audio_codec="aac",
|
||||
options={"crf": "10", "pix_fmt": "yuv420p"},
|
||||
)
|
||||
@ -777,22 +777,22 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
prompt_config_path.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
prompts=prompts,
|
||||
seeds=seeds,
|
||||
num_interpolation_steps=num_interpolation_steps,
|
||||
fps=fps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
upsample=upsample,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_filepath=audio_filepath,
|
||||
audio_start_sec=audio_start_sec,
|
||||
),
|
||||
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
prompts=prompts,
|
||||
seeds=seeds,
|
||||
num_interpolation_steps=num_interpolation_steps,
|
||||
fps=fps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
upsample=upsample,
|
||||
height=height,
|
||||
width=width,
|
||||
audio_filepath=audio_filepath,
|
||||
audio_start_sec=audio_start_sec,
|
||||
),
|
||||
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -946,6 +946,7 @@ def diffuse(
|
||||
num_inference_steps,
|
||||
cfg_scale,
|
||||
eta,
|
||||
fps=30
|
||||
):
|
||||
|
||||
torch_device = cond_latents.get_device()
|
||||
@ -1055,8 +1056,7 @@ def diffuse(
|
||||
speed = "it/s"
|
||||
duration = 1 / duration
|
||||
|
||||
#
|
||||
total_frames = (st.session_state.sampling_steps + st.session_state.num_inference_steps) * st.session_state.max_duration_in_seconds
|
||||
total_frames = st.session_state.max_duration_in_seconds * fps
|
||||
total_steps = st.session_state.sampling_steps + st.session_state.num_inference_steps
|
||||
|
||||
if i > st.session_state.sampling_steps:
|
||||
@ -1124,16 +1124,18 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
|
||||
if weights_path == "runwayml/stable-diffusion-v1-5":
|
||||
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
|
||||
else:
|
||||
model_path = weights_path
|
||||
|
||||
if not os.path.exists(model_path + "/model_index.json"):
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
|
||||
)
|
||||
|
||||
@ -1141,11 +1143,11 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
else:
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
use_local_file=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
use_local_file=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
)
|
||||
|
||||
server_state["pipe"].unet.to(torch_device)
|
||||
@ -1195,13 +1197,13 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
st.session_state["progress_bar_text"].error(e)
|
||||
|
||||
#
|
||||
def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6,save_video=True, outdir='outputs'):
|
||||
def save_video_to_disk(frames, seeds, sanitized_prompt, fps=30,save_video=True, outdir='outputs'):
|
||||
if save_video:
|
||||
# write video to memory
|
||||
#output = io.BytesIO()
|
||||
#writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid"), im, extension=".mp4", fps=30)
|
||||
#try:
|
||||
video_path = os.path.join(os.getcwd(), outdir, "txt2vid",f"{seeds}_{sanitized_prompt}{datetime.now().strftime('%Y%m-%d%H-%M%S-') + str(uuid4())[:8]}.mp4")
|
||||
video_path = os.path.join(os.getcwd(), outdir, "txt2vid",f"{seeds}_{sanitized_prompt}{datetime.datetime.now().strftime('%Y%m-%d%H-%M%S-') + str(uuid4())[:8]}.mp4")
|
||||
writer = imageio.get_writer(video_path, fps=fps)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
@ -1357,8 +1359,30 @@ def txt2vid(
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
#flaxddims_scheduler = FlaxDDIMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
#flaxddpms_scheduler = FlaxDDPMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
#flaxpndms_scheduler = FlaxPNDMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
ddpms_scheduler = DDPMScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
|
||||
klms=klms_scheduler,
|
||||
ddpms=ddpms_scheduler,
|
||||
#flaxddims=flaxddims_scheduler,
|
||||
#flaxddpms=flaxddpms_scheduler,
|
||||
#flaxpndms=flaxpndms_scheduler,
|
||||
)
|
||||
|
||||
with st.session_state["progress_bar_text"].container():
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
@ -1482,9 +1506,9 @@ def txt2vid(
|
||||
#)
|
||||
|
||||
# old code
|
||||
total_frames = (st.session_state.sampling_steps + st.session_state.num_inference_steps) * st.session_state.max_duration_in_seconds
|
||||
total_frames = st.session_state.max_duration_in_seconds * fps
|
||||
|
||||
while second_count < max_duration_in_seconds:
|
||||
while frame_index+1 <= total_frames:
|
||||
st.session_state["frame_duration"] = 0
|
||||
st.session_state["frame_speed"] = 0
|
||||
st.session_state["current_frame"] = frame_index
|
||||
@ -1506,7 +1530,7 @@ def txt2vid(
|
||||
#init = slerp(gpu, float(t), init1, init2)
|
||||
|
||||
with autocast("cuda"):
|
||||
image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
|
||||
image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta, fps=fps)
|
||||
|
||||
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
|
||||
#im = Image.fromarray(image)
|
||||
@ -1560,6 +1584,8 @@ def txt2vid(
|
||||
|
||||
st.session_state["frame_duration"] = duration
|
||||
st.session_state["frame_speed"] = speed
|
||||
if frame_index+1 > total_frames:
|
||||
break
|
||||
|
||||
init1 = init2
|
||||
|
||||
@ -1602,6 +1628,10 @@ def layout():
|
||||
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].admin.global_negative_prompt:
|
||||
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
@ -1632,6 +1662,9 @@ def layout():
|
||||
|
||||
st.session_state["max_duration_in_seconds"] = st.number_input("Max Duration In Seconds:", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds,
|
||||
help="Specify the max duration in seconds you want your video to be.")
|
||||
|
||||
st.session_state["fps"] = st.number_input("Frames per Second (FPS):", value=st.session_state['defaults'].txt2vid.fps,
|
||||
help="Specify the frame rate of the video.")
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
#st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview,
|
||||
@ -1713,7 +1746,9 @@ def layout():
|
||||
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
scheduler_name_list = ["klms", "ddim"]
|
||||
scheduler_name_list = ["klms", "ddim", "ddpms",
|
||||
#"flaxddims", "flaxddpms", "flaxpndms"
|
||||
]
|
||||
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
||||
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
||||
|
||||
@ -1874,45 +1909,46 @@ def layout():
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
|
||||
|
||||
if st.session_state["use_GFPGAN"]:
|
||||
if "GFPGAN" in server_state:
|
||||
logger.info("GFPGAN already loaded")
|
||||
with no_rerun:
|
||||
if st.session_state["use_GFPGAN"]:
|
||||
if "GFPGAN" in server_state:
|
||||
logger.info("GFPGAN already loaded")
|
||||
else:
|
||||
with col2:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
# Load GFPGAN
|
||||
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
|
||||
try:
|
||||
load_GFPGAN()
|
||||
logger.info("Loaded GFPGAN")
|
||||
except Exception:
|
||||
import traceback
|
||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||
logger.error(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
with col2:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
# Load GFPGAN
|
||||
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
|
||||
try:
|
||||
load_GFPGAN()
|
||||
logger.info("Loaded GFPGAN")
|
||||
except Exception:
|
||||
import traceback
|
||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
||||
logger.error(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if "GFPGAN" in server_state:
|
||||
del server_state["GFPGAN"]
|
||||
if "GFPGAN" in server_state:
|
||||
del server_state["GFPGAN"]
|
||||
|
||||
#try:
|
||||
# run video generation
|
||||
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
|
||||
num_steps=st.session_state.sampling_steps, max_duration_in_seconds=st.session_state.max_duration_in_seconds,
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
|
||||
outdir=st.session_state["defaults"].general.outdir,
|
||||
do_loop=st.session_state["do_loop"],
|
||||
use_lerp_for_text=st.session_state["use_lerp_for_text"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
|
||||
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
|
||||
beta_schedule=beta_scheduler_type, starting_image=None)
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
|
||||
outdir=st.session_state["defaults"].general.outdir,
|
||||
do_loop=st.session_state["do_loop"],
|
||||
use_lerp_for_text=st.session_state["use_lerp_for_text"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
|
||||
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
|
||||
beta_schedule=beta_scheduler_type, starting_image=None, fps=st.session_state.fps)
|
||||
|
||||
if video and save_video_on_stop:
|
||||
if os.path.exists(video): # temporary solution to bypass exception
|
||||
# show video preview on the UI after we hit the stop button
|
||||
# currently not working as session_state is cleared on StopException
|
||||
preview_video.video(open(video, 'rb').read())
|
||||
preview_video.video(open(video, 'rb').read())
|
||||
|
||||
#message.success('Done!', icon="✅")
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
1239
scripts/webui_flet.py
Normal file
1239
scripts/webui_flet.py
Normal file
File diff suppressed because it is too large
Load Diff
65
scripts/webui_flet_utils.py
Normal file
65
scripts/webui_flet_utils.py
Normal file
@ -0,0 +1,65 @@
|
||||
# webui_utils.py
|
||||
|
||||
# imports
|
||||
import os, yaml
|
||||
from PIL import Image
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
# logging
|
||||
log_file = 'webui_flet.log'
|
||||
|
||||
def log_message(message):
|
||||
with open(log_file,'a+') as log:
|
||||
log.write(message)
|
||||
|
||||
|
||||
# Settings
|
||||
path_to_default_config = 'configs/webui/webui_flet.yaml'
|
||||
path_to_user_config = 'configs/webui/userconfig_flet.yaml'
|
||||
|
||||
def get_default_settings_from_config():
|
||||
with open(path_to_default_config) as f:
|
||||
default_settings = yaml.safe_load(f)
|
||||
return default_settings
|
||||
|
||||
def get_user_settings_from_config():
|
||||
settings = get_default_settings_from_config()
|
||||
if os.path.exists(path_to_user_config):
|
||||
with open(path_to_user_config) as f:
|
||||
user_settings = yaml.safe_load(f)
|
||||
settings.update(user_settings)
|
||||
return settings
|
||||
|
||||
def save_user_settings_to_config(settings):
|
||||
with open(path_to_user_config, 'w+') as f:
|
||||
yaml.dump(settings, f, default_flow_style=False)
|
||||
|
||||
|
||||
# Image handling
|
||||
|
||||
def load_images(images): # just for testing, needs love to function
|
||||
images_loaded = {}
|
||||
images_not_loaded = []
|
||||
for i in images:
|
||||
try:
|
||||
img = Image.open(images[i]['path'])
|
||||
if img:
|
||||
images_loaded.update({images[i].name:img})
|
||||
except:
|
||||
images_not_loaded.append(i)
|
||||
|
||||
return images_loaded, images_not_loaded
|
||||
|
||||
def create_blank_image():
|
||||
img = Image.new('RGBA',(512,512),(0,0,0,0))
|
||||
return img
|
||||
|
||||
|
||||
# Textual Inversion
|
||||
textual_inversion_grid_row_list = [
|
||||
'model', 'medium', 'artist', 'trending', 'movement', 'flavors', 'techniques', 'tags',
|
||||
]
|
||||
|
||||
def run_textual_inversion(args):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user