Add bfloat16 tensor loading support (#396)

This commit is contained in:
guillaume-be 2023-06-25 09:21:52 +01:00 committed by GitHub
parent 7b1ab24371
commit a74d023583
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,16 @@
# Copyright 2019-2023 Guillaume Becquin
# Copyright 2023 https://github.com/starkat99/half-rs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import subprocess
@ -7,6 +20,23 @@ import torch
from pathlib import Path
from torch import Tensor
def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
    """Convert a bfloat16 tensor to its raw 16-bit representation in Numpy.

    Numpy has no native bfloat16 dtype, so the tensor is widened to float32,
    reinterpreted as ``uint32`` bit patterns and narrowed back to the upper
    16 bits. The narrowing (NaN handling and round-to-nearest-even) is a
    vectorized port of ``f32_to_bf16`` from
    https://github.com/starkat99/half-rs/blob/main/src/bfloat/convert.rs
    (shared under Apache 2.0 license at
    https://github.com/starkat99/half-rs/blob/main/LICENSES/Apache-2.0.txt)

    Args:
        input_tensor: Tensor to convert. It is moved to CPU and cast to
            float32 before its bits are read.

    Returns:
        A flat (1-D) ``np.uint16`` array with one bfloat16 bit pattern per
        element, in C (row-major) order.
    """
    v_fp32 = input_tensor.cpu().float().numpy()
    # NOTE(review): going through `tobytes()` flattens the data, so the input
    # shape is not carried over to the result - confirm callers expect this.
    bits = np.frombuffer(v_fp32.tobytes(), dtype=np.uint32)

    abs_mask = np.uint32(0x7FFF_FFFF)  # everything except the sign bit
    inf_bits = np.uint32(0x7F80_0000)  # exponent all ones, mantissa zero
    quiet_bit = np.uint32(0x0040)  # most significant bfloat16 mantissa bit
    round_bit = np.uint32(0x0000_8000)  # highest bit dropped by truncation
    # Result LSB (bit 16) together with the sticky bits below the round bit.
    lsb_and_sticky = np.uint32(3 * 0x0000_8000 - 1)

    truncated = np.right_shift(bits, np.uint32(16))

    # NaN: keep the upper mantissa bits but force the quiet bit so that the
    # truncated mantissa can never become zero (which would encode infinity).
    nan_mask = np.bitwise_and(bits, abs_mask) > inf_bits
    nan_value = np.bitwise_or(truncated, quiet_bit)

    # Round to nearest, ties to even: round up when the round bit is set and
    # either the result is odd or any lower (sticky) bit is set.
    round_up = (np.bitwise_and(bits, round_bit) != 0) & (
        np.bitwise_and(bits, lsb_and_sticky) != 0
    )
    rounded = np.where(round_up, truncated + np.uint32(1), truncated)

    return np.where(nan_mask, nan_value, rounded).astype(np.uint16)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@ -64,7 +94,10 @@ if __name__ == "__main__":
if args.suffix:
k = k.split(".")[-1]
if isinstance(v, Tensor):
tensor = v.cpu().numpy()
if v.dtype == torch.bfloat16:
tensor = get_bf16_repr(v)
else:
tensor = v.cpu().numpy()
if args.dtype is not None:
nps[k] = np.ascontiguousarray(tensor.astype(np.dtype(args.dtype)))
else: