mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Updated XLNet for FP16 compatibility
This commit is contained in:
parent
889f509e6c
commit
de89e2d165
@ -20,7 +20,6 @@ use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetConfigResources::XLNET_BASE_CASED,
|
||||
@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
max_length: 32,
|
||||
do_sample: true,
|
||||
do_sample: false,
|
||||
num_beams: 3,
|
||||
temperature: 1.0,
|
||||
num_return_sequences: 1,
|
||||
|
@ -186,15 +186,17 @@ impl XLNetRelativeAttention {
|
||||
);
|
||||
Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
|
||||
}
|
||||
None => Tensor::zeros(&[1], (Kind::Float, ac.device())),
|
||||
None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
|
||||
};
|
||||
let mut attention_score = (ac + bd + ef) * self.scale;
|
||||
if let Some(value) = attention_mask {
|
||||
attention_score = attention_score - value.permute(&[2, 3, 0, 1]) * 1e30;
|
||||
let target_kind = attention_score.kind();
|
||||
attention_score =
|
||||
(attention_score - value.permute(&[2, 3, 0, 1]) * 1e30).to_kind(target_kind);
|
||||
};
|
||||
|
||||
let attention_probas = attention_score
|
||||
.softmax(3, Kind::Float)
|
||||
.softmax(3, attention_score.kind())
|
||||
.apply_t(&self.dropout, train);
|
||||
|
||||
let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
|
||||
|
@ -285,6 +285,7 @@ impl XLNetModel {
|
||||
q_len: i64,
|
||||
k_len: i64,
|
||||
batch_size: Option<i64>,
|
||||
kind: Kind,
|
||||
device: Device,
|
||||
) -> Tensor {
|
||||
let frequency_sequence =
|
||||
@ -303,7 +304,7 @@ impl XLNetModel {
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
if self.bi_data {
|
||||
let position_embeddings = if self.bi_data {
|
||||
let mut backward_positions_sequence =
|
||||
Tensor::arange_start(-begin, -end, (Kind::Float, device));
|
||||
match self.clamp_len {
|
||||
@ -324,7 +325,8 @@ impl XLNetModel {
|
||||
)
|
||||
} else {
|
||||
self.positional_embedding(&forward_positions_sequence, &inverse_frequency, batch_size)
|
||||
}
|
||||
};
|
||||
position_embeddings.to_kind(kind)
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -424,9 +426,18 @@ impl XLNetModel {
|
||||
token_type_ids.map(|token_type_ids| token_type_ids.transpose(0, 1).contiguous());
|
||||
let attention_mask =
|
||||
attention_mask.map(|attention_mask| attention_mask.transpose(0, 1).contiguous());
|
||||
let perm_mask = perm_mask.map(|perm_mask| perm_mask.permute(&[1, 2, 0]).contiguous());
|
||||
let target_mapping =
|
||||
target_mapping.map(|target_mapping| target_mapping.permute(&[1, 2, 0]).contiguous());
|
||||
let perm_mask = perm_mask.map(|perm_mask| {
|
||||
perm_mask
|
||||
.to_kind(word_emb_k.kind())
|
||||
.permute(&[1, 2, 0])
|
||||
.contiguous()
|
||||
});
|
||||
let target_mapping = target_mapping.map(|target_mapping| {
|
||||
target_mapping
|
||||
.to_kind(word_emb_k.kind())
|
||||
.permute(&[1, 2, 0])
|
||||
.contiguous()
|
||||
});
|
||||
|
||||
let m_len = if let Some(mems) = &old_layer_states {
|
||||
if let Some(mem_0) = &mems[0] {
|
||||
@ -513,13 +524,19 @@ impl XLNetModel {
|
||||
.unsqueeze(-1)
|
||||
.ne_tensor(&cat_ids.unsqueeze(0))
|
||||
.to_kind(Kind::Int64);
|
||||
Some(seg_mat.one_hot(2).to_kind(Kind::Float))
|
||||
Some(seg_mat.one_hot(2).to_kind(output_h.kind()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let pos_emb = self
|
||||
.relative_positional_encoding(q_len, k_len, Some(batch_size), output_h.device())
|
||||
.relative_positional_encoding(
|
||||
q_len,
|
||||
k_len,
|
||||
Some(batch_size),
|
||||
output_h.kind(),
|
||||
output_h.device(),
|
||||
)
|
||||
.apply_t(&self.dropout, train);
|
||||
|
||||
let mut all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>> =
|
||||
@ -1695,7 +1712,6 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
|
||||
match past {
|
||||
Cache::XLNetCache(past) => {
|
||||
if let Some(past) = past {
|
||||
// let new_past = Vec::with_capacity(past.len());
|
||||
let past = if let Some(first_past) = &past[0] {
|
||||
let past_len = first_past.prev_content.size()[0];
|
||||
past.iter()
|
||||
|
Loading…
Reference in New Issue
Block a user