Updated XLNet for FP16 compatibility

Guillaume Becquin 2021-10-06 17:52:25 +02:00
parent 889f509e6c
commit de89e2d165
3 changed files with 30 additions and 13 deletions
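
Most of the changes below follow one pattern: wherever a tensor kind was hard-coded as Kind::Float, the kind is now taken from an existing activation tensor, so the forward pass also works when weights and inputs are half precision. A minimal sketch of that pattern with tch-rs (tensor names and shapes here are illustrative, not the library's):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Stand-in for an activation tensor; under FP16 inference its kind is Kind::Half.
    let ac = Tensor::randn(&[4, 8], (Kind::Float, Device::Cpu));

    // Previous behaviour: a placeholder hard-coded to Kind::Float, which no longer
    // matches the activations once they are half precision.
    // let ef = Tensor::zeros(&[1], (Kind::Float, ac.device()));

    // Patched pattern: inherit both kind and device from the activations.
    let ef = Tensor::zeros(&[1], (ac.kind(), ac.device()));

    let sum = &ac + &ef;
    assert_eq!(sum.kind(), ac.kind());
}
```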

@@ -20,7 +20,6 @@ use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
fn main() -> anyhow::Result<()> {
// Set-up model
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
@@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
vocab_resource,
merges_resource,
max_length: 32,
do_sample: true,
do_sample: false,
num_beams: 3,
temperature: 1.0,
num_return_sequences: 1,

@@ -186,15 +186,17 @@ impl XLNetRelativeAttention {
);
Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
}
None => Tensor::zeros(&[1], (Kind::Float, ac.device())),
None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
};
let mut attention_score = (ac + bd + ef) * self.scale;
if let Some(value) = attention_mask {
attention_score = attention_score - value.permute(&[2, 3, 0, 1]) * 1e30;
let target_kind = attention_score.kind();
attention_score =
(attention_score - value.permute(&[2, 3, 0, 1]) * 1e30).to_kind(target_kind);
};
let attention_probas = attention_score
.softmax(3, Kind::Float)
.softmax(3, attention_score.kind())
.apply_t(&self.dropout, train);
let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
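
The attention hunk above captures the score's kind before the large additive mask (1e30, far beyond the f16 maximum of about 65504) is applied, casts the masked scores back to that kind, and runs the softmax in the score's own kind instead of a fixed Kind::Float. A self-contained sketch of the same pattern (the helper name, shapes, and the explicit float32 cast of the mask are illustrative choices, not the library's code):

```rust
use tch::{Device, Kind, Tensor};

// Illustrative helper mirroring the patched masking logic: apply the large
// additive mask in float32, then cast back to the kind the scores had before.
fn mask_scores(scores: &Tensor, mask: &Tensor) -> Tensor {
    let target_kind = scores.kind();
    let additive_mask = mask.to_kind(Kind::Float) * 1e30;
    (scores - &additive_mask).to_kind(target_kind)
}

fn main() {
    let device = Device::Cpu;
    // Float on CPU for the demo; on a CUDA device the same code runs with
    // Kind::Half scores, which is the case this commit addresses.
    let scores = Tensor::randn(&[1, 2, 4, 4], (Kind::Float, device));
    // 1.0 marks positions to hide; here an upper-triangular (causal-style) mask.
    let mask = Tensor::ones(&[4, 4], (Kind::Float, device)).triu(1);

    let masked = mask_scores(&scores, &mask);
    // Softmax in the score's own kind, as in the patched call site.
    let probas = masked.softmax(-1, masked.kind());
    assert_eq!(probas.kind(), scores.kind());
}
```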

@@ -285,6 +285,7 @@ impl XLNetModel {
q_len: i64,
k_len: i64,
batch_size: Option<i64>,
kind: Kind,
device: Device,
) -> Tensor {
let frequency_sequence =
@@ -303,7 +304,7 @@ impl XLNetModel {
}
_ => {}
}
if self.bi_data {
let position_embeddings = if self.bi_data {
let mut backward_positions_sequence =
Tensor::arange_start(-begin, -end, (Kind::Float, device));
match self.clamp_len {
@@ -324,7 +325,8 @@
)
} else {
self.positional_embedding(&forward_positions_sequence, &inverse_frequency, batch_size)
}
};
position_embeddings.to_kind(kind)
}
/// Forward pass through the model
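
relative_positional_encoding gains a kind parameter and casts its result, since the sinusoidal table itself is still built from Kind::Float position and frequency tensors. A simplified stand-in for that step (function name, shapes, and formula details are illustrative, not the library's signature):

```rust
use tch::{Device, Kind, Tensor};

// Simplified stand-in: the sinusoid table is computed in float32 and only the
// final result is cast to the requested kind (e.g. Kind::Half under FP16).
fn positional_embedding(positions: &Tensor, inverse_frequency: &Tensor, kind: Kind) -> Tensor {
    // Outer product of positions and inverse frequencies via broadcasting.
    let sinusoid = positions.unsqueeze(-1) * inverse_frequency.unsqueeze(0);
    Tensor::cat(&[sinusoid.sin(), sinusoid.cos()], -1).to_kind(kind)
}

fn main() {
    let device = Device::Cpu;
    let d_model = 8i64;
    // Relative positions and inverse frequencies are still built in float32.
    let forward_positions_sequence = Tensor::arange_start(-4, 4, (Kind::Float, device));
    let frequency_sequence = Tensor::arange(d_model / 2, (Kind::Float, device)) * 2.0;
    let inverse_frequency = (frequency_sequence / d_model as f64 * -(10000f64.ln())).exp();

    // Kind::Float for this CPU demo; the forward pass passes output_h.kind() instead.
    let pos_emb =
        positional_embedding(&forward_positions_sequence, &inverse_frequency, Kind::Float);
    assert_eq!(pos_emb.size(), vec![8, d_model]);
}
```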
@@ -424,9 +426,18 @@ impl XLNetModel {
token_type_ids.map(|token_type_ids| token_type_ids.transpose(0, 1).contiguous());
let attention_mask =
attention_mask.map(|attention_mask| attention_mask.transpose(0, 1).contiguous());
let perm_mask = perm_mask.map(|perm_mask| perm_mask.permute(&[1, 2, 0]).contiguous());
let target_mapping =
target_mapping.map(|target_mapping| target_mapping.permute(&[1, 2, 0]).contiguous());
let perm_mask = perm_mask.map(|perm_mask| {
perm_mask
.to_kind(word_emb_k.kind())
.permute(&[1, 2, 0])
.contiguous()
});
let target_mapping = target_mapping.map(|target_mapping| {
target_mapping
.to_kind(word_emb_k.kind())
.permute(&[1, 2, 0])
.contiguous()
});
let m_len = if let Some(mems) = &old_layer_states {
if let Some(mem_0) = &mems[0] {
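
perm_mask and target_mapping are caller-supplied tensors, typically built in float32, so the forward pass now casts them to the kind of the token embeddings (word_emb_k) before permuting them. A small sketch of that cast, with illustrative shapes:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let device = Device::Cpu;
    // Token embeddings set the computation kind; Kind::Half under FP16.
    let word_emb_k = Tensor::randn(&[6, 2, 16], (Kind::Float, device));
    // Caller-supplied permutation mask, usually created as float32
    // (illustrative shape: batch, seq_len, seq_len).
    let perm_mask = Some(Tensor::zeros(&[2, 6, 6], (Kind::Float, device)));

    // Patched pattern: align the mask's kind with the embeddings before the
    // permute/contiguous reshuffle used inside the forward pass.
    let perm_mask = perm_mask.map(|perm_mask| {
        perm_mask
            .to_kind(word_emb_k.kind())
            .permute(&[1, 2, 0])
            .contiguous()
    });

    let perm_mask = perm_mask.unwrap();
    assert_eq!(perm_mask.kind(), word_emb_k.kind());
    assert_eq!(perm_mask.size(), vec![6, 6, 2]);
}
```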
@@ -513,13 +524,19 @@ impl XLNetModel {
.unsqueeze(-1)
.ne_tensor(&cat_ids.unsqueeze(0))
.to_kind(Kind::Int64);
Some(seg_mat.one_hot(2).to_kind(Kind::Float))
Some(seg_mat.one_hot(2).to_kind(output_h.kind()))
} else {
None
};
let pos_emb = self
.relative_positional_encoding(q_len, k_len, Some(batch_size), output_h.device())
.relative_positional_encoding(
q_len,
k_len,
Some(batch_size),
output_h.kind(),
output_h.device(),
)
.apply_t(&self.dropout, train);
let mut all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>> =
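
The segment matrix gets the same treatment: its one-hot encoding is cast to the kind of the hidden states (output_h) instead of a fixed Kind::Float, and the relative_positional_encoding call above now receives output_h.kind() as well. A sketch of the segment-matrix step, assuming toy segment ids:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let device = Device::Cpu;
    // Hidden states set the computation kind; Kind::Half under FP16.
    let output_h = Tensor::randn(&[4, 2, 16], (Kind::Float, device));
    // Toy segment ids for a single sequence: two tokens per segment.
    let cat_ids = Tensor::cat(
        &[
            Tensor::zeros(&[2], (Kind::Int64, device)),
            Tensor::ones(&[2], (Kind::Int64, device)),
        ],
        0,
    );

    // Pairwise "different segment?" matrix, kept integral for one_hot.
    let seg_mat = cat_ids
        .unsqueeze(-1)
        .ne_tensor(&cat_ids.unsqueeze(0))
        .to_kind(Kind::Int64);

    // Patched pattern: cast the one-hot encoding to the hidden state's kind
    // instead of a hard-coded Kind::Float.
    let seg_mat = seg_mat.one_hot(2).to_kind(output_h.kind());
    assert_eq!(seg_mat.kind(), output_h.kind());
}
```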
@@ -1695,7 +1712,6 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
match past {
Cache::XLNetCache(past) => {
if let Some(past) = past {
// let new_past = Vec::with_capacity(past.len());
let past = if let Some(first_past) = &past[0] {
let past_len = first_past.prev_content.size()[0];
past.iter()