Working forward pass (short output, no cache)

This commit is contained in:
Guillaume B 2020-11-13 16:58:03 +01:00
parent aea41062ef
commit 3995a3ee92
3 changed files with 5 additions and 5 deletions

View File

@ -66,9 +66,9 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
let _model_output =
reformer_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)?;
model_output.logits.print();
_model_output.logits.print();
Ok(())
}

View File

@ -370,7 +370,7 @@ impl LSHSelfAttention {
self.num_chunks_after,
);
(query_bucket_idx, key_value_bucket_idx)
} else if do_standard_self_attention & (query_key_dots.dim() > 4) {
} else if do_cached_attention & (query_key_dots.dim() > 4) {
let mut query_shape = sorted_bucket_indices_per_hash.size();
query_shape[sorted_bucket_indices_per_hash.dim() - 1] = 1;
let query_bucket_idx = sorted_bucket_indices_per_hash.new_full(
@ -379,7 +379,7 @@ impl LSHSelfAttention {
(Kind::Int64, sorted_bucket_indices_per_hash.device()),
);
(query_bucket_idx, sorted_bucket_indices_per_hash)
} else if do_standard_self_attention & (query_key_dots.dim() <= 4) {
} else if do_cached_attention & (query_key_dots.dim() <= 4) {
let query_bucket_idx = query_key_dots.select(3, -1).ones_like()
* (query_key_dots.size().last().unwrap() - 1);
let mut query_shape = query_bucket_idx.size();

View File

@ -329,7 +329,7 @@ impl ReformerEncoder {
next_cache[layer_idx] = temp.new_layer_state;
}
hidden_state = Tensor::cat(&[hidden_state, attention_output], -1)
hidden_state = Tensor::cat(&[attention_output, hidden_state], -1)
.apply(&self.layer_norm)
.apply_t(&self.dropout, train);