Addition of AlbertLayerGroup

This commit is contained in:
Guillaume B 2020-06-18 20:30:35 +02:00
parent 7fa19a9284
commit 9c4fdd3179

View File

@ -16,6 +16,7 @@ use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::albert::albert::Activation;
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish};
use std::borrow::BorrowMut;
pub struct AlbertLayer {
attention: AlbertSelfAttention,
@ -60,4 +61,68 @@ impl AlbertLayer {
(ffn_output, attention_weights)
}
}
}
pub struct AlbertLayerGroup {
output_hidden_states: bool,
output_attentions: bool,
layers: Vec<AlbertLayer>,
}
impl AlbertLayerGroup {
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertLayerGroup {
let p = &(p / "albert_layers");
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
};
let mut layers: Vec<AlbertLayer> = vec!();
for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayer::new(&(p / layer_index), config));
};
AlbertLayerGroup { output_hidden_states, output_attentions, layers }
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
};
};
(hidden_state, all_hidden_states, all_attentions)
}
}