Addition of Electra generator and discriminator heads

Guillaume B 2020-04-29 18:59:37 +02:00
parent 45eeb7ae5b
commit 5bec2548c1
2 changed files with 50 additions and 0 deletions

@@ -69,5 +69,7 @@ fn main() -> failure::Fallible<()> {
            .unwrap()
    });
    output.print();
    Ok(())
}

@@ -19,6 +19,7 @@ use crate::Config;
use crate::electra::embeddings::ElectraEmbeddings;
use tch::{nn, Tensor, Kind};
use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
#[derive(Debug, Serialize, Deserialize)]
/// # Electra model configuration
@@ -132,4 +133,51 @@ impl ElectraModel {
        Ok((hidden_state, all_hidden_states, all_attentions))
    }
}

/// # Electra discriminator head
/// Dense layer followed by an activation and a single-unit projection,
/// producing one logit per input token (replaced vs. original).
pub struct ElectraDiscriminatorHead {
    dense: nn::Linear,
    dense_prediction: nn::Linear,
    activation: Box<dyn Fn(&Tensor) -> Tensor>,
}

impl ElectraDiscriminatorHead {
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
        let dense = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
        let dense_prediction = nn::linear(&(p / "dense_prediction"), config.hidden_size, 1, Default::default());
        // Select the activation function from the model configuration
        let activation = Box::new(match &config.hidden_act {
            Activation::gelu => _gelu,
            Activation::relu => _relu,
            Activation::mish => _mish,
        });
        ElectraDiscriminatorHead { dense, dense_prediction, activation }
    }

    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let output = encoder_hidden_states.apply(&self.dense);
        let output = (self.activation)(&output);
        // Squeeze the trailing unit dimension: (batch, sequence, 1) -> (batch, sequence)
        output.apply(&self.dense_prediction).squeeze()
    }
}
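
As a usage sketch (not part of this commit), the discriminator logits can be turned into per-token replacement probabilities. The helper below is hypothetical: it assumes `head` was built as above and that `hidden_state` is the (batch, sequence, hidden_size) output returned by ElectraModel.

// Hypothetical helper (illustration only, not in the commit): maps the
// discriminator logits to the probability that each token was replaced
// by the generator.
fn replaced_token_probabilities(head: &ElectraDiscriminatorHead, hidden_state: &Tensor) -> Tensor {
    let logits = head.forward(hidden_state); // (batch, sequence) after the squeeze
    logits.sigmoid()
}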

/// # Electra generator head
/// Dense projection from the hidden size down to the embedding size, followed
/// by a GELU activation and layer normalization, producing the representation
/// consumed by the masked language modeling projection.
pub struct ElectraGeneratorHead {
    dense: nn::Linear,
    layer_norm: nn::LayerNorm,
    activation: Box<dyn Fn(&Tensor) -> Tensor>,
}

impl ElectraGeneratorHead {
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
        let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], Default::default());
        let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
        // The generator head uses GELU regardless of the configured hidden activation
        let activation = Box::new(_gelu);
        ElectraGeneratorHead { layer_norm, dense, activation }
    }

    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let output = encoder_hidden_states.apply(&self.dense);
        let output = (self.activation)(&output);
        output.apply(&self.layer_norm)
    }
}
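
For completeness, a hedged sketch of how the generator head output might be consumed: ELECTRA projects it onto the vocabulary for masked language modeling (in the reference implementation the projection weights are tied to the input embeddings, which this commit does not yet wire up). The `lm_projection` layer below is a hypothetical stand-in for that projection.

// Hypothetical helper (illustration only, not in the commit): projects the
// generator head output to vocabulary logits for masked language modeling.
fn generator_lm_logits(head: &ElectraGeneratorHead, hidden_state: &Tensor, lm_projection: &nn::Linear) -> Tensor {
    let output = head.forward(hidden_state); // (batch, sequence, embedding_size)
    output.apply(lm_projection) // (batch, sequence, vocab_size)
}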