Updated T5 for FP16 compatibility

Guillaume Becquin 2021-10-05 18:23:34 +02:00
parent 038ac90757
commit 889f509e6c
4 changed files with 47 additions and 14 deletions

@@ -18,8 +18,6 @@ use rust_bert::resources::{RemoteResource, Resource};
 use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
 
 fn main() -> anyhow::Result<()> {
-    // let summarization_model = SummarizationModel::new(Default::default())?;
-
     let config_resource =
         Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
     let vocab_resource =

@@ -224,7 +224,7 @@ impl T5Attention {
         scores += position_bias;
 
         let attention_weights = scores
-            .softmax(-1, Kind::Float)
+            .softmax(-1, scores.kind())
            .apply_t(&self.dropout, train);
         let context = self
             .unshape(attention_weights.matmul(&v), bs)
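
Note: the hard-coded `Kind::Float` forced the attention weights back to single precision even when the rest of the forward pass ran in FP16; `scores.kind()` keeps the softmax in the scores' own dtype, so the following `matmul` with `v` stays half precision. A minimal sketch of the dtype difference (assuming a libtorch build with half-precision softmax kernels on the chosen device; older CPU-only builds may need CUDA for `Kind::Half`):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Attention scores as they would look during an FP16 forward pass.
    let scores = Tensor::randn(&[2, 8, 16, 16], (Kind::Float, Device::cuda_if_available()))
        .to_kind(Kind::Half);

    // A hard-coded Kind::Float upcasts the attention weights to FP32...
    let upcast = scores.softmax(-1, Kind::Float);
    assert_eq!(upcast.kind(), Kind::Float);

    // ...while scores.kind() keeps them in the input's own precision.
    let same_kind = scores.softmax(-1, scores.kind());
    assert_eq!(same_kind.kind(), Kind::Half);
}
```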

@@ -18,7 +18,7 @@ use crate::t5::T5Config;
 use crate::RustBertError;
 use std::borrow::{Borrow, BorrowMut};
 use tch::nn::LinearConfig;
-use tch::{nn, Kind, Tensor};
+use tch::{nn, Kind, Scalar, Tensor};
 
 pub struct T5DenseReluDense {
     wi: nn::Linear,
@@ -140,6 +140,21 @@ impl T5Block {
         }
     }
 
+    fn clamp_hidden_states(hidden_states: Tensor) -> Tensor {
+        if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
+            let clamp_value = match hidden_states.kind() {
+                Kind::Half => half::f16::MAX.to_f64() - 1000.,
+                Kind::BFloat16 => half::bf16::MAX.to_f64() - 1000.,
+                _ => {
+                    panic!("Type not supported: supported types are Float (single precision), Half and BFloat16 (half precision)");
+                }
+            };
+            hidden_states.clamp(Scalar::from(-clamp_value), Scalar::from(clamp_value))
+        } else {
+            hidden_states
+        }
+    }
+
     pub fn forward_t(
         &self,
         hidden_states: &Tensor,
@@ -152,7 +167,7 @@
         train: bool,
     ) -> T5BlockOutput {
         let (
-            hidden_states,
+            mut hidden_states,
             self_attention_weights,
             self_attention_position_bias,
             self_attention_layer_past,
@@ -164,8 +179,10 @@
             train,
         );
 
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
+
         let (
-            hidden_states,
+            mut hidden_states,
             cross_attention_weights,
             cross_attention_position_bias,
             cross_attention_layer_past,
@@ -186,8 +203,12 @@
             (hidden_states, None, None, None)
         };
 
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
+
         layer_states = (self_attention_layer_past, cross_attention_layer_past);
 
-        let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
+        let mut hidden_states = self.ff_layer.forward_t(&hidden_states, train);
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
+
         T5BlockOutput {
             hidden_states,
@@ -305,8 +326,10 @@
             3 => attention_mask.unsqueeze(1),
             2 => {
                 if self.is_decoder {
-                    let seq_ids =
-                        Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
+                    let seq_ids = Tensor::arange(
+                        input_shape[1],
+                        (input_embeddings.kind(), input_embeddings.device()),
+                    );
                     let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
                         input_shape[0],
                         input_shape[1],
@@ -325,8 +348,10 @@
             }
         };
 
-        let extended_attention_mask: Option<Tensor> =
-            Some((extended_attention_mask.ones_like() - extended_attention_mask) * -1e9);
+        let extended_attention_mask: Option<Tensor> = Some(
+            ((extended_attention_mask.ones_like() - extended_attention_mask) * -1e4)
+                .to_kind(input_embeddings.kind()),
+        );
 
         let extended_encoder_attention_mask = if self.is_decoder & encoder_hidden_states.is_some() {
             let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
@@ -350,7 +375,9 @@
                 ));
             }
         };
-        Some((encoder_mask.ones_like() - encoder_mask) * -1e9)
+        Some(
+            ((encoder_mask.ones_like() - encoder_mask) * -1e4).to_kind(input_embeddings.kind()),
+        )
     } else {
         None
     };
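
Note: two FP16-specific fixes run through this file. First, `f16::MAX` is only 65504, so additive attention masks built with `-1e9` overflow to `-inf` in half precision and the following softmax can produce NaN; `-1e4` is representable and still drives masked positions to ~0 after softmax. Second, residual additions can still overflow to `inf`, so `clamp_hidden_states` pulls any infinite activations back to just inside the dtype's range after each sub-layer. A standalone sketch of both points (`clamp_half_precision` is a hypothetical free-function rendering of `T5Block::clamp_hidden_states`; it assumes the `half` crate the patch already imports):

```rust
use tch::{Kind, Scalar, Tensor};

// Hypothetical standalone rendering of T5Block::clamp_hidden_states above:
// clamp only half-precision tensors that actually contain +/-inf, to a
// value just inside the dtype's representable range.
fn clamp_half_precision(hidden_states: Tensor) -> Tensor {
    if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
        let clamp_value = match hidden_states.kind() {
            Kind::Half => half::f16::MAX.to_f64() - 1000., // 65504 - 1000
            Kind::BFloat16 => half::bf16::MAX.to_f64() - 1000.,
            _ => panic!("only Float, Half and BFloat16 are supported"),
        };
        hidden_states.clamp(Scalar::from(-clamp_value), Scalar::from(clamp_value))
    } else {
        hidden_states
    }
}

fn main() {
    // f16 tops out at 65504, so the old -1e9 mask value overflows to -inf...
    let old_mask = Tensor::of_slice(&[-1e9_f64]).to_kind(Kind::Half);
    assert!(bool::from(old_mask.isinf().any()));
    // ...while -1e4 is representable and still ~0 after softmax.
    let new_mask = Tensor::of_slice(&[-1e4_f64]).to_kind(Kind::Half);
    assert!(!bool::from(new_mask.isinf().any()));

    // An overflowed activation is pulled back into range instead of
    // propagating inf (and then NaN) through the following layers.
    let activations = Tensor::of_slice(&[1.0_f64, f64::INFINITY]).to_kind(Kind::Half);
    let clamped = clamp_half_precision(activations);
    assert!(!bool::from(clamped.isinf().any()));
}
```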

@@ -32,8 +32,16 @@ impl T5LayerNorm {
 
 impl Module for T5LayerNorm {
     fn forward(&self, x: &Tensor) -> Tensor {
-        let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
+        let input_type = x.kind();
+        let variance = x
+            .to_kind(Kind::Float)
+            .pow(2.0_f64)
+            .mean_dim(&[-1], true, Kind::Float);
         let x = x * (variance + self.epsilon).rsqrt();
-        &self.weight * x
+        if input_type != Kind::Float {
+            (&self.weight * x).to_kind(input_type)
+        } else {
+            &self.weight * x
+        }
     }
 }
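
Note: `T5LayerNorm` is a scale-only RMS norm, `x / sqrt(mean(x²) + ε)` times a learned weight, with no mean subtraction and no bias. Since `sqrt(65504) ≈ 256`, squaring FP16 activations much larger than ~256 overflows, so the variance is now accumulated in FP32 and the result cast back to the input dtype. A free-function sketch of the same pattern (the name `t5_rms_norm` is hypothetical; the body mirrors the method above):

```rust
use tch::{Kind, Tensor};

// Sketch of the pattern above: T5's layer norm is a scale-only RMS norm,
// and the variance is accumulated in FP32 even for half-precision inputs,
// with the result cast back to the input dtype at the end.
fn t5_rms_norm(x: &Tensor, weight: &Tensor, epsilon: f64) -> Tensor {
    let input_type = x.kind();
    // sqrt(f16::MAX) is roughly 256, so squaring larger FP16 activations
    // overflows; upcast to FP32 before taking the square and the mean.
    let variance = x
        .to_kind(Kind::Float)
        .pow(2.0_f64)
        .mean_dim(&[-1], true, Kind::Float);
    // Half * Float promotes to Float, so the normalization itself runs in FP32.
    let x = x * (variance + epsilon).rsqrt();
    if input_type != Kind::Float {
        (weight * x).to_kind(input_type)
    } else {
        weight * x
    }
}

fn main() {
    // 300^2 would already be inf in f16; the FP32 accumulation avoids that.
    let x = Tensor::of_slice(&[300.0_f64, -300.0]).to_kind(Kind::Half);
    let weight = Tensor::of_slice(&[1.0_f64, 1.0]).to_kind(Kind::Half);
    let normed = t5_rms_norm(&x, &weight, 1e-6);
    assert_eq!(normed.kind(), Kind::Half);
    assert!(!bool::from(normed.isinf().any()));
}
```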