Updated Albert for Half precision support

Guillaume Becquin 2021-09-30 16:04:42 +02:00
parent 9d921b67b6
commit fc2b2972f9
7 changed files with 51 additions and 34 deletions

View File

@@ -53,7 +53,6 @@ fn main() -> anyhow::Result<()> {
};
let mut model = TextGenerationModel::new(generate_config)?;
// model.half();
model.set_device(Device::cuda_if_available());
let input_context_1 = "It was a very nice and sunny";
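Note: this hunk drops one line, apparently the leftover // model.half(); comment. For context, switching the pipeline to half precision happens right after construction. A minimal sketch, illustrative only, assuming TextGenerationConfig implements Default and that the pipeline exposes the half() conversion hinted at by the removed comment (only set_device() is confirmed by this hunk):

    use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
    use tch::Device;

    fn main() -> anyhow::Result<()> {
        let generate_config = TextGenerationConfig::default();
        let mut model = TextGenerationModel::new(generate_config)?;
        // Assumed API: cast all weights to f16, then move them to the GPU if available.
        model.half();
        model.set_device(Device::cuda_if_available());
        // Prompts such as "It was a very nice and sunny" are then passed to the
        // pipeline exactly as in the unchanged remainder of the example.
        Ok(())
    }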

View File

@@ -235,7 +235,7 @@ pub(crate) fn _make_causal_mask(
let mut mask = Tensor::full(
&[target_length, target_length],
f64::NEG_INFINITY,
get_negative_infinity(dtype).unwrap(),
(dtype, device),
);
let mask_cond = Tensor::arange(target_length, (dtype, device));
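The causal mask is no longer filled with a hard-coded f64::NEG_INFINITY; the fill value is derived from the requested dtype through the new get_negative_infinity helper. The helper's body is not part of this diff; the following is only a hypothetical sketch of what a kind-aware variant could look like (the real common::kind::get_negative_infinity may differ in signature, error type, and the values it returns):

    use tch::{Kind, Scalar};

    // Hypothetical: pick a "negative infinity" fill value that is meaningful
    // for the tensor kind the mask will be created with.
    fn negative_infinity(kind: Kind) -> Result<Scalar, String> {
        match kind {
            Kind::Half | Kind::BFloat16 | Kind::Float | Kind::Double => {
                Ok(Scalar::float(f64::NEG_INFINITY))
            }
            other => Err(format!("no negative infinity for kind {:?}", other)),
        }
    }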

View File

@@ -320,8 +320,17 @@ impl<T: BertEmbedding> BertModel<T> {
}
};
let embedding_output = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {
@@ -350,14 +359,6 @@ impl<T: BertEmbedding> BertModel<T> {
None
};
let embedding_output = self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let encoder_output = self.encoder.forward_t(
&embedding_output,
Some(&extended_attention_mask),
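Two changes here: the embedding lookup moves above the mask extension (its former location is removed in the second hunk), and the additive mask is cast to the embeddings' kind. In half precision this keeps the -10000.0 bias in f16 instead of promoting the attention scores back to f32 when the mask is added. A minimal sketch of the pattern, illustrative rather than the exact rust-bert code:

    use tch::{Kind, Tensor};

    // Masked positions receive a large negative bias, attended positions 0,
    // and the result is produced in the model's compute kind.
    fn extended_attention_mask(mask: &Tensor, compute_kind: Kind) -> Tensor {
        ((mask.ones_like() - mask) * -10000.0).to_kind(compute_kind)
    }

    fn main() {
        let mask = Tensor::of_slice(&[1i64, 1, 1, 0]).view([1, 1, 1, 4]);
        let bias = extended_attention_mask(&mask, Kind::Half);
        assert_eq!(bias.kind(), Kind::Half);
    }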

View File

@@ -13,7 +13,6 @@
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert_model::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
@@ -91,7 +90,9 @@ impl MultiHeadSelfAttention {
q.matmul(&k.transpose(2, 3))
};
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let weights = scores
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let context = self
.flatten(weights.matmul(&v), bs, self.dim_per_head)
.apply(&self.out_lin);
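Passing a hard-coded Kind::Float to softmax silently promotes half-precision attention scores back to f32; taking the kind from the scores keeps the weights, the dropout output, and the following matmul with the value tensor in one kind (matmul requires matching kinds). A small illustrative sketch, not part of the commit:

    use tch::Tensor;

    // Softmax in the kind of its input: f16 scores yield f16 weights, whereas
    // softmax(-1, Kind::Float) would always yield f32 weights.
    fn attention_weights(scores: &Tensor) -> Tensor {
        scores.softmax(-1, scores.kind())
    }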

View File

@@ -19,7 +19,14 @@ use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Device, Kind, Tensor};
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
fn create_sinusoidal_embeddings<'p, P>(
config: &DistilBertConfig,
p: P,
device: Device,
) -> nn::Embedding
where
P: Borrow<nn::Path<'p>>,
{
let mut sinusoidal_embedding: Vec<Tensor> =
Vec::with_capacity(config.max_position_embeddings as usize);
for pos in 0..config.max_position_embeddings {
@@ -27,11 +34,11 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
for j in 0..config.dim {
if j % 2 == 0 {
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
(pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
);
} else {
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
(pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
);
}
}
@@ -47,7 +54,7 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
..Default::default()
};
let mut embeddings = embedding(
&nn::VarStore::new(device).root(),
p.borrow(),
config.max_position_embeddings,
config.dim,
embedding_config,
@@ -90,8 +97,7 @@ impl DistilBertEmbedding {
config.dim,
embedding_config,
),
true => create_sinusoidal_embeddings(config, p.device()),
true => create_sinusoidal_embeddings(config, p / "position_embeddings", p.device()),
};
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
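The sinusoidal position embeddings were previously registered in a throw-away nn::VarStore created inside the function, so a later half-precision cast or device move of the model's own variable store never reached them. Passing the caller's path (p / "position_embeddings") into the builder places the table in the same store as every other weight. The general pattern, shown with a hypothetical build_linear helper that is not code from this commit:

    use std::borrow::Borrow;
    use tch::nn;

    // A submodule builder borrows a path into the caller's VarStore instead of
    // creating its own store, so store-wide conversions (device moves, f16
    // casts) also apply to the weights it registers.
    fn build_linear<'p, P>(p: P, in_dim: i64, out_dim: i64) -> nn::Linear
    where
        P: Borrow<nn::Path<'p>>,
    {
        nn::linear(p.borrow() / "proj", in_dim, out_dim, Default::default())
    }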

View File

@@ -415,8 +415,8 @@ impl Gpt2Model {
.unsqueeze(2)
.to_kind(input_embeddings.kind());
let attention_mask = (1.0 - attention_mask) * (-10000.0);
attention_mask
let attention_mask: Tensor = (1.0 - attention_mask) * (-10000.0);
attention_mask.to_kind(input_embeddings.kind())
});
let position_embeds = position_ids.apply(&self.wpe);
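Same fix as in the BERT model above: the additive attention mask is materialised in the kind of the input embeddings. The motivation is tch's type promotion, as sketched below (illustrative only): adding an f32 mask to f16 scores silently promotes the result back to f32.

    use tch::{Device, Kind, Tensor};

    fn main() {
        let scores = Tensor::zeros(&[1, 4], (Kind::Half, Device::Cpu));
        let mask_f32 = Tensor::zeros(&[1, 4], (Kind::Float, Device::Cpu));
        // Mixed-kind addition promotes: the f32 mask drags the f16 scores up to f32 ...
        assert_eq!((&scores + &mask_f32).kind(), Kind::Float);
        // ... while a mask cast to the scores' kind keeps the result in f16.
        let mask_f16 = mask_f32.to_kind(scores.kind());
        assert_eq!((&scores + &mask_f16).kind(), Kind::Half);
    }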

View File

@@ -11,6 +11,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity;
use crate::longformer::LongformerConfig;
use std::borrow::Borrow;
use tch::{nn, Kind, Tensor};
@@ -183,7 +184,10 @@ impl LongformerSelfAttention {
let _ = input_tensor
.slice(1, 0, affected_sequence_length, 1)
.slice(3, 0, affected_sequence_length + 1, 1)
.masked_fill_(&beginning_mask, f64::NEG_INFINITY);
.masked_fill_(
&beginning_mask,
get_negative_infinity(input_tensor.kind()).unwrap(),
);
let _ = input_tensor
.narrow(1, -affected_sequence_length, affected_sequence_length)
@@ -192,7 +196,10 @@ impl LongformerSelfAttention {
-(affected_sequence_length + 1),
affected_sequence_length + 1,
)
.masked_fill_(&ending_mask, f64::NEG_INFINITY);
.masked_fill_(
&ending_mask,
get_negative_infinity(input_tensor.kind()).unwrap(),
);
}
fn sliding_chunks_query_key_matmul(
@@ -227,7 +234,10 @@ impl LongformerSelfAttention {
window_overlap,
window_overlap * 2 + 1,
],
(Kind::Float, diagonal_chunked_attention_scores.device()),
(
diagonal_chunked_attention_scores.kind(),
diagonal_chunked_attention_scores.device(),
),
);
let diagonal_attention_scores_size = diagonal_attention_scores.size();
@@ -406,7 +416,7 @@ impl LongformerSelfAttention {
self.num_heads,
self.head_dim,
],
(Kind::Float, key_vectors.device()),
(key_vectors.kind(), key_vectors.device()),
);
let _ = key_vectors_only_global.index_put_(
@@ -457,7 +467,7 @@ impl LongformerSelfAttention {
self.num_heads,
self.head_dim,
],
(Kind::Float, value_vectors.device()),
(value_vectors.kind(), value_vectors.device()),
);
let _ = value_vectors_only_global.index_put_(
@@ -502,7 +512,7 @@ impl LongformerSelfAttention {
let mut global_attention_hidden_states = Tensor::zeros(
&[max_num_global_attention_indices, batch_size, self.embed_dim],
(Kind::Float, hidden_states.device()),
(hidden_states.kind(), hidden_states.device()),
);
let _ = global_attention_hidden_states.index_put_(
@@ -566,10 +576,10 @@ impl LongformerSelfAttention {
.as_ref()
.unwrap(),
)
.fill_(-10000f64);
.fill_(-10000_f64);
let global_attention_scores = global_attention_scores
.masked_fill(&is_index_masked.unsqueeze(1).unsqueeze(1), -10000f64)
.masked_fill(&is_index_masked.unsqueeze(1).unsqueeze(1), -10000_f64)
.view([
batch_size * self.num_heads,
max_num_global_attention_indices,
@@ -577,7 +587,7 @@ impl LongformerSelfAttention {
]);
let global_attention_probas = global_attention_scores
.softmax(-1, Kind::Float)
.softmax(-1, global_attention_scores.kind())
.apply_t(&self.dropout, train);
let global_attention_output = global_attention_probas.bmm(&global_value_vectors);
@@ -629,13 +639,13 @@ impl LongformerSelfAttention {
let remove_from_windowed_attention_mask = attention_mask.ne(0).unsqueeze(-1).unsqueeze(-1);
let float_mask = remove_from_windowed_attention_mask
.totype(Kind::Float)
.totype(attention_scores.kind())
.masked_fill(&remove_from_windowed_attention_mask, -10000.0);
let diagonal_mask = self.sliding_chunks_query_key_matmul(
&Tensor::ones(
float_mask.size().as_slice(),
(Kind::Float, float_mask.device()),
(float_mask.kind(), float_mask.device()),
),
&float_mask,
self.one_sided_attention_window_size,
@@ -679,7 +689,7 @@ impl LongformerSelfAttention {
};
let mut attention_probas = attention_scores
.softmax(-1, Kind::Float)
.softmax(-1, attention_scores.kind())
.masked_fill(&is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
.apply_t(&self.dropout, train);
@@ -758,7 +768,7 @@ impl LongformerSelfAttention {
.index(is_index_global_attn_nonzero.as_ref().unwrap())
.size()
.as_slice(),
(Kind::Float, attention_output.device()),
(attention_output.kind(), attention_output.device()),
),
false,
);
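Throughout this module the same two ideas recur: scratch buffers created with Tensor::zeros or Tensor::empty take their kind and device from an existing activation instead of hard-coding Kind::Float, and mask fills go through get_negative_infinity so the fill value matches the tensor's kind. A compact sketch of the buffer pattern, using a hypothetical helper rather than code from this commit:

    use tch::Tensor;

    // Allocate a work buffer matching an existing activation, so the module
    // keeps working unchanged after the variable store is cast to f16.
    fn zeros_matching(reference: &Tensor, size: &[i64]) -> Tensor {
        Tensor::zeros(size, (reference.kind(), reference.device()))
    }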