diff --git a/examples/summarization_t5.rs b/examples/summarization_t5.rs
index 897814c..2fb08c2 100644
--- a/examples/summarization_t5.rs
+++ b/examples/summarization_t5.rs
@@ -18,8 +18,6 @@ use rust_bert::resources::{RemoteResource, Resource};
 use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
 
 fn main() -> anyhow::Result<()> {
-    // let summarization_model = SummarizationModel::new(Default::default())?;
-
     let config_resource =
         Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
     let vocab_resource =
diff --git a/src/t5/attention.rs b/src/t5/attention.rs
index 01b4486..9d46497 100644
--- a/src/t5/attention.rs
+++ b/src/t5/attention.rs
@@ -224,7 +224,7 @@ impl T5Attention {
         scores += position_bias;
 
         let attention_weights = scores
-            .softmax(-1, Kind::Float)
+            .softmax(-1, scores.kind())
             .apply_t(&self.dropout, train);
         let context = self
             .unshape(attention_weights.matmul(&v), bs)
diff --git a/src/t5/encoder.rs b/src/t5/encoder.rs
index e372344..60a893d 100644
--- a/src/t5/encoder.rs
+++ b/src/t5/encoder.rs
@@ -18,7 +18,7 @@ use crate::t5::T5Config;
 use crate::RustBertError;
 use std::borrow::{Borrow, BorrowMut};
 use tch::nn::LinearConfig;
-use tch::{nn, Kind, Tensor};
+use tch::{nn, Kind, Scalar, Tensor};
 
 pub struct T5DenseReluDense {
     wi: nn::Linear,
@@ -140,6 +140,21 @@ impl T5Block {
         }
     }
 
+    fn clamp_hidden_states(hidden_states: Tensor) -> Tensor {
+        if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
+            let clamp_value = match hidden_states.kind() {
+                Kind::Half => half::f16::MAX.to_f64() - 1000.,
+                Kind::BFloat16 => half::bf16::MAX.to_f64() - 1000.,
+                _ => {
+                    panic!("Type not supported: supported types are Float (single precision), Half and BFloat16 (half precision)");
+                }
+            };
+            hidden_states.clamp(Scalar::from(-clamp_value), Scalar::from(clamp_value))
+        } else {
+            hidden_states
+        }
+    }
+
     pub fn forward_t(
         &self,
         hidden_states: &Tensor,
@@ -152,7 +167,7 @@
         train: bool,
     ) -> T5BlockOutput {
         let (
-            hidden_states,
+            mut hidden_states,
             self_attention_weights,
             self_attention_position_bias,
             self_attention_layer_past,
@@ -164,8 +179,10 @@
             train,
         );
 
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
+
         let (
-            hidden_states,
+            mut hidden_states,
             cross_attention_weights,
             cross_attention_position_bias,
             cross_attention_layer_past,
@@ -186,8 +203,12 @@
             (hidden_states, None, None, None)
         };
 
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
+
         layer_states = (self_attention_layer_past, cross_attention_layer_past);
-        let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
+        let mut hidden_states = self.ff_layer.forward_t(&hidden_states, train);
+
+        hidden_states = T5Block::clamp_hidden_states(hidden_states);
 
         T5BlockOutput {
             hidden_states,
@@ -305,8 +326,10 @@
             3 => attention_mask.unsqueeze(1),
             2 => {
                 if self.is_decoder {
-                    let seq_ids =
-                        Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
+                    let seq_ids = Tensor::arange(
+                        input_shape[1],
+                        (input_embeddings.kind(), input_embeddings.device()),
+                    );
                     let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
                         input_shape[0],
                         input_shape[1],
@@ -325,8 +348,10 @@
             }
         };
 
-        let extended_attention_mask: Option<Tensor> =
-            Some((extended_attention_mask.ones_like() - extended_attention_mask) * -1e9);
+        let extended_attention_mask: Option<Tensor> = Some(
+            ((extended_attention_mask.ones_like() - extended_attention_mask) * -1e4)
+                .to_kind(input_embeddings.kind()),
+        );
 
         let extended_encoder_attention_mask = if self.is_decoder & encoder_hidden_states.is_some() {
             let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
@@ -350,7 +375,9 @@
                     ));
                 }
             };
-            Some((encoder_mask.ones_like() - encoder_mask) * -1e9)
+            Some(
+                ((encoder_mask.ones_like() - encoder_mask) * -1e4).to_kind(input_embeddings.kind()),
+            )
         } else {
             None
         };
diff --git a/src/t5/layer_norm.rs b/src/t5/layer_norm.rs
index c9dc475..9c4c902 100644
--- a/src/t5/layer_norm.rs
+++ b/src/t5/layer_norm.rs
@@ -32,8 +32,16 @@ impl T5LayerNorm {
 
 impl Module for T5LayerNorm {
     fn forward(&self, x: &Tensor) -> Tensor {
-        let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
+        let input_type = x.kind();
+        let variance = x
+            .to_kind(Kind::Float)
+            .pow(2.0_f64)
+            .mean_dim(&[-1], true, Kind::Float);
         let x = x * (variance + self.epsilon).rsqrt();
-        &self.weight * x
+        if input_type != Kind::Float {
+            (&self.weight * x).to_kind(input_type)
+        } else {
+            &self.weight * x
+        }
     }
 }
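For reference, a minimal standalone sketch of the overflow clamp the patch adds to T5Block in src/t5/encoder.rs, assuming the tch and half crates are available; the free function clamp_half_precision and the toy main are illustrative only and are not part of the patch or of rust-bert's API:

// Sketch of the half-precision overflow clamp: reduced-precision activations that
// overflowed to +/-inf are pulled back just below the type's maximum value.
use tch::{Device, Kind, Scalar, Tensor};

fn clamp_half_precision(hidden_states: Tensor) -> Tensor {
    // Float32 activations never need the clamp; Half/BFloat16 do once a value overflows to inf.
    if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
        let clamp_value = match hidden_states.kind() {
            Kind::Half => half::f16::MAX.to_f64() - 1000.,
            Kind::BFloat16 => half::bf16::MAX.to_f64() - 1000.,
            _ => panic!("Only Float, Half and BFloat16 tensors are supported"),
        };
        hidden_states.clamp(Scalar::from(-clamp_value), Scalar::from(clamp_value))
    } else {
        hidden_states
    }
}

fn main() {
    // 1e12 is far beyond f16::MAX (~65504), so the Half tensor overflows to inf;
    // the clamp replaces those values with finite ones near the representable maximum.
    let activations = Tensor::ones(&[2, 3], (Kind::Half, Device::Cpu)) * 1e12;
    let clamped = clamp_half_precision(activations);
    clamped.print();
}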