diff --git a/examples/generation_xlnet.rs b/examples/generation_xlnet.rs
index 7cf98a3..3c73993 100644
--- a/examples/generation_xlnet.rs
+++ b/examples/generation_xlnet.rs
@@ -20,7 +20,6 @@ use rust_bert::resources::{RemoteResource, Resource};
 use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
 
 fn main() -> anyhow::Result<()> {
-    // Set-up model
     // Resources paths
     let config_resource = Resource::Remote(RemoteResource::from_pretrained(
         XLNetConfigResources::XLNET_BASE_CASED,
@@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
         vocab_resource,
         merges_resource,
         max_length: 32,
-        do_sample: true,
+        do_sample: false,
         num_beams: 3,
         temperature: 1.0,
         num_return_sequences: 1,
diff --git a/src/xlnet/attention.rs b/src/xlnet/attention.rs
index ec6f86e..438713a 100644
--- a/src/xlnet/attention.rs
+++ b/src/xlnet/attention.rs
@@ -186,15 +186,17 @@ impl XLNetRelativeAttention {
                 );
                 Tensor::einsum("ijbs,ibns->bnij", &[seg_mat, &ef])
             }
-            None => Tensor::zeros(&[1], (Kind::Float, ac.device())),
+            None => Tensor::zeros(&[1], (ac.kind(), ac.device())),
         };
         let mut attention_score = (ac + bd + ef) * self.scale;
         if let Some(value) = attention_mask {
-            attention_score = attention_score - value.permute(&[2, 3, 0, 1]) * 1e30;
+            let target_kind = attention_score.kind();
+            attention_score =
+                (attention_score - value.permute(&[2, 3, 0, 1]) * 1e30).to_kind(target_kind);
         };
 
         let attention_probas = attention_score
-            .softmax(3, Kind::Float)
+            .softmax(3, attention_score.kind())
             .apply_t(&self.dropout, train);
 
         let attention_vector = Tensor::einsum("bnij,jbnd->ibnd", &[&attention_probas, v_head_h]);
diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs
index 6cf411c..bbb1538 100644
--- a/src/xlnet/xlnet_model.rs
+++ b/src/xlnet/xlnet_model.rs
@@ -285,6 +285,7 @@ impl XLNetModel {
         q_len: i64,
         k_len: i64,
         batch_size: Option<i64>,
+        kind: Kind,
         device: Device,
     ) -> Tensor {
         let frequency_sequence =
@@ -303,7 +304,7 @@ impl XLNetModel {
             }
             _ => {}
         }
-        if self.bi_data {
+        let position_embeddings = if self.bi_data {
             let mut backward_positions_sequence =
                 Tensor::arange_start(-begin, -end, (Kind::Float, device));
             match self.clamp_len {
@@ -324,7 +325,8 @@ impl XLNetModel {
             )
         } else {
             self.positional_embedding(&forward_positions_sequence, &inverse_frequency, batch_size)
-        }
+        };
+        position_embeddings.to_kind(kind)
     }
 
     /// Forward pass through the model
@@ -424,9 +426,18 @@ impl XLNetModel {
             token_type_ids.map(|token_type_ids| token_type_ids.transpose(0, 1).contiguous());
         let attention_mask =
             attention_mask.map(|attention_mask| attention_mask.transpose(0, 1).contiguous());
-        let perm_mask = perm_mask.map(|perm_mask| perm_mask.permute(&[1, 2, 0]).contiguous());
-        let target_mapping =
-            target_mapping.map(|target_mapping| target_mapping.permute(&[1, 2, 0]).contiguous());
+        let perm_mask = perm_mask.map(|perm_mask| {
+            perm_mask
+                .to_kind(word_emb_k.kind())
+                .permute(&[1, 2, 0])
+                .contiguous()
+        });
+        let target_mapping = target_mapping.map(|target_mapping| {
+            target_mapping
+                .to_kind(word_emb_k.kind())
+                .permute(&[1, 2, 0])
+                .contiguous()
+        });
 
         let m_len = if let Some(mems) = &old_layer_states {
             if let Some(mem_0) = &mems[0] {
@@ -513,13 +524,19 @@ impl XLNetModel {
                 .unsqueeze(-1)
                 .ne_tensor(&cat_ids.unsqueeze(0))
                 .to_kind(Kind::Int64);
-            Some(seg_mat.one_hot(2).to_kind(Kind::Float))
+            Some(seg_mat.one_hot(2).to_kind(output_h.kind()))
         } else {
             None
         };
 
         let pos_emb = self
-            .relative_positional_encoding(q_len, k_len, Some(batch_size), output_h.device())
+            .relative_positional_encoding(
+                q_len,
+                k_len,
+                Some(batch_size),
+                output_h.kind(),
+                output_h.device(),
+            )
             .apply_t(&self.dropout, train);
 
         let mut all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>> =
@@ -1695,7 +1712,6 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
         match past {
             Cache::XLNetCache(past) => {
                 if let Some(past) = past {
-                    // let new_past = Vec::with_capacity(past.len());
                     let past = if let Some(first_past) = &past[0] {
                         let past_len = first_past.prev_content.size()[0];
                         past.iter()
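The patch applies one pattern throughout: wherever the XLNet code hard-coded `Kind::Float`, the tensor kind is now derived from an activation that already carries the model's precision (`ac.kind()`, `output_h.kind()`, `word_emb_k.kind()`), so a model loaded in half precision is no longer silently promoted back to fp32 mid-forward-pass. (The switch from `do_sample: true` to `do_sample: false` in the example is unrelated: with sampling off and `num_beams: 3`, the example decodes with deterministic beam search, so its output is reproducible.) Below is a minimal sketch of the promotion problem that kind propagation avoids, using standard tch tensor ops; the shapes are invented for illustration.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Stand-in for an fp16 activation such as `ac` in rel_attn_core.
    let ac = Tensor::zeros(&[1, 2, 4, 4], (Kind::Half, Device::Cpu));

    // Old behaviour: the `None` arm of the seg_mat match built an fp32
    // placeholder; fp16 + fp32 promotes the whole attention score to fp32.
    let ef_old = Tensor::zeros(&[1], (Kind::Float, ac.device()));
    assert_eq!((&ac + &ef_old).kind(), Kind::Float);

    // New behaviour: deriving the kind from `ac` keeps the sum in fp16.
    let ef_new = Tensor::zeros(&[1], (ac.kind(), ac.device()));
    assert_eq!((&ac + &ef_new).kind(), Kind::Half);
}
```

The mask handling follows the same logic: `perm_mask`, `target_mapping`, and the attention mask are cast to the embedding kind before they mix with activations, and the masked attention score is cast back to its original kind after the `* 1e30` subtraction, so the subsequent softmax stays in the model's precision.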
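The `relative_positional_encoding` change is largely a refactor in service of the same goal: the `if self.bi_data { ... } else { ... }` block was previously the tail expression of the function, so each branch returned its embedding directly; binding it to `position_embeddings` lets a single `.to_kind(kind)` cover both branches, while the embedding values themselves are still computed in `Kind::Float` and only cast at the end. A minimal sketch of the idiom (function names and shapes are illustrative, not from the crate):

```rust
use tch::{Device, Kind, Tensor};

// Before: each branch is the tail expression, so a precision cast would
// have to be duplicated inside both branches.
fn encoding_old(bi_data: bool, device: Device) -> Tensor {
    if bi_data {
        Tensor::arange(8, (Kind::Float, device))
    } else {
        Tensor::arange(4, (Kind::Float, device))
    }
}

// After: bind the `if` expression, then apply the cast once at the end.
fn encoding_new(bi_data: bool, kind: Kind, device: Device) -> Tensor {
    let position_embeddings = if bi_data {
        Tensor::arange(8, (Kind::Float, device))
    } else {
        Tensor::arange(4, (Kind::Float, device))
    };
    position_embeddings.to_kind(kind)
}

fn main() {
    // The old form is stuck at fp32; the new form follows the caller's kind.
    assert_eq!(encoding_old(true, Device::Cpu).kind(), Kind::Float);
    assert_eq!(encoding_new(true, Kind::Half, Device::Cpu).kind(), Kind::Half);
}
```

Computing the embedding in fp32 and casting once at the end keeps the intermediate arithmetic out of low precision while still handing the caller a tensor that matches `output_h.kind()`.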