Mirror of https://github.com/guillaume-be/rust-bert.git
Updated Albert for Half precision support
parent 9d921b67b6
commit fc2b2972f9
@@ -53,7 +53,6 @@ fn main() -> anyhow::Result<()> {
     };
 
     let mut model = TextGenerationModel::new(generate_config)?;
-    // model.half();
     model.set_device(Device::cuda_if_available());
 
     let input_context_1 = "It was a very nice and sunny";
@@ -235,7 +235,7 @@ pub(crate) fn _make_causal_mask(
 
     let mut mask = Tensor::full(
         &[target_length, target_length],
-        f64::NEG_INFINITY,
+        get_negative_infinity(dtype).unwrap(),
         (dtype, device),
     );
     let mask_cond = Tensor::arange(target_length, (dtype, device));
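The hunk above swaps the hard-coded f64::NEG_INFINITY fill value for a kind-dependent helper so the causal mask can be built directly in the target precision. A minimal sketch of what such a helper could look like, assuming it maps floating-point kinds to a negative-infinity scalar (rust-bert's actual get_negative_infinity may differ in signature and error type):

use tch::{Kind, Scalar};

// Hypothetical sketch, not rust-bert's exact implementation: pick a
// negative-infinity scalar for a given tensor kind, returning None for
// integer kinds that cannot represent -inf.
fn negative_infinity_for(kind: Kind) -> Option<Scalar> {
    match kind {
        // Every floating-point kind can hold -inf; libtorch casts the
        // scalar to the target kind when it is used as a fill value.
        Kind::Half | Kind::BFloat16 | Kind::Float | Kind::Double => {
            Some(Scalar::float(f64::NEG_INFINITY))
        }
        _ => None,
    }
}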
@@ -320,8 +320,17 @@ impl<T: BertEmbedding> BertModel<T> {
             }
         };
 
+        let embedding_output = self.embeddings.forward_t(
+            input_ids,
+            token_type_ids,
+            position_ids,
+            input_embeds,
+            train,
+        )?;
+
         let extended_attention_mask: Tensor =
-            (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
+            ((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
+                .to_kind(embedding_output.kind());
 
         let encoder_extended_attention_mask: Option<Tensor> =
             if self.is_decoder & encoder_hidden_states.is_some() {
@@ -350,14 +359,6 @@ impl<T: BertEmbedding> BertModel<T> {
                 None
             };
 
-        let embedding_output = self.embeddings.forward_t(
-            input_ids,
-            token_type_ids,
-            position_ids,
-            input_embeds,
-            train,
-        )?;
-
         let encoder_output = self.encoder.forward_t(
             &embedding_output,
             Some(&extended_attention_mask),
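Together, these two BERT hunks move the embeddings forward pass ahead of the attention-mask construction so that embedding_output.kind() is available when the additive mask is built; the mask then follows the activations' precision instead of staying in Float. A self-contained sketch of the pattern (the helper name and main are illustrative, not rust-bert code):

use tch::{Device, Kind, Tensor};

// Illustrative helper: turn a {0, 1} attention mask into the additive
// -10000.0 mask used by BERT-style models, cast to the embeddings' kind so
// a half-precision forward pass stays in half precision.
fn additive_attention_mask(attention_mask: &Tensor, embedding_output: &Tensor) -> Tensor {
    ((attention_mask.ones_like() - attention_mask) * -10000.0)
        .to_kind(embedding_output.kind())
}

fn main() {
    let mask = Tensor::ones(&[1, 8], (Kind::Int64, Device::Cpu));
    let embeddings = Tensor::zeros(&[1, 8, 16], (Kind::Half, Device::Cpu));
    let additive = additive_attention_mask(&mask, &embeddings);
    assert_eq!(additive.kind(), Kind::Half);
}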
@@ -13,7 +13,6 @@
 use crate::common::dropout::Dropout;
 use crate::distilbert::distilbert_model::DistilBertConfig;
 use std::borrow::Borrow;
-use tch::kind::Kind::Float;
 use tch::{nn, Tensor};
 
 #[derive(Debug)]
@@ -91,7 +90,9 @@ impl MultiHeadSelfAttention {
             q.matmul(&k.transpose(2, 3))
         };
 
-        let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
+        let weights = scores
+            .softmax(-1, scores.kind())
+            .apply_t(&self.dropout, train);
         let context = self
             .flatten(weights.matmul(&v), bs, self.dim_per_head)
             .apply(&self.out_lin);
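The same pattern recurs throughout this commit: take the softmax in the kind of the attention scores rather than forcing Kind::Float, so half-precision activations are not promoted mid-forward. The pattern in isolation (the helper name and main are illustrative):

use tch::{Device, Kind, Tensor};

// Illustrative: softmax over the last dimension in the input's own kind.
fn softmax_same_kind(scores: &Tensor) -> Tensor {
    scores.softmax(-1, scores.kind())
}

fn main() {
    // With a Float input the result stays Float; with a Half input on a
    // backend that supports it, the result would stay Half.
    let scores = Tensor::rand(&[2, 4, 8, 8], (Kind::Float, Device::Cpu));
    assert_eq!(softmax_same_kind(&scores).kind(), Kind::Float);
}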
@@ -19,7 +19,14 @@ use tch::kind::Kind::Float;
 use tch::nn::{embedding, EmbeddingConfig};
 use tch::{nn, Device, Kind, Tensor};
 
-fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
+fn create_sinusoidal_embeddings<'p, P>(
+    config: &DistilBertConfig,
+    p: P,
+    device: Device,
+) -> nn::Embedding
+where
+    P: Borrow<nn::Path<'p>>,
+{
     let mut sinusoidal_embedding: Vec<Tensor> =
         Vec::with_capacity(config.max_position_embeddings as usize);
     for pos in 0..config.max_position_embeddings {
@@ -27,11 +34,11 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
         for j in 0..config.dim {
             if j % 2 == 0 {
                 temp_vec.push(
-                    (pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
+                    (pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
                 );
             } else {
                 temp_vec.push(
-                    (pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
+                    (pos as f64 / 10000_f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
                 );
             }
         }
@@ -47,7 +54,7 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
         ..Default::default()
     };
     let mut embeddings = embedding(
-        &nn::VarStore::new(device).root(),
+        p.borrow(),
         config.max_position_embeddings,
         config.dim,
         embedding_config,
@@ -90,8 +97,7 @@ impl DistilBertEmbedding {
                 config.dim,
                 embedding_config,
             ),
-
-            true => create_sinusoidal_embeddings(config, p.device()),
+            true => create_sinusoidal_embeddings(config, p / "position_embeddings", p.device()),
         };
         let layer_norm_config = nn::LayerNormConfig {
             eps: 1e-12,
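Threading a var-store path into create_sinusoidal_embeddings registers the sinusoidal table under the model's own VarStore (here under a "position_embeddings" sub-path) instead of a throw-away store created on the spot, so store-wide device and precision handling also covers it. A rough sketch of creating an embedding under an explicit path with tch (sizes and names are illustrative, not DistilBERT's actual configuration):

use tch::nn::{self, embedding, EmbeddingConfig};
use tch::Device;

fn main() {
    // Illustrative sizes; the real values come from DistilBertConfig.
    let max_position_embeddings = 512;
    let dim = 768;

    let vs = nn::VarStore::new(Device::Cpu);
    // Registering the weights under an explicit sub-path keeps them inside
    // the model's VarStore rather than inside a temporary one.
    let path = vs.root() / "position_embeddings";
    let _position_embeddings = embedding(
        &path,
        max_position_embeddings,
        dim,
        EmbeddingConfig::default(),
    );
}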
@@ -415,8 +415,8 @@ impl Gpt2Model {
                 .unsqueeze(2)
                 .to_kind(input_embeddings.kind());
 
-            let attention_mask = (1.0 - attention_mask) * (-10000.0);
-            attention_mask
+            let attention_mask: Tensor = (1.0 - attention_mask) * (-10000.0);
+            attention_mask.to_kind(input_embeddings.kind())
         });
 
         let position_embeds = position_ids.apply(&self.wpe);
@@ -11,6 +11,7 @@
 // limitations under the License.
 
 use crate::common::dropout::Dropout;
+use crate::common::kind::get_negative_infinity;
 use crate::longformer::LongformerConfig;
 use std::borrow::Borrow;
 use tch::{nn, Kind, Tensor};
@@ -183,7 +184,10 @@ impl LongformerSelfAttention {
         let _ = input_tensor
             .slice(1, 0, affected_sequence_length, 1)
             .slice(3, 0, affected_sequence_length + 1, 1)
-            .masked_fill_(&beginning_mask, f64::NEG_INFINITY);
+            .masked_fill_(
+                &beginning_mask,
+                get_negative_infinity(input_tensor.kind()).unwrap(),
+            );
 
         let _ = input_tensor
             .narrow(1, -affected_sequence_length, affected_sequence_length)
@@ -192,7 +196,10 @@ impl LongformerSelfAttention {
                 -(affected_sequence_length + 1),
                 affected_sequence_length + 1,
             )
-            .masked_fill_(&ending_mask, f64::NEG_INFINITY);
+            .masked_fill_(
+                &ending_mask,
+                get_negative_infinity(input_tensor.kind()).unwrap(),
+            );
     }
 
     fn sliding_chunks_query_key_matmul(
@@ -227,7 +234,10 @@ impl LongformerSelfAttention {
                 window_overlap,
                 window_overlap * 2 + 1,
             ],
-            (Kind::Float, diagonal_chunked_attention_scores.device()),
+            (
+                diagonal_chunked_attention_scores.kind(),
+                diagonal_chunked_attention_scores.device(),
+            ),
         );
 
         let diagonal_attention_scores_size = diagonal_attention_scores.size();
@@ -406,7 +416,7 @@ impl LongformerSelfAttention {
                 self.num_heads,
                 self.head_dim,
             ],
-            (Kind::Float, key_vectors.device()),
+            (key_vectors.kind(), key_vectors.device()),
        );
 
         let _ = key_vectors_only_global.index_put_(
@@ -457,7 +467,7 @@ impl LongformerSelfAttention {
                 self.num_heads,
                 self.head_dim,
             ],
-            (Kind::Float, value_vectors.device()),
+            (value_vectors.kind(), value_vectors.device()),
         );
 
         let _ = value_vectors_only_global.index_put_(
@@ -502,7 +512,7 @@ impl LongformerSelfAttention {
 
         let mut global_attention_hidden_states = Tensor::zeros(
             &[max_num_global_attention_indices, batch_size, self.embed_dim],
-            (Kind::Float, hidden_states.device()),
+            (hidden_states.kind(), hidden_states.device()),
         );
 
         let _ = global_attention_hidden_states.index_put_(
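The Longformer hunks replace every hard-coded (Kind::Float, device) allocation with the kind and device of an existing activation, so intermediate buffers match the precision of the tensors they are later combined with. The pattern in isolation (the helper name and main are illustrative):

use tch::{Device, Kind, Tensor};

// Illustrative: allocate a scratch buffer with the kind and device of a
// reference tensor instead of hard-coding Kind::Float.
fn zeros_matching(reference: &Tensor, shape: &[i64]) -> Tensor {
    Tensor::zeros(shape, (reference.kind(), reference.device()))
}

fn main() {
    let hidden_states = Tensor::zeros(&[2, 16, 32], (Kind::Half, Device::Cpu));
    let buffer = zeros_matching(&hidden_states, &[4, 2, 32]);
    assert_eq!(buffer.kind(), Kind::Half);
    assert_eq!(buffer.device(), Device::Cpu);
}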
@@ -566,10 +576,10 @@ impl LongformerSelfAttention {
                     .as_ref()
                     .unwrap(),
             )
-            .fill_(-10000f64);
+            .fill_(-10000_f64);
 
         let global_attention_scores = global_attention_scores
-            .masked_fill(&is_index_masked.unsqueeze(1).unsqueeze(1), -10000f64)
+            .masked_fill(&is_index_masked.unsqueeze(1).unsqueeze(1), -10000_f64)
             .view([
                 batch_size * self.num_heads,
                 max_num_global_attention_indices,
@@ -577,7 +587,7 @@ impl LongformerSelfAttention {
             ]);
 
         let global_attention_probas = global_attention_scores
-            .softmax(-1, Kind::Float)
+            .softmax(-1, global_attention_scores.kind())
             .apply_t(&self.dropout, train);
 
         let global_attention_output = global_attention_probas.bmm(&global_value_vectors);
@@ -629,13 +639,13 @@ impl LongformerSelfAttention {
 
         let remove_from_windowed_attention_mask = attention_mask.ne(0).unsqueeze(-1).unsqueeze(-1);
         let float_mask = remove_from_windowed_attention_mask
-            .totype(Kind::Float)
+            .totype(attention_scores.kind())
             .masked_fill(&remove_from_windowed_attention_mask, -10000.0);
 
         let diagonal_mask = self.sliding_chunks_query_key_matmul(
             &Tensor::ones(
                 float_mask.size().as_slice(),
-                (Kind::Float, float_mask.device()),
+                (float_mask.kind(), float_mask.device()),
             ),
             &float_mask,
             self.one_sided_attention_window_size,
@@ -679,7 +689,7 @@ impl LongformerSelfAttention {
         };
 
         let mut attention_probas = attention_scores
-            .softmax(-1, Kind::Float)
+            .softmax(-1, attention_scores.kind())
             .masked_fill(&is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
             .apply_t(&self.dropout, train);
 
@@ -758,7 +768,7 @@ impl LongformerSelfAttention {
                     .index(is_index_global_attn_nonzero.as_ref().unwrap())
                     .size()
                     .as_slice(),
-                (Kind::Float, attention_output.device()),
+                (attention_output.kind(), attention_output.device()),
             ),
             false,
         );