[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2023-06-23 02:58:20 +00:00
parent fe762e4813
commit a9bc7eae19
234 changed files with 34344 additions and 19018 deletions

View File

@ -1,3 +1,3 @@
outputs/
src/
configs/webui/userconfig_streamlit.yaml
configs/webui/userconfig_streamlit.yaml

.gitattributes (vendored), 2 lines changed
View File

@ -1,4 +1,4 @@
* text=auto
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf
*.sh text eol=lf
*.sh text eol=lf

View File

@ -40,7 +40,7 @@ body:
- type: dropdown
id: os
attributes:
label: Where are you running the webui?
label: Where are you running the webui?
multiple: true
options:
- Windows
@ -52,7 +52,7 @@ body:
attributes:
label: Custom settings
description: If you are running the webui with specific settings, please paste them here for reference (like --nitro)
render: shell
render: shell
- type: textarea
id: logs
attributes:
@ -66,4 +66,4 @@ body:
description: By submitting this issue, you agree to follow our [Code of Conduct](https://docs.github.com/en/site-policy/github-terms/github-community-code-of-conduct)
options:
- label: I agree to follow this project's Code of Conduct
required: true
required: true

View File

@ -13,4 +13,4 @@ Closes: # (issue)
- [ ] I have changed the base branch to `dev`
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] I have made corresponding changes to the documentation

View File

@ -37,4 +37,4 @@ jobs:
# The GH actions bot is used by default if you didn't specify the two fields.
# You can swap them out with your own user credentials.
user_name: github-actions[bot]
user_email: 41898282+github-actions[bot]@users.noreply.github.com
user_email: 41898282+github-actions[bot]@users.noreply.github.com

View File

@ -21,4 +21,4 @@ jobs:
- name: Install dependencies
run: yarn install
- name: Test build website
run: yarn build
run: yarn build

View File

@ -6,7 +6,7 @@
## Installation instructions for:
- **[Windows](https://sygil-dev.github.io/sygil-webui/docs/Installation/windows-installation)**
- **[Windows](https://sygil-dev.github.io/sygil-webui/docs/Installation/windows-installation)**
- **[Linux](https://sygil-dev.github.io/sygil-webui/docs/Installation/linux-installation)**
### Want to ask a question or request a feature?
@ -34,10 +34,10 @@ Check the [Contribution Guide](CONTRIBUTING.md)
* Run additional upscaling models on CPU to save VRAM
* Textual inversion: [Research Paper](https://textual-inversion.github.io/)
* Textual inversion: [Research Paper](https://textual-inversion.github.io/)
* K-Diffusion Samplers: A great collection of samplers to use, including:
- `k_euler`
- `k_lms`
- `k_euler_a`
@ -95,8 +95,8 @@ An easy way to work with Stable Diffusion right from your browser.
To give a token (tag recognized by the AI) a specific or increased weight (emphasis), add `:0.##` to the prompt, where `0.##` is a decimal that will specify the weight of all tokens before the colon.
Ex: `cat:0.30, dog:0.70` or `guy riding a bicycle :0.7, incoming car :0.30`
Negative prompts can be added by using `###`, after which any tokens will be seen as negative.
Ex: `cat playing with string ### yarn` will negate `yarn` from the generated image.
Negative prompts can be added by using `###`, after which any tokens will be seen as negative.
Ex: `cat playing with string ### yarn` will negate `yarn` from the generated image.
Negatives are a very powerful tool to get rid of contextually similar or related topics, but **be careful when adding them since the AI might see connections you can't**, and end up outputting gibberish
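As a concrete illustration of the prompt syntax described above, here is a minimal Python sketch that splits a prompt into positive and negative parts on `###` and reads the `:0.##` weights. It is not part of the webui codebase; the function name and the per-comma parsing are invented for this example.

```
import re

def parse_prompt(prompt: str):
    """Split a prompt on '###' and extract ':0.##' token weights (illustrative only)."""
    positive, _, negative = prompt.partition("###")

    def weighted_tokens(text: str):
        tokens = []
        for part in text.split(","):
            part = part.strip()
            if not part:
                continue
            match = re.search(r"^(.*?):(\d*\.\d+)\s*$", part)
            if match:
                tokens.append((match.group(1).strip(), float(match.group(2))))
            else:
                tokens.append((part, 1.0))  # unweighted tokens default to full weight
        return tokens

    return weighted_tokens(positive), weighted_tokens(negative)

# Examples from the docs above:
print(parse_prompt("cat:0.30, dog:0.70 ### yarn"))
```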
@ -131,7 +131,7 @@ Lets you improve faces in pictures using the GFPGAN model. There is a checkbox i
If you want to use GFPGAN to improve generated faces, you need to install it separately.
Download [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth) and put it
into the `/sygil-webui/models/gfpgan` directory.
into the `/sygil-webui/models/gfpgan` directory.
### RealESRGAN
@ -141,7 +141,7 @@ Lets you double the resolution of generated images. There is a checkbox in every
There is also a separate tab for using RealESRGAN on any picture.
Download [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) and [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth).
Put them into the `sygil-webui/models/realesrgan` directory.
Put them into the `sygil-webui/models/realesrgan` directory.
### LDSR
@ -174,8 +174,8 @@ which is available on [GitHub](https://github.com/CompVis/latent-diffusion). PDF
[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
model.
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
@ -184,26 +184,26 @@ See [this section](#stable-diffusion-v1) below and the [model card](https://hugg
Stable Diffusion v1 refers to a specific configuration of the model
architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet
and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
then finetuned on 512x512 images.
*Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present
in its training data.
in its training data.
Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion).
## Comments
- Our code base for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
Thanks for open-sourcing!
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
## BibTeX
```
@misc{rombach2021highresolution,
title={High-Resolution Image Synthesis with Latent Diffusion Models},
title={High-Resolution Image Synthesis with Latent Diffusion Models},
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
year={2021},
eprint={2112.10752},

View File

@ -21,7 +21,7 @@ This model card focuses on the model associated with the Stable Diffusion model,
# Uses
## Direct Use
## Direct Use
The model is intended for research purposes only. Possible research areas and
tasks include
@ -68,11 +68,11 @@ Using the model to generate content that is cruel to individuals is a misuse of
considerations.
### Bias
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
which consists of images that are primarily limited to English descriptions.
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
which consists of images that are primarily limited to English descriptions.
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
@ -84,7 +84,7 @@ The model developers used the following dataset for training the model:
- LAION-2B (en) and subsets thereof (see next section)
**Training Procedure**
Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
- Text prompts are encoded through a ViT-L/14 text-encoder.
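For example, with the autoencoder's relative downsampling factor f = 8, a 512 x 512 x 3 image is mapped to a 64 x 64 x 4 latent (512 / 8 = 64).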
@ -108,12 +108,12 @@ filtered to images with an original size `>= 512x512`, estimated aesthetics scor
- **Batch:** 32 x 8 x 2 x 4 = 2048
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
## Evaluation Results
## Evaluation Results
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
steps show the relative improvements of the checkpoints:
![pareto](assets/v1-variants-scores.jpg)
![pareto](assets/v1-variants-scores.jpg)
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
## Environmental Impact
@ -137,4 +137,3 @@ Based on that information, we estimate the following CO2 emissions using the [Ma
}
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*

View File

@ -582,4 +582,4 @@
"outputs": []
}
]
}
}

View File

@ -23,7 +23,7 @@ Hopefully demand will be high, we want to train **hundreds** of new concepts!
# What does `most inventive use` mean?
Whatever you want it to mean! Be creative! Experiment!
Whatever you want it to mean! Be creative! Experiment!
There are several categories we will look at:
@ -33,7 +33,7 @@ There are several categories we will look at:
* composition; meaning anything related to how big things are, their position, the angle, etc
* styling;
* styling;
![image](https://user-images.githubusercontent.com/106811348/197045629-029ba6f5-1f79-475c-9ce7-969aaf3d253b.png)
@ -45,7 +45,7 @@ There are several categories we will look at:
## `The Sims(TM): Stable Diffusion edition` ?
For this event the theme is “The Sims: Stable Diffusion edition”.
For this event the theme is “The Sims: Stable Diffusion edition”.
So we have selected a subset of [products from the Amazon Berkeley Objects dataset](https://github.com/sd-webui/abo).

View File

@ -17,5 +17,5 @@
"type_vocab_size": 2,
"vocab_size": 30522,
"encoder_width": 768,
"add_cross_attention": true
"add_cross_attention": true
}

View File

@ -21,7 +21,7 @@ init_lr: 1e-5
image_size: 384
# generation configs
max_length: 20
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
@ -30,4 +30,3 @@ prompt: 'a picture of '
weight_decay: 0.05
min_lr: 0
max_epoch: 5

View File

@ -17,5 +17,5 @@
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
"add_cross_attention": true
}

View File

@ -1,4 +1,4 @@
image_root: '/export/share/datasets/vision/NLVR2/'
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'
# set pretrained as a file path or an url
@ -6,8 +6,8 @@ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/mo
#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15
@ -18,4 +18,3 @@ image_size: 384
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0

View File

@ -12,4 +12,4 @@ image_size: 384
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
prompt: 'a picture of '

View File

@ -1,7 +1,7 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
]
laion_path: ''
laion_path: ''
# size of vit model; base or large
vit: 'base'
@ -22,6 +22,3 @@ warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000

View File

@ -31,4 +31,3 @@ negative_all_rank: True
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -31,4 +31,3 @@ negative_all_rank: False
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -9,4 +9,4 @@ vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8
num_frm_test: 8

View File

@ -8,8 +8,8 @@ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/mo
# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5
@ -22,4 +22,4 @@ inference: 'rank'
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10
max_epoch: 10

View File

@ -83,4 +83,4 @@ lightning:
increase_log_steps: False
trainer:
benchmark: True
benchmark: True

View File

@ -95,4 +95,4 @@ lightning:
increase_log_steps: False
trainer:
benchmark: True
benchmark: True

View File

@ -15,7 +15,7 @@ model:
conditioning_key: crossattn
monitor: val/loss
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
@ -37,7 +37,7 @@ model:
use_spatial_transformer: true
transformer_depth: 1
context_dim: 512
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
@ -59,7 +59,7 @@ model:
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.ClassEmbedder
params:

View File

@ -82,4 +82,4 @@ lightning:
increase_log_steps: False
trainer:
benchmark: True
benchmark: True

View File

@ -82,4 +82,4 @@ lightning:
increase_log_steps: False
trainer:
benchmark: True
benchmark: True

View File

@ -88,4 +88,4 @@ lightning:
trainer:
benchmark: True
benchmark: True

View File

@ -65,4 +65,4 @@ model:
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: torch.nn.Identity
target: torch.nn.Identity

View File

@ -70,5 +70,3 @@ model:
params:
freeze: True
layer: "penultimate"

View File

@ -73,4 +73,3 @@ model:
params:
freeze: True
layer: "penultimate"

View File

@ -12,7 +12,7 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# UI defaults configuration file. Is read automatically if located at configs/webui/webui.yaml, or specify path via --defaults.

View File

@ -436,4 +436,4 @@ model_manager:
files:
sygil_diffusion:
file_name: "sygil-diffusion-v0.4.ckpt"
download_link: "https://huggingface.co/Sygil/Sygil-Diffusion/resolve/main/sygil-diffusion-v0.4.ckpt"
download_link: "https://huggingface.co/Sygil/Sygil-Diffusion/resolve/main/sygil-diffusion-v0.4.ckpt"

View File

@ -1,26 +1,46 @@
import os, subprocess
import yaml
print (os.getcwd)
print(os.getcwd)
try:
with open("environment.yaml") as file_handle:
environment_data = yaml.safe_load(file_handle, Loader=yaml.FullLoader)
with open("environment.yaml") as file_handle:
environment_data = yaml.safe_load(file_handle, Loader=yaml.FullLoader)
except FileNotFoundError:
try:
with open(os.path.join("..", "environment.yaml")) as file_handle:
environment_data = yaml.safe_load(file_handle, Loader=yaml.FullLoader)
except:
pass
try:
with open(os.path.join("..", "environment.yaml")) as file_handle:
environment_data = yaml.safe_load(file_handle, Loader=yaml.FullLoader)
except:
pass
try:
for dependency in environment_data["dependencies"]:
package_name, package_version = dependency.split("=")
os.system("pip install {}=={}".format(package_name, package_version))
for dependency in environment_data["dependencies"]:
package_name, package_version = dependency.split("=")
os.system("pip install {}=={}".format(package_name, package_version))
except:
pass
pass
try:
subprocess.run(['python', '-m', 'streamlit', "run" ,os.path.join("..","scripts/webui_streamlit.py"), "--theme.base dark"], stdout=subprocess.DEVNULL)
subprocess.run(
[
"python",
"-m",
"streamlit",
"run",
os.path.join("..", "scripts/webui_streamlit.py"),
"--theme.base dark",
],
stdout=subprocess.DEVNULL,
)
except FileExistsError:
subprocess.run(['python', '-m', 'streamlit', "run" ,"scripts/webui_streamlit.py", "--theme.base dark"], stdout=subprocess.DEVNULL)
subprocess.run(
[
"python",
"-m",
"streamlit",
"run",
"scripts/webui_streamlit.py",
"--theme.base dark",
],
stdout=subprocess.DEVNULL,
)
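The script above (as reformatted by this commit) still carries a few pre-existing issues: `print(os.getcwd)` prints the function object rather than the working directory, `yaml.safe_load()` does not accept a `Loader` argument, and the theme flag is passed as a single argv entry. A minimal corrected sketch of the same logic, for reference only and not part of this commit:

```
import os
import subprocess

import yaml

print(os.getcwd())  # print the current working directory, not the function object

# Look for environment.yaml next to the script or one level up.
environment_data = None
for candidate in ("environment.yaml", os.path.join("..", "environment.yaml")):
    try:
        with open(candidate) as file_handle:
            environment_data = yaml.safe_load(file_handle)  # safe_load takes no Loader argument
        break
    except FileNotFoundError:
        continue

# Install the pinned conda-style dependencies via pip (best effort).
if environment_data:
    for dependency in environment_data.get("dependencies", []):
        if isinstance(dependency, str) and "=" in dependency:
            package_name, package_version = dependency.split("=", 1)
            subprocess.run(["pip", "install", f"{package_name}=={package_version}"])

# Launch the Streamlit UI, falling back to a repo-root relative path;
# the flag and its value are passed as separate argv entries.
for script in (os.path.join("..", "scripts/webui_streamlit.py"), "scripts/webui_streamlit.py"):
    if os.path.exists(script):
        subprocess.run(
            ["python", "-m", "streamlit", "run", script, "--theme.base", "dark"],
            stdout=subprocess.DEVNULL,
        )
        break
```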

View File

@ -10580,4 +10580,4 @@ zdzisław beksinski
Ödön Márffy
Þórarinn B Þorláksson
Þórarinn B. Þorláksson
Ștefan Luchian
Ștefan Luchian

View File

@ -102634,4 +102634,4 @@ zzislaw beksinski
🦑 design
🦩🪐🐞👩🏻🦳
🧒 📸 🎨
🪔 🎨;🌞🌄
🪔 🎨;🌞🌄

View File

@ -101,4 +101,4 @@ graffiti art
lineart
pixel art
poster art
vector art
vector art

View File

@ -197,4 +197,4 @@ verdadism
video art
viennese actionism
visual art
vorticism
vorticism

View File

@ -15,4 +15,4 @@ reddit
shutterstock
tumblr
unsplash
zbrush central
zbrush central

View File

@ -157,4 +157,4 @@
/r/ImaginaryWitches
/r/ImaginaryWizards
/r/ImaginaryWorldEaters
/r/ImaginaryWorlds
/r/ImaginaryWorlds

View File

@ -1799,7 +1799,7 @@ vacbed
vaginal-birth
vaginal-sticker
vampire
variant-set
variant-set
vegetablenabe
vel
very-long-hair
@ -1933,4 +1933,4 @@ zenzai-monaka
zijou
zin-crow
zinkurou
zombie
zombie

View File

@ -23,7 +23,7 @@ Food Art
Tattoo
Digital
Pixel
Embroidery
Embroidery
Line
Pointillism
Single Color
@ -60,4 +60,4 @@ Street
Realistic
Photo Realistic
Hyper Realistic
Doodle
Doodle

View File

@ -6,7 +6,7 @@ denoising_strength: 0.55
variation: 3
initial_seed: 1
# put foreground onto background
# put foreground onto background
size: 512, 512
color: 0,0,0
@ -16,7 +16,7 @@ color:0,0,0,0
resize: 300, 300
pos: 256, 350
// select mask by probing some pixels from the image
// select mask by probing some pixels from the image
mask_by_color_at: 15, 15, 15, 256, 85, 465, 100, 480
mask_by_color_threshold:80
mask_by_color_space: HLS

View File

@ -27,7 +27,7 @@ transform3d_max_mask: 255
transform3d_inpaint_radius: 1
transform3d_inpaint_method: 0
## put foreground onto background
## put foreground onto background
size: 512, 512

View File

@ -4,7 +4,7 @@ ddim_steps: 50
denoising_strength: 0.5
initial_seed: 2
# put foreground onto background
# put foreground onto background
size: 512, 512
## create foreground

View File

@ -23,4 +23,4 @@
"8": ["seagreen", "darkseagreen"]
}
}
}
}

View File

@ -36701,4 +36701,4 @@
}
]
}
}
}

File diff suppressed because one or more lines are too long

View File

@ -14,7 +14,7 @@ Home Page: https://github.com/Sygil-Dev/sygil-webui
- Open the `installer` folder and copy the `install.bat` to the root folder next to the `webui.bat`
- Double-click the `install.bat` file and wait for it to handle everything for you.
- Double-click the `install.bat` file and wait for it to handle everything for you.
### Installation on Linux:
@ -26,4 +26,4 @@ Home Page: https://github.com/Sygil-Dev/sygil-webui
- Wait for the installer to handle everything for you.
After installation, you can run the `webui.cmd` file (on Windows) or `webui.sh` file (on Linux/Mac) to start the WebUI.
After installation, you can run the `webui.cmd` file (on Windows) or `webui.sh` file (on Linux/Mac) to start the WebUI.

View File

@ -37,45 +37,45 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
* Open Miniconda3 Prompt from your start menu after it has been installed
* _(Optional)_ Create a new text file in your root directory `/sygil-webui/custom-conda-path.txt` that contains the path to your relevant Miniconda3, for example `C:\Users\<username>\miniconda3` (replace `<username>` with your own username). This is required if you have more than one Miniconda installation or are using a custom installation location.
* _(Optional)_ Create a new text file in your root directory `/sygil-webui/custom-conda-path.txt` that contains the path to your relevant Miniconda3, for example `C:\Users\<username>\miniconda3` (replace `<username>` with your own username). This is required if you have more than one Miniconda installation or are using a custom installation location.
## Cloning the repo
Type `git clone https://github.com/Sygil-Dev/sygil-webui.git` into the prompt.
Type `git clone https://github.com/Sygil-Dev/sygil-webui.git` into the prompt.
This will create the `sygil-webui` directory in your Windows user folder.
This will create the `sygil-webui` directory in your Windows user folder.
![CleanShot 2022-08-31 at 16 31 20@2x](https://user-images.githubusercontent.com/463317/187796462-29e5bafd-bbc1-4a48-adc8-7eccc174cb62.jpg)
---
---
Once a repo has been cloned, updating it is as easy as typing `git pull` inside Miniconda when in the repo's topmost directory downloaded by the clone command. Below you can see I used the `cd` command to navigate into that folder.
![CleanShot 2022-08-31 at 16 36 34@2x](https://user-images.githubusercontent.com/463317/187796970-db94402f-717b-43a8-9c85-270c0cd256c3.jpg)
* Next you are going to want to create a Hugging Face account: [https://huggingface.co/](https://huggingface.co/)
* Next you are going to want to create a Hugging Face account: [https://huggingface.co/](https://huggingface.co/)
* After you have signed up and are signed in, go to this link and click on Authorize: [https://huggingface.co/CompVis/stable-diffusion-v-1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)
* After you have signed up and are signed in, go to this link and click on Authorize: [https://huggingface.co/CompVis/stable-diffusion-v-1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original)
* After you have authorized your account, go to this link to download the model weights for version 1.4 of the model; future versions will be released in the same way, and updating them will be a similar process:
* After you have authorized your account, go to this link to download the model weights for version 1.4 of the model; future versions will be released in the same way, and updating them will be a similar process:
[https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt)
* Download the model into this directory: `C:\Users\<username>\sygil-webui\models\ldm\stable-diffusion-v1`
* Rename `sd-v1-4.ckpt` to `model.ckpt` once it is inside the stable-diffusion-v1 folder.
* Since we are already in our sygil-webui folder in Miniconda, our next step is to create the environment Stable Diffusion needs to work.
* Since we are already in our sygil-webui folder in Miniconda, our next step is to create the environment Stable Diffusion needs to work.
* _(Optional)_ If you already have an environment set up for an installation of Stable Diffusion named `ldm`, open the `environment.yaml` file in `\sygil-webui\` and change the environment name inside it from `ldm` to `ldo`.
---
---
## First run
* `webui.cmd` at the root folder (`\sygil-webui\`) is your main script that you'll always run. It automatically does the following:
* Create conda env
* Create conda env
* Install and update requirements
* Run the relauncher and webui.py script for gradio UI options
* Run the relauncher and webui.py script for gradio UI options
* Run `webui.cmd` by double clicking the file.
@ -83,7 +83,7 @@ Once a repo has been cloned, updating it is as easy as typing `git pull` inside
![First successful run](https://user-images.githubusercontent.com/3688500/189009827-66c5df32-be44-4851-a265-6791444f537f.JPG)
* You'll receive warning messages about **GFPGAN**, **RealESRGAN** and **LDSR**, but these are optional and will be explained further below.
* You'll receive warning messages about **GFPGAN**, **RealESRGAN** and **LDSR**, but these are optional and will be explained further below.
* In the meantime, you can now go to your web browser and open the link to [http://localhost:7860/](http://localhost:7860/).
@ -91,9 +91,9 @@ Once a repo has been cloned, updating it is as easy as typing `git pull` inside
* You should be able to see progress in your `webui.cmd` window. The [http://localhost:7860/](http://localhost:7860/) page will be automatically updated to show the final image once progress reaches 100%.
* Images created with the web interface will be saved to `\sygil-webui\outputs\` in their respective folders alongside `.yaml` text files with all of the details of your prompts for easy referencing later. Images will also be saved with their seed and numbered so that they can be cross referenced with their `.yaml` files easily.
* Images created with the web interface will be saved to `\sygil-webui\outputs\` in their respective folders alongside `.yaml` text files with all of the details of your prompts for easy referencing later. Images will also be saved with their seed and numbered so that they can be cross referenced with their `.yaml` files easily.
---
---
### Optional additional models
@ -104,12 +104,12 @@ There are three more models that we need to download in order to get the most ou
### GFPGAN
1. If you want to use GFPGAN to improve generated faces, you need to install it separately.
2. Download [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) and [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth) and put them into the `/sygil-webui/models/gfpgan` directory.
2. Download [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) and [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth) and put them into the `/sygil-webui/models/gfpgan` directory.
### RealESRGAN
1. Download [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) and [RealESRGAN_x4plus_anime_6B.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth).
2. Put them into the `sygil-webui/models/realesrgan` directory.
2. Put them into the `sygil-webui/models/realesrgan` directory.
### LDSR
@ -117,7 +117,7 @@ There are three more models that we need to download in order to get the most ou
2. Git clone [Hafiidz/latent-diffusion](https://github.com/Hafiidz/latent-diffusion) into your `/sygil-webui/src/` folder.
3. Run `/sygil-webui/models/ldsr/download_model.bat` to automatically download and rename the models.
4. Wait until it is done; you can confirm by checking for two new files in `sygil-webui/models/ldsr/`
5. _(Optional)_ If there are no files there, you can manually download **LDSR** [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.ckpt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1).
5. _(Optional)_ If there are no files there, you can manually download **LDSR** [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.ckpt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1).
6. Rename last.ckpt to model.ckpt and place both under `sygil-webui/models/ldsr/`.
7. Refer to [here](https://github.com/Sygil-Dev/sygil-webui/issues/488) for any issue.

View File

@ -46,7 +46,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
**Step 3:** Make the script executable by opening the directory in your Terminal and typing `chmod +x linux-sd.sh` (or whatever you named the file).
**Step 4:** Run the script with `./linux-sd.sh`; it will begin by cloning the [WebUI GitHub Repo](https://github.com/Sygil-Dev/sygil-webui) to the directory the script is located in. This folder will be named `sygil-webui`.
**Step 4:** Run the script with `./linux-sd.sh`; it will begin by cloning the [WebUI GitHub Repo](https://github.com/Sygil-Dev/sygil-webui) to the directory the script is located in. This folder will be named `sygil-webui`.
**Step 5:** The script will pause and ask that you move/copy the downloaded 1.4 AI models to the `sygil-webui` folder. Press Enter once you have done so to continue.
@ -67,8 +67,8 @@ The user will have the ability to set these to yes or no using the menu choices.
**Building the Conda environment may take upwards of 15 minutes, depending on your network connection and system specs. This is normal; just leave it be and let it finish. If you are trying to update and the script hangs at `Installing PIP Dependencies` for more than 10 minutes, you will need to `Ctrl-C` to stop the script, delete your `src` folder, and rerun `linux-sd.sh`.**
**Step 8:** Once the conda environment has been created and the upscaler models have been downloaded, the user is asked to choose between the Streamlit and Gradio versions of the WebUI interface.
- Streamlit:
**Step 8:** Once the conda environment has been created and the upscaler models have been downloaded, the user is asked to choose between the Streamlit and Gradio versions of the WebUI interface.
- Streamlit:
- Has A More Modern UI
- More Features Planned
- Will Be The Main UI Going Forward

View File

@ -56,9 +56,9 @@ Requirements:
* Host computer is AMD64 architecture (e.g. Intel/AMD x86 64-bit CPUs)
* Host computer operating system (Linux or Windows with WSL2 enabled)
* See [Microsoft WSL2 Installation Guide for Windows 10](https://learn.microsoft.com/en-us/windows/wsl/) for more information on installing.
* Ubuntu (Default) for WSL2 is recommended for Windows users
* Ubuntu (Default) for WSL2 is recommended for Windows users
* Host computer has Docker, or compatible container runtime
* Docker Compose (v1.29+) or later
* Docker Compose (v1.29+) or later
* See [Install Docker Engine](https://docs.docker.com/engine/install/#supported-platforms) to learn more about installing Docker on your Linux operating system
* 10+ GB Free Disk Space (used by Docker base image, the Stable Diffusion WebUI Docker image for dependencies, model files/weights)
@ -78,7 +78,7 @@ to issues with AMDs support of GPU passthrough. You also _must_ have ROCm driver
```
docker compose -f docker-compose.yml -f docker-compose.amd.yml ...
```
or, by setting
or, by setting
```
export COMPOSE_FILE=docker-compose.yml:docker-compose.amd.yml
```

View File

@ -98,22 +98,22 @@ It is hierarchical, so each layer can have their own child layers.
In the frontend you can find brief documentation for the syntax, examples, and a reference for the various arguments.
Here is a summary:
Markdown headings, e.g. '# layer0', define layers.
The content of sections defines the arguments for image generation.
Markdown headings, e.g. '# layer0', define layers.
The content of sections defines the arguments for image generation.
Arguments are defined by lines of the form 'arg:value' or 'arg=value'.
Layers are hierarchical, i.e. each layer can contain more layers.
Layers are hierarchical, i.e. each layer can contain more layers.
The number of '#' increases in the headings of child layers.
Child layers are blended together by their image masks, like layers in image editors.
By default alpha composition is used for blending.
By default alpha composition is used for blending.
Other blend modes from [ImageChops](https://pillow.readthedocs.io/en/stable/reference/ImageChops.html) can also be used.
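As an aside, here is a tiny PIL sketch of the two blending options mentioned above (alpha composition by default, or an ImageChops blend mode). It is illustrative only and is not the scene-file parser itself:

```
from PIL import Image, ImageChops

background = Image.new("RGBA", (512, 512), (0, 0, 0, 255))
foreground = Image.new("RGBA", (512, 512), (255, 0, 0, 128))

# Default behaviour described above: alpha composition of child layers.
composited = Image.alpha_composite(background, foreground)

# Alternative: any blend mode from ImageChops, e.g. multiply.
multiplied = ImageChops.multiply(background, foreground)
```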
Sections with "prompt" and child layers invoke Image2Image, without child layers they invoke Text2Image.
Sections with "prompt" and child layers invoke Image2Image, without child layers they invoke Text2Image.
The result of blending child layers will be the input for Image2Image.
Without "prompt" they are just images, useful for mask selection, image composition, etc.
Images can be initialized with "color", resized with "resize" and their position specified with "pos".
Rotation and rotation center are "rotation" and "center".
Rotation and rotation center are "rotation" and "center".
Masks can be selected automatically by color, by the color at given pixels of the image, or by estimated depth.
@ -128,15 +128,15 @@ The poses describe the camera position and orientation as x,y,z,rotate_x,rotate_
The camera coordinate system is the pinhole camera as described and pictured in [OpenCV "Camera Calibration and 3D Reconstruction" documentation](https://docs.opencv.org/4.x/d9/d0c/group__calib3d.html).
When `transform3d_from_pose`, the camera pose where the input image was taken, is not specified, the camera pose `transform3d_to_pose` to which the image is to be transformed is interpreted in the input camera's coordinate system:
Walking forwards one depth unit in the input image corresponds to a position `0,0,1`.
Walking to the right is something like `1,0,0`.
Walking forwards one depth unit in the input image corresponds to a position `0,0,1`.
Walking to the right is something like `1,0,0`.
Going downwards is then `0,1,0`.
## Gradio Optional Customizations
---
Gradio allows for a number of possible customizations via command line arguments/terminal parameters. If you are running these manually, they would need to be run like this: `python scripts/webui.py --param`. Otherwise, you may add your own parameter customizations to `scripts/relauncher.py`, the program that automatically relaunches the Gradio interface should a crash happen.
Gradio allows for a number of possible customizations via command line arguments/terminal parameters. If you are running these manually, they would need to be run like this: `python scripts/webui.py --param`. Otherwise, you may add your own parameter customizations to `scripts/relauncher.py`, the program that automatically relaunches the Gradio interface should a crash happen.
Inside of `relauncher.py` are a few preset defaults most people would likely access:
@ -171,7 +171,7 @@ additional_arguments = ""
---
This is a list of the full set of optional parameters you can launch the Gradio Interface with.
This is a list of the full set of optional parameters you can launch the Gradio Interface with.
```
usage: webui.py [-h] [--ckpt CKPT] [--cli CLI] [--config CONFIG] [--defaults DEFAULTS] [--esrgan-cpu] [--esrgan-gpu ESRGAN_GPU] [--extra-models-cpu] [--extra-models-gpu] [--gfpgan-cpu] [--gfpgan-dir GFPGAN_DIR] [--gfpgan-gpu GFPGAN_GPU] [--gpu GPU]

View File

@ -61,7 +61,7 @@ To use GoBig, you will need to download the RealESRGAN models as directed above.
LDSR is a 4X upscaler with high VRAM usage that uses a Latent Diffusion model to upscale the image. This will accentuate the details of an image, but won't change the composition. This might introduce sharpening, but it is great for textures or compositions with plenty of details. However, it is slower and will use more VRAM.
If you want to use LDSR to upscale your images, you need to download its models separately if you are on Windows, or do so manually on Linux.
Download the LDSR [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.ckpt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1). Rename `last.ckpt` to `model.ckpt` and place both in the `sygil-webui/models/ldsr` directory after you have set up the conda environment for the first time.
Download the LDSR [project.yaml](https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1) and [model last.ckpt](https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1). Rename `last.ckpt` to `model.ckpt` and place both in the `sygil-webui/models/ldsr` directory after you have set up the conda environment for the first time.
## GoLatent (Gradio only currently)

View File

@ -6,4 +6,4 @@
![](../images/streamlit/streamlit-concepts.png)
The Concept Library allows for the easy usage of custom textual inversion models. These models may be loaded into `models/custom/sd-concepts-library` and will appear in the Concepts Library in Streamlit. To use one of these custom models in a prompt, either copy it using the button on the model, or type `<model-name>` in the prompt where you wish to use it.
The Concept Library allows for the easy usage of custom textual inversion models. These models may be loaded into `models/custom/sd-concepts-library` and will appear in the Concepts Library in Streamlit. To use one of these custom models in a prompt, either copy it using the button on the model, or type `<model-name>` in the prompt where you wish to use it.

View File

@ -18,7 +18,7 @@ You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
-->
You can use other *versions* of Stable Diffusion, and *fine-tunes* of Stable Diffusion.
You can use other *versions* of Stable Diffusion, and *fine-tunes* of Stable Diffusion.
Any model with the `.ckpt` extension can be placed into the `models/custom` folder and used in the UI. The filename of the model is used to show it in the drop-down menu on the UI, from which you can select and use your custom model, so make sure it has a recognizable filename.
@ -44,7 +44,7 @@ Any model with the `.ckpt` extension can be placed into the `models/custom` fold
- ### [Trinart v2](https://huggingface.co/naclbit/trinart_stable_diffusion_v2)
-
-
## Unofficial Model List:

View File

@ -27,7 +27,7 @@ const config = {
defaultLocale: 'en',
locales: ['en'],
},
// ...
plugins: [
[
@ -108,7 +108,7 @@ const config = {
/** @type {import('@docusaurus/preset-classic').Options} */
({
docs: {
sidebarCollapsed: false,
sidebarPath: require.resolve('./sidebars.js'),
// Please change this to your repo.
@ -193,4 +193,4 @@ const config = {
}),
};
module.exports = config;
module.exports = config;

View File

@ -13,7 +13,7 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Starts the webserver inside the docker container
#

View File

@ -13,7 +13,7 @@ name: ldm
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
channels:
- pytorch
- defaults

View File

@ -20,4 +20,4 @@ module.exports = {
'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',
'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off'
}
}
}

View File

@ -1,7 +1,7 @@
/* ----------------------------------------------
* Generated by Animista on 2022-9-3 12:0:51
* Licensed under FreeBSD License.
* See http://animista.net/license for more info.
* See http://animista.net/license for more info.
* w: http://animista.net, t: @cssanimista
* ---------------------------------------------- */
@ -26,7 +26,7 @@
opacity: 1;
}
}
/* CSS HEX */
:root {
@ -130,7 +130,7 @@ background-color:#9d85fbdf!important;
border: none!important;}
/* Background for Gradio stuff along with colors for text */
.dark .gr-box {
background-color:rgba(55, 55, 55, 0.105)!important;
border: solid 0.5px!important;
@ -206,4 +206,3 @@ button, select, textarea {
.dark .gr-check-radio{
background-color: #373737ff!important;
}

View File

@ -18,4 +18,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text { display:none!important; }

View File

@ -146,7 +146,7 @@ div.gallery:hover {
.css-jn99sy {
display: none
}
/* Make the text area widget have a similar height as the text input field */
.st-dy{
height: 54px;
@ -154,14 +154,14 @@ div.gallery:hover {
}
.css-17useex{
gap: 3px;
}
/* Remove some empty spaces to make the UI more compact. */
.css-18e3th9{
padding-left: 10px;
padding-right: 30px;
position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */
position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */
}
.css-k1vhr4{
padding-top: initial;

View File

@ -12,7 +12,7 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from os import path
import json
@ -28,7 +28,7 @@ def readTextFile(*args):
def css(opt):
styling = readTextFile("css", "styles.css")
if not opt.no_progressbar_hiding:
styling += readTextFile("css", "no_progress_bar.css")
styling += readTextFile("css", "no_progress_bar.css")
return styling

View File

@ -10,4 +10,4 @@
<div id="app"></div>
</body>
</html>
</html>

View File

@ -187,7 +187,7 @@ PERFORMANCE OF THIS SOFTWARE.
***************************************************************************** */var Oa=function(){return Oa=Object.assign||function(t){for(var n,r=1,i=arguments.length;r<i;r++){n=arguments[r];for(var o in n)Object.prototype.hasOwnProperty.call(n,o)&&(t[o]=n[o])}return t},Oa.apply(this,arguments)},sO={thumbnail:!0,animateThumb:!0,currentPagerPosition:"middle",alignThumbnails:"middle",thumbWidth:100,thumbHeight:"80px",thumbMargin:5,appendThumbnailsTo:".lg-components",toggleThumb:!1,enableThumbDrag:!0,enableThumbSwipe:!0,thumbnailSwipeThreshold:10,loadYouTubeThumbnail:!0,youTubeThumbSize:1,thumbnailPluginStrings:{toggleThumbnails:"Toggle thumbnails"}},ys={afterAppendSlide:"lgAfterAppendSlide",init:"lgInit",hasVideo:"lgHasVideo",containerResize:"lgContainerResize",updateSlides:"lgUpdateSlides",afterAppendSubHtml:"lgAfterAppendSubHtml",beforeOpen:"lgBeforeOpen",afterOpen:"lgAfterOpen",slideItemLoad:"lgSlideItemLoad",beforeSlide:"lgBeforeSlide",afterSlide:"lgAfterSlide",posterClick:"lgPosterClick",dragStart:"lgDragStart",dragMove:"lgDragMove",dragEnd:"lgDragEnd",beforeNextSlide:"lgBeforeNextSlide",beforePrevSlide:"lgBeforePrevSlide",beforeClose:"lgBeforeClose",afterClose:"lgAfterClose",rotateLeft:"lgRotateLeft",rotateRight:"lgRotateRight",flipHorizontal:"lgFlipHorizontal",flipVertical:"lgFlipVertical",autoplay:"lgAutoplay",autoplayStart:"lgAutoplayStart",autoplayStop:"lgAutoplayStop"},oO=function(){function e(t,n){return this.thumbOuterWidth=0,this.thumbTotalWidth=0,this.translateX=0,this.thumbClickable=!1,this.core=t,this.$LG=n,this}return e.prototype.init=function(){this.settings=Oa(Oa({},sO),this.core.settings),this.thumbOuterWidth=0,this.thumbTotalWidth=this.core.galleryItems.length*(this.settings.thumbWidth+this.settings.thumbMargin),this.translateX=0,this.setAnimateThumbStyles(),this.core.settings.allowMediaOverlap||(this.settings.toggleThumb=!1),this.settings.thumbnail&&(this.build(),this.settings.animateThumb?(this.settings.enableThumbDrag&&this.enableThumbDrag(),this.settings.enableThumbSwipe&&this.enableThumbSwipe(),this.thumbClickable=!1):this.thumbClickable=!0,this.toggleThumbBar(),this.thumbKeyPress())},e.prototype.build=function(){var t=this;this.setThumbMarkup(),this.manageActiveClassOnSlideChange(),this.$lgThumb.first().on("click.lg touchend.lg",function(n){var r=t.$LG(n.target);!r.hasAttribute("data-lg-item-id")||setTimeout(function(){if(t.thumbClickable&&!t.core.lgBusy){var i=parseInt(r.attr("data-lg-item-id"));t.core.slide(i,!1,!0,!1)}},50)}),this.core.LGel.on(ys.beforeSlide+".thumb",function(n){var r=n.detail.index;t.animateThumb(r)}),this.core.LGel.on(ys.beforeOpen+".thumb",function(){t.thumbOuterWidth=t.core.outer.get().offsetWidth}),this.core.LGel.on(ys.updateSlides+".thumb",function(){t.rebuildThumbnails()}),this.core.LGel.on(ys.containerResize+".thumb",function(){!t.core.lgOpened||setTimeout(function(){t.thumbOuterWidth=t.core.outer.get().offsetWidth,t.animateThumb(t.core.index),t.thumbOuterWidth=t.core.outer.get().offsetWidth},50)})},e.prototype.setThumbMarkup=function(){var t="lg-thumb-outer ";this.settings.alignThumbnails&&(t+="lg-thumb-align-"+this.settings.alignThumbnails);var n='<div class="'+t+`">
<div class="lg-thumb lg-group">
</div>
</div>`;this.core.outer.addClass("lg-has-thumb"),this.settings.appendThumbnailsTo===".lg-components"?this.core.$lgComponents.append(n):this.core.outer.append(n),this.$thumbOuter=this.core.outer.find(".lg-thumb-outer").first(),this.$lgThumb=this.core.outer.find(".lg-thumb").first(),this.settings.animateThumb&&this.core.outer.find(".lg-thumb").css("transition-duration",this.core.settings.speed+"ms").css("width",this.thumbTotalWidth+"px").css("position","relative"),this.setThumbItemHtml(this.core.galleryItems)},e.prototype.enableThumbDrag=function(){var t=this,n={cords:{startX:0,endX:0},isMoved:!1,newTranslateX:0,startTime:new Date,endTime:new Date,touchMoveTime:0},r=!1;this.$thumbOuter.addClass("lg-grab"),this.core.outer.find(".lg-thumb").first().on("mousedown.lg.thumb",function(i){t.thumbTotalWidth>t.thumbOuterWidth&&(i.preventDefault(),n.cords.startX=i.pageX,n.startTime=new Date,t.thumbClickable=!1,r=!0,t.core.outer.get().scrollLeft+=1,t.core.outer.get().scrollLeft-=1,t.$thumbOuter.removeClass("lg-grab").addClass("lg-grabbing"))}),this.$LG(window).on("mousemove.lg.thumb.global"+this.core.lgId,function(i){!t.core.lgOpened||r&&(n.cords.endX=i.pageX,n=t.onThumbTouchMove(n))}),this.$LG(window).on("mouseup.lg.thumb.global"+this.core.lgId,function(){!t.core.lgOpened||(n.isMoved?n=t.onThumbTouchEnd(n):t.thumbClickable=!0,r&&(r=!1,t.$thumbOuter.removeClass("lg-grabbing").addClass("lg-grab")))})},e.prototype.enableThumbSwipe=function(){var t=this,n={cords:{startX:0,endX:0},isMoved:!1,newTranslateX:0,startTime:new Date,endTime:new Date,touchMoveTime:0};this.$lgThumb.on("touchstart.lg",function(r){t.thumbTotalWidth>t.thumbOuterWidth&&(r.preventDefault(),n.cords.startX=r.targetTouches[0].pageX,t.thumbClickable=!1,n.startTime=new Date)}),this.$lgThumb.on("touchmove.lg",function(r){t.thumbTotalWidth>t.thumbOuterWidth&&(r.preventDefault(),n.cords.endX=r.targetTouches[0].pageX,n=t.onThumbTouchMove(n))}),this.$lgThumb.on("touchend.lg",function(){n.isMoved?n=t.onThumbTouchEnd(n):t.thumbClickable=!0})},e.prototype.rebuildThumbnails=function(){var t=this;this.$thumbOuter.addClass("lg-rebuilding-thumbnails"),setTimeout(function(){t.thumbTotalWidth=t.core.galleryItems.length*(t.settings.thumbWidth+t.settings.thumbMargin),t.$lgThumb.css("width",t.thumbTotalWidth+"px"),t.$lgThumb.empty(),t.setThumbItemHtml(t.core.galleryItems),t.animateThumb(t.core.index)},50),setTimeout(function(){t.$thumbOuter.removeClass("lg-rebuilding-thumbnails")},200)},e.prototype.setTranslate=function(t){this.$lgThumb.css("transform","translate3d(-"+t+"px, 0px, 0px)")},e.prototype.getPossibleTransformX=function(t){return t>this.thumbTotalWidth-this.thumbOuterWidth&&(t=this.thumbTotalWidth-this.thumbOuterWidth),t<0&&(t=0),t},e.prototype.animateThumb=function(t){if(this.$lgThumb.css("transition-duration",this.core.settings.speed+"ms"),this.settings.animateThumb){var n=0;switch(this.settings.currentPagerPosition){case"left":n=0;break;case"middle":n=this.thumbOuterWidth/2-this.settings.thumbWidth/2;break;case"right":n=this.thumbOuterWidth-this.settings.thumbWidth}this.translateX=(this.settings.thumbWidth+this.settings.thumbMargin)*t-1-n,this.translateX>this.thumbTotalWidth-this.thumbOuterWidth&&(this.translateX=this.thumbTotalWidth-this.thumbOuterWidth),this.translateX<0&&(this.translateX=0),this.setTranslate(this.translateX)}},e.prototype.onThumbTouchMove=function(t){return t.newTranslateX=this.translateX,t.isMoved=!0,t.touchMoveTime=new 
Date().valueOf(),t.newTranslateX-=t.cords.endX-t.cords.startX,t.newTranslateX=this.getPossibleTransformX(t.newTranslateX),this.setTranslate(t.newTranslateX),this.$thumbOuter.addClass("lg-dragging"),t},e.prototype.onThumbTouchEnd=function(t){t.isMoved=!1,t.endTime=new Date,this.$thumbOuter.removeClass("lg-dragging");var n=t.endTime.valueOf()-t.startTime.valueOf(),r=t.cords.endX-t.cords.startX,i=Math.abs(r)/n;return i>.15&&t.endTime.valueOf()-t.touchMoveTime<30?(i+=1,i>2&&(i+=1),i=i+i*(Math.abs(r)/this.thumbOuterWidth),this.$lgThumb.css("transition-duration",Math.min(i-1,2)+"settings"),r=r*i,this.translateX=this.getPossibleTransformX(this.translateX-r),this.setTranslate(this.translateX)):this.translateX=t.newTranslateX,Math.abs(t.cords.endX-t.cords.startX)<this.settings.thumbnailSwipeThreshold&&(this.thumbClickable=!0),t},e.prototype.getThumbHtml=function(t,n){var r=this.core.galleryItems[n].__slideVideoInfo||{},i;return r.youtube&&this.settings.loadYouTubeThumbnail?i="//img.youtube.com/vi/"+r.youtube[1]+"/"+this.settings.youTubeThumbSize+".jpg":i=t,'<div data-lg-item-id="'+n+'" class="lg-thumb-item '+(n===this.core.index?" active":"")+`"
</div>`;this.core.outer.addClass("lg-has-thumb"),this.settings.appendThumbnailsTo===".lg-components"?this.core.$lgComponents.append(n):this.core.outer.append(n),this.$thumbOuter=this.core.outer.find(".lg-thumb-outer").first(),this.$lgThumb=this.core.outer.find(".lg-thumb").first(),this.settings.animateThumb&&this.core.outer.find(".lg-thumb").css("transition-duration",this.core.settings.speed+"ms").css("width",this.thumbTotalWidth+"px").css("position","relative"),this.setThumbItemHtml(this.core.galleryItems)},e.prototype.enableThumbDrag=function(){var t=this,n={cords:{startX:0,endX:0},isMoved:!1,newTranslateX:0,startTime:new Date,endTime:new Date,touchMoveTime:0},r=!1;this.$thumbOuter.addClass("lg-grab"),this.core.outer.find(".lg-thumb").first().on("mousedown.lg.thumb",function(i){t.thumbTotalWidth>t.thumbOuterWidth&&(i.preventDefault(),n.cords.startX=i.pageX,n.startTime=new Date,t.thumbClickable=!1,r=!0,t.core.outer.get().scrollLeft+=1,t.core.outer.get().scrollLeft-=1,t.$thumbOuter.removeClass("lg-grab").addClass("lg-grabbing"))}),this.$LG(window).on("mousemove.lg.thumb.global"+this.core.lgId,function(i){!t.core.lgOpened||r&&(n.cords.endX=i.pageX,n=t.onThumbTouchMove(n))}),this.$LG(window).on("mouseup.lg.thumb.global"+this.core.lgId,function(){!t.core.lgOpened||(n.isMoved?n=t.onThumbTouchEnd(n):t.thumbClickable=!0,r&&(r=!1,t.$thumbOuter.removeClass("lg-grabbing").addClass("lg-grab")))})},e.prototype.enableThumbSwipe=function(){var t=this,n={cords:{startX:0,endX:0},isMoved:!1,newTranslateX:0,startTime:new Date,endTime:new Date,touchMoveTime:0};this.$lgThumb.on("touchstart.lg",function(r){t.thumbTotalWidth>t.thumbOuterWidth&&(r.preventDefault(),n.cords.startX=r.targetTouches[0].pageX,t.thumbClickable=!1,n.startTime=new Date)}),this.$lgThumb.on("touchmove.lg",function(r){t.thumbTotalWidth>t.thumbOuterWidth&&(r.preventDefault(),n.cords.endX=r.targetTouches[0].pageX,n=t.onThumbTouchMove(n))}),this.$lgThumb.on("touchend.lg",function(){n.isMoved?n=t.onThumbTouchEnd(n):t.thumbClickable=!0})},e.prototype.rebuildThumbnails=function(){var t=this;this.$thumbOuter.addClass("lg-rebuilding-thumbnails"),setTimeout(function(){t.thumbTotalWidth=t.core.galleryItems.length*(t.settings.thumbWidth+t.settings.thumbMargin),t.$lgThumb.css("width",t.thumbTotalWidth+"px"),t.$lgThumb.empty(),t.setThumbItemHtml(t.core.galleryItems),t.animateThumb(t.core.index)},50),setTimeout(function(){t.$thumbOuter.removeClass("lg-rebuilding-thumbnails")},200)},e.prototype.setTranslate=function(t){this.$lgThumb.css("transform","translate3d(-"+t+"px, 0px, 0px)")},e.prototype.getPossibleTransformX=function(t){return t>this.thumbTotalWidth-this.thumbOuterWidth&&(t=this.thumbTotalWidth-this.thumbOuterWidth),t<0&&(t=0),t},e.prototype.animateThumb=function(t){if(this.$lgThumb.css("transition-duration",this.core.settings.speed+"ms"),this.settings.animateThumb){var n=0;switch(this.settings.currentPagerPosition){case"left":n=0;break;case"middle":n=this.thumbOuterWidth/2-this.settings.thumbWidth/2;break;case"right":n=this.thumbOuterWidth-this.settings.thumbWidth}this.translateX=(this.settings.thumbWidth+this.settings.thumbMargin)*t-1-n,this.translateX>this.thumbTotalWidth-this.thumbOuterWidth&&(this.translateX=this.thumbTotalWidth-this.thumbOuterWidth),this.translateX<0&&(this.translateX=0),this.setTranslate(this.translateX)}},e.prototype.onThumbTouchMove=function(t){return t.newTranslateX=this.translateX,t.isMoved=!0,t.touchMoveTime=new 
Date().valueOf(),t.newTranslateX-=t.cords.endX-t.cords.startX,t.newTranslateX=this.getPossibleTransformX(t.newTranslateX),this.setTranslate(t.newTranslateX),this.$thumbOuter.addClass("lg-dragging"),t},e.prototype.onThumbTouchEnd=function(t){t.isMoved=!1,t.endTime=new Date,this.$thumbOuter.removeClass("lg-dragging");var n=t.endTime.valueOf()-t.startTime.valueOf(),r=t.cords.endX-t.cords.startX,i=Math.abs(r)/n;return i>.15&&t.endTime.valueOf()-t.touchMoveTime<30?(i+=1,i>2&&(i+=1),i=i+i*(Math.abs(r)/this.thumbOuterWidth),this.$lgThumb.css("transition-duration",Math.min(i-1,2)+"settings"),r=r*i,this.translateX=this.getPossibleTransformX(this.translateX-r),this.setTranslate(this.translateX)):this.translateX=t.newTranslateX,Math.abs(t.cords.endX-t.cords.startX)<this.settings.thumbnailSwipeThreshold&&(this.thumbClickable=!0),t},e.prototype.getThumbHtml=function(t,n){var r=this.core.galleryItems[n].__slideVideoInfo||{},i;return r.youtube&&this.settings.loadYouTubeThumbnail?i="//img.youtube.com/vi/"+r.youtube[1]+"/"+this.settings.youTubeThumbSize+".jpg":i=t,'<div data-lg-item-id="'+n+'" class="lg-thumb-item '+(n===this.core.index?" active":"")+`"
style="width:`+this.settings.thumbWidth+"px; height: "+this.settings.thumbHeight+`;
margin-right: `+this.settings.thumbMargin+`px;">
<img data-lg-item-id="`+n+'" src="'+i+`" />

View File

@ -51,4 +51,4 @@
<glyph unicode="&#xe908;" glyph-name="message-circle" data-tags="message-circle" d="M938.667 448.128v21.205c0 0.725-0.043 1.621-0.085 2.475-5.803 99.755-47.488 190.336-112.725 258.176-68.352 71.125-162.731 117.419-268.843 123.264-0.683 0.043-1.536 0.085-2.347 0.085h-20.864c-59.947 0.683-122.965-13.227-181.931-43.008-52.181-26.496-97.749-63.488-133.931-108.16-56.405-69.717-89.899-158.080-89.941-253.696-0.597-54.4 10.795-111.36 35.157-165.419l-75.605-226.859c-2.816-8.363-3.072-17.835 0-26.965 7.467-22.357 31.616-34.432 53.973-26.965l226.731 75.563c49.493-22.485 105.984-35.243 165.376-35.115 58.539 0.384 115.797 13.141 168.149 36.949 81.579 37.163 151.040 101.248 193.749 186.667 27.477 53.291 43.307 115.84 43.136 181.803zM853.333 447.872c0.128-52.267-12.459-101.333-33.664-142.464-34.176-68.352-88.832-118.827-153.259-148.139-41.387-18.859-86.827-28.971-133.376-29.269-52.096-0.128-101.163 12.459-142.293 33.664-10.624 5.504-22.528 6.059-33.067 2.56l-162.261-54.101 54.101 162.261c3.755 11.221 2.56 22.912-2.389 32.725-23.552 46.677-34.304 96.171-33.792 142.421 0.043 76.331 26.411 145.92 70.955 200.917 28.629 35.371 64.768 64.725 106.24 85.76 46.592 23.552 96.085 34.304 142.336 33.792h19.456c83.712-4.565 158.037-41.003 212.011-97.152 51.285-53.376 84.139-124.416 89.003-202.795z" />
<glyph unicode="&#xe909;" glyph-name="maximize-2" data-tags="maximize-2" d="M793.003 768l-225.835-225.835c-16.683-16.683-16.683-43.691 0-60.331s43.691-16.683 60.331 0l225.835 225.835v-153.003c0-23.552 19.115-42.667 42.667-42.667s42.667 19.115 42.667 42.667v256c0 5.803-1.152 11.307-3.243 16.341s-5.163 9.728-9.216 13.781c-0.043 0.043-0.043 0.043-0.085 0.085-3.925 3.925-8.619 7.083-13.781 9.216-5.035 2.091-10.539 3.243-16.341 3.243h-256c-23.552 0-42.667-19.115-42.667-42.667s19.115-42.667 42.667-42.667zM230.997 85.334l225.835 225.835c16.683 16.683 16.683 43.691 0 60.331s-43.691 16.683-60.331 0l-225.835-225.835v153.003c0 23.552-19.115 42.667-42.667 42.667s-42.667-19.115-42.667-42.667v-256c0-23.552 19.115-42.667 42.667-42.667h256c23.552 0 42.667 19.115 42.667 42.667s-19.115 42.667-42.667 42.667z" />
<glyph unicode="&#xe90a;" glyph-name="minimize-2" data-tags="minimize-2" d="M700.331 554.667l225.835 225.835c16.683 16.683 16.683 43.691 0 60.331s-43.691 16.683-60.331 0l-225.835-225.835v153.003c0 23.552-19.115 42.667-42.667 42.667s-42.667-19.115-42.667-42.667v-256c0-5.803 1.152-11.307 3.243-16.341s5.163-9.728 9.216-13.781c0.043-0.043 0.043-0.043 0.085-0.085 3.925-3.925 8.619-7.083 13.781-9.216 5.035-2.091 10.539-3.243 16.341-3.243h256c23.552 0 42.667 19.115 42.667 42.667s-19.115 42.667-42.667 42.667zM158.165 12.502l225.835 225.835v-153.003c0-23.552 19.115-42.667 42.667-42.667s42.667 19.115 42.667 42.667v256c0 5.803-1.152 11.307-3.243 16.341s-5.163 9.728-9.216 13.781c-0.043 0.043-0.043 0.043-0.085 0.085-4.096 4.053-8.789 7.125-13.781 9.216-5.035 2.091-10.539 3.243-16.341 3.243h-256c-23.552 0-42.667-19.115-42.667-42.667s19.115-42.667 42.667-42.667h153.003l-225.835-225.835c-16.683-16.683-16.683-43.691 0-60.331s43.691-16.683 60.331 0z" />
</font></defs></svg>
</font></defs></svg>

Before: 12 KiB  |  After: 12 KiB

View File

@ -10,4 +10,4 @@
<div id="app"></div>
</body>
</html>
</html>

File diff suppressed because it is too large

View File

@ -12,8 +12,8 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
''' Class to store image generation parameters to be stored as metadata in the image'''
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Class to store image generation parameters to be stored as metadata in the image"""
from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Dict, Optional
@ -21,6 +21,7 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import copy
@dataclass
class ImageMetadata:
prompt: str = None
@ -40,11 +41,15 @@ class ImageMetadata:
return info
def as_dict(self) -> Dict[str, str]:
return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None}
return {
f"SD:{key}": str(value)
for key, value in asdict(self).items()
if value is not None
}
@classmethod
def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None:
''' Sets metadata on image, in both text form and as an ImageMetadata object '''
"""Sets metadata on image, in both text form and as an ImageMetadata object"""
if metadata:
image.info = metadata.as_dict()
else:
@ -53,8 +58,8 @@ class ImageMetadata:
@classmethod
def get_from_image(cls, image: Image) -> Optional[ImageMetadata]:
''' Gets metadata from an image, first looking for an ImageMetadata,
then if not found tries to construct one from the info '''
"""Gets metadata from an image, first looking for an ImageMetadata,
then if not found tries to construct one from the info"""
metadata = image.info.get("ImageMetadata", None)
if not metadata:
found_metadata = False
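
The reformatted as_dict/get_from_image helpers above rely on PIL's PNG text chunks: every field is written under an "SD:"-prefixed key and later recovered from Image.info. A minimal round-trip sketch of that mechanism using plain PIL (the field names below are illustrative, not the project's full metadata set):

from PIL import Image
from PIL.PngImagePlugin import PngInfo

# Illustrative fields, mirroring the "SD:"-prefixed keys produced by as_dict() above.
fields = {"SD:prompt": "cat playing with string", "SD:seed": "42"}

png_info = PngInfo()
for key, value in fields.items():
    png_info.add_text(key, value)  # each field becomes a PNG tEXt chunk

Image.new("RGB", (64, 64)).save("metadata_demo.png", pnginfo=png_info)

# On reload, PIL exposes the chunks through Image.info, which is where
# get_from_image() looks when reconstructing the metadata.
reloaded = Image.open("metadata_demo.png")
print({k: v for k, v in reloaded.info.items() if k.startswith("SD:")})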

View File

@ -8,4 +8,4 @@
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>
</html>

View File

@ -12,11 +12,11 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations """
from __future__ import annotations
import gradio as gr
from gradio.components import Component, Gallery, Slider
from gradio.components import Component, Gallery
from threading import Event, Timer
from typing import Callable, List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field
@ -82,11 +82,9 @@ def triggerChangeEvent():
@dataclass
class JobManagerUi:
def wrap_func(
self,
func: Callable,
inputs: List[Component],
outputs: List[Component]) -> Tuple[Callable, List[Component], List[Component]]:
''' Takes a gradio event listener function and its input/outputs and returns wrapped replacements which will
self, func: Callable, inputs: List[Component], outputs: List[Component]
) -> Tuple[Callable, List[Component], List[Component]]:
"""Takes a gradio event listener function and its input/outputs and returns wrapped replacements which will
be managed by JobManager
Parameters:
func (Callable) the original event listener to be wrapped.
@ -101,10 +99,9 @@ class JobManagerUi:
Returns:
Tuple(newFunc (Callable), newInputs (List[Component]), newOutputs (List[Component]), which should be used as
replacements for the passed in function, inputs and outputs
'''
"""
return self._job_manager._wrap_func(
func=func, inputs=inputs, outputs=outputs,
job_ui=self
func=func, inputs=inputs, outputs=outputs, job_ui=self
)
_refresh_btn: gr.Button
@ -123,7 +120,9 @@ class JobManagerUi:
class JobManager:
JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
JOB_MAX_START_TIME = (
5.0 # How long can a job be stuck 'starting' before assuming it isn't running
)
def __init__(self, max_jobs: int):
self._max_jobs: int = max_jobs
@ -133,66 +132,106 @@ class JobManager:
self._session_key: gr.JSON = None
def draw_gradio_ui(self) -> JobManagerUi:
''' draws the job manager ui in gradio
Returns:
ui (JobManagerUi): object which can connect functions to the ui
'''
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
"""draws the job manager ui in gradio
Returns:
ui (JobManagerUi): object which can connect functions to the ui
"""
assert (
gr.context.Context.block is not None
), "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
with gr.Tabs():
with gr.TabItem("Job Controls"):
with gr.Row():
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
stop_btn = gr.Button(
"Stop All Batches", elem_id="stop", variant="secondary"
)
refresh_btn = gr.Button(
"Refresh Finished Batches",
elem_id="refresh",
variant="secondary",
)
status_text = gr.Textbox(
placeholder="Job Status", interactive=False, show_label=False
)
with gr.Row():
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
active_image_stop_btn = gr.Button(
"Skip Active Batch", variant="secondary"
)
active_image_refresh_btn = gr.Button(
"View Batch Progress", variant="secondary"
)
active_image = gr.Image(
type="pil",
interactive=False,
visible=False,
elem_id="active_iteration_image",
)
with gr.TabItem("Batch Progress Settings"):
with gr.Row():
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
record_steps_checkbox = gr.Checkbox(
value=False, label="Enable Batch Progress Grid"
)
record_steps_interval_slider = gr.Slider(
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
with gr.Row() as record_steps_box:
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
value=3,
label="Record Interval (steps)",
minimum=1,
maximum=25,
step=1,
)
with gr.Row():
steps_to_gallery_checkbox = gr.Checkbox(
value=False, label="Save Progress Grid to Gallery"
)
steps_to_file_checkbox = gr.Checkbox(
value=False, label="Save Progress Grid to File"
)
with gr.TabItem("Maintenance"):
with gr.Row():
gr.Markdown(
"Stop all concurrent sessions, or free memory associated with jobs which were finished after the browser was closed")
"Stop all concurrent sessions, or free memory associated with jobs which were finished after the browser was closed"
)
with gr.Row():
stop_all_sessions_btn = gr.Button(
"Stop All Sessions", elem_id="stop_all", variant="secondary"
)
free_done_sessions_btn = gr.Button(
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
"Clear Finished Jobs",
elem_id="clear_finished",
variant="secondary",
)
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
_active_image_refresh_btn=active_image_refresh_btn,
_rec_steps_checkbox=record_steps_checkbox,
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
return JobManagerUi(
_refresh_btn=refresh_btn,
_stop_btn=stop_btn,
_status_text=status_text,
_stop_all_session_btn=stop_all_sessions_btn,
_free_done_sessions_btn=free_done_sessions_btn,
_active_image=active_image,
_active_image_stop_btn=active_image_stop_btn,
_active_image_refresh_btn=active_image_refresh_btn,
_rec_steps_checkbox=record_steps_checkbox,
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
_rec_steps_intrvl_sldr=record_steps_interval_slider,
_job_manager=self,
)
def clear_all_finished_jobs(self):
''' Removes all currently finished jobs, across all sessions.
Useful to free memory if a job is started and the browser is closed
before it finishes '''
"""Removes all currently finished jobs, across all sessions.
Useful to free memory if a job is started and the browser is closed
before it finishes"""
for session in self._sessions.values():
session.finished_jobs.clear()
def stop_all_jobs(self):
''' Stops all active jobs, across all sessions'''
"""Stops all active jobs, across all sessions"""
for session in self._sessions.values():
for job in session.jobs.values():
job.should_stop.set()
job.stop_cur_iter.set()
def _get_job_token(self, block: bool = False) -> Optional[int]:
''' Attempts to acquire a job token, optionally blocking until available '''
"""Attempts to acquire a job token, optionally blocking until available"""
token = None
while token is None:
try:
@ -212,27 +251,31 @@ class JobManager:
return token
def _release_job_token(self, token: int) -> None:
''' Returns a job token to allow another job to start '''
"""Returns a job token to allow another job to start"""
self._avail_job_tokens.append(token)
self._run_queued_jobs()
def _refresh_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Updates information from the active job '''
"""Updates information from the active job"""
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
return [triggerChangeEvent(), job_info.job_status]
def _stop_wrapped_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Marks that the job should be stopped'''
def _stop_wrapped_func(
self, func_key: FuncKey, session_key: str
) -> List[Component]:
"""Marks that the job should be stopped"""
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return f"Session {session_key} was not running function {func_key}"
job_info.should_stop.set()
return "Stopping after current batch finishes"
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Updates information from the active iteration '''
def _refresh_cur_iter_func(
self, func_key: FuncKey, session_key: str
) -> List[Component]:
"""Updates information from the active iteration"""
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
@ -240,19 +283,26 @@ class JobManager:
job_info.refresh_active_image_requested.set()
if job_info.refresh_active_image_done.wait(timeout=20.0):
job_info.refresh_active_image_done.clear()
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
return [
gr.Image.update(value=job_info.active_image, visible=True),
f"Sample iteration {job_info.active_iteration_cnt}",
]
return [gr.Image.update(visible=False), "Timed out getting image"]
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Marks that the active iteration should be stopped'''
def _stop_cur_iter_func(
self, func_key: FuncKey, session_key: str
) -> List[Component]:
"""Marks that the active iteration should be stopped"""
session_info, job_info = self._get_call_info(func_key, session_key)
if job_info is None:
return [None, f"Session {session_key} was not running function {func_key}"]
job_info.stop_cur_iter.set()
return [gr.Image.update(visible=False), "Stopping current iteration"]
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
''' Helper to get the SessionInfo and JobInfo. '''
def _get_call_info(
self, func_key: FuncKey, session_key: str
) -> Tuple[SessionInfo, JobInfo]:
"""Helper to get the SessionInfo and JobInfo."""
session_info = self._sessions.get(session_key, None)
if not session_info:
print(f"Couldn't find session {session_key} for call to {func_key}")
@ -268,7 +318,7 @@ class JobManager:
return session_info, job_info
def _run_queued_jobs(self) -> None:
''' Runs queued jobs for any available slots '''
"""Runs queued jobs for any available slots"""
if self._avail_job_tokens:
try:
# Notify next queued job it may begin
@ -282,10 +332,18 @@ class JobManager:
pass # No queued jobs
def _pre_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
''' Called when a job is about to start '''
self,
func_key: FuncKey,
output_dummy_obj: Component,
refresh_btn: gr.Button,
stop_btn: gr.Button,
status_text: gr.Textbox,
active_image: gr.Image,
active_refresh_btn: gr.Button,
active_stop_btn: gr.Button,
session_key: str,
) -> List[Component]:
"""Called when a job is about to start"""
session_info, job_info = self._get_call_info(func_key, session_key)
# If we didn't already get a token then queue up for one
@ -293,16 +351,23 @@ class JobManager:
job_info.job_token = self._get_job_token(block=True)
# Buttons don't seem to update unless value is set on them as well...
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
}
return {
output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
status_text: gr.Textbox.update(
value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"
),
active_refresh_btn: gr.Button.update(
variant="primary", value=active_refresh_btn.value
),
active_stop_btn: gr.Button.update(
variant="primary", value=active_stop_btn.value
),
}
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Runs the real function with job management. '''
"""Runs the real function with job management."""
session_info, job_info = self._get_call_info(func_key, session_key)
if session_info is None or job_info is None:
return []
@ -310,7 +375,9 @@ class JobManager:
job_info.started = True
try:
if job_info.should_stop.is_set():
raise Exception(f"Job {job_info} requested a stop before execution began")
raise Exception(
f"Job {job_info} requested a stop before execution began"
)
outputs = job_info.func(*job_info.inputs, job_info=job_info)
except Exception as e:
job_info.job_status = f"Error: {e}"
@ -334,35 +401,56 @@ class JobManager:
return tuple(filtered_output)
def _post_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
session_key: str) -> List[Component]:
''' Called when a job completes '''
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has finished!"),
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
active_image: gr.Image.update(visible=False)
}
self,
func_key: FuncKey,
output_dummy_obj: Component,
refresh_btn: gr.Button,
stop_btn: gr.Button,
status_text: gr.Textbox,
active_image: gr.Image,
active_refresh_btn: gr.Button,
active_stop_btn: gr.Button,
session_key: str,
) -> List[Component]:
"""Called when a job completes"""
return {
output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
status_text: gr.Textbox.update(value="Generation has finished!"),
active_refresh_btn: gr.Button.update(
variant="secondary", value=active_refresh_btn.value
),
active_stop_btn: gr.Button.update(
variant="secondary", value=active_stop_btn.value
),
active_image: gr.Image.update(visible=False),
}
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
''' Updates the gallery with results from the given job.
Frees the images after return if the job is finished.
Triggered by changing the update_gallery_obj dummy object '''
def _update_gallery_event(
self, func_key: FuncKey, session_key: str
) -> List[Component]:
"""Updates the gallery with results from the given job.
Frees the images after return if the job is finished.
Triggered by changing the update_gallery_obj dummy object"""
session_info, job_info = self._get_call_info(func_key, session_key)
if session_info is None or job_info is None:
return []
return job_info.images
def _wrap_func(self, func: Callable, inputs: List[Component],
outputs: List[Component],
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
''' handles JobManageUI's wrap_func'''
def _wrap_func(
self,
func: Callable,
inputs: List[Component],
outputs: List[Component],
job_ui: JobManagerUi,
) -> Tuple[Callable, List[Component]]:
"""handles JobManageUI's wrap_func"""
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
assert (
gr.context.Context.block is not None
), "wrap_func must be called within a 'gr.Blocks' 'with' context"
# Create a unique key for this job
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
@ -370,8 +458,11 @@ class JobManager:
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
if self._session_key is None:
# When this gradio object is received as an event handler input it will resolve to a unique per-session id
self._session_key = gr.JSON(value=lambda: uuid.uuid4().hex, visible=False,
elem_id="JobManagerDummyObject_sessionKey")
self._session_key = gr.JSON(
value=lambda: uuid.uuid4().hex,
visible=False,
elem_id="JobManagerDummyObject_sessionKey",
)
# Pull the gallery out of the original outputs and assign it to the gallery update dummy object
gallery_comp = None
@ -389,25 +480,25 @@ class JobManager:
partial(self._update_gallery_event, func_key),
[self._session_key],
[gallery_comp],
queue=False
queue=False,
)
if job_ui._refresh_btn:
job_ui._refresh_btn.variant = 'secondary'
job_ui._refresh_btn.variant = "secondary"
job_ui._refresh_btn.click(
partial(self._refresh_func, func_key),
[self._session_key],
[update_gallery_obj, job_ui._status_text],
queue=False
queue=False,
)
if job_ui._stop_btn:
job_ui._stop_btn.variant = 'secondary'
job_ui._stop_btn.variant = "secondary"
job_ui._stop_btn.click(
partial(self._stop_wrapped_func, func_key),
[self._session_key],
[job_ui._status_text],
queue=False
queue=False,
)
if job_ui._active_image and job_ui._active_image_refresh_btn:
@ -415,7 +506,7 @@ class JobManager:
partial(self._refresh_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text],
queue=False
queue=False,
)
if job_ui._active_image_stop_btn:
@ -423,19 +514,15 @@ class JobManager:
partial(self._stop_cur_iter_func, func_key),
[self._session_key],
[job_ui._active_image, job_ui._status_text],
queue=False
queue=False,
)
if job_ui._stop_all_session_btn:
job_ui._stop_all_session_btn.click(
self.stop_all_jobs, [], [],
queue=False
)
job_ui._stop_all_session_btn.click(self.stop_all_jobs, [], [], queue=False)
if job_ui._free_done_sessions_btn:
job_ui._free_done_sessions_btn.click(
self.clear_all_finished_jobs, [], [],
queue=False
self.clear_all_finished_jobs, [], [], queue=False
)
# (ab)use gr.JSON to forward events.
@ -452,18 +539,26 @@ class JobManager:
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
# the Component as a key... so group together the UI components that the event listeners are going to update
# to make it easy to append to function calls and outputs
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
job_ui_params = [
job_ui._refresh_btn,
job_ui._stop_btn,
job_ui._status_text,
job_ui._active_image,
job_ui._active_image_refresh_btn,
job_ui._active_image_stop_btn,
]
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
# to be able to update the UI before and after, as well as run the actual call
post_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_postCall")
post_call_dummyobj = gr.JSON(
visible=False, elem_id="JobManagerDummyObject_postCall"
)
post_call_dummyobj.change(
partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params),
[self._session_key],
[update_gallery_obj] + job_ui_outputs,
queue=False
queue=False,
)
call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall")
@ -471,20 +566,27 @@ class JobManager:
partial(self._call_func, func_key),
[self._session_key],
outputs + [post_call_dummyobj],
queue=False
queue=False,
)
pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall")
pre_call_dummyobj = gr.JSON(
visible=False, elem_id="JobManagerDummyObject_preCall"
)
pre_call_dummyobj.change(
partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params),
[self._session_key],
[call_dummyobj] + job_ui_outputs,
queue=False
queue=False,
)
# Add any components that we want the runtime values for
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
added_inputs = [
self._session_key,
job_ui._rec_steps_checkbox,
job_ui._save_rec_steps_to_gallery_chkbx,
job_ui._save_rec_steps_to_file_chkbx,
job_ui._rec_steps_intrvl_sldr,
]
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
def wrapped_func(*wrapped_inputs):
@ -505,12 +607,19 @@ class JobManager:
if func_key in session_info.jobs:
job_info = session_info.jobs[func_key]
# If the job seems stuck in 'starting' then go ahead and toss it
if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
if (
not job_info.started
and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME
):
job_info.should_stop.set()
job_info.stop_cur_iter.set()
session_info.jobs.pop(func_key)
return {job_ui._status_text: "Canceled possibly hung job. Try again"}
return {job_ui._status_text: "This session is already running that function!"}
return {
job_ui._status_text: "Canceled possibly hung job. Try again"
}
return {
job_ui._status_text: "This session is already running that function!"
}
# Is this a new run of a previously finished job? Clear old info
if func_key in session_info.finished_jobs:
@ -518,9 +627,17 @@ class JobManager:
job_token = self._get_job_token(block=False)
job = JobInfo(
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
inputs=job_inputs,
func=func,
removed_output_idxs=removed_idxs,
session_key=session_key,
job_token=job_token,
rec_steps_enabled=record_steps_enabled,
rec_steps_intrvl=rec_steps_interval,
rec_steps_to_gallery=save_rec_steps_grid,
rec_steps_to_file=save_rec_steps_file,
timestamp=time.time(),
)
session_info.jobs[func_key] = job
ret = {pre_call_dummyobj: triggerChangeEvent()}
@ -528,4 +645,8 @@ class JobManager:
ret[job_ui._status_text] = "Job is queued"
return ret
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
return (
wrapped_func,
inputs + added_inputs,
[pre_call_dummyobj, job_ui._status_text],
)
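
Taken together, draw_gradio_ui() and JobManagerUi.wrap_func() above replace a normal gradio event registration: the wrapped function, inputs and outputs are what get passed to .click(). A minimal wiring sketch, assuming the module is importable as job_manager (the real path is not shown here) and using a hypothetical generate callback; as _call_func shows, the wrapper passes that callback a job_info keyword argument:

import gradio as gr
from PIL import Image

from job_manager import JobManager  # assumed module name for the file in the hunk above


def generate(prompt_text, *, job_info=None):
    # Hypothetical stand-in for a txt2img callback; a real one would honour
    # job_info.should_stop / stop_cur_iter and report progress images.
    return [[Image.new("RGB", (64, 64), "black")]]  # one gallery's worth of images


job_manager = JobManager(max_jobs=1)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    gallery = gr.Gallery()
    run_btn = gr.Button("Generate")

    job_ui = job_manager.draw_gradio_ui()  # must run inside the gr.Blocks context
    wrapped_fn, wrapped_inputs, wrapped_outputs = job_ui.wrap_func(
        func=generate, inputs=[prompt], outputs=[gallery]
    )
    run_btn.click(wrapped_fn, wrapped_inputs, wrapped_outputs)

demo.launch()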

View File

@ -240,4 +240,4 @@ svg.no-preview-icon {
border-color: var(--primary-color);
color: var(--primary-color);
} */
</style>
</style>

View File

@ -49,4 +49,3 @@ body, html {
</style>

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#fff"><path d="M 16 2 C 14.742188 2 13.847656 2.890625 13.40625 4 L 5 4 L 5 29 L 27 29 L 27 4 L 18.59375 4 C 18.152344 2.890625 17.257813 2 16 2 Z M 16 4 C 16.554688 4 17 4.445313 17 5 L 17 6 L 20 6 L 20 8 L 12 8 L 12 6 L 15 6 L 15 5 C 15 4.445313 15.445313 4 16 4 Z M 7 6 L 10 6 L 10 10 L 22 10 L 22 6 L 25 6 L 25 27 L 7 27 Z M 21.28125 13.28125 L 15 19.5625 L 11.71875 16.28125 L 10.28125 17.71875 L 14.28125 21.71875 L 15 22.40625 L 15.71875 21.71875 L 22.71875 14.71875 Z"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#fff"><path d="M 16 2 C 14.742188 2 13.847656 2.890625 13.40625 4 L 5 4 L 5 29 L 27 29 L 27 4 L 18.59375 4 C 18.152344 2.890625 17.257813 2 16 2 Z M 16 4 C 16.554688 4 17 4.445313 17 5 L 17 6 L 20 6 L 20 8 L 12 8 L 12 6 L 15 6 L 15 5 C 15 4.445313 15.445313 4 16 4 Z M 7 6 L 10 6 L 10 10 L 22 10 L 22 6 L 25 6 L 25 27 L 7 27 Z M 21.28125 13.28125 L 15 19.5625 L 11.71875 16.28125 L 10.28125 17.71875 L 14.28125 21.71875 L 15 22.40625 L 15.71875 21.71875 L 22.71875 14.71875 Z"/></svg>

Before: 550 B  |  After: 551 B

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#fff"><path d="M 15 3 C 13.742188 3 12.847656 3.890625 12.40625 5 L 5 5 L 5 28 L 13 28 L 13 30 L 27 30 L 27 14 L 25 14 L 25 5 L 17.59375 5 C 17.152344 3.890625 16.257813 3 15 3 Z M 15 5 C 15.554688 5 16 5.445313 16 6 L 16 7 L 19 7 L 19 9 L 11 9 L 11 7 L 14 7 L 14 6 C 14 5.445313 14.445313 5 15 5 Z M 7 7 L 9 7 L 9 11 L 21 11 L 21 7 L 23 7 L 23 14 L 13 14 L 13 26 L 7 26 Z M 15 16 L 25 16 L 25 28 L 15 28 Z"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#fff"><path d="M 15 3 C 13.742188 3 12.847656 3.890625 12.40625 5 L 5 5 L 5 28 L 13 28 L 13 30 L 27 30 L 27 14 L 25 14 L 25 5 L 17.59375 5 C 17.152344 3.890625 16.257813 3 15 3 Z M 15 5 C 15.554688 5 16 5.445313 16 6 L 16 7 L 19 7 L 19 9 L 11 9 L 11 7 L 14 7 L 14 6 C 14 5.445313 14.445313 5 15 5 Z M 7 7 L 9 7 L 9 11 L 21 11 L 21 7 L 23 7 L 23 14 L 13 14 L 13 26 L 7 26 Z M 15 16 L 25 16 L 25 28 L 15 28 Z"/></svg>

Before: 481 B  |  After: 482 B

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#3F6078"><path d="M 30.335938 12.546875 L 20.164063 11.472656 L 16 2.132813 L 11.835938 11.472656 L 1.664063 12.546875 L 9.261719 19.394531 L 7.140625 29.398438 L 16 24.289063 L 24.859375 29.398438 L 22.738281 19.394531 Z"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#3F6078"><path d="M 30.335938 12.546875 L 20.164063 11.472656 L 16 2.132813 L 11.835938 11.472656 L 1.664063 12.546875 L 9.261719 19.394531 L 7.140625 29.398438 L 16 24.289063 L 24.859375 29.398438 L 22.738281 19.394531 Z"/></svg>

Before: 296 B  |  After: 297 B

View File

@ -1 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#3F6078"><path d="M 30.335938 12.546875 L 20.164063 11.472656 L 16 2.132813 L 11.835938 11.472656 L 1.664063 12.546875 L 9.261719 19.394531 L 7.140625 29.398438 L 16 24.289063 L 24.859375 29.398438 L 22.738281 19.394531 Z"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" fill="#3F6078"><path d="M 30.335938 12.546875 L 20.164063 11.472656 L 16 2.132813 L 11.835938 11.472656 L 1.664063 12.546875 L 9.261719 19.394531 L 7.140625 29.398438 L 16 24.289063 L 24.859375 29.398438 L 22.738281 19.394531 Z"/></svg>

Before: 296 B  |  After: 297 B

View File

@ -5,7 +5,7 @@
<h1 class="err__title">Component Error</h1>
<div class="err__msg">Message: {{ componentError }}</div>
</div>
<!--
<!--
Else render the component slot and pass Streamlit event data in `args` props to it.
Don't render until we've gotten our first RENDER_EVENT from Streamlit.
All components get disabled while the app is being re-run, and become re-enabled when the re-run has finished.

View File

@ -12,122 +12,170 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import re
import gradio as gr
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
from PIL import Image
from io import BytesIO
import base64
import re
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
def change_image_editor_mode(
choice, cropped_image, masked_image, resize_mode, width, height
):
if choice == "Mask":
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
update_image_result = update_image_mask(
cropped_image, resize_mode, width, height
)
return [
gr.update(visible=False),
update_image_result,
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=True),
]
update_image_result = update_image_mask(
masked_image["image"] if masked_image is not None else None,
resize_mode,
width,
height,
)
return [
update_image_result,
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
]
update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
resized_cropped_image = (
resize_image(resize_mode, cropped_image, width, height)
if cropped_image
else None
)
return gr.update(value=resized_cropped_image, visible=True)
def toggle_options_gfpgan(selection):
if 0 in selection:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def toggle_options_upscalers(selection):
if 1 in selection:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def toggle_options_realesrgan(selection):
if selection == 0 or selection == 1 or selection == 3:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def toggle_options_gobig(selection):
if selection == 1:
#print(selection)
# print(selection)
return gr.update(visible=True)
if selection == 3:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def toggle_options_ldsr(selection):
if selection == 2 or selection == 3:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def increment_down(value):
return value - 1
def increment_up(value):
return value + 1
def copy_img_to_lab(img):
try:
image_data = re.sub('^data:image/.+;base64,', '', img)
image_data = re.sub("^data:image/.+;base64,", "", img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
tab_update = gr.update(selected='imgproc_tab')
img_update = gr.update(value=processed_image)
return processed_image, tab_update,
tab_update = gr.update(selected="imgproc_tab")
gr.update(value=processed_image)
return (
processed_image,
tab_update,
)
except IndexError:
return [None, None]
def copy_img_params_to_lab(params):
try:
prompt = params[0][0].replace('\n', ' ').replace('\r', '')
prompt = params[0][0].replace("\n", " ").replace("\r", "")
seed = int(params[1][1])
steps = int(params[7][1])
cfg_scale = float(params[9][1])
sampler = params[11][1]
return prompt,seed,steps,cfg_scale,sampler
return prompt, seed, steps, cfg_scale, sampler
except IndexError:
return [None, None]
def copy_img_to_input(img):
try:
image_data = re.sub('^data:image/.+;base64,', '', img)
image_data = re.sub("^data:image/.+;base64,", "", img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
tab_update = gr.update(selected='img2img_tab')
img_update = gr.update(value=processed_image)
return processed_image, processed_image , tab_update
tab_update = gr.update(selected="img2img_tab")
gr.update(value=processed_image)
return processed_image, processed_image, tab_update
except IndexError:
return [None, None]
def copy_img_to_edit(img):
try:
image_data = re.sub('^data:image/.+;base64,', '', img)
image_data = re.sub("^data:image/.+;base64,", "", img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
tab_update = gr.update(selected='img2img_tab')
img_update = gr.update(value=processed_image)
mode_update = gr.update(value='Crop')
tab_update = gr.update(selected="img2img_tab")
gr.update(value=processed_image)
mode_update = gr.update(value="Crop")
return processed_image, tab_update, mode_update
except IndexError:
return [None, None]
def copy_img_to_mask(img):
try:
image_data = re.sub('^data:image/.+;base64,', '', img)
image_data = re.sub("^data:image/.+;base64,", "", img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
tab_update = gr.update(selected='img2img_tab')
img_update = gr.update(value=processed_image)
mode_update = gr.update(value='Mask')
tab_update = gr.update(selected="img2img_tab")
gr.update(value=processed_image)
mode_update = gr.update(value="Mask")
return processed_image, tab_update, mode_update
except IndexError:
return [None, None]
def copy_img_to_upscale_esrgan(img):
tabs_update = gr.update(selected='realesrgan_tab')
image_data = re.sub('^data:image/.+;base64,', '', img)
tabs_update = gr.update(selected="realesrgan_tab")
image_data = re.sub("^data:image/.+;base64,", "", img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
return processed_image, tabs_update
@ -147,8 +195,11 @@ help_text = """
If anything breaks, try switching modes again, switch tabs, clear the image, or reload.
"""
def resize_image(resize_mode, im, width, height):
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
LANCZOS = (
Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS
)
if resize_mode == 0:
res = im.resize((width, height), resample=LANCZOS)
elif resize_mode == 1:
@ -174,28 +225,45 @@ def resize_image(resize_mode, im, width, height):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
res.paste(
resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)
)
res.paste(
resized.resize(
(width, fill_height), box=(0, resized.height, width, resized.height)
),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
res.paste(
resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)
)
res.paste(
resized.resize(
(fill_width, height), box=(resized.width, 0, resized.width, height)
),
box=(fill_width + src_w, 0),
)
return res
def update_dimensions_info(width, height):
pixel_count_formated = "{:,.0f}".format(width * height)
return f"Aspect ratio: {round(width / height, 5)}\nTotal pixel count: {pixel_count_formated}"
def get_png_nfo( image: Image ):
def get_png_nfo(image: Image):
info_text = ""
visible = bool(image and any(image.info))
if visible:
for key,value in image.info.items():
for key, value in image.info.items():
info_text += f"{key}: {value}\n"
info_text = info_text.rstrip('\n')
info_text = info_text.rstrip("\n")
return gr.Textbox.update(value=info_text, visible=visible)
def load_settings(*values):
new_settings, key_names, checkboxgroup_info = values[-3:]
values = list(values[:-3])
@ -205,7 +273,9 @@ def load_settings(*values):
if os.path.exists(new_settings):
with open(new_settings, "r", encoding="utf8") as f:
new_settings = yaml.safe_load(f)
elif new_settings.startswith("file://") and os.path.exists(new_settings[7:]):
elif new_settings.startswith("file://") and os.path.exists(
new_settings[7:]
):
with open(new_settings[7:], "r", encoding="utf8") as f:
new_settings = yaml.safe_load(f)
else:
@ -216,7 +286,10 @@ def load_settings(*values):
new_settings = new_settings["txt2img"]
target = new_settings.pop("target", "txt2img")
if target != "txt2img":
print(f"Warning: applying settings to txt2img even though {target} is specified as target.", file=sys.stderr)
print(
f"Warning: applying settings to txt2img even though {target} is specified as target.",
file=sys.stderr,
)
skipped_settings = {}
for key in new_settings.keys():
@ -228,7 +301,7 @@ def load_settings(*values):
print(f"Settings could not be applied: {skipped_settings}", file=sys.stderr)
# Convert lists of checkbox indices to lists of checkbox labels:
for (cbg_index, cbg_choices) in checkboxgroup_info:
for cbg_index, cbg_choices in checkboxgroup_info:
values[cbg_index] = [cbg_choices[i] for i in values[cbg_index]]
return values
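
The copy_img_to_* helpers above all share the same decode step: strip the browser's data-URL prefix with a regex, then hand the base64 payload to PIL via BytesIO. A self-contained sketch of that round trip (the red 8x8 test image is just a stand-in for a canvas export):

import base64
import re
from io import BytesIO

from PIL import Image

# Build a data URL the way a browser canvas would hand one to gradio.
buffer = BytesIO()
Image.new("RGB", (8, 8), "red").save(buffer, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()

# Same pattern as the helpers above: drop the "data:image/...;base64," prefix,
# decode the remainder, and open it as a PIL image.
image_data = re.sub("^data:image/.+;base64,", "", data_url)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
print(processed_image.size)  # (8, 8)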

View File

@ -13,7 +13,7 @@
:: GNU Affero General Public License for more details.
:: You should have received a copy of the GNU Affero General Public License
:: along with this program. If not, see <http://www.gnu.org/licenses/>.
:: along with this program. If not, see <http://www.gnu.org/licenses/>.
:: Run all commands using this script's directory as the working directory
cd %~dp0

View File

@ -13,7 +13,7 @@
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Start the Stable Diffusion WebUI for Linux Users
DIRECTORY="."
@ -33,9 +33,9 @@ REALESRGAN_ANIME_MODEL="https://github.com/xinntao/Real-ESRGAN/releases/download
SD_CONCEPT_REPO="https://github.com/Sygil-Dev/sd-concepts-library/archive/refs/heads/main.zip"
if [[ -f $ENV_MODIFED_FILE ]]; then
if [[ -f $ENV_MODIFED_FILE ]]; then
ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE})
else
else
ENV_MODIFIED_CACHED=0
fi
@ -93,7 +93,7 @@ sd_model_loading () {
printf "\n\n########## MOVE MODEL FILE ##########\n\n"
printf "Please download the 1.4 AI Model from Huggingface (or another source) and place it inside of the sygil-webui folder\n\n"
read -p "Once you have sd-v1-4.ckpt in the project root, Press Enter...\n\n"
# Check to make sure checksum of models is the original one from HuggingFace and not a fake model set
printf "fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556 sd-v1-4.ckpt" | sha256sum --check || exit 1
mv sd-v1-4.ckpt $DIRECTORY/models/ldm/stable-diffusion-v1/model.ckpt
@ -166,4 +166,4 @@ start_initialization () {
}
start_initialization "$@"
start_initialization "$@"

File diff suppressed because one or more lines are too long

Before: 6.3 KiB  |  After: 6.3 KiB

View File

@ -3,7 +3,11 @@ from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.coco_karpathy_dataset import (
coco_karpathy_train,
coco_karpathy_caption_eval,
coco_karpathy_retrieval_eval,
)
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
@ -11,77 +15,154 @@ from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
transform_train = transforms.Compose([
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
def create_dataset(dataset, config, min_scale=0.5):
normalize = transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
)
transform_train = transforms.Compose(
[
transforms.RandomResizedCrop(
config["image_size"],
scale=(min_scale, 1.0),
interpolation=InterpolationMode.BICUBIC,
),
transforms.RandomHorizontalFlip(),
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
RandomAugment(
2,
5,
isPIL=True,
augs=[
"Identity",
"AutoContrast",
"Brightness",
"Sharpness",
"Equalize",
"ShearX",
"ShearY",
"TranslateX",
"TranslateY",
"Rotate",
],
),
transforms.ToTensor(),
normalize,
])
transform_test = transforms.Compose([
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
normalize,
])
if dataset=='pretrain':
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
return dataset
elif dataset=='caption_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
]
)
transform_test = transforms.Compose(
[
transforms.Resize(
(config["image_size"], config["image_size"]),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
normalize,
]
)
if dataset == "pretrain":
dataset = pretrain_dataset(
config["train_file"], config["laion_path"], transform_train
)
return dataset
elif dataset == "caption_coco":
train_dataset = coco_karpathy_train(
transform_train,
config["image_root"],
config["ann_root"],
prompt=config["prompt"],
)
val_dataset = coco_karpathy_caption_eval(
transform_test, config["image_root"], config["ann_root"], "val"
)
test_dataset = coco_karpathy_caption_eval(
transform_test, config["image_root"], config["ann_root"], "test"
)
return train_dataset, val_dataset, test_dataset
elif dataset=='nocaps':
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return val_dataset, test_dataset
elif dataset=='retrieval_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='retrieval_flickr':
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='vqa':
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
train_files = config['train_files'], split='train')
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
elif dataset == "nocaps":
val_dataset = nocaps_eval(
transform_test, config["image_root"], config["ann_root"], "val"
)
test_dataset = nocaps_eval(
transform_test, config["image_root"], config["ann_root"], "test"
)
return val_dataset, test_dataset
elif dataset == "retrieval_coco":
train_dataset = coco_karpathy_train(
transform_train, config["image_root"], config["ann_root"]
)
val_dataset = coco_karpathy_retrieval_eval(
transform_test, config["image_root"], config["ann_root"], "val"
)
test_dataset = coco_karpathy_retrieval_eval(
transform_test, config["image_root"], config["ann_root"], "test"
)
return train_dataset, val_dataset, test_dataset
elif dataset == "retrieval_flickr":
train_dataset = flickr30k_train(
transform_train, config["image_root"], config["ann_root"]
)
val_dataset = flickr30k_retrieval_eval(
transform_test, config["image_root"], config["ann_root"], "val"
)
test_dataset = flickr30k_retrieval_eval(
transform_test, config["image_root"], config["ann_root"], "test"
)
return train_dataset, val_dataset, test_dataset
elif dataset == "vqa":
train_dataset = vqa_dataset(
transform_train,
config["ann_root"],
config["vqa_root"],
config["vg_root"],
train_files=config["train_files"],
split="train",
)
test_dataset = vqa_dataset(
transform_test,
config["ann_root"],
config["vqa_root"],
config["vg_root"],
split="test",
)
return train_dataset, test_dataset
elif dataset=='nlvr':
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
return train_dataset, val_dataset, test_dataset
elif dataset == "nlvr":
train_dataset = nlvr_dataset(
transform_train, config["image_root"], config["ann_root"], "train"
)
val_dataset = nlvr_dataset(
transform_test, config["image_root"], config["ann_root"], "val"
)
test_dataset = nlvr_dataset(
transform_test, config["image_root"], config["ann_root"], "test"
)
return train_dataset, val_dataset, test_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset,shuffle in zip(datasets,shuffles):
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
for dataset, shuffle in zip(datasets, shuffles):
sampler = torch.utils.data.DistributedSampler(
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
)
samplers.append(sampler)
return samplers
return samplers
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
loaders = []
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
if is_train:
shuffle = (sampler is None)
shuffle = sampler is None
drop_last = True
else:
shuffle = False
@ -95,7 +176,6 @@ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collat
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
)
loaders.append(loader)
return loaders
return loaders
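
create_sampler and create_loader above simply zip per-split settings into DistributedSamplers and DataLoaders. A small sketch of calling create_loader outside of distributed training, with samplers left as None so the train split falls back to shuffle=True; the import path and the toy datasets are assumptions for illustration:

import torch
from torch.utils.data import TensorDataset

from data import create_loader  # assumed import path for the module above

# Toy stand-ins for the real train/val datasets.
train_ds = TensorDataset(torch.randn(32, 3), torch.arange(32))
val_ds = TensorDataset(torch.randn(8, 3), torch.arange(8))

train_loader, val_loader = create_loader(
    datasets=[train_ds, val_ds],
    samplers=[None, None],       # no DistributedSampler when not running under DDP
    batch_size=[4, 4],
    num_workers=[0, 0],
    is_trains=[True, False],
    collate_fns=[None, None],
)
print(len(train_loader), len(val_loader))  # 8 train batches (drop_last), 2 val batches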

View File

@ -1,11 +1,12 @@
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
"""
Define an interface to make the IterableDatasets for text2img data chainable
'''
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
@ -13,11 +14,11 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
pass

View File

@ -8,119 +8,121 @@ from PIL import Image
from data.utils import pre_caption
class coco_karpathy_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=""):
"""
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
filename = 'coco_karpathy_train.json'
"""
url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json"
filename = "coco_karpathy_train.json"
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
download_url(url, ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filename), "r"))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
img_id = ann["image_id"]
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
n += 1
return image, caption, self.img_ids[ann['image_id']]
class coco_karpathy_caption_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
return image, int(img_id)
class coco_karpathy_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_path = os.path.join(self.image_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
caption = self.prompt + pre_caption(ann["caption"], self.max_words)
return image, caption, self.img_ids[ann["image_id"]]
class coco_karpathy_caption_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
"""
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
"""
urls = {
"val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json",
"test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json",
}
filenames = {"val": "coco_karpathy_val.json", "test": "coco_karpathy_test.json"}
download_url(urls[split], ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r"))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
return image, int(img_id)
class coco_karpathy_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
"""
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
"""
urls = {
"val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json",
"test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json",
}
filenames = {"val": "coco_karpathy_val.json", "test": "coco_karpathy_test.json"}
download_url(urls[split], ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r"))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.image.append(ann["image"])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
for i, caption in enumerate(ann["caption"]):
self.text.append(pre_caption(caption, max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
return image, index

View File

@ -8,86 +8,87 @@ from PIL import Image
from data.utils import pre_caption
class flickr30k_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=""):
"""
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
filename = 'flickr30k_train.json'
"""
url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json"
filename = "flickr30k_train.json"
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
download_url(url, ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filename), "r"))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
img_id = ann["image_id"]
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
caption = self.prompt + pre_caption(ann["caption"], self.max_words)
return image, caption, self.img_ids[ann["image_id"]]
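
flickr30k_train (like the COCO train loader) remaps arbitrary image_id strings to contiguous integers so they can index queues and embeddings during training. A small sketch of that remapping, with invented annotations standing in for flickr30k_train.json:

# Sketch of the image_id -> contiguous index remapping built in __init__.
annotation = [
    {"image": "1000092795.jpg", "image_id": "1000092795", "caption": "two dogs play"},
    {"image": "1000092795.jpg", "image_id": "1000092795", "caption": "dogs in a yard"},
    {"image": "10002456.jpg", "image_id": "10002456", "caption": "a crowded market"},
]

img_ids, n = {}, 0
for ann in annotation:
    img_id = ann["image_id"]
    if img_id not in img_ids:      # first time this image appears
        img_ids[img_id] = n        # assign the next free dense index
        n += 1

# img_ids == {"1000092795": 0, "10002456": 1}; __getitem__ returns this
# dense index alongside the image tensor and caption.
print(img_ids)
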
class flickr30k_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
def __init__(self, transform, image_root, ann_root, split, max_words=30):
"""
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
"""
urls = {
"val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json",
"test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json",
}
filenames = {"val": "flickr30k_val.json", "test": "flickr30k_test.json"}
download_url(urls[split], ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r"))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.image.append(ann["image"])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
for i, caption in enumerate(ann["caption"]):
self.text.append(pre_caption(caption, max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
return image, index

View File

@ -11,7 +11,12 @@ from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import (
str_to_indices,
give_synsets_from_indices,
download,
retrieve,
)
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
@ -20,13 +25,13 @@ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
return dict((v, k) for k, v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
@ -46,13 +51,17 @@ class ImageNetBase(Dataset):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
ignore = set(
[
"n06596364_9591.JPEG",
]
)
relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
synsets = give_synsets_from_indices(
indices, path_to_yaml=self.idx2syn
) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
@ -67,20 +76,24 @@ class ImageNetBase(Dataset):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
if (
not os.path.exists(self.human_dict)
or not os.path.getsize(self.human_dict) == SIZE
):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
if not os.path.exists(self.idx2syn):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
self.human2integer = os.path.join(
self.root, "imagenet1000_clsidx_to_labels.txt"
)
if not os.path.exists(self.human2integer):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
lines = f.read().splitlines()
@ -95,7 +108,11 @@ class ImageNetBase(Dataset):
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
print(
"Removed {} files from filelist during filtering.".format(
l1 - len(self.relpaths)
)
)
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
@ -122,11 +139,12 @@ class ImageNetBase(Dataset):
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
self.data = ImagePaths(
self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
@ -157,8 +175,9 @@ class ImageNetTrain(ImageNetBase):
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
self.random_crop = retrieve(
self.config, "ImageNetTrain/random_crop", default=True
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
@ -166,8 +185,12 @@ class ImageNetTrain(ImageNetBase):
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
@ -179,7 +202,7 @@ class ImageNetTrain(ImageNetBase):
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
subdir = subpath[: -len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)
@ -187,7 +210,7 @@ class ImageNetTrain(ImageNetBase):
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
@ -222,8 +245,9 @@ class ImageNetValidation(ImageNetBase):
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
self.random_crop = retrieve(
self.config, "ImageNetValidation/random_crop", default=False
)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
@ -231,8 +255,12 @@ class ImageNetValidation(ImageNetBase):
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
if (
not os.path.exists(path)
or not os.path.getsize(path) == self.SIZES[0]
):
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
@ -242,7 +270,10 @@ class ImageNetValidation(ImageNetBase):
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
if (
not os.path.exists(vspath)
or not os.path.getsize(vspath) == self.SIZES[1]
):
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
@ -261,18 +292,23 @@ class ImageNetValidation(ImageNetBase):
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
def __init__(
self,
size=None,
degradation=None,
downscale_f=4,
min_crop_f=0.5,
max_crop_f=1.0,
random_crop=True,
):
"""
Imagenet Superresolution Dataloader
Performs the following ops in order:
@ -296,12 +332,16 @@ class ImageNetSR(Dataset):
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
assert max_crop_f <= 1.0
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.image_rescaler = albumentations.SmallestMaxSize(
max_size=size, interpolation=cv2.INTER_AREA
)
self.pil_interpolation = False # gets reset later in case interp_op is from pillow
self.pil_interpolation = (
False  # gets reset later in case interp_op is from pillow
)
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
@ -311,27 +351,30 @@ class ImageNetSR(Dataset):
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
self.degradation_process = partial(
TF.resize, size=self.LR_size, interpolation=interpolation_fn
)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
self.degradation_process = albumentations.SmallestMaxSize(
max_size=self.LR_size, interpolation=interpolation_fn
)
def __len__(self):
return len(self.base)
@ -346,14 +389,20 @@ class ImageNetSR(Dataset):
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = min_side_len * np.random.uniform(
self.min_crop_f, self.max_crop_f, size=None
)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.CenterCrop(
height=crop_side_len, width=crop_side_len
)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
self.cropper = albumentations.RandomCrop(
height=crop_side_len, width=crop_side_len
)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
@ -366,8 +415,8 @@ class ImageNetSR(Dataset):
else:
LR_image = self.degradation_process(image=image)["image"]
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example
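
The superresolution loader above produces an HR/LR pair per sample: take a random (or centered) square crop, rescale it so its side equals size, degrade a copy down to LR_size, then map both images from uint8 [0, 255] to float32 [-1, 1]. A simplified numpy/PIL sketch of that pipeline; plain PIL bicubic resizing is an assumed stand-in for the albumentations/BSRGAN degradations used here:

# Simplified HR/LR pair construction (PIL bicubic stands in for the real degradation).
import numpy as np
from PIL import Image

def make_sr_pair(img, size=256, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0):
    arr = np.array(img.convert("RGB")).astype(np.uint8)
    min_side = min(arr.shape[:2])
    crop = int(min_side * np.random.uniform(min_crop_f, max_crop_f))
    # random top-left corner for a square crop of side `crop`
    top = np.random.randint(0, arr.shape[0] - crop + 1)
    left = np.random.randint(0, arr.shape[1] - crop + 1)
    arr = arr[top:top + crop, left:left + crop]

    hr = Image.fromarray(arr).resize((size, size), Image.BICUBIC)
    lr = hr.resize((size // downscale_f, size // downscale_f), Image.BICUBIC)

    def to_float(im):
        # uint8 [0, 255] -> float32 [-1, 1], as in the hunk above
        return (np.array(im) / 127.5 - 1.0).astype(np.float32)

    return {"image": to_float(hr), "LR_image": to_float(lr)}

example = make_sr_pair(Image.new("RGB", (512, 384), "gray"))
print(example["image"].shape, example["LR_image"].shape)  # (256, 256, 3) (64, 64, 3)
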
@ -379,7 +428,9 @@ class ImageNetSRTrain(ImageNetSR):
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
dset = ImageNetTrain(
process_images=False,
)
return Subset(dset, indices)
@ -390,5 +441,7 @@ class ImageNetSRValidation(ImageNetSR):
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
dset = ImageNetValidation(
process_images=False,
)
return Subset(dset, indices)

View File

@ -7,13 +7,9 @@ from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
def __init__(
self, txt_file, data_root, size=None, interpolation="bicubic", flip_p=0.5
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
@ -21,16 +17,16 @@ class LSUNBase(Dataset):
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l) for l in self.image_paths],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
@ -45,9 +41,14 @@ class LSUNBase(Dataset):
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
@ -61,32 +62,54 @@ class LSUNBase(Dataset):
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
super().__init__(
txt_file="data/lsun/church_outdoor_train.txt",
data_root="data/lsun/churches",
**kwargs
)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file="data/lsun/church_outdoor_val.txt",
data_root="data/lsun/churches",
flip_p=flip_p,
**kwargs
)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
super().__init__(
txt_file="data/lsun/bedrooms_train.txt",
data_root="data/lsun/bedrooms",
**kwargs
)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
super().__init__(
txt_file="data/lsun/bedrooms_val.txt",
data_root="data/lsun/bedrooms",
flip_p=flip_p,
**kwargs
)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
super().__init__(
txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs
)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(
txt_file="data/lsun/cat_val.txt",
data_root="data/lsun/cats",
flip_p=flip_p,
**kwargs
)
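
LSUNBase crops the largest centered square out of each image before resizing, via the slice shown in the hunk above. A quick numeric check of that indexing on a toy array (no LSUN data needed):

# The centered-square-crop slice used by LSUNBase, checked on a toy array.
import numpy as np

img = np.arange(5 * 8 * 3).reshape(5, 8, 3)   # pretend 5x8 RGB image
h, w = img.shape[0], img.shape[1]
crop = min(h, w)                              # 5

square = img[(h - crop) // 2:(h + crop) // 2,
             (w - crop) // 2:(w + crop) // 2]

print(square.shape)  # (5, 5, 3): full height kept, width trimmed to the center
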

View File

@ -9,70 +9,71 @@ from PIL import Image
from data.utils import pre_caption
class nlvr_dataset(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images
def __init__(self, transform, image_root, ann_root, split):
"""
image_root (string): Root directory of images
ann_root (string): directory to store the annotation file
split (string): train, val or test
'''
urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
"""
urls = {
"train": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json",
"val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json",
"test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json",
}
filenames = {
"train": "nlvr_train.json",
"val": "nlvr_dev.json",
"test": "nlvr_test.json",
}
download_url(urls[split], ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r"))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
def __getitem__(self, index):
ann = self.annotation[index]
image0_path = os.path.join(self.image_root,ann['images'][0])
image0 = Image.open(image0_path).convert('RGB')
image0 = self.transform(image0)
image1_path = os.path.join(self.image_root,ann['images'][1])
image1 = Image.open(image1_path).convert('RGB')
image1 = self.transform(image1)
sentence = pre_caption(ann['sentence'], 40)
if ann['label']=='True':
image0_path = os.path.join(self.image_root, ann["images"][0])
image0 = Image.open(image0_path).convert("RGB")
image0 = self.transform(image0)
image1_path = os.path.join(self.image_root, ann["images"][1])
image1 = Image.open(image1_path).convert("RGB")
image1 = self.transform(image1)
sentence = pre_caption(ann["sentence"], 40)
if ann["label"] == "True":
label = 1
else:
label = 0
words = sentence.split(' ')
if 'left' not in words and 'right' not in words:
if random.random()<0.5:
words = sentence.split(" ")
if "left" not in words and "right" not in words:
if random.random() < 0.5:
return image0, image1, sentence, label
else:
return image1, image0, sentence, label
else:
if random.random()<0.5:
if random.random() < 0.5:
return image0, image1, sentence, label
else:
new_words = []
for word in words:
if word=='left':
new_words.append('right')
elif word=='right':
new_words.append('left')
if word == "left":
new_words.append("right")
elif word == "right":
new_words.append("left")
else:
new_words.append(word)
sentence = ' '.join(new_words)
new_words.append(word)
sentence = " ".join(new_words)
return image1, image0, sentence, label
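
During NLVR2 training the image pair order is randomly reversed, and when the sentence mentions "left" or "right" those words are swapped so the caption still matches the shuffled images. A standalone sketch of just that swap logic:

# Left/right mirroring used when the NLVR image pair is presented in reversed order.
import random

def maybe_swap(sentence, image0, image1):
    """Return the pair (possibly reversed) and a caption that still matches."""
    words = sentence.split(" ")
    if random.random() < 0.5:
        return image0, image1, sentence           # keep the original order
    if "left" not in words and "right" not in words:
        return image1, image0, sentence           # order is irrelevant to the text
    swapped = ["right" if w == "left" else "left" if w == "right" else w
               for w in words]
    return image1, image0, " ".join(swapped)      # mirrored text for the reversed pair

print(maybe_swap("the dog is on the left", "img_a", "img_b"))
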

View File

@ -6,27 +6,29 @@ from torchvision.datasets.utils import download_url
from PIL import Image
class nocaps_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
def __init__(self, transform, image_root, ann_root, split):
urls = {
"val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json",
"test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json",
}
filenames = {"val": "nocaps_val.json", "test": "nocaps_test.json"}
download_url(urls[split], ann_root)
self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), "r"))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, int(ann['img_id'])
image_path = os.path.join(self.image_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
return image, int(ann["img_id"])

View File

@ -8,92 +8,92 @@ from torchvision import transforms
import random
imagenet_templates_smallest = [
'a photo of a {}',
"a photo of a {}",
]
imagenet_templates_small = [
'a photo of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a photo of a clean {}',
'a photo of a dirty {}',
'a dark photo of the {}',
'a photo of my {}',
'a photo of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a photo of the {}',
'a good photo of the {}',
'a photo of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a photo of the clean {}',
'a rendition of a {}',
'a photo of a nice {}',
'a good photo of a {}',
'a photo of the nice {}',
'a photo of the small {}',
'a photo of the weird {}',
'a photo of the large {}',
'a photo of a cool {}',
'a photo of a small {}',
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_dual_templates_small = [
'a photo of a {} with {}',
'a rendering of a {} with {}',
'a cropped photo of the {} with {}',
'the photo of a {} with {}',
'a photo of a clean {} with {}',
'a photo of a dirty {} with {}',
'a dark photo of the {} with {}',
'a photo of my {} with {}',
'a photo of the cool {} with {}',
'a close-up photo of a {} with {}',
'a bright photo of the {} with {}',
'a cropped photo of a {} with {}',
'a photo of the {} with {}',
'a good photo of the {} with {}',
'a photo of one {} with {}',
'a close-up photo of the {} with {}',
'a rendition of the {} with {}',
'a photo of the clean {} with {}',
'a rendition of a {} with {}',
'a photo of a nice {} with {}',
'a good photo of a {} with {}',
'a photo of the nice {} with {}',
'a photo of the small {} with {}',
'a photo of the weird {} with {}',
'a photo of the large {} with {}',
'a photo of a cool {} with {}',
'a photo of a small {} with {}',
"a photo of a {} with {}",
"a rendering of a {} with {}",
"a cropped photo of the {} with {}",
"the photo of a {} with {}",
"a photo of a clean {} with {}",
"a photo of a dirty {} with {}",
"a dark photo of the {} with {}",
"a photo of my {} with {}",
"a photo of the cool {} with {}",
"a close-up photo of a {} with {}",
"a bright photo of the {} with {}",
"a cropped photo of a {} with {}",
"a photo of the {} with {}",
"a good photo of the {} with {}",
"a photo of one {} with {}",
"a close-up photo of the {} with {}",
"a rendition of the {} with {}",
"a photo of the clean {} with {}",
"a rendition of a {} with {}",
"a photo of a nice {} with {}",
"a good photo of a {} with {}",
"a photo of the nice {} with {}",
"a photo of the small {} with {}",
"a photo of the weird {} with {}",
"a photo of the large {} with {}",
"a photo of a cool {} with {}",
"a photo of a small {} with {}",
]
per_img_token_list = [
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
"א",
"ב",
"ג",
"ד",
"ה",
"ו",
"ז",
"ח",
"ט",
"י",
"כ",
"ל",
"מ",
"נ",
"ס",
"ע",
"פ",
"צ",
"ק",
"ר",
"ש",
"ת",
]
@ -103,16 +103,15 @@ class PersonalizedBase(Dataset):
data_root,
size=None,
repeats=100,
interpolation='bicubic',
interpolation="bicubic",
flip_p=0.5,
set='train',
placeholder_token='*',
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root
self.image_paths = [
@ -137,15 +136,15 @@ class PersonalizedBase(Dataset):
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == 'train':
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -156,32 +155,31 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == 'RGB':
image = image.convert('RGB')
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = (
f'{self.coarse_class_text} {placeholder_string}'
)
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(
placeholder_string, per_img_token_list[i % self.num_images]
)
else:
text = random.choice(imagenet_templates_small).format(
placeholder_string
)
text = random.choice(imagenet_templates_small).format(placeholder_string)
example['caption'] = text
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
@ -192,11 +190,9 @@ class PersonalizedBase(Dataset):
image = Image.fromarray(img)
if self.size is not None:
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
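
The textual-inversion dataset builds each caption by dropping the placeholder token (optionally prefixed with a coarse class word) into one of the ImageNet templates, and occasionally mixes in a per-image Hebrew token. A small sketch of that caption sampling; the shortened template lists here are excerpts for illustration:

# Caption sampling as done in PersonalizedBase.__getitem__ (illustrative values).
import random

imagenet_templates_small = ["a photo of a {}", "a rendering of a {}"]
imagenet_dual_templates_small = ["a photo of a {} with {}"]
per_img_token_list = ["א", "ב"]

def sample_caption(i, num_images, placeholder="*", coarse_class_text=None,
                   per_image_tokens=False, mixing_prob=0.25):
    token = f"{coarse_class_text} {placeholder}" if coarse_class_text else placeholder
    if per_image_tokens and random.random() < mixing_prob:
        # pair the placeholder with a token unique to this training image
        return random.choice(imagenet_dual_templates_small).format(
            token, per_img_token_list[i % num_images])
    return random.choice(imagenet_templates_small).format(token)

print(sample_caption(0, 2, placeholder="*", coarse_class_text="dog"))
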

View File

@ -8,70 +8,70 @@ from torchvision import transforms
import random
imagenet_templates_small = [
'a painting in the style of {}',
'a rendering in the style of {}',
'a cropped painting in the style of {}',
'the painting in the style of {}',
'a clean painting in the style of {}',
'a dirty painting in the style of {}',
'a dark painting in the style of {}',
'a picture in the style of {}',
'a cool painting in the style of {}',
'a close-up painting in the style of {}',
'a bright painting in the style of {}',
'a cropped painting in the style of {}',
'a good painting in the style of {}',
'a close-up painting in the style of {}',
'a rendition in the style of {}',
'a nice painting in the style of {}',
'a small painting in the style of {}',
'a weird painting in the style of {}',
'a large painting in the style of {}',
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
imagenet_dual_templates_small = [
'a painting in the style of {} with {}',
'a rendering in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'the painting in the style of {} with {}',
'a clean painting in the style of {} with {}',
'a dirty painting in the style of {} with {}',
'a dark painting in the style of {} with {}',
'a cool painting in the style of {} with {}',
'a close-up painting in the style of {} with {}',
'a bright painting in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'a good painting in the style of {} with {}',
'a painting of one {} in the style of {}',
'a nice painting in the style of {} with {}',
'a small painting in the style of {} with {}',
'a weird painting in the style of {} with {}',
'a large painting in the style of {} with {}',
"a painting in the style of {} with {}",
"a rendering in the style of {} with {}",
"a cropped painting in the style of {} with {}",
"the painting in the style of {} with {}",
"a clean painting in the style of {} with {}",
"a dirty painting in the style of {} with {}",
"a dark painting in the style of {} with {}",
"a cool painting in the style of {} with {}",
"a close-up painting in the style of {} with {}",
"a bright painting in the style of {} with {}",
"a cropped painting in the style of {} with {}",
"a good painting in the style of {} with {}",
"a painting of one {} in the style of {}",
"a nice painting in the style of {} with {}",
"a small painting in the style of {} with {}",
"a weird painting in the style of {} with {}",
"a large painting in the style of {} with {}",
]
per_img_token_list = [
'א',
'ב',
'ג',
'ד',
'ה',
'ו',
'ז',
'ח',
'ט',
'י',
'כ',
'ל',
'מ',
'נ',
'ס',
'ע',
'פ',
'צ',
'ק',
'ר',
'ש',
'ת',
"א",
"ב",
"ג",
"ד",
"ה",
"ו",
"ז",
"ח",
"ט",
"י",
"כ",
"ל",
"מ",
"נ",
"ס",
"ע",
"פ",
"צ",
"ק",
"ר",
"ש",
"ת",
]
@ -81,14 +81,13 @@ class PersonalizedBase(Dataset):
data_root,
size=None,
repeats=100,
interpolation='bicubic',
interpolation="bicubic",
flip_p=0.5,
set='train',
placeholder_token='*',
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [
@ -110,15 +109,15 @@ class PersonalizedBase(Dataset):
per_img_token_list
), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == 'train':
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {
'linear': PIL.Image.LINEAR,
'bilinear': PIL.Image.BILINEAR,
'bicubic': PIL.Image.BICUBIC,
'lanczos': PIL.Image.LANCZOS,
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -129,8 +128,8 @@ class PersonalizedBase(Dataset):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == 'RGB':
image = image.convert('RGB')
if not image.mode == "RGB":
image = image.convert("RGB")
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(
@ -141,14 +140,17 @@ class PersonalizedBase(Dataset):
self.placeholder_token
)
example['caption'] = text
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)
@ -159,11 +161,9 @@ class PersonalizedBase(Dataset):
image = Image.fromarray(img)
if self.size is not None:
image = image.resize(
(self.size, self.size), resample=self.interpolation
)
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example['image'] = (image / 127.5 - 1.0).astype(np.float32)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example

View File

@ -1,59 +1,56 @@
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from data.utils import pre_caption
import os,glob
import os, glob
class pretrain_dataset(Dataset):
def __init__(self, ann_file, laion_path, transform):
def __init__(self, ann_file, laion_path, transform):
self.ann_pretrain = []
for f in ann_file:
print('loading '+f)
ann = json.load(open(f,'r'))
print("loading " + f)
ann = json.load(open(f, "r"))
self.ann_pretrain += ann
self.laion_path = laion_path
if self.laion_path:
self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
self.laion_files = glob.glob(os.path.join(laion_path, "*.json"))
print('loading '+self.laion_files[0])
with open(self.laion_files[0],'r') as f:
self.ann_laion = json.load(f)
print("loading " + self.laion_files[0])
with open(self.laion_files[0], "r") as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
else:
self.annotation = self.ann_pretrain
self.transform = transform
def reload_laion(self, epoch):
n = epoch%len(self.laion_files)
print('loading '+self.laion_files[n])
with open(self.laion_files[n],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
n = epoch % len(self.laion_files)
print("loading " + self.laion_files[n])
with open(self.laion_files[n], "r") as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image = Image.open(ann['image']).convert('RGB')
def __getitem__(self, index):
ann = self.annotation[index]
image = Image.open(ann["image"]).convert("RGB")
image = self.transform(image)
caption = pre_caption(ann['caption'],30)
return image, caption
caption = pre_caption(ann["caption"], 30)
return image, caption
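
reload_laion rotates through the LAION shard files one per epoch with a simple modulo, so every epoch mixes the fixed pretraining annotations with a different shard. The indexing is just the following (shard paths are hypothetical):

# Which LAION shard gets loaded at each epoch (hypothetical 3-shard setup).
laion_files = ["laion/part0.json", "laion/part1.json", "laion/part2.json"]

for epoch in range(7):
    n = epoch % len(laion_files)
    print(epoch, "->", laion_files[n])
# epochs 0, 1, 2 use parts 0, 1, 2; epoch 3 wraps back to part0, and so on.
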

View File

@ -9,16 +9,16 @@ class AddMiDaS(object):
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
x = ((x + 1.0) * 0.5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
x = torch.from_numpy(x) * 2 - 1.0
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.pt2np(sample["jpg"])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
sample["midas_in"] = x
return sample
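
pt2np and np2pt convert between the [-1, 1] image tensors used by the diffusion pipeline and the [0, 1] arrays the MiDaS transform expects; the two maps are exact inverses. A quick round-trip check (torch only, no MiDaS weights needed):

# Round-trip check for the [-1, 1] <-> [0, 1] conversions in AddMiDaS.
import torch

def pt2np(x):
    return ((x + 1.0) * 0.5).detach().cpu().numpy()

def np2pt(x):
    return torch.from_numpy(x) * 2 - 1.0

x = torch.rand(4, 4, 3) * 2 - 1          # fake hwc image in [-1, 1]
assert torch.allclose(np2pt(pt2np(x)), x, atol=1e-6)
print("round trip ok")
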

View File

@ -1,7 +1,6 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
import torch
import numpy as np
import random
@ -13,67 +12,84 @@ from data.utils import pre_caption
decord.bridge.set_bridge("torch")
class ImageNorm(object):
"""Apply Normalization to Image Pixels on GPU
"""
"""Apply Normalization to Image Pixels on GPU"""
def __init__(self, mean, std):
self.mean = torch.tensor(mean).view(1, 3, 1, 1)
self.std = torch.tensor(std).view(1, 3, 1, 1)
def __call__(self, img):
def __call__(self, img):
if torch.max(img) > 1 and self.mean.max() <= 1:
img.div_(255.)
img.div_(255.0)
return img.sub_(self.mean).div_(self.std)
def load_jsonl(filename):
with open(filename, "r") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
class VideoDataset(Dataset):
def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
'''
class VideoDataset(Dataset):
def __init__(
self,
video_root,
ann_root,
num_frm=4,
frm_sampling_strategy="rand",
max_img_size=384,
video_fmt=".mp4",
):
"""
image_root (string): Root directory of video
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
filename = 'msrvtt_test.jsonl'
"""
url = "https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl"
filename = "msrvtt_test.jsonl"
download_url(url, ann_root)
self.annotation = load_jsonl(os.path.join(ann_root, filename))
download_url(url,ann_root)
self.annotation = load_jsonl(os.path.join(ann_root,filename))
self.num_frm = num_frm
self.frm_sampling_strategy = frm_sampling_strategy
self.max_img_size = max_img_size
self.video_root = video_root
self.video_fmt = video_fmt
self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
self.img_norm = ImageNorm(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
self.text = [pre_caption(ann["caption"], 40) for ann in self.annotation]
self.txt2video = [i for i in range(len(self.annotation))]
self.video2txt = self.txt2video
self.video2txt = self.txt2video
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
ann = self.annotation[index]
video_path = os.path.join(self.video_root, ann["clip_name"] + self.video_fmt)
video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
vid_frm_array = self._load_video_from_path_decord(
video_path, height=self.max_img_size, width=self.max_img_size
)
video = self.img_norm(vid_frm_array.float())
return video, ann['clip_name']
return video, ann["clip_name"]
def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
def _load_video_from_path_decord(
self,
video_path,
height=None,
width=None,
start_time=None,
end_time=None,
fps=-1,
):
try:
if not height or not width:
vr = VideoReader(video_path)
@ -83,26 +99,36 @@ class VideoDataset(Dataset):
vlen = len(vr)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
assert (
fps > 0
), "must provide video fps if specifying start and end time."
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
else:
start_idx, end_idx = 0, vlen
if self.frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
elif self.frm_sampling_strategy == 'rand':
if self.frm_sampling_strategy == "uniform":
frame_indices = np.arange(
start_idx, end_idx, vlen / self.num_frm, dtype=int
)
elif self.frm_sampling_strategy == "rand":
frame_indices = sorted(random.sample(range(vlen), self.num_frm))
elif self.frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
elif self.frm_sampling_strategy == "headtail":
frame_indices_head = sorted(
random.sample(range(vlen // 2), self.num_frm // 2)
)
frame_indices_tail = sorted(
random.sample(range(vlen // 2, vlen), self.num_frm // 2)
)
frame_indices = frame_indices_head + frame_indices_tail
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
raise NotImplementedError(
"Invalid sampling strategy {} ".format(self.frm_sampling_strategy)
)
raw_sample_frms = vr.get_batch(frame_indices)
except Exception as e:
except Exception:
return None
raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
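
The three frm_sampling_strategy options pick num_frm frame indices out of vlen frames in different ways: evenly spaced, uniformly random, or half from the first half of the clip and half from the second. A sketch computing the indices for each strategy, independent of decord and any actual video:

# Frame-index selection for the three sampling strategies (no video needed).
import random
import numpy as np

def sample_frame_indices(vlen, num_frm, strategy):
    if strategy == "uniform":
        return np.arange(0, vlen, vlen / num_frm, dtype=int).tolist()
    if strategy == "rand":
        return sorted(random.sample(range(vlen), num_frm))
    if strategy == "headtail":
        head = sorted(random.sample(range(vlen // 2), num_frm // 2))
        tail = sorted(random.sample(range(vlen // 2, vlen), num_frm // 2))
        return head + tail
    raise NotImplementedError(f"Invalid sampling strategy {strategy}")

for s in ("uniform", "rand", "headtail"):
    print(s, sample_frame_indices(vlen=120, num_frm=4, strategy=s))
# "uniform" with vlen=120, num_frm=4 gives [0, 30, 60, 90].
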

View File

@ -1,6 +1,5 @@
import os
import json
import random
from PIL import Image
import torch
@ -9,80 +8,99 @@ from data.utils import pre_question
from torchvision.datasets.utils import download_url
class vqa_dataset(Dataset):
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
self.split = split
def __init__(
self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"
):
self.split = split
self.transform = transform
self.vqa_root = vqa_root
self.vg_root = vg_root
if split=='train':
urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
if split == "train":
urls = {
"vqa_train": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json",
"vqa_val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json",
"vg_qa": "https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json",
}
self.annotation = []
for f in train_files:
download_url(urls[f],ann_root)
self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
download_url(urls[f], ann_root)
self.annotation += json.load(
open(os.path.join(ann_root, "%s.json" % f), "r")
)
else:
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
download_url(
"https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json",
ann_root,
)
self.annotation = json.load(
open(os.path.join(ann_root, "vqa_test.json"), "r")
)
download_url(
"https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json",
ann_root,
)
self.answer_list = json.load(
open(os.path.join(ann_root, "answer_list.json"), "r")
)
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
def __getitem__(self, index):
ann = self.annotation[index]
if ann['dataset']=='vqa':
image_path = os.path.join(self.vqa_root,ann['image'])
elif ann['dataset']=='vg':
image_path = os.path.join(self.vg_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
if self.split == 'test':
question = pre_question(ann['question'])
question_id = ann['question_id']
if ann["dataset"] == "vqa":
image_path = os.path.join(self.vqa_root, ann["image"])
elif ann["dataset"] == "vg":
image_path = os.path.join(self.vg_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.transform(image)
if self.split == "test":
question = pre_question(ann["question"])
question_id = ann["question_id"]
return image, question, question_id
elif self.split == "train":
question = pre_question(ann["question"])
elif self.split=='train':
question = pre_question(ann['question'])
if ann['dataset']=='vqa':
if ann["dataset"] == "vqa":
answer_weight = {}
for answer in ann['answer']:
for answer in ann["answer"]:
if answer in answer_weight.keys():
answer_weight[answer] += 1/len(ann['answer'])
answer_weight[answer] += 1 / len(ann["answer"])
else:
answer_weight[answer] = 1/len(ann['answer'])
answer_weight[answer] = 1 / len(ann["answer"])
answers = list(answer_weight.keys())
weights = list(answer_weight.values())
elif ann['dataset']=='vg':
answers = [ann['answer']]
weights = [0.2]
elif ann["dataset"] == "vg":
answers = [ann["answer"]]
weights = [0.2]
return image, question, answers, weights
def vqa_collate_fn(batch):
image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
for image, question, answer, weights in batch:
image_list.append(image)
question_list.append(question)
weight_list += weights
weight_list += weights
answer_list += answer
n.append(len(answer))
return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
return (
torch.stack(image_list, dim=0),
question_list,
answer_list,
torch.Tensor(weight_list),
n,
)
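
For VQA training samples, each of the human answers contributes weight 1/len(answers), so repeated answers accumulate weight; VG samples instead get a single answer with a fixed 0.2 weight. A tiny sketch of the VQA branch with a made-up answer list:

# Soft answer weights for one VQA annotation (toy answer list).
answers_raw = ["blue", "blue", "blue", "navy", "blue"]

answer_weight = {}
for answer in answers_raw:
    answer_weight[answer] = answer_weight.get(answer, 0) + 1 / len(answers_raw)

answers = list(answer_weight.keys())
weights = list(answer_weight.values())
print(answers, weights)   # ['blue', 'navy'] with weights ~0.8 and 0.2: majority dominates
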

View File

@ -1 +0,0 @@
from ldm.devices.devices import choose_autocast_device, choose_torch_device

View File

@ -1,24 +1,26 @@
import torch
from torch import autocast
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
"""Convenience routine for guessing which GPU device to run model on"""
if torch.cuda.is_available():
return 'cuda'
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return 'mps'
return 'cpu'
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
def choose_autocast_device(device):
'''Returns an autocast compatible device from a torch device'''
device_type = device.type # this returns 'mps' on M1
"""Returns an autocast compatible device from a torch device"""
device_type = device.type # this returns 'mps' on M1
# autocast only for cuda, but GTX 16xx have issues with it
if device_type == 'cuda':
if device_type == "cuda":
device_name = torch.cuda.get_device_name()
if 'GeForce GTX 1660' in device_name or 'GeForce GTX 1650' in device_name:
return device_type,nullcontext
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
return device_type, nullcontext
else:
return device_type,autocast
return device_type, autocast
else:
return 'cpu',nullcontext
return "cpu", nullcontext

View File

@ -5,32 +5,47 @@ class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
@ -38,15 +53,24 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
@ -60,17 +84,25 @@ class LambdaWarmUpCosineScheduler2:
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
@ -79,20 +111,25 @@ class LambdaWarmUpCosineScheduler2:
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f
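
All of these schedulers return a multiplier for a base_lr of 1.0: a linear warm-up from lr_start to lr_max over warm_up_steps, then a half-cosine (or linear) decay towards lr_min. A compact sketch of the single-cycle warm-up + cosine rule, evaluated at a few steps with illustrative parameters:

# Warm-up + cosine multiplier, as computed by LambdaWarmUpCosineScheduler.schedule.
import numpy as np

def warmup_cosine(n, warm_up_steps=100, lr_start=0.0, lr_max=1.0,
                  lr_min=0.1, max_decay_steps=1000):
    if n < warm_up_steps:
        # linear ramp from lr_start to lr_max
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    # half-cosine decay from lr_max down to lr_min
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(t * np.pi))

for step in (0, 50, 100, 550, 1000, 5000):
    print(step, round(warmup_cosine(step), 3))
# 0 -> 0.0, 50 -> 0.5, 100 -> 1.0, 550 -> 0.55, 1000 and beyond -> 0.1
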

View File

@ -12,23 +12,24 @@ from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
@ -36,19 +37,25 @@ class VQModel(pl.LightningModule):
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
print(
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)
self.use_ema = use_ema
if self.use_ema:
@ -84,7 +91,9 @@ class VQModel(pl.LightningModule):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
@ -115,7 +124,7 @@ class VQModel(pl.LightningModule):
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@ -133,7 +142,9 @@ class VQModel(pl.LightningModule):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
@ -147,48 +158,88 @@ class VQModel(pl.LightningModule):
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
predicted_indices=ind,
)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
self.log(
f"val{suffix}/rec_loss",
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f"val{suffix}/aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse("1.4.0"):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
@ -196,17 +247,21 @@ class VQModel(pl.LightningModule):
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
@ -214,14 +269,14 @@ class VQModel(pl.LightningModule):
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
]
return [opt_ae, opt_disc], scheduler
@ -248,7 +303,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
@ -257,7 +313,7 @@ class VQModel(pl.LightningModule):
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
@ -283,27 +339,28 @@ class VQModelInterface(VQModel):
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
@ -354,29 +411,75 @@ class AutoencoderKL(pl.LightningModule):
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
@ -385,13 +488,17 @@ class AutoencoderKL(pl.LightningModule):
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
@ -419,7 +526,7 @@ class AutoencoderKL(pl.LightningModule):
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
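The optimizer_idx branches and the paired Adam optimizers re-wrapped above follow PyTorch Lightning's multi-optimizer (GAN-style) convention. A minimal sketch of that pattern, assuming Lightning 1.x automatic optimization (where training_step receives optimizer_idx); the TwoOptimizerToy module and its losses are invented purely for illustration and are not part of this repository:

import pytorch_lightning as pl
import torch


class TwoOptimizerToy(pl.LightningModule):
    """Toy module showing the optimizer_idx routing used by VQModel/AutoencoderKL."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(4, 4)  # stands in for encoder + decoder
        self.discriminator = torch.nn.Linear(4, 1)  # stands in for loss.discriminator

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = batch
        if optimizer_idx == 0:
            # optimizer 0 updates the autoencoder/generator parameters
            return self.generator(x).pow(2).mean()
        if optimizer_idx == 1:
            # optimizer 1 updates the discriminator parameters
            return self.discriminator(x).pow(2).mean()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=4.5e-6, betas=(0.5, 0.9))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=4.5e-6, betas=(0.5, 0.9))
        # Returning two optimizers makes Lightning call training_step once per
        # optimizer_idx for every batch, mirroring the classes above.
        return [opt_g, opt_d], []

Running the sketch only needs a Trainer and a DataLoader yielding (N, 4) float tensors; the learning rate and betas mirror the values used above.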

View File

@ -1,11 +1,12 @@
'''
"""
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
"""
import warnings
warnings.filterwarnings("ignore")
from .vit import VisionTransformer, interpolate_pos_embed
@ -14,225 +15,291 @@ from transformers import BertTokenizer
import torch
from torch import nn
#import torch.nn.functional as F
# import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
class BLIP_Base(nn.Module):
def __init__(self,
med_config = 'configs/blip/med_config.json',
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
def __init__(
self,
med_config="configs/blip/med_config.json",
image_size=224,
vit="base",
vit_grad_ckpt=False,
vit_ckpt_layer=0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
self.visual_encoder, vision_width = create_vit(
vit, image_size, vit_grad_ckpt, vit_ckpt_layer
)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
def forward(self, image, caption, mode):
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode=='image':
assert mode in [
"image",
"text",
"multimodal",
], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode == "image":
# return image features
image_embeds = self.visual_encoder(image)
image_embeds = self.visual_encoder(image)
return image_embeds
elif mode=='text':
elif mode == "text":
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
text_output = self.text_encoder(
text.input_ids,
attention_mask=text.attention_mask,
return_dict=True,
mode="text",
)
return text_output.last_hidden_state
elif mode=='multimodal':
elif mode == "multimodal":
# return multimodal features
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
text.input_ids[:, 0] = self.tokenizer.enc_token_id
output = self.text_encoder(
text.input_ids,
attention_mask=text.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
return output.last_hidden_state
class BLIP_Decoder(nn.Module):
def __init__(self,
med_config = 'configs/blip/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
prompt = 'a picture of ',
):
def __init__(
self,
med_config="configs/blip/med_config.json",
image_size=384,
vit="base",
vit_grad_ckpt=False,
vit_ckpt_layer=0,
prompt="a picture of ",
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
self.visual_encoder, vision_width = create_vit(
vit, image_size, vit_grad_ckpt, vit_ckpt_layer
)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel(config=med_config)
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
self.text_decoder = BertLMHeadModel(config=med_config)
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
def forward(self, image, caption):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100
decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
text = self.tokenizer(
caption,
padding="longest",
truncation=True,
max_length=40,
return_tensors="pt",
).to(image.device)
text.input_ids[:, 0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(
text.input_ids == self.tokenizer.pad_token_id, -100
)
decoder_targets[:, : self.prompt_length] = -100
decoder_output = self.text_decoder(
text.input_ids,
attention_mask=text.attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
labels=decoder_targets,
return_dict=True,
)
loss_lm = decoder_output.loss
return loss_lm
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
def generate(
self,
image,
sample=False,
num_beams=3,
max_length=30,
min_length=10,
top_p=0.9,
repetition_penalty=1.0,
):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
model_kwargs = {
"encoder_hidden_states": image_embeds,
"encoder_attention_mask": image_atts,
}
prompt = [self.prompt] * image.size(0)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
input_ids[:,0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
image.device
)
input_ids[:, 0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]
if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
# nucleus sampling
outputs = self.text_decoder.generate(
input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs
)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):])
return captions
# beam search
outputs = self.text_decoder.generate(
input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs
)
def blip_decoder(pretrained='',**kwargs):
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt) :])
return captions
def blip_decoder(pretrained="", **kwargs):
model = BLIP_Decoder(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def blip_feature_extractor(pretrained='',**kwargs):
model, msg = load_checkpoint(model, pretrained)
assert len(msg.missing_keys) == 0
return model
def blip_feature_extractor(pretrained="", **kwargs):
model = BLIP_Base(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
model, msg = load_checkpoint(model, pretrained)
assert len(msg.missing_keys) == 0
return model
def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
def create_vit(
vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0
):
assert vit in ["base", "large"], "vit parameter must be base or large"
if vit == "base":
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
visual_encoder = VisionTransformer(
img_size=image_size,
patch_size=16,
embed_dim=vision_width,
depth=12,
num_heads=12,
use_grad_checkpointing=use_grad_checkpointing,
ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate,
)
elif vit == "large":
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
visual_encoder = VisionTransformer(
img_size=image_size,
patch_size=16,
embed_dim=vision_width,
depth=24,
num_heads=16,
use_grad_checkpointing=use_grad_checkpointing,
ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate,
)
return visual_encoder, vision_width
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def load_checkpoint(model,url_or_filename):
def load_checkpoint(model, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
state_dict["visual_encoder.pos_embed"], model.visual_encoder
)
if "visual_encoder_m.pos_embed" in model.state_dict().keys():
state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed(
state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m
)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
if state_dict[key].shape != model.state_dict()[key].shape:
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg
msg = model.load_state_dict(state_dict, strict=False)
print("load checkpoint from %s" % url_or_filename)
return model, msg
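A minimal usage sketch of the captioning API above, assuming blip_decoder from this file is in scope, that configs/blip/med_config.json and the bert-base-uncased tokenizer are available, and that "model_base_caption.pth" is a placeholder for a real BLIP checkpoint; the random tensor stands in for a preprocessed 384x384 RGB image batch:

import torch

model = blip_decoder(
    pretrained="model_base_caption.pth",  # placeholder checkpoint path
    image_size=384,
    vit="base",
)
model.eval()

image = torch.randn(1, 3, 384, 384)  # stand-in for a normalized image batch
with torch.no_grad():
    captions = model.generate(
        image, sample=False, num_beams=3, max_length=30, min_length=10
    )
print(captions[0])  # the "a picture of " prompt is already stripped by generate()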

Some files were not shown because too many files have changed in this diff.