Merge pull request #20 from guillaume-be/bart_optimizations

Bart optimizations
guillaume-be 2020-04-07 10:43:28 +02:00 committed by GitHub
commit 1451a64b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 20 additions and 19 deletions

Cargo.toml

@@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.6.1"
version = "0.6.2"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"

README.md

@@ -5,7 +5,7 @@
[![Documentation](https://docs.rs/rust-bert/badge.svg)](https://docs.rs/rust-bert)
![License](https://img.shields.io/crates/l/rust_bert.svg)
-Rust native BERT implementation. Port of Huggingface's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multithreaded tokenization and GPU inference.
+Rust native Transformer-based models implementation. Port of Huggingface's [Transformers library](https://github.com/huggingface/transformers), using the [tch-rs](https://github.com/LaurentMazare/tch-rs) crate and pre-processing from [rust-tokenizers](https://github.com/guillaume-be/rust-tokenizers). Supports multithreaded tokenization and GPU inference.
This repository exposes the model base architecture, task-specific heads (see below) and ready-to-use pipelines.
The following models are currently implemented:
@@ -110,7 +110,7 @@ Example output:
[
"The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year"
"The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me"
"The dog has been able to stay in the home for more than three months now. "It's a very good dog. She's"
"The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's"
"The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,"
"The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said"
"The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."

examples/summarization.rs

@@ -40,7 +40,7 @@ fn main() -> failure::Fallible<()> {
let device = Device::cuda_if_available();
let summarization_config = SummarizationConfig {
-num_beams: 1,
+num_beams: 3,
..Default::default()
};
@@ -62,7 +62,7 @@ but previous discoveries were made on planets with high temperatures or other pr
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
-a potentially habitable planet, but further observations will be required to say for sure. \"
+a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \

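For context, this hunk widens the example's beam search from greedy decoding (one beam) to three beams, trading generation speed for summary quality. Below is a condensed sketch of the surrounding example, with the long input text shortened and the `SummarizationModel` usage assumed from this version of the example file (details elided by the diff may differ):

```rust
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};

fn main() -> failure::Fallible<()> {
    // num_beams: 3 keeps the three highest-scoring partial hypotheses alive at
    // each decoding step instead of committing greedily to a single token.
    let summarization_config = SummarizationConfig {
        num_beams: 3,
        ..Default::default()
    };
    let summarization_model = SummarizationModel::new(summarization_config)?;

    let input = ["Astronomers have found water vapour in the atmosphere of K2-18b, \
                  a potentially habitable exoplanet about 110 light-years from Earth."];
    for summary in summarization_model.summarize(&input) {
        println!("{}", summary);
    }
    Ok(())
}
```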
src/bart/attention.rs

@@ -137,8 +137,8 @@ impl SelfAttention {
self.prev_state = match &self.prev_state {
Some(_) => Some(LayerState {
-prev_key: Some(k.copy().view((bs, self.num_heads, -1, self.head_dim))),
-prev_value: Some(v.copy().view((bs, self.num_heads, -1, self.head_dim))),
+prev_key: Some(k.view((bs, self.num_heads, -1, self.head_dim))),
+prev_value: Some(v.view((bs, self.num_heads, -1, self.head_dim))),
prev_key_padding_mask: match key_padding_mask.as_ref() {
Some(tensor) => Some(tensor.copy()),
None => None

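This change removes a `copy()` from the decoder's key/value cache update: `view` only reinterprets the shape of an existing buffer, so the cache now aliases the freshly projected tensors instead of duplicating them on every decoding step. A minimal standalone sketch of the difference, assuming only the tch crate (shapes mirror the (batch * heads, seq_len, head_dim) layout used here):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let (bs, num_heads, seq_len, head_dim) = (2, 16, 8, 64);
    // Projected keys as attention sees them: (batch * heads, seq_len, head_dim).
    let k = Tensor::rand(&[bs * num_heads, seq_len, head_dim], (Kind::Float, Device::Cpu));

    // `view` reinterprets the same storage under a new shape: no allocation.
    let cached = k.view((bs, num_heads, -1, head_dim));
    assert_eq!(cached.size(), vec![bs, num_heads, seq_len, head_dim]);

    // `copy()` (what the old code did first) duplicates the whole buffer; that
    // is only needed if the source tensor might later be mutated in place.
    let _duplicated = k.copy().view((bs, num_heads, -1, head_dim));
}
```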
src/bart/decoder.rs

@@ -109,20 +109,16 @@ impl DecoderLayer {
let (output, attention_weights) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
-let residual = output.copy();
-let (output, _) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, train);
-let output: Tensor = output.apply_t(&self.dropout, train) + residual;
-let output = output.apply(&self.encoder_attention_layer_norm);
-let residual = output.copy();
-let output = (self.activation)(&output.apply(&self.fc1));
-let output = output
+let (output1, _) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, train);
+let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
+let output1 = output1.apply(&self.encoder_attention_layer_norm);
+let output2 = (self.activation)(&output1.apply(&self.fc1));
+let output2 = output2
.apply_t(&self.activation_dropout, train)
.apply(&self.fc2)
.apply_t(&self.dropout, train);
-let output: Tensor = output + residual;
-(output.apply(&self.final_layer_norm), attention_weights)
+let output2: Tensor = output2 + output1;
+(output2.apply(&self.final_layer_norm), attention_weights)
}
}

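The decoder-layer rewrite applies the same idea to the residual connections: rather than cloning the current activation into a `residual` temporary, each sub-layer output gets its own binding (`output`, `output1`, `output2`) and the earlier binding is added back directly. A toy sketch of the ownership pattern, with plain arithmetic standing in for the attention and feed-forward sub-layers (not the crate's code):

```rust
use tch::{Device, Kind, Tensor};

// Every stage reads the previous binding by reference, so the residual adds at
// the end consume the bindings directly and no intermediate `copy()` is needed.
fn decoder_block_sketch(x: &Tensor) -> Tensor {
    let output = x * 2.0; // stand-in for self-attention + layer norm
    let output1 = &output * 3.0 + output; // stand-in for cross-attention + its residual
    let output2 = &output1 * 0.5; // stand-in for the feed-forward path
    output2 + output1 // final residual connection
}

fn main() {
    let x = Tensor::ones(&[2, 4], (Kind::Float, Device::Cpu));
    assert_eq!(decoder_block_sketch(&x).size(), vec![2, 4]);
}
```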
src/distilbert/distilbert.rs

@@ -153,7 +153,7 @@ impl DistilBertModel {
None => input_value.apply_t(&self.embeddings, train)
}
None => match input_embeds {
-Some(embeds) => embeds.copy(),
+Some(embeds) => embeds,
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
};

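The DistilBERT fix is the same ownership observation in miniature: the model owns the `Option<Tensor>` it was passed, so the match arm can move the embeddings out rather than duplicating the buffer with `copy()`. A self-contained sketch using a hypothetical `resolve_embeddings` helper (not the crate's actual function):

```rust
use tch::{Device, Kind, Tensor};

// Owns its argument, so `Some(embeds) => embeds` is a move: the embedding
// buffer changes hands without being copied.
fn resolve_embeddings(input_embeds: Option<Tensor>) -> Result<Tensor, &'static str> {
    match input_embeds {
        Some(embeds) => Ok(embeds),
        None => Err("At least one of input ids or input embeddings must be set"),
    }
}

fn main() {
    let embeds = Tensor::zeros(&[1, 5, 768], (Kind::Float, Device::Cpu));
    assert!(resolve_embeddings(Some(embeds)).is_ok());
    assert!(resolve_embeddings(None).is_err());
}
```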
src/pipelines/generation.rs

@@ -826,6 +826,7 @@ mod private_generation_utils {
outputs = temp.0;
encoder_outputs = temp.1;
past = temp.2;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
@@ -849,6 +850,7 @@ mod private_generation_utils {
for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
&scores.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
}
let (next_scores, next_tokens) = if do_sample {
let mut _scores: Tensor = &scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
@@ -865,6 +867,7 @@ mod private_generation_utils {
let next_scores = next_scores.contiguous().view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
};
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec!();
for batch_index in 0..batch_size {
if done[batch_index as usize] {
@@ -923,6 +926,7 @@ mod private_generation_utils {
if done.iter().all(|&x| x) {
break;
}
beam_scores = Tensor::of_slice(&next_batch_beam.iter().map(|(score, _, _)| *score).collect_vec()).to(input_ids.device());
beam_tokens = Tensor::of_slice(&next_batch_beam.iter().map(|(_, token, _)| *token).collect_vec()).to(input_ids.device());
beam_indices = Tensor::of_slice(&next_batch_beam.iter().map(|(_, _, index)| *index).collect_vec()).to(input_ids.device());
@@ -1002,6 +1006,7 @@ mod private_generation_utils {
} else {
Tensor::stack(&best_ids, 0).to_kind(Int64).to(input_ids.device())
};
decoded
}
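
For readers following the beam-search bookkeeping in the hunks above: each decoding step produces one `(score, token, source beam index)` triple per surviving hypothesis, and the loop then unzips those triples into three device tensors for the next step. A standalone sketch of that step, using the same `Tensor::of_slice` and `collect_vec` calls that appear in the diff (the `unzip_beam_state` name is hypothetical):

```rust
use itertools::Itertools;
use tch::{Device, Tensor};

// Unzip per-beam (score, token, parent beam index) triples into the three
// parallel tensors the generation loop rebuilds after every decoding step.
fn unzip_beam_state(next_batch_beam: &[(f64, i64, i64)], device: Device) -> (Tensor, Tensor, Tensor) {
    let beam_scores = Tensor::of_slice(&next_batch_beam.iter().map(|(score, _, _)| *score).collect_vec()).to(device);
    let beam_tokens = Tensor::of_slice(&next_batch_beam.iter().map(|(_, token, _)| *token).collect_vec()).to(device);
    let beam_indices = Tensor::of_slice(&next_batch_beam.iter().map(|(_, _, index)| *index).collect_vec()).to(device);
    (beam_scores, beam_tokens, beam_indices)
}

fn main() {
    // Two beams: (log-probability, chosen token id, index of the parent beam).
    let next_batch_beam: Vec<(f64, i64, i64)> = vec![(-0.7, 42, 0), (-1.2, 17, 0)];
    let (scores, tokens, indices) = unzip_beam_state(&next_batch_beam, Device::Cpu);
    assert_eq!(scores.size(), vec![2]);
    assert_eq!(tokens.size(), vec![2]);
    assert_eq!(indices.size(), vec![2]);
}
```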