Merge remote-tracking branch 'origin/master' into m2m100_implementation

# Conflicts:
#	Cargo.toml
Guillaume B 2021-06-28 18:53:46 +02:00
commit 0b2e339e87
44 changed files with 415 additions and 206 deletions

View File

@@ -3,13 +3,15 @@ All notable changes to this project will be documented in this file. The format
 ## [Unreleased]
 ## Added
 - (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
 - (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
+- (BREAKING) Support for `output_scores` boolean argument for generation, allowing users to output the log-probability scores of generated sequences. Updated the return type of low-level generate API to `GeneratedTextOutput` and `GeneratedIndicesOutput` containing optional scores along with the generated output.
 - Addition of the MBart Language model and support for text generation / direct translation between 50 language
 - Addition of the M2M100 Language model and support for text generation / direct translation between 100 language
 ## Changed
 - Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)
+- Upgraded `tch` version to 0.5.0 (using `libtorch` 1.9.0)

 ## [0.15.1] - 2021-06-01
 ### Fixed
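Taken together, the three `(BREAKING)` entries above change both the argument list and the return type of the low-level `generate` API. The snippet below is a minimal usage sketch, not taken from the diff: generator construction is elided, and `model` is assumed to be any type implementing the `LanguageGenerator` trait updated later in this commit.

```rust
// Sketch only: `model` is an assumed instance of a `LanguageGenerator` implementor.
let outputs = model.generate(
    Some(&["The dog"]), // prompt texts
    None,               // attention mask
    None,               // min_length
    None,               // max_length
    None,               // decoder_start_token_id
    None,               // forced_bos_token_id (e.g. a language token id for MBart/M2M100)
    None,               // prefix_allowed_tokens_fn
    true,               // output_scores
);
for output in outputs {
    // The return type is now `GeneratedTextOutput { text, score }` rather than `Vec<String>`.
    println!("{} (score: {:?})", output.text, output.score);
}
```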

View File

@@ -58,7 +58,7 @@ features = ["doc-only"]

 [dependencies]
 rust_tokenizers = { version = "~6.2.4", path = "E:/Coding/backup-rust/rust-tokenizers/main" }
-tch = "~0.4.1"
+tch = "~0.5.0"
 serde_json = "1.0.64"
 serde = { version = "1.0.126", features = ["derive"] }
 dirs = "3.0.2"
@@ -72,5 +72,5 @@ thiserror = "1.0.24"
 anyhow = "1.0.40"
 csv = "1.1.6"
 criterion = "0.3.4"
-torch-sys = "0.4.1"
+torch-sys = "0.5.0"
 tempfile = "3.2.0"

View File

@@ -71,8 +71,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
 ### Manual installation (recommended)
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
-   the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11.
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.9.0`: if this version is no longer available on the "get started" page,
+   the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.9.0%2Bcu111.zip` for a Linux version with CUDA11.
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:

View File

@@ -20,17 +20,17 @@ fn main() -> anyhow::Result<()> {
     let generate_config = TextGenerationConfig {
         model_type: ModelType::GPT2,
         max_length: 30,
-        do_sample: true,
-        num_beams: 5,
-        temperature: 1.1,
-        num_return_sequences: 3,
+        do_sample: false,
+        num_beams: 1,
+        temperature: 1.0,
+        num_return_sequences: 1,
         ..Default::default()
     };
     let model = TextGenerationModel::new(generate_config)?;

     let input_context = "The dog";
-    let second_input_context = "The cat was";
-    let output = model.generate(&[input_context, second_input_context], None);
+    // let second_input_context = "The cat was";
+    let output = model.generate(&[input_context], None);

     for sentence in output {
         println!("{:?}", sentence);

View File

@@ -50,10 +50,11 @@ fn main() -> anyhow::Result<()> {
         None,
         target_language,
         None,
+        false,
     );

     for sentence in output {
-        println!("{:?}", sentence);
+        println!("{:?}", sentence.text);
     }
     Ok(())
 }

View File

@@ -41,7 +41,16 @@ fn main() -> anyhow::Result<()> {
     // Define input
     let input = ["translate English to German: This sentence will get translated to German"];

-    let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
+    let output = t5_model.generate(
+        Some(input.to_vec()),
+        None,
+        None,
+        None,
+        None,
+        None,
+        None,
+        false,
+    );

     println!("{:?}", output);
     Ok(())

View File

@ -848,8 +848,8 @@ impl AlbertForQuestionAnswering {
.apply(&self.qa_outputs) .apply(&self.qa_outputs)
.split(1, -1); .split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze1(-1); let end_logits = end_logits.squeeze_dim(-1);
AlbertQuestionAnsweringOutput { AlbertQuestionAnsweringOutput {
start_logits, start_logits,
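The remaining library diffs in this commit apply the same mechanical renames required by `tch` 0.5.0: `squeeze1` → `squeeze_dim`, `sum1` → `sum_dim_intlist`, `mean1` → `mean_dim`, `arange1`/`arange2` → `arange_start`, the tensor comparisons `lt1`/`le1`/`gt1`/`ge1`/`ne1` → `lt_tensor`/`le_tensor`/`gt_tensor`/`ge_tensor`/`ne_tensor`, `where1` → `where_self`, `max1`/`min1` → `max_other`/`min_other`, `max2` → `max_dim`, and `floor_divide1` → `divide_scalar_mode(_, "floor")`. A small self-contained sketch of a few of the new names follows; the tensor shapes and values are made up for illustration.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // A fake [batch, sequence, 2] logits tensor, standing in for a QA-head output.
    let logits = Tensor::rand(&[2, 4, 2], (Kind::Float, Device::Cpu));
    let chunks = logits.split(1, -1);
    // tch 0.4: `squeeze1(-1)`; tch 0.5: `squeeze_dim(-1)`
    let start_logits = chunks[0].squeeze_dim(-1);
    // tch 0.4: `sum1(&[-1], true, Kind::Float)`; tch 0.5: `sum_dim_intlist(&[-1], true, Kind::Float)`
    let row_sums = start_logits.sum_dim_intlist(&[-1], true, Kind::Float);
    // tch 0.4: `arange1(start, end, options)`; tch 0.5: `arange_start(start, end, options)`
    let positions = Tensor::arange_start(0, 4, (Kind::Int64, Device::Cpu));
    println!("{:?} {:?} {:?}", start_logits.size(), row_sums.size(), positions.size());
}
```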

View File

@ -240,7 +240,7 @@ pub(crate) fn _make_causal_mask(
); );
let mask_cond = Tensor::arange(target_length, (dtype, device)); let mask_cond = Tensor::arange(target_length, (dtype, device));
let _ = mask.masked_fill_( let _ = mask.masked_fill_(
&mask_cond.lt1(&(&mask_cond + 1).view([target_length, 1])), &mask_cond.lt_tensor(&(&mask_cond + 1).view([target_length, 1])),
0, 0,
); );
@ -306,7 +306,10 @@ pub(crate) fn _prepare_decoder_attention_mask(
} }
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor { fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1; let index_eos: Tensor = input_ids
.ne(pad_token_id)
.sum_dim_intlist(&[-1], true, Int64)
- 1;
let output = input_ids.empty_like().to_kind(Int64); let output = input_ids.empty_like().to_kind(Int64);
output output
.select(1, 0) .select(1, 0)
@ -809,7 +812,7 @@ impl BartForSequenceClassification {
train, train,
); );
let eos_mask = input_ids.eq(self.eos_token_id); let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum1(&[1], true, Int64); let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
let sentence_representation = base_model_output let sentence_representation = base_model_output
.decoder_output .decoder_output
.permute(&[2, 0, 1]) .permute(&[2, 0, 1])

View File

@ -64,7 +64,7 @@ impl LearnedPositionalEmbedding {
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor { pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
let input_shape = input.size(); let input_shape = input.size();
let (_, sequence_length) = (input_shape[0], input_shape[1]); let (_, sequence_length) = (input_shape[0], input_shape[1]);
let positions = Tensor::arange1( let positions = Tensor::arange_start(
past_key_values_length, past_key_values_length,
past_key_values_length + sequence_length, past_key_values_length + sequence_length,
(Int64, input.device()), (Int64, input.device()),
@ -99,7 +99,7 @@ impl SinusoidalPositionalEmbedding {
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor { pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
let input_shape = input.size(); let input_shape = input.size();
let (_, sequence_length) = (input_shape[0], input_shape[1]); let (_, sequence_length) = (input_shape[0], input_shape[1]);
let positions = Tensor::arange1( let positions = Tensor::arange_start(
past_key_values_length, past_key_values_length,
past_key_values_length + sequence_length, past_key_values_length + sequence_length,
(Int64, input.device()), (Int64, input.device()),

View File

@ -323,7 +323,7 @@ impl<T: BertEmbedding> BertModel<T> {
input_shape[1], input_shape[1],
1, 1,
]); ]);
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1)); let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
causal_mask * mask.unsqueeze(1).unsqueeze(1) causal_mask * mask.unsqueeze(1).unsqueeze(1)
} else { } else {
mask.unsqueeze(1).unsqueeze(1) mask.unsqueeze(1).unsqueeze(1)
@ -1161,8 +1161,8 @@ impl BertForQuestionAnswering {
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs); let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1); let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze1(-1); let end_logits = end_logits.squeeze_dim(-1);
BertQuestionAnsweringOutput { BertQuestionAnsweringOutput {
start_logits, start_logits,

View File

@ -131,11 +131,12 @@ impl BertLayer {
encoder_mask: &Option<Tensor>, encoder_mask: &Option<Tensor>,
train: bool, train: bool,
) -> BertLayerOutput { ) -> BertLayerOutput {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, attention_scores, cross_attention_scores) = let (attention_output, attention_scores, cross_attention_scores) =
if self.is_decoder & encoder_hidden_states.is_some() { if self.is_decoder & encoder_hidden_states.is_some() {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, cross_attention_weights) = let (attention_output, cross_attention_weights) =
self.cross_attention.as_ref().unwrap().forward_t( self.cross_attention.as_ref().unwrap().forward_t(
&attention_output, &attention_output,
@ -146,9 +147,6 @@ impl BertLayer {
); );
(attention_output, attention_weights, cross_attention_weights) (attention_output, attention_weights, cross_attention_weights)
} else { } else {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None) (attention_output, attention_weights, None)
}; };

View File

@ -78,7 +78,7 @@ impl SequenceSummary {
{ {
let p = p.borrow(); let p = p.borrow();
let summary_type = config.summary_type.clone().unwrap_or(SummaryType::last); let summary_type = config.summary_type.unwrap_or(SummaryType::last);
let summary = if let Some(summary_use_proj) = config.summary_use_proj { let summary = if let Some(summary_use_proj) = config.summary_use_proj {
let num_classes = match (config.summary_proj_to_labels, config.num_labels) { let num_classes = match (config.summary_proj_to_labels, config.num_labels) {
(Some(summary_proj_to_labels), Some(num_labels)) (Some(summary_proj_to_labels), Some(num_labels))
@ -132,7 +132,7 @@ impl SequenceSummary {
let mut output = match self.summary_type { let mut output = match self.summary_type {
SummaryType::last => hidden_states.select(1, -1), SummaryType::last => hidden_states.select(1, -1),
SummaryType::first => hidden_states.select(1, 0), SummaryType::first => hidden_states.select(1, 0),
SummaryType::mean => hidden_states.mean1(&[1], false, Kind::Float), SummaryType::mean => hidden_states.mean_dim(&[1], false, Kind::Float),
SummaryType::cls_index => { SummaryType::cls_index => {
let cls_index = if let Some(cls_index_value) = cls_index { let cls_index = if let Some(cls_index_value) = cls_index {
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1]; let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
@ -147,7 +147,7 @@ impl SequenceSummary {
let fill_value = fill_value[2]; let fill_value = fill_value[2];
hidden_states.select(-2, 0).full_like(fill_value) hidden_states.select(-2, 0).full_like(fill_value)
}; };
hidden_states.gather(-2, &cls_index, false).squeeze1(-2) hidden_states.gather(-2, &cls_index, false).squeeze_dim(-2)
} }
}; };

View File

@ -83,7 +83,7 @@ impl MultiHeadSelfAttention {
let scores = if let Some(mask) = mask { let scores = if let Some(mask) = mask {
let unmasked_scores = q.matmul(&k.transpose(2, 3)); let unmasked_scores = q.matmul(&k.transpose(2, 3));
let mask = mask let mask = mask
.le1(&(mask.zeros_like() + 0.1)) .le_tensor(&(mask.zeros_like() + 0.1))
.view((bs, 1i64, 1i64, k_length)) .view((bs, 1i64, 1i64, k_length))
.expand_as(&unmasked_scores); .expand_as(&unmasked_scores);
unmasked_scores.masked_fill(&mask, f64::NEG_INFINITY) unmasked_scores.masked_fill(&mask, f64::NEG_INFINITY)

View File

@ -614,8 +614,8 @@ impl DistilBertForQuestionAnswering {
let logits = output.split(1, -1); let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze1(-1); let end_logits = end_logits.squeeze_dim(-1);
Ok(DistilBertQuestionAnsweringOutput { Ok(DistilBertQuestionAnsweringOutput {
start_logits, start_logits,

View File

@ -414,7 +414,7 @@ impl Gpt2Model {
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value.copy(), Some(value) => value.copy(),
None => Tensor::arange1( None => Tensor::arange_start(
layer_past_length, layer_past_length,
seq_length + layer_past_length, seq_length + layer_past_length,
(Int64, input_embeddings.device()), (Int64, input_embeddings.device()),

View File

@ -132,7 +132,9 @@ pub(crate) trait GptNeoAttentionUtils {
let query_indices = Self::split_sequence_length_dim_to(&indices, num_blocks, block_length)?; let query_indices = Self::split_sequence_length_dim_to(&indices, num_blocks, block_length)?;
let key_indices = Self::look_back(&indices, block_length, window_size, None, false)?; let key_indices = Self::look_back(&indices, block_length, window_size, None, false)?;
let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2)); let causal_mask = query_indices
.unsqueeze(-1)
.ge_tensor(&key_indices.unsqueeze(-2));
let calc_attention_mask = if attention_mask.is_none() { let calc_attention_mask = if attention_mask.is_none() {
Some(Tensor::ones( Some(Tensor::ones(
@ -212,7 +214,7 @@ pub(crate) trait GptNeoAttentionUtils {
) -> (Tensor, Tensor) { ) -> (Tensor, Tensor) {
let mut attention_weights = query let mut attention_weights = query
.matmul(&key.transpose(-1, -2)) .matmul(&key.transpose(-1, -2))
.where1(causal_mask, &masked_bias.to_kind(query.kind())); .where_self(causal_mask, &masked_bias.to_kind(query.kind()));
if let Some(attention_mask_value) = attention_mask { if let Some(attention_mask_value) = attention_mask {
attention_weights = attention_weights + attention_mask_value; attention_weights = attention_weights + attention_mask_value;

View File

@ -345,7 +345,7 @@ impl GptNeoModel {
let calc_position_ids = if position_ids.is_none() { let calc_position_ids = if position_ids.is_none() {
let position_ids = let position_ids =
Tensor::arange1(past_length, full_sequence_length, (Kind::Int64, device)); Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
Some( Some(
position_ids position_ids
.unsqueeze(0) .unsqueeze(0)

View File

@ -352,7 +352,8 @@ impl LongformerSelfAttention {
&self, &self,
is_index_global_attn: &Tensor, is_index_global_attn: &Tensor,
) -> GlobalAttentionIndices { ) -> GlobalAttentionIndices {
let num_global_attention_indices = is_index_global_attn.sum1(&[1], false, Kind::Int64); let num_global_attention_indices =
is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
let max_num_global_attention_indices = i64::from(num_global_attention_indices.max()); let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
let is_index_global_attn_nonzero = is_index_global_attn let is_index_global_attn_nonzero = is_index_global_attn
.nonzero_numpy() .nonzero_numpy()
@ -364,7 +365,7 @@ impl LongformerSelfAttention {
max_num_global_attention_indices, max_num_global_attention_indices,
(Kind::Int64, is_index_global_attn.device()), (Kind::Int64, is_index_global_attn.device()),
) )
.lt1(&num_global_attention_indices.unsqueeze(-1)); .lt_tensor(&num_global_attention_indices.unsqueeze(-1));
let is_local_index_global_attention_nonzero = is_local_index_global_attention let is_local_index_global_attention_nonzero = is_local_index_global_attention
.nonzero_numpy() .nonzero_numpy()

View File

@ -86,7 +86,7 @@ impl LongformerEmbeddings {
let input_shape = inputs_embeds.size(); let input_shape = inputs_embeds.size();
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]); let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
Tensor::arange1( Tensor::arange_start(
self.pad_token_id + 1, self.pad_token_id + 1,
sequence_length + self.pad_token_id + 1, sequence_length + self.pad_token_id + 1,
(Kind::Int64, inputs_embeds.device()), (Kind::Int64, inputs_embeds.device()),

View File

@ -140,11 +140,13 @@ fn compute_global_attention_mask(
let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device())); let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device()));
if before_sep_token { if before_sep_token {
attention_mask.expand_as(input_ids).lt1(&question_end_index) attention_mask
.expand_as(input_ids)
.lt_tensor(&question_end_index)
} else { } else {
attention_mask attention_mask
.expand_as(input_ids) .expand_as(input_ids)
.gt1(&(question_end_index + 1)) .gt_tensor(&(question_end_index + 1))
* attention_mask * attention_mask
.expand_as(input_ids) .expand_as(input_ids)
.lt(*input_ids.size().last().unwrap()) .lt(*input_ids.size().last().unwrap())
@ -580,7 +582,7 @@ impl LongformerModel {
.unsqueeze(0) .unsqueeze(0)
.unsqueeze(0) .unsqueeze(0)
.repeat(&[batch_size, sequence_length, 1]) .repeat(&[batch_size, sequence_length, 1])
.le1(&sequence_ids.unsqueeze(-1).unsqueeze(0)) .le_tensor(&sequence_ids.unsqueeze(-1).unsqueeze(0))
.totype(Kind::Int); .totype(Kind::Int);
if causal_mask.size()[1] < padded_attention_mask.size()[1] { if causal_mask.size()[1] < padded_attention_mask.size()[1] {
let prefix_sequence_length = let prefix_sequence_length =
@ -1147,8 +1149,8 @@ impl LongformerForQuestionAnswering {
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs); let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1); let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze1(-1); let end_logits = end_logits.squeeze_dim(-1);
Ok(LongformerQuestionAnsweringOutput { Ok(LongformerQuestionAnsweringOutput {
start_logits, start_logits,

View File

@ -111,7 +111,10 @@ impl Config for MBartConfig {}
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor { fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id); let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[1], true, Int64) - 1; let index_eos: Tensor = input_ids
.ne(pad_token_id)
.sum_dim_intlist(&[1], true, Int64)
- 1;
output output
.select(1, 0) .select(1, 0)
.copy_(&input_ids.gather(1, &index_eos, true).squeeze()); .copy_(&input_ids.gather(1, &index_eos, true).squeeze());
@ -632,7 +635,7 @@ impl MBartForSequenceClassification {
train, train,
); );
let eos_mask = input_ids.eq(self.eos_token_id); let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum1(&[1], true, Int64); let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
let sentence_representation = base_model_output let sentence_representation = base_model_output
.decoder_output .decoder_output
.permute(&[2, 0, 1]) .permute(&[2, 0, 1])

View File

@ -901,8 +901,8 @@ impl MobileBertForQuestionAnswering {
let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs); let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1); let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze_dim(-1);
let end_logits = end_logits.squeeze1(-1); let end_logits = end_logits.squeeze_dim(-1);
Ok(MobileBertQuestionAnsweringOutput { Ok(MobileBertQuestionAnsweringOutput {
start_logits, start_logits,

View File

@ -87,7 +87,7 @@ impl SinusoidalPositionalEmbedding {
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor { pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
let input_shape = input.size(); let input_shape = input.size();
let (_, sequence_length) = (input_shape[0], input_shape[1]); let (_, sequence_length) = (input_shape[0], input_shape[1]);
let positions = Tensor::arange1( let positions = Tensor::arange_start(
past_key_values_length, past_key_values_length,
past_key_values_length + sequence_length, past_key_values_length + sequence_length,
(Kind::Int64, input.device()), (Kind::Int64, input.device()),

View File

@ -716,15 +716,20 @@ impl ConversationOption {
attention_mask: Option<Tensor>, attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> { ) -> Vec<Vec<i64>> {
match *self { match *self {
Self::GPT2(ref model) => model.generate_from_ids_and_past( Self::GPT2(ref model) => model
input_ids, .generate_from_ids_and_past(
attention_mask, input_ids,
None, attention_mask,
None, None,
None, None,
None, None,
None, None,
), None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
} }
} }
} }

View File

@ -48,6 +48,7 @@
//! decoder_start_id, //! decoder_start_id,
//! forced_bos_token_id, //! forced_bos_token_id,
//! None, //! None,
//! false,
//! ); //! );
//! # Ok(()) //! # Ok(())
//! # } //! # }
@ -492,13 +493,13 @@ pub(crate) mod private_generation_utils {
if min_tokens_to_keep > 1 { if min_tokens_to_keep > 1 {
let _ = sorted_indices_to_remove.index_fill_( let _ = sorted_indices_to_remove.index_fill_(
1, 1,
&Tensor::arange1(0, min_tokens_to_keep + 1, (Int64, logits.device())), &Tensor::arange_start(0, min_tokens_to_keep + 1, (Int64, logits.device())),
0, 0,
); );
} }
let _ = sorted_indices_to_remove.index_copy_( let _ = sorted_indices_to_remove.index_copy_(
1, 1,
&Tensor::arange1(1, vocab_size, (Int64, logits.device())), &Tensor::arange_start(1, vocab_size, (Int64, logits.device())),
&sorted_indices_to_remove &sorted_indices_to_remove
.slice(1, 0, vocab_size - 1, 1) .slice(1, 0, vocab_size - 1, 1)
.copy(), .copy(),
@ -585,7 +586,8 @@ pub(crate) mod private_generation_utils {
attention_mask: Tensor, attention_mask: Tensor,
gen_opt: GenerateOptions, gen_opt: GenerateOptions,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>, prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> Tensor { output_scores: bool,
) -> (Tensor, Option<Vec<f64>>) {
let mut unfinished_sentences = let mut unfinished_sentences =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device())); Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
let mut sentence_lengths: Tensor = let mut sentence_lengths: Tensor =
@ -596,6 +598,14 @@ pub(crate) mod private_generation_utils {
let mut past: Cache = Cache::None; let mut past: Cache = Cache::None;
let mut outputs: Tensor; let mut outputs: Tensor;
let mut current_length = cur_len; let mut current_length = cur_len;
let mut scores_output = if output_scores {
Some(Tensor::zeros(
&[batch_size],
(Float, self.get_var_store().device()),
))
} else {
None
};
while current_length < gen_opt.max_length { while current_length < gen_opt.max_length {
let prepared_input = self.prepare_inputs_for_generation( let prepared_input = self.prepare_inputs_for_generation(
@ -690,11 +700,22 @@ pub(crate) mod private_generation_utils {
1, 1,
); );
let probabilities = next_token_logits.softmax(-1, Float); let probabilities = next_token_logits.softmax(-1, Float);
probabilities.multinomial(1, false).squeeze1(1) probabilities.multinomial(1, false).squeeze_dim(1)
} else { } else {
next_token_logits.argmax(-1, false) next_token_logits.argmax(-1, false)
}; };
if let Some(prev_scores) = scores_output {
let finished_mask = unfinished_sentences.eq(0);
scores_output = Some(
prev_scores
+ (&next_token_logits
.log_softmax(-1, Float)
.gather(1, &next_token.reshape(&[-1, 1]), true)
.squeeze()
.masked_fill(&finished_mask, 0)),
);
}
// Add tokens to unfinished sentences // Add tokens to unfinished sentences
let tokens_to_add = match &gen_opt.eos_token_ids { let tokens_to_add = match &gen_opt.eos_token_ids {
Some(_) => { Some(_) => {
@ -736,7 +757,14 @@ pub(crate) mod private_generation_utils {
} }
current_length += 1; current_length += 1;
} }
input_ids let scores_output = scores_output.map(|scores_tensor| {
(scores_tensor / sentence_lengths.pow(gen_opt.length_penalty))
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>()
});
(input_ids, scores_output)
} }
fn generate_beam_search( fn generate_beam_search(
@ -748,7 +776,8 @@ pub(crate) mod private_generation_utils {
mut attention_mask: Tensor, mut attention_mask: Tensor,
gen_opt: GenerateOptions, gen_opt: GenerateOptions,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>, prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> Tensor { output_scores: bool,
) -> (Tensor, Option<Vec<f64>>) {
let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1); let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1);
let num_sub_beams = gen_opt.num_beams / num_beam_groups; let num_sub_beams = gen_opt.num_beams / num_beam_groups;
let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5); let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5);
@ -960,12 +989,12 @@ pub(crate) mod private_generation_utils {
}; };
let eos_token_ids = gen_opt.eos_token_ids.as_ref(); let eos_token_ids = gen_opt.eos_token_ids.as_ref();
let beam_ids_tensor = &next_tokens.floor_divide1(vocab_size); let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1) let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1)
* group_size * group_size
+ beam_ids_tensor; + beam_ids_tensor;
let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size; let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
let (max_scores, _) = next_scores.max2(1, false); let (max_scores, _) = next_scores.max_dim(1, false);
let mut eos_mask = token_id_tensor.ones_like(); let mut eos_mask = token_id_tensor.ones_like();
if let Some(eos_token_id) = eos_token_ids { if let Some(eos_token_id) = eos_token_ids {
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64); eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64);
@ -1034,7 +1063,7 @@ pub(crate) mod private_generation_utils {
&group_beam_tokens, &group_beam_tokens,
); );
let new_indices = gen_opt.num_beams let new_indices = gen_opt.num_beams
* group_beam_indices.floor_divide1(group_size) * group_beam_indices.divide_scalar_mode(group_size, "floor")
+ group_start_index + group_start_index
+ group_beam_indices.remainder(group_size); + group_beam_indices.remainder(group_size);
let _ = beam_indices.index_copy_( let _ = beam_indices.index_copy_(
@ -1110,6 +1139,11 @@ pub(crate) mod private_generation_utils {
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device())); Tensor::zeros(&[output_batch_size], (Int64, input_ids.device()));
let mut best_ids = vec![]; let mut best_ids = vec![];
let mut scores_output = if output_scores {
Some(Vec::with_capacity(best_ids.len()))
} else {
None
};
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() { for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
let mut sorted_hypotheses = hypothesis.clone(); let mut sorted_hypotheses = hypothesis.clone();
sorted_hypotheses sorted_hypotheses
@ -1118,13 +1152,16 @@ pub(crate) mod private_generation_utils {
for j in 0..output_num_return_sequences_per_batch { for j in 0..output_num_return_sequences_per_batch {
let effective_batch_index = let effective_batch_index =
output_num_return_sequences_per_batch * hypothesis_index as i64 + j; output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
let (_, best_hyp) = sorted_hypotheses.beams.pop().unwrap(); let (best_score, best_hyp) = sorted_hypotheses.beams.pop().unwrap();
let _ = sentence_lengths.index_fill_( let _ = sentence_lengths.index_fill_(
0, 0,
&Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()), &Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()),
*best_hyp.size().first().unwrap(), *best_hyp.size().first().unwrap(),
); );
best_ids.push(best_hyp); best_ids.push(best_hyp);
if let Some(current_best_scores) = &mut scores_output {
current_best_scores.push(best_score);
}
} }
} }
let sentence_max_length = let sentence_max_length =
@ -1143,7 +1180,7 @@ pub(crate) mod private_generation_utils {
for (hypothesis_index, best_id) in best_ids.iter().enumerate() { for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
let _ = decoded.get(hypothesis_index as i64).index_copy_( let _ = decoded.get(hypothesis_index as i64).index_copy_(
0, 0,
&Tensor::arange1( &Tensor::arange_start(
0, 0,
i64::from(sentence_lengths.get(hypothesis_index as i64)), i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()), (Int64, input_ids.device()),
@ -1159,7 +1196,7 @@ pub(crate) mod private_generation_utils {
); );
} }
} }
decoded (decoded, scores_output)
} }
fn reorder_cache( fn reorder_cache(
@ -1178,6 +1215,22 @@ pub(crate) mod private_generation_utils {
} }
} }
#[derive(Debug, Clone)]
/// # Generated text output
/// Contains generated text and an optional log-likelihood score for the generated sequence
pub struct GeneratedTextOutput {
pub text: String,
pub score: Option<f64>,
}
#[derive(Debug, Clone)]
/// # Generated indices output
/// Contains generated indices and an optional log-likelihood score for the generated sequence
pub struct GeneratedIndicesOutput {
pub indices: Vec<i64>,
pub score: Option<f64>,
}
/// # Common trait for text generation models. /// # Common trait for text generation models.
/// Main API for text generation /// Main API for text generation
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
@ -1195,7 +1248,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens. /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
/// ///
/// # Returns /// # Returns
/// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*. /// * `Vec<TextOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing TextOutput with the generated texts and the generation score if `output_scores` is true.
/// ///
/// # Example /// # Example
/// ///
@ -1231,6 +1284,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128; /// let max_length = 128;
/// let decoder_start_token_id = None; /// let decoder_start_token_id = None;
/// let forced_bos_token_id = None; /// let forced_bos_token_id = None;
/// let output_scores = true;
/// ///
/// //Example custom function for fine-grained generation control /// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> { /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1257,6 +1311,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id, /// decoder_start_token_id,
/// forced_bos_token_id, /// forced_bos_token_id,
/// Some(&force_one_paragraph), /// Some(&force_one_paragraph),
/// output_scores,
/// ); /// );
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -1283,11 +1338,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id: impl Into<Option<i64>>, decoder_start_token_id: impl Into<Option<i64>>,
forced_bos_token_id: impl Into<Option<i64>>, forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>, prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> Vec<String> output_scores: bool,
) -> Vec<GeneratedTextOutput>
where where
S: AsRef<[&'a str]>, S: AsRef<[&'a str]>,
{ {
let generated = self.generate_indices( let indices_outputs = self.generate_indices(
prompt_texts, prompt_texts,
attention_mask, attention_mask,
min_length, min_length,
@ -1295,10 +1351,16 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id, decoder_start_token_id,
forced_bos_token_id, forced_bos_token_id,
prefix_allowed_tokens_fn, prefix_allowed_tokens_fn,
output_scores,
); );
let mut output = Vec::with_capacity(generated.len()); let mut output = Vec::with_capacity(indices_outputs.len());
for generated_sequence in generated { for generated_sequence in indices_outputs {
output.push(self._get_tokenizer().decode(generated_sequence, true, true)); output.push(GeneratedTextOutput {
text: self
._get_tokenizer()
.decode(generated_sequence.indices, true, true),
score: generated_sequence.score,
});
} }
output output
} }
@ -1315,7 +1377,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens. /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
/// ///
/// # Returns /// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*. /// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
/// ///
/// # Example /// # Example
/// ///
@ -1350,6 +1412,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128; /// let max_length = 128;
/// let decoder_start_token_id = None; /// let decoder_start_token_id = None;
/// let forced_bos_token_id = None; /// let forced_bos_token_id = None;
/// let output_scores = true;
/// ///
/// //Example custom function for fine-grained generation control /// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> { /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1376,6 +1439,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id, /// decoder_start_token_id,
/// forced_bos_token_id, /// forced_bos_token_id,
/// Some(&force_one_paragraph), /// Some(&force_one_paragraph),
/// output_scores,
/// ); /// );
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -1389,7 +1453,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id: impl Into<Option<i64>>, decoder_start_token_id: impl Into<Option<i64>>,
forced_bos_token_id: impl Into<Option<i64>>, forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>, prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> Vec<Vec<i64>> output_scores: bool,
) -> Vec<GeneratedIndicesOutput>
where where
S: AsRef<[&'a str]>, S: AsRef<[&'a str]>,
{ {
@ -1426,6 +1491,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id, decoder_start_token_id,
forced_bos_token_id, forced_bos_token_id,
prefix_allowed_tokens_fn, prefix_allowed_tokens_fn,
output_scores,
) )
} }
@ -1442,7 +1508,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens. /// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
/// ///
/// # Returns /// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*. /// * `Vec<IndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing IndicesOutput with the generated indices and the generation score if `output_scores` is true.
/// ///
/// # Example /// # Example
/// ///
@ -1477,6 +1543,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128; /// let max_length = 128;
/// let decoder_start_token_id = None; /// let decoder_start_token_id = None;
/// let forced_bos_token_id = None; /// let forced_bos_token_id = None;
/// let output_scores = true;
/// ///
/// //Example custom function for fine-grained generation control /// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> { /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1503,6 +1570,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// decoder_start_token_id, /// decoder_start_token_id,
/// forced_bos_token_id, /// forced_bos_token_id,
/// Some(&force_one_paragraph), /// Some(&force_one_paragraph),
/// output_scores,
/// ); /// );
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -1516,7 +1584,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
decoder_start_token_id: impl Into<Option<i64>>, decoder_start_token_id: impl Into<Option<i64>>,
forced_bos_token_id: impl Into<Option<i64>>, forced_bos_token_id: impl Into<Option<i64>>,
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>, prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
) -> Vec<Vec<i64>> { output_scores: bool,
) -> Vec<GeneratedIndicesOutput> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone(); let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self); let config = PrivateLanguageGenerator::get_config(self);
@ -1647,7 +1716,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
forced_bos_token_id: forced_bos_token_id.into(), forced_bos_token_id: forced_bos_token_id.into(),
}; };
let decoded = no_grad(|| { let (decoded, scores) = no_grad(|| {
if num_beams > 1 { if num_beams > 1 {
self.generate_beam_search( self.generate_beam_search(
input_ids, input_ids,
@ -1657,6 +1726,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
attention_mask, attention_mask,
gen_opt, gen_opt,
prefix_allowed_tokens_fn, prefix_allowed_tokens_fn,
output_scores,
) )
} else { } else {
self.generate_no_beam_search( self.generate_no_beam_search(
@ -1667,21 +1737,25 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
attention_mask, attention_mask,
gen_opt, gen_opt,
prefix_allowed_tokens_fn, prefix_allowed_tokens_fn,
output_scores,
) )
} }
}); });
let num_sequences = *decoded.size().first().unwrap(); let num_sequences = *decoded.size().first().unwrap();
let mut output_ids = Vec::with_capacity(num_sequences as usize); let mut output = Vec::with_capacity(num_sequences as usize);
for sequence_index in 0..num_sequences { for sequence_index in 0..num_sequences {
let sequence_output_ids = decoded let indices = decoded
.as_ref() .as_ref()
.get(sequence_index) .get(sequence_index)
.iter::<i64>() .iter::<i64>()
.unwrap() .unwrap()
.collect::<Vec<i64>>(); .collect::<Vec<i64>>();
output_ids.push(sequence_output_ids.clone()); let score = scores
.as_ref()
.map(|scores_value| scores_value[sequence_index as usize]);
output.push(GeneratedIndicesOutput { indices, score });
} }
output_ids output
} }
/// Returns a reference to the text generator's tokenizer /// Returns a reference to the text generator's tokenizer
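For reference, the `prefix_allowed_tokens_fn` hook threaded through the calls above receives the batch index and the tokens generated so far, and returns the token ids allowed at the next step. Below is a sketch of such a constraint function with made-up token ids and vocabulary size; the `force_one_paragraph` helper referenced in the doc examples is not shown in this diff, so this is only an illustration of the expected signature.

```rust
use tch::Tensor;

// Illustrative constraint: once an (assumed) full-stop token has been generated,
// only allow the (assumed) EOS token so that generation ends after one sentence.
fn force_one_sentence(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
    const FULL_STOP_ID: i64 = 13; // made-up token id
    const EOS_ID: i64 = 50256; // made-up token id
    const VOCAB_SIZE: i64 = 50257; // made-up vocabulary size
    let generated = previous_token_ids
        .iter::<i64>()
        .unwrap()
        .collect::<Vec<i64>>();
    if generated.contains(&FULL_STOP_ID) {
        vec![EOS_ID]
    } else {
        (0..VOCAB_SIZE).collect()
    }
}
```

Such a function would be passed as `Some(&force_one_sentence)` in the `prefix_allowed_tokens_fn` position of `generate` or `generate_indices`.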

View File

@ -597,10 +597,10 @@ impl SequenceClassificationModel {
); );
output.softmax(-1, Kind::Float).detach().to(Device::Cpu) output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
}); });
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1); let label_indices = output.as_ref().argmax(-1, true).squeeze_dim(1);
let scores = output let scores = output
.gather(1, &label_indices.unsqueeze(-1), false) .gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1); .squeeze_dim(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>(); let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>(); let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();

View File

@ -263,18 +263,62 @@ impl SummarizationOption {
S: AsRef<[&'a str]>, S: AsRef<[&'a str]>,
{ {
match *self { match *self {
Self::Bart(ref model) => { Self::Bart(ref model) => model
model.generate(prompt_texts, attention_mask, None, None, None, None, None) .generate(
} prompt_texts,
Self::T5(ref model) => { attention_mask,
model.generate(prompt_texts, attention_mask, None, None, None, None, None) None,
} None,
Self::ProphetNet(ref model) => { None,
model.generate(prompt_texts, attention_mask, None, None, None, None, None) None,
} None,
Self::Pegasus(ref model) => { false,
model.generate(prompt_texts, attention_mask, None, None, None, None, None) )
} .into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::Pegasus(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
} }
} }
} }

View File

@ -249,51 +249,76 @@ impl TextGenerationOption {
S: AsRef<[&'a str]>, S: AsRef<[&'a str]>,
{ {
match *self { match *self {
Self::GPT(ref model) => model.generate_indices( Self::GPT(ref model) => model
prompt_texts, .generate_indices(
attention_mask, prompt_texts,
min_length, attention_mask,
max_length, min_length,
None, max_length,
None, None,
None, None,
), None,
Self::GPT2(ref model) => model.generate_indices( false,
prompt_texts, )
attention_mask, .into_iter()
min_length, .map(|output| output.indices)
max_length, .collect(),
None, Self::GPT2(ref model) => model
None, .generate_indices(
None, prompt_texts,
), attention_mask,
Self::GPTNeo(ref model) => model.generate_indices( min_length,
prompt_texts, max_length,
attention_mask, None,
min_length, None,
max_length, None,
None, false,
None, )
None, .into_iter()
), .map(|output| output.indices)
Self::XLNet(ref model) => model.generate_indices( .collect(),
prompt_texts, Self::GPTNeo(ref model) => model
attention_mask, .generate_indices(
min_length, prompt_texts,
max_length, attention_mask,
None, min_length,
None, max_length,
None, None,
), None,
Self::Reformer(ref model) => model.generate_indices( None,
prompt_texts, false,
attention_mask, )
min_length, .into_iter()
max_length, .map(|output| output.indices)
None, .collect(),
None, Self::XLNet(ref model) => model
None, .generate_indices(
), prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
Self::Reformer(ref model) => model
.generate_indices(
prompt_texts,
attention_mask,
min_length,
max_length,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.indices)
.collect(),
} }
} }
} }

View File

@ -692,7 +692,7 @@ impl TokenClassificationModel {
) )
}); });
let output = output.detach().to(Device::Cpu); let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float); let score: Tensor = output.exp() / output.exp().sum_dim_intlist(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true); let labels_idx = &score.argmax(-1, true);
let mut tokens: Vec<Vec<Token>> = vec![]; let mut tokens: Vec<Vec<Token>> = vec![];
for sentence_idx in 0..labels_idx.size()[0] { for sentence_idx in 0..labels_idx.size()[0] {

View File

@ -674,12 +674,34 @@ impl TranslationOption {
S: AsRef<[&'a str]>, S: AsRef<[&'a str]>,
{ {
match *self { match *self {
Self::Marian(ref model) => { Self::Marian(ref model) => model
model.generate(prompt_texts, attention_mask, None, None, None, None, None) .generate(
} prompt_texts,
Self::T5(ref model) => { attention_mask,
model.generate(prompt_texts, attention_mask, None, None, None, None, None) None,
} None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
Self::T5(ref model) => model
.generate(
prompt_texts,
attention_mask,
None,
None,
None,
None,
None,
false,
)
.into_iter()
.map(|output| output.text)
.collect(),
} }
} }
} }

View File

@ -687,10 +687,10 @@ impl ZeroShotClassificationModel {
}); });
let scores = output.softmax(1, Float).select(-1, -1); let scores = output.softmax(1, Float).select(-1, -1);
let label_indices = scores.as_ref().argmax(-1, true).squeeze1(1); let label_indices = scores.as_ref().argmax(-1, true).squeeze_dim(1);
let scores = scores let scores = scores
.gather(1, &label_indices.unsqueeze(-1), false) .gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1); .squeeze_dim(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>(); let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>(); let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();

View File

@ -602,7 +602,7 @@ impl ProphetNetNgramAttention {
let hidden_states_size = hidden_states.size(); let hidden_states_size = hidden_states.size();
let (sequence_length, batch_size) = (hidden_states_size[0], hidden_states_size[1]); let (sequence_length, batch_size) = (hidden_states_size[0], hidden_states_size[1]);
let calc_main_relative_position_buckets = if main_relative_position_buckets.is_none() { let calc_main_relative_position_buckets = if main_relative_position_buckets.is_none() {
let relative_positions = Tensor::arange1( let relative_positions = Tensor::arange_start(
1, 1,
attention_weights.size().last().unwrap() + 1, attention_weights.size().last().unwrap() + 1,
(Kind::Int64, hidden_states.device()), (Kind::Int64, hidden_states.device()),
@ -742,7 +742,7 @@ pub(crate) fn compute_relative_buckets(
( (
num_buckets, num_buckets,
relative_positions.zeros_like(), relative_positions.zeros_like(),
inverse_relative_positions.max1(&inverse_relative_positions.zeros_like()), inverse_relative_positions.max_other(&inverse_relative_positions.zeros_like()),
) )
}; };
let max_exact = num_buckets / 2; let max_exact = num_buckets / 2;
@ -754,10 +754,10 @@ pub(crate) fn compute_relative_buckets(
+ max_exact_f64; + max_exact_f64;
let val_if_large = val_if_large let val_if_large = val_if_large
.min1(&(val_if_large.ones_like() * (num_buckets as f64 - 1.0))) .min_other(&(val_if_large.ones_like() * (num_buckets as f64 - 1.0)))
.totype(Kind::Int64); .totype(Kind::Int64);
relative_positions_bucket + inverse_relative_positions.where1(&is_small, &val_if_large) relative_positions_bucket + inverse_relative_positions.where_self(&is_small, &val_if_large)
} }
pub(crate) fn compute_all_stream_relative_buckets( pub(crate) fn compute_all_stream_relative_buckets(

View File

@ -308,9 +308,9 @@ impl ProphetNetDecoder {
let hidden_states = (input_embeds + main_stream_pos_embed).transpose(0, 1); let hidden_states = (input_embeds + main_stream_pos_embed).transpose(0, 1);
let (mut ngram_hidden_states, extended_attention_mask, extended_predict_attention_mask) = let (mut ngram_hidden_states, extended_attention_mask, extended_predict_attention_mask) = {
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
if old_layer_states.is_some() { if old_layer_states.is_some() {
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
for ngram in 0..self.ngram { for ngram in 0..self.ngram {
ngram_hidden_states.push( ngram_hidden_states.push(
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed) (&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
@ -320,7 +320,6 @@ impl ProphetNetDecoder {
} }
(ngram_hidden_states, None, None) (ngram_hidden_states, None, None)
} else { } else {
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
for ngram in 0..self.ngram { for ngram in 0..self.ngram {
ngram_hidden_states.push( ngram_hidden_states.push(
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed) (&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
@ -336,7 +335,8 @@ impl ProphetNetDecoder {
Some(extended_attention_mask), Some(extended_attention_mask),
Some(extended_predict_attention_mask), Some(extended_predict_attention_mask),
) )
}; }
};
let extended_encoder_attention_mask = let extended_encoder_attention_mask =
encoder_attention_mask.map(|encoder_attention_mask_value| { encoder_attention_mask.map(|encoder_attention_mask_value| {
@ -510,7 +510,7 @@ impl ProphetNetDecoder {
let input_size = position_ids.size(); let input_size = position_ids.size();
let (batch_size, sequence_length) = (input_size[0], input_size[1]); let (batch_size, sequence_length) = (input_size[0], input_size[1]);
let position_ids = Tensor::arange1( let position_ids = Tensor::arange_start(
1, 1,
self.max_target_positions, self.max_target_positions,
(Kind::Int64, position_ids.device()), (Kind::Int64, position_ids.device()),

View File

@ -307,7 +307,7 @@ impl LSHSelfAttention {
.unsqueeze(1) .unsqueeze(1)
.expand(&buckets.size(), true) .expand(&buckets.size(), true)
.to_kind(Kind::Bool); .to_kind(Kind::Bool);
buckets = buckets.where1( buckets = buckets.where_self(
&buckets_mask, &buckets_mask,
&Tensor::of_slice(&[num_buckets - 1]) &Tensor::of_slice(&[num_buckets - 1])
.to_kind(Kind::Float) .to_kind(Kind::Float)
@ -423,15 +423,16 @@ impl LSHSelfAttention {
); );
if let Some(mask) = mask { if let Some(mask) = mask {
query_key_dots = query_key_dots.where1(&mask.to_kind(Kind::Bool), &self.mask_value); query_key_dots =
query_key_dots.where_self(&mask.to_kind(Kind::Bool), &self.mask_value);
} }
} }
{ {
let self_mask = query_bucket_idx let self_mask = query_bucket_idx
.unsqueeze(-1) .unsqueeze(-1)
-                .ne1(&key_value_bucket_idx.unsqueeze(-2));
+                .ne_tensor(&key_value_bucket_idx.unsqueeze(-2));
             query_key_dots =
-                query_key_dots.where1(&self_mask.to_kind(Kind::Bool), &self.self_mask_value);
+                query_key_dots.where_self(&self_mask.to_kind(Kind::Bool), &self.self_mask_value);
         }
         let mut logits = query_key_dots.logsumexp(&[-1], true);
@@ -441,7 +442,7 @@ impl LSHSelfAttention {
         let mut out_vectors = attention_probs.matmul(&value_vectors);
         if out_vectors.dim() > 4 {
-            logits = logits.flatten(2, 3).squeeze1(-1);
+            logits = logits.flatten(2, 3).squeeze_dim(-1);
             out_vectors = out_vectors.flatten(2, 3)
         }
@@ -476,7 +477,9 @@ impl LSHSelfAttention {
         };
         if self.is_decoder {
-            let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2));
+            let causal_mask = query_indices
+                .unsqueeze(-1)
+                .ge_tensor(&key_indices.unsqueeze(-2));
             let attention_mask = if let Some(attention_mask) = attention_mask {
                 causal_mask * attention_mask
             } else {
@@ -534,7 +537,10 @@ impl LSHSelfAttention {
                 *relevant_bucket_indices_chunk.size().last().unwrap(),
                 (Kind::Int64, hidden_states.device()),
             )
-            .floor_divide1(*relevant_bucket_indices_chunk.size().last().unwrap()));
+            .divide_scalar_mode(
+                *relevant_bucket_indices_chunk.size().last().unwrap(),
+                "floor",
+            ));
             let relevant_bucket_indices_chunk_all_batch =
                 &relevant_bucket_indices_chunk + bucket_indices_batch_offset;
@@ -566,7 +572,9 @@ impl LSHSelfAttention {
         indices: &Tensor,
         sequence_length: i64,
     ) -> Tensor {
-        let start_indices_chunk = (indices.select(1, -1).floor_divide1(self.chunk_length)
+        let start_indices_chunk = (indices
+            .select(1, -1)
+            .divide_scalar_mode(self.chunk_length, "floor")
             - self.num_chunks_before)
             * self.chunk_length;
         let total_chunk_size =
@@ -593,7 +601,7 @@ impl LSHSelfAttention {
     }

     fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
-        let variance = (input_tensor * input_tensor).mean1(&[-1], true, input_tensor.kind());
+        let variance = (input_tensor * input_tensor).mean_dim(&[-1], true, input_tensor.kind());
         input_tensor * (variance + epsilon).rsqrt()
     }
@@ -850,7 +858,7 @@ impl LSHSelfAttention {
             )?
             .unsqueeze(-1);
             let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
-            out_vectors = (out_vectors * probs_vectors).sum1(&[2], false, Kind::Float);
+            out_vectors = (out_vectors * probs_vectors).sum_dim_intlist(&[2], false, Kind::Float);
         }
         out_vectors = merge_hidden_size_dim(
@@ -967,7 +975,9 @@ impl LocalSelfAttention {
         });
         if self.is_decoder {
-            let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2));
+            let causal_mask = query_indices
+                .unsqueeze(-1)
+                .ge_tensor(&key_indices.unsqueeze(-2));
             attention_mask = Some(if let Some(mask) = attention_mask {
                 causal_mask * mask
             } else {
@@ -1087,7 +1097,7 @@ impl LocalSelfAttention {
         );
         if let Some(mask) = attention_mask {
-            query_key_dots = query_key_dots.where1(&mask.to_kind(Kind::Bool), &self.mask_value);
+            query_key_dots = query_key_dots.where_self(&mask.to_kind(Kind::Bool), &self.mask_value);
         }
         let logits = query_key_dots.logsumexp(&[-1], true);
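
The hunks above all follow the same migration pattern for the `tch` 0.5.0 bindings: the numbered overloads from 0.4 (`ne1`, `where1`, `squeeze1`, `ge1`, `mean1`, `sum1`, `floor_divide1`) become descriptively named methods. A minimal standalone sketch, separate from the diff, exercising the renamed calls used in this file; the tensors and values below are made up purely for illustration:

```rust
use tch::{Kind, Tensor};

fn main() {
    let x = Tensor::of_slice(&[3i64, 7, 11, 15]).reshape(&[4, 1]);
    let y = Tensor::of_slice(&[3i64, 5, 11, 2]).reshape(&[4, 1]);

    // tch 0.4: x.squeeze1(-1)            -> tch 0.5: x.squeeze_dim(-1)
    let flat = x.squeeze_dim(-1);

    // tch 0.4: x.ne1(&y)                 -> tch 0.5: x.ne_tensor(&y)
    let mask = x.ne_tensor(&y);

    // tch 0.4: x.where1(&cond, &other)   -> tch 0.5: x.where_self(&cond, &other)
    let picked = x.where_self(&mask.to_kind(Kind::Bool), &y);

    // tch 0.4: x.floor_divide1(4)        -> tch 0.5: x.divide_scalar_mode(4, "floor")
    let buckets = x.divide_scalar_mode(4, "floor");

    flat.print();
    picked.print();
    buckets.print();
}
```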

View File

@@ -263,10 +263,9 @@ impl ReformerEmbeddings {
         let calc_position_ids = if position_ids.is_none() {
             Some(
-                Tensor::arange2(
+                Tensor::arange_start(
                     start_ids_pos_encoding,
                     start_ids_pos_encoding + input_shape[1],
-                    1,
                     (Kind::Int64, device),
                 )
                 .unsqueeze(0)

View File

@@ -418,10 +418,9 @@ impl ReformerModel {
             let input_ids = Tensor::cat(&[input_ids, &input_ids_padding], -1);
             new_input_shape = input_ids.size();
             let position_ids = if let Some(position_ids) = position_ids {
-                let position_ids_padding = Tensor::arange2(
+                let position_ids_padding = Tensor::arange_start(
                     *input_shape.last().unwrap(),
                     self.least_common_mult_chunk_length,
-                    1,
                     (Kind::Int64, device),
                 )
                 .unsqueeze(0)
@@ -568,7 +567,7 @@ impl ReformerModelWithLMHead {
     ///
     /// let model_output = no_grad(|| {
     ///     reformer_model.forward_t(
     ///         Some(&input_tensor),
     ///         Some(&input_positions),
     ///         None,
     ///         Some(&attention_mask),
@@ -801,7 +800,7 @@ impl ReformerForSequenceClassification {
     ///
     /// let model_output = no_grad(|| {
     ///     reformer_model.forward_t(
     ///         Some(&input_tensor),
     ///         Some(&input_positions),
     ///         None,
     ///         Some(&attention_mask),
@@ -939,7 +938,7 @@ impl ReformerForQuestionAnswering {
     ///
     /// let model_output = no_grad(|| {
     ///     reformer_model.forward_t(
     ///         Some(&input_tensor),
     ///         Some(&input_positions),
     ///         None,
     ///         Some(&attention_mask),
@@ -972,8 +971,8 @@ impl ReformerForQuestionAnswering {
             .apply(&self.qa_outputs)
             .split(1, -1);
         let (start_logits, end_logits) = (&logits[0], &logits[1]);
-        let start_logits = start_logits.squeeze1(-1);
-        let end_logits = end_logits.squeeze1(-1);
+        let start_logits = start_logits.squeeze_dim(-1);
+        let end_logits = end_logits.squeeze_dim(-1);
         Ok(ReformerQuestionAnsweringModelOutput {
             start_logits,

View File

@@ -39,7 +39,7 @@ impl RobertaEmbeddings {
     fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
         let input_shape = x.size();
         let input_shape = vec![input_shape[0], input_shape[1]];
-        let position_ids: Tensor = Tensor::arange1(
+        let position_ids: Tensor = Tensor::arange_start(
             self.padding_index + 1,
             input_shape[0],
             (Kind::Int64, x.device()),

View File

@@ -961,8 +961,8 @@ impl RobertaForQuestionAnswering {
         let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
         let logits = sequence_output.split(1, -1);
         let (start_logits, end_logits) = (&logits[0], &logits[1]);
-        let start_logits = start_logits.squeeze1(-1);
-        let end_logits = end_logits.squeeze1(-1);
+        let start_logits = start_logits.squeeze_dim(-1);
+        let end_logits = end_logits.squeeze_dim(-1);
         RobertaQuestionAnsweringOutput {
             start_logits,

View File

@@ -254,7 +254,7 @@ impl T5Attention {
             ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
             n.abs()
         } else {
-            n.max1(&n.zeros_like())
+            n.max_other(&n.zeros_like())
         };
         let max_exact = num_buckets / 2;
@@ -266,8 +266,8 @@ impl T5Attention {
             .to_kind(Kind::Int64)
             + max_exact;
-        let value_if_large = value_if_large.min1(&value_if_large.full_like(num_buckets - 1));
-        ret += n.where1(&is_small, &value_if_large);
+        let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1));
+        ret += n.where_self(&is_small, &value_if_large);
         ret
     }

View File

@@ -328,7 +328,7 @@ impl T5Stack {
                     input_shape[1],
                     1,
                 ]);
-                let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
+                let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
                 causal_mask.unsqueeze(1) * attention_mask.unsqueeze(1).unsqueeze(1)
             } else {
                 attention_mask.unsqueeze(1).unsqueeze(1)
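
The T5 hunks above swap the remaining numbered comparison and clamping overloads: `max1` becomes `max_other`, `min1` becomes `min_other`, and `le1` becomes `le_tensor`. A small standalone sketch, separate from the diff and using made-up values, showing the same calls under the 0.5 names:

```rust
use tch::Tensor;

fn main() {
    let n = Tensor::of_slice(&[-3i64, -1, 0, 2, 5]);

    // tch 0.4: n.max1(&n.zeros_like())  -> tch 0.5: n.max_other(&n.zeros_like())
    let non_negative = n.max_other(&n.zeros_like());

    // tch 0.4: v.min1(&v.full_like(3))  -> tch 0.5: v.min_other(&v.full_like(3))
    let capped = non_negative.min_other(&non_negative.full_like(3));

    // tch 0.4: a.le1(&b)                -> tch 0.5: a.le_tensor(&b)
    let within_cap = n.le_tensor(&capped);

    capped.print();
    within_cap.print();
}
```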

View File

@@ -32,7 +32,7 @@ impl T5LayerNorm {
 impl Module for T5LayerNorm {
     fn forward(&self, x: &Tensor) -> Tensor {
-        let variance = x.pow(2f64).mean1(&[-1], true, Kind::Float);
+        let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
         let x = x / (variance + self.epsilon).sqrt();
         &self.weight * x
     }
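
The `T5LayerNorm` forward above is an RMS-style normalization: the variance is the mean of the squared activations over the last dimension (kept), with no mean subtraction, and only the last call is renamed (`mean1` to `mean_dim`). A standalone sketch of the same computation with the 0.5 API; the free function and dummy inputs below are made up for illustration:

```rust
use tch::{Device, Kind, Tensor};

// Same computation as the forward pass above, written as a free function.
fn rms_norm(x: &Tensor, weight: &Tensor, epsilon: f64) -> Tensor {
    // tch 0.4: .mean1(&[-1], true, Kind::Float) -> tch 0.5: .mean_dim(&[-1], true, Kind::Float)
    let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
    let x = x / (variance + epsilon).sqrt();
    weight * x
}

fn main() {
    let x = Tensor::of_slice(&[1f32, 2.0, 3.0, 4.0]).reshape(&[1, 4]);
    let weight = Tensor::ones(&[4], (Kind::Float, Device::Cpu));
    rms_norm(&x, &weight, 1e-6).print();
}
```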

View File

@@ -287,13 +287,16 @@ impl XLNetModel {
         batch_size: Option<i64>,
         device: Device,
     ) -> Tensor {
-        let frequency_sequence = Tensor::arange2(0, self.d_model, 2, (Kind::Float, device));
-        let inverse_frequency = 1f64 / Tensor::pow2(10000f64, &(frequency_sequence / self.d_model));
+        let frequency_sequence =
+            Tensor::arange_start_step(0, self.d_model, 2, (Kind::Float, device));
+        let inverse_frequency =
+            1f64 / Tensor::pow_scalar(10000f64, &(frequency_sequence / self.d_model));
         let (begin, end) = match self.attention_type {
             AttentionType::bi => (k_len, -q_len),
             AttentionType::uni => (k_len, -1),
         };
-        let mut forward_positions_sequence = Tensor::arange2(begin, end, -1, (Kind::Float, device));
+        let mut forward_positions_sequence =
+            Tensor::arange_start_step(begin, end, -1, (Kind::Float, device));
         match self.clamp_len {
             Some(clamp_value) if clamp_value > 0 => {
                 let _ = forward_positions_sequence.clamp_(-clamp_value, clamp_value);
@@ -302,7 +305,7 @@ impl XLNetModel {
             }
         if self.bi_data {
             let mut backward_positions_sequence =
-                Tensor::arange2(-begin, -end, 1, (Kind::Float, device));
+                Tensor::arange_start(-begin, -end, (Kind::Float, device));
             match self.clamp_len {
                 Some(clamp_value) if clamp_value > 0 => {
                     let _ = backward_positions_sequence.clamp_(-clamp_value, clamp_value);
@@ -512,7 +515,7 @@ impl XLNetModel {
         };
         let seg_mat = token_type_ids_value
             .unsqueeze(-1)
-            .ne1(&cat_ids.unsqueeze(0))
+            .ne_tensor(&cat_ids.unsqueeze(0))
             .to_kind(Kind::Int64);
         Some(seg_mat.one_hot(2).to_kind(Kind::Float))
     } else {
@@ -1461,8 +1464,8 @@ impl XLNetForQuestionAnswering {
         let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
         let logits = sequence_output.split(1, -1);
         let (start_logits, end_logits) = (&logits[0], &logits[1]);
-        let start_logits = start_logits.squeeze1(-1);
-        let end_logits = end_logits.squeeze1(-1);
+        let start_logits = start_logits.squeeze_dim(-1);
+        let end_logits = end_logits.squeeze_dim(-1);
         XLNetQuestionAnsweringOutput {
             start_logits,
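
The XLNet hunks above, together with the Reformer and RoBERTa ones earlier, show the renamed range constructors: `Tensor::arange1` becomes `Tensor::arange_start`, `Tensor::arange2` becomes `Tensor::arange_start_step` (and collapses to `arange_start` when the step was 1), and `Tensor::pow2` becomes `Tensor::pow_scalar` for a scalar base raised to a tensor exponent. A self-contained sketch, separate from the diff, with made-up sizes:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let device = Device::Cpu;
    let d_model = 8i64;

    // tch 0.4: Tensor::arange1(start, end, opts)        -> tch 0.5: Tensor::arange_start(start, end, opts)
    let position_ids = Tensor::arange_start(1i64, 5, (Kind::Int64, device));

    // tch 0.4: Tensor::arange2(start, end, step, opts)  -> tch 0.5: Tensor::arange_start_step(start, end, step, opts)
    let frequency_sequence = Tensor::arange_start_step(0, d_model, 2, (Kind::Float, device));

    // tch 0.4: Tensor::pow2(base, &exponent)            -> tch 0.5: Tensor::pow_scalar(base, &exponent)
    let inverse_frequency = 1f64 / Tensor::pow_scalar(10000f64, &(frequency_sequence / d_model));

    position_ids.print();
    inverse_frequency.print();
}
```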

View File

@@ -428,17 +428,20 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
         None,
         None,
         Some(&force_one_paragraph),
+        true,
     );
     assert_eq!(output.len(), 2);
     assert_eq!(
-        output[0],
+        output[0].text,
         "Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
     );
+    assert!((output[0].score.unwrap() - (-1.4666)).abs() < 1e-4);
     assert_eq!(
-        output[1],
+        output[1].text,
         "There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
     );
+    assert!((output[1].score.unwrap() - (-1.3545)).abs() < 1e-4);
     Ok(())
 }
@@ -493,17 +496,20 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
         None,
         None,
         Some(&force_one_paragraph),
+        true,
     );
     assert_eq!(output.len(), 2);
     assert_eq!(
-        output[0],
+        output[0].text,
         "Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
     );
+    assert!((output[0].score.unwrap() - (-1.2750)).abs() < 1e-4);
     assert_eq!(
-        output[1],
+        output[1].text,
         "There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
     );
+    assert!((output[1].score.unwrap() - (-1.3326)).abs() < 1e-4);
     Ok(())
 }
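
The test updates above reflect the new generate return type: the extra boolean argument requests scores, and each returned item now exposes the generated text plus an optional score instead of being a plain string. A minimal sketch of consuming such an output; the `GeneratedOutput` struct and score type below are stand-ins for illustration only, not the library's own type:

```rust
// Stand-in type mirroring the fields the assertions above read (`text`, `score`).
struct GeneratedOutput {
    text: String,
    score: Option<f64>,
}

fn report(outputs: &[GeneratedOutput]) {
    for (i, output) in outputs.iter().enumerate() {
        match output.score {
            Some(score) => println!("{}: {} (score {:.4})", i, output.text, score),
            None => println!("{}: {}", i, output.text),
        }
    }
}

fn main() {
    let outputs = vec![GeneratedOutput {
        text: "Rust is a very simple and powerful library".to_string(),
        score: Some(-1.4666),
    }];
    report(&outputs);
}
```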

View File

@@ -97,11 +97,12 @@ fn mbart_translation() -> anyhow::Result<()> {
         None,
         target_language,
         None,
+        false,
     );
     assert_eq!(output.len(), 1);
     assert_eq!(
-        output[0],
+        output[0].text,
         "de_DE Der schnelle braune Fuchs springt über den faulen Hund."
     );