diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py index eb35628..e2ca3cc 100644 --- a/scripts/sd_utils.py +++ b/scripts/sd_utils.py @@ -809,7 +809,11 @@ def generation_callback(img, i=0): x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + if x_samples_ddim.ndimension() == 4: + pil_images = [transforms.ToPILImage()(x.squeeze_(0)) for x in x_samples_ddim] + pil_image = image_grid(pil_images, 1) + else: + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) # update image on the UI so we can see the progress st.session_state["preview_image"].image(pil_image)