Update deprecated torch.qr in glow.py example (#4685)

torch.qr is deprecated for a long time and is being removed by https://github.com/pytorch/pytorch/pull/70989.

This PR makes the example compatible with new and old PyTorch versions.
This commit is contained in:
Sergii Dymchenko 2022-12-09 16:03:01 -08:00 committed by GitHub
parent f131336fc3
commit 1f53315071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -71,7 +71,8 @@ class Invertible1x1Conv(torch.nn.Module):
bias=False)
# Sample a random orthonormal matrix to initialize weights
W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
_qr = torch.linalg.qr if torch.__version__ >= "1.8" else torch.qr
W = _qr(torch.FloatTensor(c, c).normal_())[0]
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0: