mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Updated T5 for FP16 compatibility
This commit is contained in:
parent
038ac90757
commit
889f509e6c
@ -18,8 +18,6 @@ use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// let summarization_model = SummarizationModel::new(Default::default())?;
|
||||
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
|
||||
let vocab_resource =
|
||||
|
@ -224,7 +224,7 @@ impl T5Attention {
|
||||
scores += position_bias;
|
||||
|
||||
let attention_weights = scores
|
||||
.softmax(-1, Kind::Float)
|
||||
.softmax(-1, scores.kind())
|
||||
.apply_t(&self.dropout, train);
|
||||
let context = self
|
||||
.unshape(attention_weights.matmul(&v), bs)
|
||||
|
@ -18,7 +18,7 @@ use crate::t5::T5Config;
|
||||
use crate::RustBertError;
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use tch::nn::LinearConfig;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
use tch::{nn, Kind, Scalar, Tensor};
|
||||
|
||||
pub struct T5DenseReluDense {
|
||||
wi: nn::Linear,
|
||||
@ -140,6 +140,21 @@ impl T5Block {
|
||||
}
|
||||
}
|
||||
|
||||
fn clamp_hidden_states(hidden_states: Tensor) -> Tensor {
|
||||
if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
|
||||
let clamp_value = match hidden_states.kind() {
|
||||
Kind::Half => half::f16::MAX.to_f64() - 1000.,
|
||||
Kind::BFloat16 => half::bf16::MAX.to_f64() - 1000.,
|
||||
_ => {
|
||||
panic!("Type not supported: supported types are Float (single precision), Half and BFloat16 (half precision)");
|
||||
}
|
||||
};
|
||||
hidden_states.clamp(Scalar::from(-clamp_value), Scalar::from(clamp_value))
|
||||
} else {
|
||||
hidden_states
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
@ -152,7 +167,7 @@ impl T5Block {
|
||||
train: bool,
|
||||
) -> T5BlockOutput {
|
||||
let (
|
||||
hidden_states,
|
||||
mut hidden_states,
|
||||
self_attention_weights,
|
||||
self_attention_position_bias,
|
||||
self_attention_layer_past,
|
||||
@ -164,8 +179,10 @@ impl T5Block {
|
||||
train,
|
||||
);
|
||||
|
||||
hidden_states = T5Block::clamp_hidden_states(hidden_states);
|
||||
|
||||
let (
|
||||
hidden_states,
|
||||
mut hidden_states,
|
||||
cross_attention_weights,
|
||||
cross_attention_position_bias,
|
||||
cross_attention_layer_past,
|
||||
@ -186,8 +203,12 @@ impl T5Block {
|
||||
(hidden_states, None, None, None)
|
||||
};
|
||||
|
||||
hidden_states = T5Block::clamp_hidden_states(hidden_states);
|
||||
|
||||
layer_states = (self_attention_layer_past, cross_attention_layer_past);
|
||||
let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
|
||||
let mut hidden_states = self.ff_layer.forward_t(&hidden_states, train);
|
||||
|
||||
hidden_states = T5Block::clamp_hidden_states(hidden_states);
|
||||
|
||||
T5BlockOutput {
|
||||
hidden_states,
|
||||
@ -305,8 +326,10 @@ impl T5Stack {
|
||||
3 => attention_mask.unsqueeze(1),
|
||||
2 => {
|
||||
if self.is_decoder {
|
||||
let seq_ids =
|
||||
Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
|
||||
let seq_ids = Tensor::arange(
|
||||
input_shape[1],
|
||||
(input_embeddings.kind(), input_embeddings.device()),
|
||||
);
|
||||
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
@ -325,8 +348,10 @@ impl T5Stack {
|
||||
}
|
||||
};
|
||||
|
||||
let extended_attention_mask: Option<Tensor> =
|
||||
Some((extended_attention_mask.ones_like() - extended_attention_mask) * -1e9);
|
||||
let extended_attention_mask: Option<Tensor> = Some(
|
||||
((extended_attention_mask.ones_like() - extended_attention_mask) * -1e4)
|
||||
.to_kind(input_embeddings.kind()),
|
||||
);
|
||||
|
||||
let extended_encoder_attention_mask = if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
|
||||
@ -350,7 +375,9 @@ impl T5Stack {
|
||||
));
|
||||
}
|
||||
};
|
||||
Some((encoder_mask.ones_like() - encoder_mask) * -1e9)
|
||||
Some(
|
||||
((encoder_mask.ones_like() - encoder_mask) * -1e4).to_kind(input_embeddings.kind()),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
@ -32,8 +32,16 @@ impl T5LayerNorm {
|
||||
|
||||
impl Module for T5LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Tensor {
|
||||
let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
|
||||
let input_type = x.kind();
|
||||
let variance = x
|
||||
.to_kind(Kind::Float)
|
||||
.pow(2.0_f64)
|
||||
.mean_dim(&[-1], true, Kind::Float);
|
||||
let x = x * (variance + self.epsilon).rsqrt();
|
||||
&self.weight * x
|
||||
if input_type != Kind::Float {
|
||||
(&self.weight * x).to_kind(input_type)
|
||||
} else {
|
||||
&self.weight * x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user