mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Merge remote-tracking branch 'origin/master' into m2m100_implementation
# Conflicts: # Cargo.toml
This commit is contained in:
commit
0b2e339e87
@ -3,13 +3,15 @@ All notable changes to this project will be documented in this file. The format
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
## Added
|
## Added
|
||||||
- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
|
- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
|
||||||
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
|
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
|
||||||
|
- (BREAKING) Support for `output_scores` boolean argument for generation, allowing users to output the log-probability scores of generated sequences. Updated the return type of low-level generate API to `GeneratedTextOutput` and `GeneratedIndicesOutput` containing optional scores along with the generated output.
|
||||||
- Addition of the MBart Language model and support for text generation / direct translation between 50 languages
|
- Addition of the MBart Language model and support for text generation / direct translation between 50 languages
|
||||||
- Addition of the M2M100 Language model and support for text generation / direct translation between 100 languages
|
- Addition of the M2M100 Language model and support for text generation / direct translation between 100 languages
|
||||||
|
|
||||||
## Changed
|
## Changed
|
||||||
- Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)
|
- Updated GPT2 architecture to re-use embeddings for the output projection layer (resulting in smaller model weights files and memory footprint)
|
||||||
|
- Upgraded `tch` version to 0.5.0 (using `libtorch` 1.9.0)
|
||||||
|
|
||||||
## [0.15.1] - 2021-06-01
|
## [0.15.1] - 2021-06-01
|
||||||
### Fixed
|
### Fixed
|
||||||
|
@ -58,7 +58,7 @@ features = ["doc-only"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rust_tokenizers = { version = "~6.2.4", path = "E:/Coding/backup-rust/rust-tokenizers/main" }
|
rust_tokenizers = { version = "~6.2.4", path = "E:/Coding/backup-rust/rust-tokenizers/main" }
|
||||||
tch = "~0.4.1"
|
tch = "~0.5.0"
|
||||||
serde_json = "1.0.64"
|
serde_json = "1.0.64"
|
||||||
serde = { version = "1.0.126", features = ["derive"] }
|
serde = { version = "1.0.126", features = ["derive"] }
|
||||||
dirs = "3.0.2"
|
dirs = "3.0.2"
|
||||||
@ -72,5 +72,5 @@ thiserror = "1.0.24"
|
|||||||
anyhow = "1.0.40"
|
anyhow = "1.0.40"
|
||||||
csv = "1.1.6"
|
csv = "1.1.6"
|
||||||
criterion = "0.3.4"
|
criterion = "0.3.4"
|
||||||
torch-sys = "0.4.1"
|
torch-sys = "0.5.0"
|
||||||
tempfile = "3.2.0"
|
tempfile = "3.2.0"
|
||||||
|
@ -71,8 +71,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
|
|||||||
|
|
||||||
### Manual installation (recommended)
|
### Manual installation (recommended)
|
||||||
|
|
||||||
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
|
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.9.0`: if this version is no longer available on the "get started" page,
|
||||||
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11.
|
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.9.0%2Bcu111.zip` for a Linux version with CUDA11.
|
||||||
2. Extract the library to a location of your choice
|
2. Extract the library to a location of your choice
|
||||||
3. Set the following environment variables
|
3. Set the following environment variables
|
||||||
##### Linux:
|
##### Linux:
|
||||||
|
@ -20,17 +20,17 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let generate_config = TextGenerationConfig {
|
let generate_config = TextGenerationConfig {
|
||||||
model_type: ModelType::GPT2,
|
model_type: ModelType::GPT2,
|
||||||
max_length: 30,
|
max_length: 30,
|
||||||
do_sample: true,
|
do_sample: false,
|
||||||
num_beams: 5,
|
num_beams: 1,
|
||||||
temperature: 1.1,
|
temperature: 1.0,
|
||||||
num_return_sequences: 3,
|
num_return_sequences: 1,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let model = TextGenerationModel::new(generate_config)?;
|
let model = TextGenerationModel::new(generate_config)?;
|
||||||
|
|
||||||
let input_context = "The dog";
|
let input_context = "The dog";
|
||||||
let second_input_context = "The cat was";
|
// let second_input_context = "The cat was";
|
||||||
let output = model.generate(&[input_context, second_input_context], None);
|
let output = model.generate(&[input_context], None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{:?}", sentence);
|
println!("{:?}", sentence);
|
||||||
|
@ -50,10 +50,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
None,
|
None,
|
||||||
target_language,
|
target_language,
|
||||||
None,
|
None,
|
||||||
|
false,
|
||||||
);
|
);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{:?}", sentence);
|
println!("{:?}", sentence.text);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,16 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Define input
|
// Define input
|
||||||
let input = ["translate English to German: This sentence will get translated to German"];
|
let input = ["translate English to German: This sentence will get translated to German"];
|
||||||
|
|
||||||
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
|
let output = t5_model.generate(
|
||||||
|
Some(input.to_vec()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
);
|
||||||
println!("{:?}", output);
|
println!("{:?}", output);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -848,8 +848,8 @@ impl AlbertForQuestionAnswering {
|
|||||||
.apply(&self.qa_outputs)
|
.apply(&self.qa_outputs)
|
||||||
.split(1, -1);
|
.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
AlbertQuestionAnsweringOutput {
|
AlbertQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -240,7 +240,7 @@ pub(crate) fn _make_causal_mask(
|
|||||||
);
|
);
|
||||||
let mask_cond = Tensor::arange(target_length, (dtype, device));
|
let mask_cond = Tensor::arange(target_length, (dtype, device));
|
||||||
let _ = mask.masked_fill_(
|
let _ = mask.masked_fill_(
|
||||||
&mask_cond.lt1(&(&mask_cond + 1).view([target_length, 1])),
|
&mask_cond.lt_tensor(&(&mask_cond + 1).view([target_length, 1])),
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -306,7 +306,10 @@ pub(crate) fn _prepare_decoder_attention_mask(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
||||||
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
|
let index_eos: Tensor = input_ids
|
||||||
|
.ne(pad_token_id)
|
||||||
|
.sum_dim_intlist(&[-1], true, Int64)
|
||||||
|
- 1;
|
||||||
let output = input_ids.empty_like().to_kind(Int64);
|
let output = input_ids.empty_like().to_kind(Int64);
|
||||||
output
|
output
|
||||||
.select(1, 0)
|
.select(1, 0)
|
||||||
@ -809,7 +812,7 @@ impl BartForSequenceClassification {
|
|||||||
train,
|
train,
|
||||||
);
|
);
|
||||||
let eos_mask = input_ids.eq(self.eos_token_id);
|
let eos_mask = input_ids.eq(self.eos_token_id);
|
||||||
let reshape = eos_mask.sum1(&[1], true, Int64);
|
let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
|
||||||
let sentence_representation = base_model_output
|
let sentence_representation = base_model_output
|
||||||
.decoder_output
|
.decoder_output
|
||||||
.permute(&[2, 0, 1])
|
.permute(&[2, 0, 1])
|
||||||
|
@ -64,7 +64,7 @@ impl LearnedPositionalEmbedding {
|
|||||||
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
||||||
let input_shape = input.size();
|
let input_shape = input.size();
|
||||||
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
||||||
let positions = Tensor::arange1(
|
let positions = Tensor::arange_start(
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
past_key_values_length + sequence_length,
|
past_key_values_length + sequence_length,
|
||||||
(Int64, input.device()),
|
(Int64, input.device()),
|
||||||
@ -99,7 +99,7 @@ impl SinusoidalPositionalEmbedding {
|
|||||||
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
||||||
let input_shape = input.size();
|
let input_shape = input.size();
|
||||||
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
||||||
let positions = Tensor::arange1(
|
let positions = Tensor::arange_start(
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
past_key_values_length + sequence_length,
|
past_key_values_length + sequence_length,
|
||||||
(Int64, input.device()),
|
(Int64, input.device()),
|
||||||
|
@ -323,7 +323,7 @@ impl<T: BertEmbedding> BertModel<T> {
|
|||||||
input_shape[1],
|
input_shape[1],
|
||||||
1,
|
1,
|
||||||
]);
|
]);
|
||||||
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
|
let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
|
||||||
causal_mask * mask.unsqueeze(1).unsqueeze(1)
|
causal_mask * mask.unsqueeze(1).unsqueeze(1)
|
||||||
} else {
|
} else {
|
||||||
mask.unsqueeze(1).unsqueeze(1)
|
mask.unsqueeze(1).unsqueeze(1)
|
||||||
@ -1161,8 +1161,8 @@ impl BertForQuestionAnswering {
|
|||||||
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
BertQuestionAnsweringOutput {
|
BertQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -131,11 +131,12 @@ impl BertLayer {
|
|||||||
encoder_mask: &Option<Tensor>,
|
encoder_mask: &Option<Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> BertLayerOutput {
|
) -> BertLayerOutput {
|
||||||
|
let (attention_output, attention_weights) =
|
||||||
|
self.attention
|
||||||
|
.forward_t(hidden_states, mask, &None, &None, train);
|
||||||
|
|
||||||
let (attention_output, attention_scores, cross_attention_scores) =
|
let (attention_output, attention_scores, cross_attention_scores) =
|
||||||
if self.is_decoder & encoder_hidden_states.is_some() {
|
if self.is_decoder & encoder_hidden_states.is_some() {
|
||||||
let (attention_output, attention_weights) =
|
|
||||||
self.attention
|
|
||||||
.forward_t(hidden_states, mask, &None, &None, train);
|
|
||||||
let (attention_output, cross_attention_weights) =
|
let (attention_output, cross_attention_weights) =
|
||||||
self.cross_attention.as_ref().unwrap().forward_t(
|
self.cross_attention.as_ref().unwrap().forward_t(
|
||||||
&attention_output,
|
&attention_output,
|
||||||
@ -146,9 +147,6 @@ impl BertLayer {
|
|||||||
);
|
);
|
||||||
(attention_output, attention_weights, cross_attention_weights)
|
(attention_output, attention_weights, cross_attention_weights)
|
||||||
} else {
|
} else {
|
||||||
let (attention_output, attention_weights) =
|
|
||||||
self.attention
|
|
||||||
.forward_t(hidden_states, mask, &None, &None, train);
|
|
||||||
(attention_output, attention_weights, None)
|
(attention_output, attention_weights, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ impl SequenceSummary {
|
|||||||
{
|
{
|
||||||
let p = p.borrow();
|
let p = p.borrow();
|
||||||
|
|
||||||
let summary_type = config.summary_type.clone().unwrap_or(SummaryType::last);
|
let summary_type = config.summary_type.unwrap_or(SummaryType::last);
|
||||||
let summary = if let Some(summary_use_proj) = config.summary_use_proj {
|
let summary = if let Some(summary_use_proj) = config.summary_use_proj {
|
||||||
let num_classes = match (config.summary_proj_to_labels, config.num_labels) {
|
let num_classes = match (config.summary_proj_to_labels, config.num_labels) {
|
||||||
(Some(summary_proj_to_labels), Some(num_labels))
|
(Some(summary_proj_to_labels), Some(num_labels))
|
||||||
@ -132,7 +132,7 @@ impl SequenceSummary {
|
|||||||
let mut output = match self.summary_type {
|
let mut output = match self.summary_type {
|
||||||
SummaryType::last => hidden_states.select(1, -1),
|
SummaryType::last => hidden_states.select(1, -1),
|
||||||
SummaryType::first => hidden_states.select(1, 0),
|
SummaryType::first => hidden_states.select(1, 0),
|
||||||
SummaryType::mean => hidden_states.mean1(&[1], false, Kind::Float),
|
SummaryType::mean => hidden_states.mean_dim(&[1], false, Kind::Float),
|
||||||
SummaryType::cls_index => {
|
SummaryType::cls_index => {
|
||||||
let cls_index = if let Some(cls_index_value) = cls_index {
|
let cls_index = if let Some(cls_index_value) = cls_index {
|
||||||
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
|
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
|
||||||
@ -147,7 +147,7 @@ impl SequenceSummary {
|
|||||||
let fill_value = fill_value[2];
|
let fill_value = fill_value[2];
|
||||||
hidden_states.select(-2, 0).full_like(fill_value)
|
hidden_states.select(-2, 0).full_like(fill_value)
|
||||||
};
|
};
|
||||||
hidden_states.gather(-2, &cls_index, false).squeeze1(-2)
|
hidden_states.gather(-2, &cls_index, false).squeeze_dim(-2)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ impl MultiHeadSelfAttention {
|
|||||||
let scores = if let Some(mask) = mask {
|
let scores = if let Some(mask) = mask {
|
||||||
let unmasked_scores = q.matmul(&k.transpose(2, 3));
|
let unmasked_scores = q.matmul(&k.transpose(2, 3));
|
||||||
let mask = mask
|
let mask = mask
|
||||||
.le1(&(mask.zeros_like() + 0.1))
|
.le_tensor(&(mask.zeros_like() + 0.1))
|
||||||
.view((bs, 1i64, 1i64, k_length))
|
.view((bs, 1i64, 1i64, k_length))
|
||||||
.expand_as(&unmasked_scores);
|
.expand_as(&unmasked_scores);
|
||||||
unmasked_scores.masked_fill(&mask, f64::NEG_INFINITY)
|
unmasked_scores.masked_fill(&mask, f64::NEG_INFINITY)
|
||||||
|
@ -614,8 +614,8 @@ impl DistilBertForQuestionAnswering {
|
|||||||
|
|
||||||
let logits = output.split(1, -1);
|
let logits = output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
Ok(DistilBertQuestionAnsweringOutput {
|
Ok(DistilBertQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -414,7 +414,7 @@ impl Gpt2Model {
|
|||||||
|
|
||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value.copy(),
|
Some(value) => value.copy(),
|
||||||
None => Tensor::arange1(
|
None => Tensor::arange_start(
|
||||||
layer_past_length,
|
layer_past_length,
|
||||||
seq_length + layer_past_length,
|
seq_length + layer_past_length,
|
||||||
(Int64, input_embeddings.device()),
|
(Int64, input_embeddings.device()),
|
||||||
|
@ -132,7 +132,9 @@ pub(crate) trait GptNeoAttentionUtils {
|
|||||||
let query_indices = Self::split_sequence_length_dim_to(&indices, num_blocks, block_length)?;
|
let query_indices = Self::split_sequence_length_dim_to(&indices, num_blocks, block_length)?;
|
||||||
let key_indices = Self::look_back(&indices, block_length, window_size, None, false)?;
|
let key_indices = Self::look_back(&indices, block_length, window_size, None, false)?;
|
||||||
|
|
||||||
let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2));
|
let causal_mask = query_indices
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.ge_tensor(&key_indices.unsqueeze(-2));
|
||||||
|
|
||||||
let calc_attention_mask = if attention_mask.is_none() {
|
let calc_attention_mask = if attention_mask.is_none() {
|
||||||
Some(Tensor::ones(
|
Some(Tensor::ones(
|
||||||
@ -212,7 +214,7 @@ pub(crate) trait GptNeoAttentionUtils {
|
|||||||
) -> (Tensor, Tensor) {
|
) -> (Tensor, Tensor) {
|
||||||
let mut attention_weights = query
|
let mut attention_weights = query
|
||||||
.matmul(&key.transpose(-1, -2))
|
.matmul(&key.transpose(-1, -2))
|
||||||
.where1(causal_mask, &masked_bias.to_kind(query.kind()));
|
.where_self(causal_mask, &masked_bias.to_kind(query.kind()));
|
||||||
|
|
||||||
if let Some(attention_mask_value) = attention_mask {
|
if let Some(attention_mask_value) = attention_mask {
|
||||||
attention_weights = attention_weights + attention_mask_value;
|
attention_weights = attention_weights + attention_mask_value;
|
||||||
|
@ -345,7 +345,7 @@ impl GptNeoModel {
|
|||||||
|
|
||||||
let calc_position_ids = if position_ids.is_none() {
|
let calc_position_ids = if position_ids.is_none() {
|
||||||
let position_ids =
|
let position_ids =
|
||||||
Tensor::arange1(past_length, full_sequence_length, (Kind::Int64, device));
|
Tensor::arange_start(past_length, full_sequence_length, (Kind::Int64, device));
|
||||||
Some(
|
Some(
|
||||||
position_ids
|
position_ids
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
|
@ -352,7 +352,8 @@ impl LongformerSelfAttention {
|
|||||||
&self,
|
&self,
|
||||||
is_index_global_attn: &Tensor,
|
is_index_global_attn: &Tensor,
|
||||||
) -> GlobalAttentionIndices {
|
) -> GlobalAttentionIndices {
|
||||||
let num_global_attention_indices = is_index_global_attn.sum1(&[1], false, Kind::Int64);
|
let num_global_attention_indices =
|
||||||
|
is_index_global_attn.sum_dim_intlist(&[1], false, Kind::Int64);
|
||||||
let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
|
let max_num_global_attention_indices = i64::from(num_global_attention_indices.max());
|
||||||
let is_index_global_attn_nonzero = is_index_global_attn
|
let is_index_global_attn_nonzero = is_index_global_attn
|
||||||
.nonzero_numpy()
|
.nonzero_numpy()
|
||||||
@ -364,7 +365,7 @@ impl LongformerSelfAttention {
|
|||||||
max_num_global_attention_indices,
|
max_num_global_attention_indices,
|
||||||
(Kind::Int64, is_index_global_attn.device()),
|
(Kind::Int64, is_index_global_attn.device()),
|
||||||
)
|
)
|
||||||
.lt1(&num_global_attention_indices.unsqueeze(-1));
|
.lt_tensor(&num_global_attention_indices.unsqueeze(-1));
|
||||||
|
|
||||||
let is_local_index_global_attention_nonzero = is_local_index_global_attention
|
let is_local_index_global_attention_nonzero = is_local_index_global_attention
|
||||||
.nonzero_numpy()
|
.nonzero_numpy()
|
||||||
|
@ -86,7 +86,7 @@ impl LongformerEmbeddings {
|
|||||||
let input_shape = inputs_embeds.size();
|
let input_shape = inputs_embeds.size();
|
||||||
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
|
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
|
||||||
|
|
||||||
Tensor::arange1(
|
Tensor::arange_start(
|
||||||
self.pad_token_id + 1,
|
self.pad_token_id + 1,
|
||||||
sequence_length + self.pad_token_id + 1,
|
sequence_length + self.pad_token_id + 1,
|
||||||
(Kind::Int64, inputs_embeds.device()),
|
(Kind::Int64, inputs_embeds.device()),
|
||||||
|
@ -140,11 +140,13 @@ fn compute_global_attention_mask(
|
|||||||
let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device()));
|
let attention_mask = Tensor::arange(input_ids.size()[1], (Kind::Int64, input_ids.device()));
|
||||||
|
|
||||||
if before_sep_token {
|
if before_sep_token {
|
||||||
attention_mask.expand_as(input_ids).lt1(&question_end_index)
|
attention_mask
|
||||||
|
.expand_as(input_ids)
|
||||||
|
.lt_tensor(&question_end_index)
|
||||||
} else {
|
} else {
|
||||||
attention_mask
|
attention_mask
|
||||||
.expand_as(input_ids)
|
.expand_as(input_ids)
|
||||||
.gt1(&(question_end_index + 1))
|
.gt_tensor(&(question_end_index + 1))
|
||||||
* attention_mask
|
* attention_mask
|
||||||
.expand_as(input_ids)
|
.expand_as(input_ids)
|
||||||
.lt(*input_ids.size().last().unwrap())
|
.lt(*input_ids.size().last().unwrap())
|
||||||
@ -580,7 +582,7 @@ impl LongformerModel {
|
|||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.repeat(&[batch_size, sequence_length, 1])
|
.repeat(&[batch_size, sequence_length, 1])
|
||||||
.le1(&sequence_ids.unsqueeze(-1).unsqueeze(0))
|
.le_tensor(&sequence_ids.unsqueeze(-1).unsqueeze(0))
|
||||||
.totype(Kind::Int);
|
.totype(Kind::Int);
|
||||||
if causal_mask.size()[1] < padded_attention_mask.size()[1] {
|
if causal_mask.size()[1] < padded_attention_mask.size()[1] {
|
||||||
let prefix_sequence_length =
|
let prefix_sequence_length =
|
||||||
@ -1147,8 +1149,8 @@ impl LongformerForQuestionAnswering {
|
|||||||
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
Ok(LongformerQuestionAnsweringOutput {
|
Ok(LongformerQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -111,7 +111,10 @@ impl Config for MBartConfig {}
|
|||||||
|
|
||||||
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
||||||
let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
|
let output = input_ids.masked_fill(&input_ids.eq(-100), pad_token_id);
|
||||||
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[1], true, Int64) - 1;
|
let index_eos: Tensor = input_ids
|
||||||
|
.ne(pad_token_id)
|
||||||
|
.sum_dim_intlist(&[1], true, Int64)
|
||||||
|
- 1;
|
||||||
output
|
output
|
||||||
.select(1, 0)
|
.select(1, 0)
|
||||||
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
|
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
|
||||||
@ -632,7 +635,7 @@ impl MBartForSequenceClassification {
|
|||||||
train,
|
train,
|
||||||
);
|
);
|
||||||
let eos_mask = input_ids.eq(self.eos_token_id);
|
let eos_mask = input_ids.eq(self.eos_token_id);
|
||||||
let reshape = eos_mask.sum1(&[1], true, Int64);
|
let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
|
||||||
let sentence_representation = base_model_output
|
let sentence_representation = base_model_output
|
||||||
.decoder_output
|
.decoder_output
|
||||||
.permute(&[2, 0, 1])
|
.permute(&[2, 0, 1])
|
||||||
|
@ -901,8 +901,8 @@ impl MobileBertForQuestionAnswering {
|
|||||||
let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs);
|
let sequence_output = mobilebert_output.hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
Ok(MobileBertQuestionAnsweringOutput {
|
Ok(MobileBertQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -87,7 +87,7 @@ impl SinusoidalPositionalEmbedding {
|
|||||||
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
pub fn forward(&self, input: &Tensor, past_key_values_length: i64) -> Tensor {
|
||||||
let input_shape = input.size();
|
let input_shape = input.size();
|
||||||
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
let (_, sequence_length) = (input_shape[0], input_shape[1]);
|
||||||
let positions = Tensor::arange1(
|
let positions = Tensor::arange_start(
|
||||||
past_key_values_length,
|
past_key_values_length,
|
||||||
past_key_values_length + sequence_length,
|
past_key_values_length + sequence_length,
|
||||||
(Kind::Int64, input.device()),
|
(Kind::Int64, input.device()),
|
||||||
|
@ -716,15 +716,20 @@ impl ConversationOption {
|
|||||||
attention_mask: Option<Tensor>,
|
attention_mask: Option<Tensor>,
|
||||||
) -> Vec<Vec<i64>> {
|
) -> Vec<Vec<i64>> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::GPT2(ref model) => model.generate_from_ids_and_past(
|
Self::GPT2(ref model) => model
|
||||||
input_ids,
|
.generate_from_ids_and_past(
|
||||||
attention_mask,
|
input_ids,
|
||||||
None,
|
attention_mask,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
),
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.indices)
|
||||||
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,7 @@
|
|||||||
//! decoder_start_id,
|
//! decoder_start_id,
|
||||||
//! forced_bos_token_id,
|
//! forced_bos_token_id,
|
||||||
//! None,
|
//! None,
|
||||||
|
//! false,
|
||||||
//! );
|
//! );
|
||||||
//! # Ok(())
|
//! # Ok(())
|
||||||
//! # }
|
//! # }
|
||||||
@ -492,13 +493,13 @@ pub(crate) mod private_generation_utils {
|
|||||||
if min_tokens_to_keep > 1 {
|
if min_tokens_to_keep > 1 {
|
||||||
let _ = sorted_indices_to_remove.index_fill_(
|
let _ = sorted_indices_to_remove.index_fill_(
|
||||||
1,
|
1,
|
||||||
&Tensor::arange1(0, min_tokens_to_keep + 1, (Int64, logits.device())),
|
&Tensor::arange_start(0, min_tokens_to_keep + 1, (Int64, logits.device())),
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let _ = sorted_indices_to_remove.index_copy_(
|
let _ = sorted_indices_to_remove.index_copy_(
|
||||||
1,
|
1,
|
||||||
&Tensor::arange1(1, vocab_size, (Int64, logits.device())),
|
&Tensor::arange_start(1, vocab_size, (Int64, logits.device())),
|
||||||
&sorted_indices_to_remove
|
&sorted_indices_to_remove
|
||||||
.slice(1, 0, vocab_size - 1, 1)
|
.slice(1, 0, vocab_size - 1, 1)
|
||||||
.copy(),
|
.copy(),
|
||||||
@ -585,7 +586,8 @@ pub(crate) mod private_generation_utils {
|
|||||||
attention_mask: Tensor,
|
attention_mask: Tensor,
|
||||||
gen_opt: GenerateOptions,
|
gen_opt: GenerateOptions,
|
||||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||||
) -> Tensor {
|
output_scores: bool,
|
||||||
|
) -> (Tensor, Option<Vec<f64>>) {
|
||||||
let mut unfinished_sentences =
|
let mut unfinished_sentences =
|
||||||
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
|
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
|
||||||
let mut sentence_lengths: Tensor =
|
let mut sentence_lengths: Tensor =
|
||||||
@ -596,6 +598,14 @@ pub(crate) mod private_generation_utils {
|
|||||||
let mut past: Cache = Cache::None;
|
let mut past: Cache = Cache::None;
|
||||||
let mut outputs: Tensor;
|
let mut outputs: Tensor;
|
||||||
let mut current_length = cur_len;
|
let mut current_length = cur_len;
|
||||||
|
let mut scores_output = if output_scores {
|
||||||
|
Some(Tensor::zeros(
|
||||||
|
&[batch_size],
|
||||||
|
(Float, self.get_var_store().device()),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
while current_length < gen_opt.max_length {
|
while current_length < gen_opt.max_length {
|
||||||
let prepared_input = self.prepare_inputs_for_generation(
|
let prepared_input = self.prepare_inputs_for_generation(
|
||||||
@ -690,11 +700,22 @@ pub(crate) mod private_generation_utils {
|
|||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
let probabilities = next_token_logits.softmax(-1, Float);
|
let probabilities = next_token_logits.softmax(-1, Float);
|
||||||
probabilities.multinomial(1, false).squeeze1(1)
|
probabilities.multinomial(1, false).squeeze_dim(1)
|
||||||
} else {
|
} else {
|
||||||
next_token_logits.argmax(-1, false)
|
next_token_logits.argmax(-1, false)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if let Some(prev_scores) = scores_output {
|
||||||
|
let finished_mask = unfinished_sentences.eq(0);
|
||||||
|
scores_output = Some(
|
||||||
|
prev_scores
|
||||||
|
+ (&next_token_logits
|
||||||
|
.log_softmax(-1, Float)
|
||||||
|
.gather(1, &next_token.reshape(&[-1, 1]), true)
|
||||||
|
.squeeze()
|
||||||
|
.masked_fill(&finished_mask, 0)),
|
||||||
|
);
|
||||||
|
}
|
||||||
// Add tokens to unfinished sentences
|
// Add tokens to unfinished sentences
|
||||||
let tokens_to_add = match &gen_opt.eos_token_ids {
|
let tokens_to_add = match &gen_opt.eos_token_ids {
|
||||||
Some(_) => {
|
Some(_) => {
|
||||||
@ -736,7 +757,14 @@ pub(crate) mod private_generation_utils {
|
|||||||
}
|
}
|
||||||
current_length += 1;
|
current_length += 1;
|
||||||
}
|
}
|
||||||
input_ids
|
let scores_output = scores_output.map(|scores_tensor| {
|
||||||
|
(scores_tensor / sentence_lengths.pow(gen_opt.length_penalty))
|
||||||
|
.iter::<f64>()
|
||||||
|
.unwrap()
|
||||||
|
.collect::<Vec<f64>>()
|
||||||
|
});
|
||||||
|
|
||||||
|
(input_ids, scores_output)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_beam_search(
|
fn generate_beam_search(
|
||||||
@ -748,7 +776,8 @@ pub(crate) mod private_generation_utils {
|
|||||||
mut attention_mask: Tensor,
|
mut attention_mask: Tensor,
|
||||||
gen_opt: GenerateOptions,
|
gen_opt: GenerateOptions,
|
||||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||||
) -> Tensor {
|
output_scores: bool,
|
||||||
|
) -> (Tensor, Option<Vec<f64>>) {
|
||||||
let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1);
|
let num_beam_groups = gen_opt.num_beam_groups.unwrap_or(1);
|
||||||
let num_sub_beams = gen_opt.num_beams / num_beam_groups;
|
let num_sub_beams = gen_opt.num_beams / num_beam_groups;
|
||||||
let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5);
|
let diversity_penalty = gen_opt.diversity_penalty.unwrap_or(5.5);
|
||||||
@ -960,12 +989,12 @@ pub(crate) mod private_generation_utils {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let eos_token_ids = gen_opt.eos_token_ids.as_ref();
|
let eos_token_ids = gen_opt.eos_token_ids.as_ref();
|
||||||
let beam_ids_tensor = &next_tokens.floor_divide1(vocab_size);
|
let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
|
||||||
let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1)
|
let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1)
|
||||||
* group_size
|
* group_size
|
||||||
+ beam_ids_tensor;
|
+ beam_ids_tensor;
|
||||||
let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
|
let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
|
||||||
let (max_scores, _) = next_scores.max2(1, false);
|
let (max_scores, _) = next_scores.max_dim(1, false);
|
||||||
let mut eos_mask = token_id_tensor.ones_like();
|
let mut eos_mask = token_id_tensor.ones_like();
|
||||||
if let Some(eos_token_id) = eos_token_ids {
|
if let Some(eos_token_id) = eos_token_ids {
|
||||||
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64);
|
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64);
|
||||||
@ -1034,7 +1063,7 @@ pub(crate) mod private_generation_utils {
|
|||||||
&group_beam_tokens,
|
&group_beam_tokens,
|
||||||
);
|
);
|
||||||
let new_indices = gen_opt.num_beams
|
let new_indices = gen_opt.num_beams
|
||||||
* group_beam_indices.floor_divide1(group_size)
|
* group_beam_indices.divide_scalar_mode(group_size, "floor")
|
||||||
+ group_start_index
|
+ group_start_index
|
||||||
+ group_beam_indices.remainder(group_size);
|
+ group_beam_indices.remainder(group_size);
|
||||||
let _ = beam_indices.index_copy_(
|
let _ = beam_indices.index_copy_(
|
||||||
@ -1110,6 +1139,11 @@ pub(crate) mod private_generation_utils {
|
|||||||
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device()));
|
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device()));
|
||||||
let mut best_ids = vec![];
|
let mut best_ids = vec![];
|
||||||
|
|
||||||
|
let mut scores_output = if output_scores {
|
||||||
|
Some(Vec::with_capacity(best_ids.len()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
|
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
|
||||||
let mut sorted_hypotheses = hypothesis.clone();
|
let mut sorted_hypotheses = hypothesis.clone();
|
||||||
sorted_hypotheses
|
sorted_hypotheses
|
||||||
@ -1118,13 +1152,16 @@ pub(crate) mod private_generation_utils {
|
|||||||
for j in 0..output_num_return_sequences_per_batch {
|
for j in 0..output_num_return_sequences_per_batch {
|
||||||
let effective_batch_index =
|
let effective_batch_index =
|
||||||
output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
|
output_num_return_sequences_per_batch * hypothesis_index as i64 + j;
|
||||||
let (_, best_hyp) = sorted_hypotheses.beams.pop().unwrap();
|
let (best_score, best_hyp) = sorted_hypotheses.beams.pop().unwrap();
|
||||||
let _ = sentence_lengths.index_fill_(
|
let _ = sentence_lengths.index_fill_(
|
||||||
0,
|
0,
|
||||||
&Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()),
|
&Tensor::of_slice(&[effective_batch_index]).to(sentence_lengths.device()),
|
||||||
*best_hyp.size().first().unwrap(),
|
*best_hyp.size().first().unwrap(),
|
||||||
);
|
);
|
||||||
best_ids.push(best_hyp);
|
best_ids.push(best_hyp);
|
||||||
|
if let Some(current_best_scores) = &mut scores_output {
|
||||||
|
current_best_scores.push(best_score);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let sentence_max_length =
|
let sentence_max_length =
|
||||||
@ -1143,7 +1180,7 @@ pub(crate) mod private_generation_utils {
|
|||||||
for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
|
for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
|
||||||
let _ = decoded.get(hypothesis_index as i64).index_copy_(
|
let _ = decoded.get(hypothesis_index as i64).index_copy_(
|
||||||
0,
|
0,
|
||||||
&Tensor::arange1(
|
&Tensor::arange_start(
|
||||||
0,
|
0,
|
||||||
i64::from(sentence_lengths.get(hypothesis_index as i64)),
|
i64::from(sentence_lengths.get(hypothesis_index as i64)),
|
||||||
(Int64, input_ids.device()),
|
(Int64, input_ids.device()),
|
||||||
@ -1159,7 +1196,7 @@ pub(crate) mod private_generation_utils {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
decoded
|
(decoded, scores_output)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reorder_cache(
|
fn reorder_cache(
|
||||||
@ -1178,6 +1215,22 @@ pub(crate) mod private_generation_utils {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// # Generated text output
|
||||||
|
/// Contains generated text and an optional log-likelihood score for the generated sequence
|
||||||
|
pub struct GeneratedTextOutput {
|
||||||
|
pub text: String,
|
||||||
|
pub score: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// # Generated indices output
|
||||||
|
/// Contains generated indices and an optional log-likelihood score for the generated sequence
|
||||||
|
pub struct GeneratedIndicesOutput {
|
||||||
|
pub indices: Vec<i64>,
|
||||||
|
pub score: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
/// # Common trait for text generation models.
|
/// # Common trait for text generation models.
|
||||||
/// Main API for text generation
|
/// Main API for text generation
|
||||||
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||||
@ -1195,7 +1248,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
/// * `Vec<GeneratedTextOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedTextOutput` with the generated texts and the generation score if `output_scores` is true.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
@ -1231,6 +1284,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// let max_length = 128;
|
/// let max_length = 128;
|
||||||
/// let decoder_start_token_id = None;
|
/// let decoder_start_token_id = None;
|
||||||
/// let forced_bos_token_id = None;
|
/// let forced_bos_token_id = None;
|
||||||
|
/// let output_scores = true;
|
||||||
///
|
///
|
||||||
/// //Example custom function for fine-grained generation control
|
/// //Example custom function for fine-grained generation control
|
||||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||||
@ -1257,6 +1311,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// decoder_start_token_id,
|
/// decoder_start_token_id,
|
||||||
/// forced_bos_token_id,
|
/// forced_bos_token_id,
|
||||||
/// Some(&force_one_paragraph),
|
/// Some(&force_one_paragraph),
|
||||||
|
/// output_scores,
|
||||||
/// );
|
/// );
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
@ -1283,11 +1338,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
decoder_start_token_id: impl Into<Option<i64>>,
|
decoder_start_token_id: impl Into<Option<i64>>,
|
||||||
forced_bos_token_id: impl Into<Option<i64>>,
|
forced_bos_token_id: impl Into<Option<i64>>,
|
||||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||||
) -> Vec<String>
|
output_scores: bool,
|
||||||
|
) -> Vec<GeneratedTextOutput>
|
||||||
where
|
where
|
||||||
S: AsRef<[&'a str]>,
|
S: AsRef<[&'a str]>,
|
||||||
{
|
{
|
||||||
let generated = self.generate_indices(
|
let indices_outputs = self.generate_indices(
|
||||||
prompt_texts,
|
prompt_texts,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
min_length,
|
min_length,
|
||||||
@ -1295,10 +1351,16 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
forced_bos_token_id,
|
forced_bos_token_id,
|
||||||
prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn,
|
||||||
|
output_scores,
|
||||||
);
|
);
|
||||||
let mut output = Vec::with_capacity(generated.len());
|
let mut output = Vec::with_capacity(indices_outputs.len());
|
||||||
for generated_sequence in generated {
|
for generated_sequence in indices_outputs {
|
||||||
output.push(self._get_tokenizer().decode(generated_sequence, true, true));
|
output.push(GeneratedTextOutput {
|
||||||
|
text: self
|
||||||
|
._get_tokenizer()
|
||||||
|
.decode(generated_sequence.indices, true, true),
|
||||||
|
score: generated_sequence.score,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
output
|
output
|
||||||
}
|
}
|
||||||
@ -1315,7 +1377,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
/// * `Vec<GeneratedIndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedIndicesOutput` with the generated indices and the generation score if `output_scores` is true.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
@ -1350,6 +1412,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// let max_length = 128;
|
/// let max_length = 128;
|
||||||
/// let decoder_start_token_id = None;
|
/// let decoder_start_token_id = None;
|
||||||
/// let forced_bos_token_id = None;
|
/// let forced_bos_token_id = None;
|
||||||
|
/// let output_scores = true;
|
||||||
///
|
///
|
||||||
/// //Example custom function for fine-grained generation control
|
/// //Example custom function for fine-grained generation control
|
||||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||||
@ -1376,6 +1439,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// decoder_start_token_id,
|
/// decoder_start_token_id,
|
||||||
/// forced_bos_token_id,
|
/// forced_bos_token_id,
|
||||||
/// Some(&force_one_paragraph),
|
/// Some(&force_one_paragraph),
|
||||||
|
/// output_scores,
|
||||||
/// );
|
/// );
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
@ -1389,7 +1453,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
decoder_start_token_id: impl Into<Option<i64>>,
|
decoder_start_token_id: impl Into<Option<i64>>,
|
||||||
forced_bos_token_id: impl Into<Option<i64>>,
|
forced_bos_token_id: impl Into<Option<i64>>,
|
||||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||||
) -> Vec<Vec<i64>>
|
output_scores: bool,
|
||||||
|
) -> Vec<GeneratedIndicesOutput>
|
||||||
where
|
where
|
||||||
S: AsRef<[&'a str]>,
|
S: AsRef<[&'a str]>,
|
||||||
{
|
{
|
||||||
@ -1426,6 +1491,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
forced_bos_token_id,
|
forced_bos_token_id,
|
||||||
prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn,
|
||||||
|
output_scores,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1442,7 +1508,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and returns a `Vec<i64>` of allowed tokens.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
|
/// * `Vec<GeneratedIndicesOutput>` Vector of length *number_of_prompts* x *num_return_sequences* containing `GeneratedIndicesOutput` with the generated indices and the generation score if `output_scores` is true.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
@ -1477,6 +1543,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// let max_length = 128;
|
/// let max_length = 128;
|
||||||
/// let decoder_start_token_id = None;
|
/// let decoder_start_token_id = None;
|
||||||
/// let forced_bos_token_id = None;
|
/// let forced_bos_token_id = None;
|
||||||
|
/// let output_scores = true;
|
||||||
///
|
///
|
||||||
/// //Example custom function for fine-grained generation control
|
/// //Example custom function for fine-grained generation control
|
||||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||||
@ -1503,6 +1570,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
/// decoder_start_token_id,
|
/// decoder_start_token_id,
|
||||||
/// forced_bos_token_id,
|
/// forced_bos_token_id,
|
||||||
/// Some(&force_one_paragraph),
|
/// Some(&force_one_paragraph),
|
||||||
|
/// output_scores,
|
||||||
/// );
|
/// );
|
||||||
/// # Ok(())
|
/// # Ok(())
|
||||||
/// # }
|
/// # }
|
||||||
@ -1516,7 +1584,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
decoder_start_token_id: impl Into<Option<i64>>,
|
decoder_start_token_id: impl Into<Option<i64>>,
|
||||||
forced_bos_token_id: impl Into<Option<i64>>,
|
forced_bos_token_id: impl Into<Option<i64>>,
|
||||||
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
|
||||||
) -> Vec<Vec<i64>> {
|
output_scores: bool,
|
||||||
|
) -> Vec<GeneratedIndicesOutput> {
|
||||||
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
||||||
|
|
||||||
let config = PrivateLanguageGenerator::get_config(self);
|
let config = PrivateLanguageGenerator::get_config(self);
|
||||||
@ -1647,7 +1716,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
forced_bos_token_id: forced_bos_token_id.into(),
|
forced_bos_token_id: forced_bos_token_id.into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let decoded = no_grad(|| {
|
let (decoded, scores) = no_grad(|| {
|
||||||
if num_beams > 1 {
|
if num_beams > 1 {
|
||||||
self.generate_beam_search(
|
self.generate_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1657,6 +1726,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
gen_opt,
|
gen_opt,
|
||||||
prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn,
|
||||||
|
output_scores,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
self.generate_no_beam_search(
|
self.generate_no_beam_search(
|
||||||
@ -1667,21 +1737,25 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
gen_opt,
|
gen_opt,
|
||||||
prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn,
|
||||||
|
output_scores,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
let num_sequences = *decoded.size().first().unwrap();
|
let num_sequences = *decoded.size().first().unwrap();
|
||||||
let mut output_ids = Vec::with_capacity(num_sequences as usize);
|
let mut output = Vec::with_capacity(num_sequences as usize);
|
||||||
for sequence_index in 0..num_sequences {
|
for sequence_index in 0..num_sequences {
|
||||||
let sequence_output_ids = decoded
|
let indices = decoded
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.get(sequence_index)
|
.get(sequence_index)
|
||||||
.iter::<i64>()
|
.iter::<i64>()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.collect::<Vec<i64>>();
|
.collect::<Vec<i64>>();
|
||||||
output_ids.push(sequence_output_ids.clone());
|
let score = scores
|
||||||
|
.as_ref()
|
||||||
|
.map(|scores_value| scores_value[sequence_index as usize]);
|
||||||
|
output.push(GeneratedIndicesOutput { indices, score });
|
||||||
}
|
}
|
||||||
output_ids
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the text generator's tokenizer
|
/// Returns a reference to the text generator's tokenizer
|
||||||
|
@ -597,10 +597,10 @@ impl SequenceClassificationModel {
|
|||||||
);
|
);
|
||||||
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
|
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
|
||||||
});
|
});
|
||||||
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
|
let label_indices = output.as_ref().argmax(-1, true).squeeze_dim(1);
|
||||||
let scores = output
|
let scores = output
|
||||||
.gather(1, &label_indices.unsqueeze(-1), false)
|
.gather(1, &label_indices.unsqueeze(-1), false)
|
||||||
.squeeze1(1);
|
.squeeze_dim(1);
|
||||||
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||||
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||||
|
|
||||||
|
@ -263,18 +263,62 @@ impl SummarizationOption {
|
|||||||
S: AsRef<[&'a str]>,
|
S: AsRef<[&'a str]>,
|
||||||
{
|
{
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bart(ref model) => {
|
Self::Bart(ref model) => model
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
.generate(
|
||||||
}
|
prompt_texts,
|
||||||
Self::T5(ref model) => {
|
attention_mask,
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
None,
|
||||||
}
|
None,
|
||||||
Self::ProphetNet(ref model) => {
|
None,
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
None,
|
||||||
}
|
None,
|
||||||
Self::Pegasus(ref model) => {
|
false,
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
)
|
||||||
}
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
|
Self::T5(ref model) => model
|
||||||
|
.generate(
|
||||||
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
|
Self::ProphetNet(ref model) => model
|
||||||
|
.generate(
|
||||||
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
|
Self::Pegasus(ref model) => model
|
||||||
|
.generate(
|
||||||
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -249,51 +249,76 @@ impl TextGenerationOption {
|
|||||||
S: AsRef<[&'a str]>,
|
S: AsRef<[&'a str]>,
|
||||||
{
|
{
|
||||||
match *self {
|
match *self {
|
||||||
Self::GPT(ref model) => model.generate_indices(
|
Self::GPT(ref model) => model
|
||||||
prompt_texts,
|
.generate_indices(
|
||||||
attention_mask,
|
prompt_texts,
|
||||||
min_length,
|
attention_mask,
|
||||||
max_length,
|
min_length,
|
||||||
None,
|
max_length,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
),
|
None,
|
||||||
Self::GPT2(ref model) => model.generate_indices(
|
false,
|
||||||
prompt_texts,
|
)
|
||||||
attention_mask,
|
.into_iter()
|
||||||
min_length,
|
.map(|output| output.indices)
|
||||||
max_length,
|
.collect(),
|
||||||
None,
|
Self::GPT2(ref model) => model
|
||||||
None,
|
.generate_indices(
|
||||||
None,
|
prompt_texts,
|
||||||
),
|
attention_mask,
|
||||||
Self::GPTNeo(ref model) => model.generate_indices(
|
min_length,
|
||||||
prompt_texts,
|
max_length,
|
||||||
attention_mask,
|
None,
|
||||||
min_length,
|
None,
|
||||||
max_length,
|
None,
|
||||||
None,
|
false,
|
||||||
None,
|
)
|
||||||
None,
|
.into_iter()
|
||||||
),
|
.map(|output| output.indices)
|
||||||
Self::XLNet(ref model) => model.generate_indices(
|
.collect(),
|
||||||
prompt_texts,
|
Self::GPTNeo(ref model) => model
|
||||||
attention_mask,
|
.generate_indices(
|
||||||
min_length,
|
prompt_texts,
|
||||||
max_length,
|
attention_mask,
|
||||||
None,
|
min_length,
|
||||||
None,
|
max_length,
|
||||||
None,
|
None,
|
||||||
),
|
None,
|
||||||
Self::Reformer(ref model) => model.generate_indices(
|
None,
|
||||||
prompt_texts,
|
false,
|
||||||
attention_mask,
|
)
|
||||||
min_length,
|
.into_iter()
|
||||||
max_length,
|
.map(|output| output.indices)
|
||||||
None,
|
.collect(),
|
||||||
None,
|
Self::XLNet(ref model) => model
|
||||||
None,
|
.generate_indices(
|
||||||
),
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
min_length,
|
||||||
|
max_length,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.indices)
|
||||||
|
.collect(),
|
||||||
|
Self::Reformer(ref model) => model
|
||||||
|
.generate_indices(
|
||||||
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
min_length,
|
||||||
|
max_length,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.indices)
|
||||||
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -692,7 +692,7 @@ impl TokenClassificationModel {
|
|||||||
)
|
)
|
||||||
});
|
});
|
||||||
let output = output.detach().to(Device::Cpu);
|
let output = output.detach().to(Device::Cpu);
|
||||||
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
|
let score: Tensor = output.exp() / output.exp().sum_dim_intlist(&[-1], true, Float);
|
||||||
let labels_idx = &score.argmax(-1, true);
|
let labels_idx = &score.argmax(-1, true);
|
||||||
let mut tokens: Vec<Vec<Token>> = vec![];
|
let mut tokens: Vec<Vec<Token>> = vec![];
|
||||||
for sentence_idx in 0..labels_idx.size()[0] {
|
for sentence_idx in 0..labels_idx.size()[0] {
|
||||||
|
@ -674,12 +674,34 @@ impl TranslationOption {
|
|||||||
S: AsRef<[&'a str]>,
|
S: AsRef<[&'a str]>,
|
||||||
{
|
{
|
||||||
match *self {
|
match *self {
|
||||||
Self::Marian(ref model) => {
|
Self::Marian(ref model) => model
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
.generate(
|
||||||
}
|
prompt_texts,
|
||||||
Self::T5(ref model) => {
|
attention_mask,
|
||||||
model.generate(prompt_texts, attention_mask, None, None, None, None, None)
|
None,
|
||||||
}
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
|
Self::T5(ref model) => model
|
||||||
|
.generate(
|
||||||
|
prompt_texts,
|
||||||
|
attention_mask,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.into_iter()
|
||||||
|
.map(|output| output.text)
|
||||||
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -687,10 +687,10 @@ impl ZeroShotClassificationModel {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let scores = output.softmax(1, Float).select(-1, -1);
|
let scores = output.softmax(1, Float).select(-1, -1);
|
||||||
let label_indices = scores.as_ref().argmax(-1, true).squeeze1(1);
|
let label_indices = scores.as_ref().argmax(-1, true).squeeze_dim(1);
|
||||||
let scores = scores
|
let scores = scores
|
||||||
.gather(1, &label_indices.unsqueeze(-1), false)
|
.gather(1, &label_indices.unsqueeze(-1), false)
|
||||||
.squeeze1(1);
|
.squeeze_dim(1);
|
||||||
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||||
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||||
|
|
||||||
|
@ -602,7 +602,7 @@ impl ProphetNetNgramAttention {
|
|||||||
let hidden_states_size = hidden_states.size();
|
let hidden_states_size = hidden_states.size();
|
||||||
let (sequence_length, batch_size) = (hidden_states_size[0], hidden_states_size[1]);
|
let (sequence_length, batch_size) = (hidden_states_size[0], hidden_states_size[1]);
|
||||||
let calc_main_relative_position_buckets = if main_relative_position_buckets.is_none() {
|
let calc_main_relative_position_buckets = if main_relative_position_buckets.is_none() {
|
||||||
let relative_positions = Tensor::arange1(
|
let relative_positions = Tensor::arange_start(
|
||||||
1,
|
1,
|
||||||
attention_weights.size().last().unwrap() + 1,
|
attention_weights.size().last().unwrap() + 1,
|
||||||
(Kind::Int64, hidden_states.device()),
|
(Kind::Int64, hidden_states.device()),
|
||||||
@ -742,7 +742,7 @@ pub(crate) fn compute_relative_buckets(
|
|||||||
(
|
(
|
||||||
num_buckets,
|
num_buckets,
|
||||||
relative_positions.zeros_like(),
|
relative_positions.zeros_like(),
|
||||||
inverse_relative_positions.max1(&inverse_relative_positions.zeros_like()),
|
inverse_relative_positions.max_other(&inverse_relative_positions.zeros_like()),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
let max_exact = num_buckets / 2;
|
let max_exact = num_buckets / 2;
|
||||||
@ -754,10 +754,10 @@ pub(crate) fn compute_relative_buckets(
|
|||||||
+ max_exact_f64;
|
+ max_exact_f64;
|
||||||
|
|
||||||
let val_if_large = val_if_large
|
let val_if_large = val_if_large
|
||||||
.min1(&(val_if_large.ones_like() * (num_buckets as f64 - 1.0)))
|
.min_other(&(val_if_large.ones_like() * (num_buckets as f64 - 1.0)))
|
||||||
.totype(Kind::Int64);
|
.totype(Kind::Int64);
|
||||||
|
|
||||||
relative_positions_bucket + inverse_relative_positions.where1(&is_small, &val_if_large)
|
relative_positions_bucket + inverse_relative_positions.where_self(&is_small, &val_if_large)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn compute_all_stream_relative_buckets(
|
pub(crate) fn compute_all_stream_relative_buckets(
|
||||||
|
@ -308,9 +308,9 @@ impl ProphetNetDecoder {
|
|||||||
|
|
||||||
let hidden_states = (input_embeds + main_stream_pos_embed).transpose(0, 1);
|
let hidden_states = (input_embeds + main_stream_pos_embed).transpose(0, 1);
|
||||||
|
|
||||||
let (mut ngram_hidden_states, extended_attention_mask, extended_predict_attention_mask) =
|
let (mut ngram_hidden_states, extended_attention_mask, extended_predict_attention_mask) = {
|
||||||
|
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
|
||||||
if old_layer_states.is_some() {
|
if old_layer_states.is_some() {
|
||||||
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
|
|
||||||
for ngram in 0..self.ngram {
|
for ngram in 0..self.ngram {
|
||||||
ngram_hidden_states.push(
|
ngram_hidden_states.push(
|
||||||
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
|
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
|
||||||
@ -320,7 +320,6 @@ impl ProphetNetDecoder {
|
|||||||
}
|
}
|
||||||
(ngram_hidden_states, None, None)
|
(ngram_hidden_states, None, None)
|
||||||
} else {
|
} else {
|
||||||
let mut ngram_hidden_states = Vec::with_capacity(self.ngram as usize);
|
|
||||||
for ngram in 0..self.ngram {
|
for ngram in 0..self.ngram {
|
||||||
ngram_hidden_states.push(
|
ngram_hidden_states.push(
|
||||||
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
|
(&self.ngram_embeddings.get(ngram - 1) + &predicting_stream_pos_embed)
|
||||||
@ -336,7 +335,8 @@ impl ProphetNetDecoder {
|
|||||||
Some(extended_attention_mask),
|
Some(extended_attention_mask),
|
||||||
Some(extended_predict_attention_mask),
|
Some(extended_predict_attention_mask),
|
||||||
)
|
)
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let extended_encoder_attention_mask =
|
let extended_encoder_attention_mask =
|
||||||
encoder_attention_mask.map(|encoder_attention_mask_value| {
|
encoder_attention_mask.map(|encoder_attention_mask_value| {
|
||||||
@ -510,7 +510,7 @@ impl ProphetNetDecoder {
|
|||||||
let input_size = position_ids.size();
|
let input_size = position_ids.size();
|
||||||
let (batch_size, sequence_length) = (input_size[0], input_size[1]);
|
let (batch_size, sequence_length) = (input_size[0], input_size[1]);
|
||||||
|
|
||||||
let position_ids = Tensor::arange1(
|
let position_ids = Tensor::arange_start(
|
||||||
1,
|
1,
|
||||||
self.max_target_positions,
|
self.max_target_positions,
|
||||||
(Kind::Int64, position_ids.device()),
|
(Kind::Int64, position_ids.device()),
|
||||||
|
@ -307,7 +307,7 @@ impl LSHSelfAttention {
|
|||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
.expand(&buckets.size(), true)
|
.expand(&buckets.size(), true)
|
||||||
.to_kind(Kind::Bool);
|
.to_kind(Kind::Bool);
|
||||||
buckets = buckets.where1(
|
buckets = buckets.where_self(
|
||||||
&buckets_mask,
|
&buckets_mask,
|
||||||
&Tensor::of_slice(&[num_buckets - 1])
|
&Tensor::of_slice(&[num_buckets - 1])
|
||||||
.to_kind(Kind::Float)
|
.to_kind(Kind::Float)
|
||||||
@ -423,15 +423,16 @@ impl LSHSelfAttention {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if let Some(mask) = mask {
|
if let Some(mask) = mask {
|
||||||
query_key_dots = query_key_dots.where1(&mask.to_kind(Kind::Bool), &self.mask_value);
|
query_key_dots =
|
||||||
|
query_key_dots.where_self(&mask.to_kind(Kind::Bool), &self.mask_value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
let self_mask = query_bucket_idx
|
let self_mask = query_bucket_idx
|
||||||
.unsqueeze(-1)
|
.unsqueeze(-1)
|
||||||
.ne1(&key_value_bucket_idx.unsqueeze(-2));
|
.ne_tensor(&key_value_bucket_idx.unsqueeze(-2));
|
||||||
query_key_dots =
|
query_key_dots =
|
||||||
query_key_dots.where1(&self_mask.to_kind(Kind::Bool), &self.self_mask_value);
|
query_key_dots.where_self(&self_mask.to_kind(Kind::Bool), &self.self_mask_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut logits = query_key_dots.logsumexp(&[-1], true);
|
let mut logits = query_key_dots.logsumexp(&[-1], true);
|
||||||
@ -441,7 +442,7 @@ impl LSHSelfAttention {
|
|||||||
|
|
||||||
let mut out_vectors = attention_probs.matmul(&value_vectors);
|
let mut out_vectors = attention_probs.matmul(&value_vectors);
|
||||||
if out_vectors.dim() > 4 {
|
if out_vectors.dim() > 4 {
|
||||||
logits = logits.flatten(2, 3).squeeze1(-1);
|
logits = logits.flatten(2, 3).squeeze_dim(-1);
|
||||||
out_vectors = out_vectors.flatten(2, 3)
|
out_vectors = out_vectors.flatten(2, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -476,7 +477,9 @@ impl LSHSelfAttention {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if self.is_decoder {
|
if self.is_decoder {
|
||||||
let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2));
|
let causal_mask = query_indices
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.ge_tensor(&key_indices.unsqueeze(-2));
|
||||||
let attention_mask = if let Some(attention_mask) = attention_mask {
|
let attention_mask = if let Some(attention_mask) = attention_mask {
|
||||||
causal_mask * attention_mask
|
causal_mask * attention_mask
|
||||||
} else {
|
} else {
|
||||||
@ -534,7 +537,10 @@ impl LSHSelfAttention {
|
|||||||
*relevant_bucket_indices_chunk.size().last().unwrap(),
|
*relevant_bucket_indices_chunk.size().last().unwrap(),
|
||||||
(Kind::Int64, hidden_states.device()),
|
(Kind::Int64, hidden_states.device()),
|
||||||
)
|
)
|
||||||
.floor_divide1(*relevant_bucket_indices_chunk.size().last().unwrap()));
|
.divide_scalar_mode(
|
||||||
|
*relevant_bucket_indices_chunk.size().last().unwrap(),
|
||||||
|
"floor",
|
||||||
|
));
|
||||||
|
|
||||||
let relevant_bucket_indices_chunk_all_batch =
|
let relevant_bucket_indices_chunk_all_batch =
|
||||||
&relevant_bucket_indices_chunk + bucket_indices_batch_offset;
|
&relevant_bucket_indices_chunk + bucket_indices_batch_offset;
|
||||||
@ -566,7 +572,9 @@ impl LSHSelfAttention {
|
|||||||
indices: &Tensor,
|
indices: &Tensor,
|
||||||
sequence_length: i64,
|
sequence_length: i64,
|
||||||
) -> Tensor {
|
) -> Tensor {
|
||||||
let start_indices_chunk = (indices.select(1, -1).floor_divide1(self.chunk_length)
|
let start_indices_chunk = (indices
|
||||||
|
.select(1, -1)
|
||||||
|
.divide_scalar_mode(self.chunk_length, "floor")
|
||||||
- self.num_chunks_before)
|
- self.num_chunks_before)
|
||||||
* self.chunk_length;
|
* self.chunk_length;
|
||||||
let total_chunk_size =
|
let total_chunk_size =
|
||||||
@ -593,7 +601,7 @@ impl LSHSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
|
fn len_norm(&self, input_tensor: &Tensor, epsilon: f64) -> Tensor {
|
||||||
let variance = (input_tensor * input_tensor).mean1(&[-1], true, input_tensor.kind());
|
let variance = (input_tensor * input_tensor).mean_dim(&[-1], true, input_tensor.kind());
|
||||||
input_tensor * (variance + epsilon).rsqrt()
|
input_tensor * (variance + epsilon).rsqrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -850,7 +858,7 @@ impl LSHSelfAttention {
|
|||||||
)?
|
)?
|
||||||
.unsqueeze(-1);
|
.unsqueeze(-1);
|
||||||
let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
|
let probs_vectors = (&logits - &logits.logsumexp(&[2], true)).exp();
|
||||||
out_vectors = (out_vectors * probs_vectors).sum1(&[2], false, Kind::Float);
|
out_vectors = (out_vectors * probs_vectors).sum_dim_intlist(&[2], false, Kind::Float);
|
||||||
}
|
}
|
||||||
|
|
||||||
out_vectors = merge_hidden_size_dim(
|
out_vectors = merge_hidden_size_dim(
|
||||||
@ -967,7 +975,9 @@ impl LocalSelfAttention {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if self.is_decoder {
|
if self.is_decoder {
|
||||||
let causal_mask = query_indices.unsqueeze(-1).ge1(&key_indices.unsqueeze(-2));
|
let causal_mask = query_indices
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.ge_tensor(&key_indices.unsqueeze(-2));
|
||||||
attention_mask = Some(if let Some(mask) = attention_mask {
|
attention_mask = Some(if let Some(mask) = attention_mask {
|
||||||
causal_mask * mask
|
causal_mask * mask
|
||||||
} else {
|
} else {
|
||||||
@ -1087,7 +1097,7 @@ impl LocalSelfAttention {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if let Some(mask) = attention_mask {
|
if let Some(mask) = attention_mask {
|
||||||
query_key_dots = query_key_dots.where1(&mask.to_kind(Kind::Bool), &self.mask_value);
|
query_key_dots = query_key_dots.where_self(&mask.to_kind(Kind::Bool), &self.mask_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
let logits = query_key_dots.logsumexp(&[-1], true);
|
let logits = query_key_dots.logsumexp(&[-1], true);
|
||||||
|
@ -263,10 +263,9 @@ impl ReformerEmbeddings {
|
|||||||
|
|
||||||
let calc_position_ids = if position_ids.is_none() {
|
let calc_position_ids = if position_ids.is_none() {
|
||||||
Some(
|
Some(
|
||||||
Tensor::arange2(
|
Tensor::arange_start(
|
||||||
start_ids_pos_encoding,
|
start_ids_pos_encoding,
|
||||||
start_ids_pos_encoding + input_shape[1],
|
start_ids_pos_encoding + input_shape[1],
|
||||||
1,
|
|
||||||
(Kind::Int64, device),
|
(Kind::Int64, device),
|
||||||
)
|
)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
|
@ -418,10 +418,9 @@ impl ReformerModel {
|
|||||||
let input_ids = Tensor::cat(&[input_ids, &input_ids_padding], -1);
|
let input_ids = Tensor::cat(&[input_ids, &input_ids_padding], -1);
|
||||||
new_input_shape = input_ids.size();
|
new_input_shape = input_ids.size();
|
||||||
let position_ids = if let Some(position_ids) = position_ids {
|
let position_ids = if let Some(position_ids) = position_ids {
|
||||||
let position_ids_padding = Tensor::arange2(
|
let position_ids_padding = Tensor::arange_start(
|
||||||
*input_shape.last().unwrap(),
|
*input_shape.last().unwrap(),
|
||||||
self.least_common_mult_chunk_length,
|
self.least_common_mult_chunk_length,
|
||||||
1,
|
|
||||||
(Kind::Int64, device),
|
(Kind::Int64, device),
|
||||||
)
|
)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
@ -568,7 +567,7 @@ impl ReformerModelWithLMHead {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// reformer_model.forward_t(
|
/// reformer_model.forward_t(
|
||||||
/// Some(&input_tensor),
|
/// Some(&input_tensor),
|
||||||
/// Some(&input_positions),
|
/// Some(&input_positions),
|
||||||
/// None,
|
/// None,
|
||||||
/// Some(&attention_mask),
|
/// Some(&attention_mask),
|
||||||
@ -801,7 +800,7 @@ impl ReformerForSequenceClassification {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// reformer_model.forward_t(
|
/// reformer_model.forward_t(
|
||||||
/// Some(&input_tensor),
|
/// Some(&input_tensor),
|
||||||
/// Some(&input_positions),
|
/// Some(&input_positions),
|
||||||
/// None,
|
/// None,
|
||||||
/// Some(&attention_mask),
|
/// Some(&attention_mask),
|
||||||
@ -939,7 +938,7 @@ impl ReformerForQuestionAnswering {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// reformer_model.forward_t(
|
/// reformer_model.forward_t(
|
||||||
/// Some(&input_tensor),
|
/// Some(&input_tensor),
|
||||||
/// Some(&input_positions),
|
/// Some(&input_positions),
|
||||||
/// None,
|
/// None,
|
||||||
/// Some(&attention_mask),
|
/// Some(&attention_mask),
|
||||||
@ -972,8 +971,8 @@ impl ReformerForQuestionAnswering {
|
|||||||
.apply(&self.qa_outputs)
|
.apply(&self.qa_outputs)
|
||||||
.split(1, -1);
|
.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
Ok(ReformerQuestionAnsweringModelOutput {
|
Ok(ReformerQuestionAnsweringModelOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -39,7 +39,7 @@ impl RobertaEmbeddings {
|
|||||||
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
|
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
|
||||||
let input_shape = x.size();
|
let input_shape = x.size();
|
||||||
let input_shape = vec![input_shape[0], input_shape[1]];
|
let input_shape = vec![input_shape[0], input_shape[1]];
|
||||||
let position_ids: Tensor = Tensor::arange1(
|
let position_ids: Tensor = Tensor::arange_start(
|
||||||
self.padding_index + 1,
|
self.padding_index + 1,
|
||||||
input_shape[0],
|
input_shape[0],
|
||||||
(Kind::Int64, x.device()),
|
(Kind::Int64, x.device()),
|
||||||
|
@ -961,8 +961,8 @@ impl RobertaForQuestionAnswering {
|
|||||||
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
RobertaQuestionAnsweringOutput {
|
RobertaQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -254,7 +254,7 @@ impl T5Attention {
|
|||||||
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
|
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
|
||||||
n.abs()
|
n.abs()
|
||||||
} else {
|
} else {
|
||||||
n.max1(&n.zeros_like())
|
n.max_other(&n.zeros_like())
|
||||||
};
|
};
|
||||||
|
|
||||||
let max_exact = num_buckets / 2;
|
let max_exact = num_buckets / 2;
|
||||||
@ -266,8 +266,8 @@ impl T5Attention {
|
|||||||
.to_kind(Kind::Int64)
|
.to_kind(Kind::Int64)
|
||||||
+ max_exact;
|
+ max_exact;
|
||||||
|
|
||||||
let value_if_large = value_if_large.min1(&value_if_large.full_like(num_buckets - 1));
|
let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1));
|
||||||
ret += n.where1(&is_small, &value_if_large);
|
ret += n.where_self(&is_small, &value_if_large);
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,7 +328,7 @@ impl T5Stack {
|
|||||||
input_shape[1],
|
input_shape[1],
|
||||||
1,
|
1,
|
||||||
]);
|
]);
|
||||||
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
|
let causal_mask = causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
|
||||||
causal_mask.unsqueeze(1) * attention_mask.unsqueeze(1).unsqueeze(1)
|
causal_mask.unsqueeze(1) * attention_mask.unsqueeze(1).unsqueeze(1)
|
||||||
} else {
|
} else {
|
||||||
attention_mask.unsqueeze(1).unsqueeze(1)
|
attention_mask.unsqueeze(1).unsqueeze(1)
|
||||||
|
@ -32,7 +32,7 @@ impl T5LayerNorm {
|
|||||||
|
|
||||||
impl Module for T5LayerNorm {
|
impl Module for T5LayerNorm {
|
||||||
fn forward(&self, x: &Tensor) -> Tensor {
|
fn forward(&self, x: &Tensor) -> Tensor {
|
||||||
let variance = x.pow(2f64).mean1(&[-1], true, Kind::Float);
|
let variance = x.pow(2f64).mean_dim(&[-1], true, Kind::Float);
|
||||||
let x = x / (variance + self.epsilon).sqrt();
|
let x = x / (variance + self.epsilon).sqrt();
|
||||||
&self.weight * x
|
&self.weight * x
|
||||||
}
|
}
|
||||||
|
@ -287,13 +287,16 @@ impl XLNetModel {
|
|||||||
batch_size: Option<i64>,
|
batch_size: Option<i64>,
|
||||||
device: Device,
|
device: Device,
|
||||||
) -> Tensor {
|
) -> Tensor {
|
||||||
let frequency_sequence = Tensor::arange2(0, self.d_model, 2, (Kind::Float, device));
|
let frequency_sequence =
|
||||||
let inverse_frequency = 1f64 / Tensor::pow2(10000f64, &(frequency_sequence / self.d_model));
|
Tensor::arange_start_step(0, self.d_model, 2, (Kind::Float, device));
|
||||||
|
let inverse_frequency =
|
||||||
|
1f64 / Tensor::pow_scalar(10000f64, &(frequency_sequence / self.d_model));
|
||||||
let (begin, end) = match self.attention_type {
|
let (begin, end) = match self.attention_type {
|
||||||
AttentionType::bi => (k_len, -q_len),
|
AttentionType::bi => (k_len, -q_len),
|
||||||
AttentionType::uni => (k_len, -1),
|
AttentionType::uni => (k_len, -1),
|
||||||
};
|
};
|
||||||
let mut forward_positions_sequence = Tensor::arange2(begin, end, -1, (Kind::Float, device));
|
let mut forward_positions_sequence =
|
||||||
|
Tensor::arange_start_step(begin, end, -1, (Kind::Float, device));
|
||||||
match self.clamp_len {
|
match self.clamp_len {
|
||||||
Some(clamp_value) if clamp_value > 0 => {
|
Some(clamp_value) if clamp_value > 0 => {
|
||||||
let _ = forward_positions_sequence.clamp_(-clamp_value, clamp_value);
|
let _ = forward_positions_sequence.clamp_(-clamp_value, clamp_value);
|
||||||
@ -302,7 +305,7 @@ impl XLNetModel {
|
|||||||
}
|
}
|
||||||
if self.bi_data {
|
if self.bi_data {
|
||||||
let mut backward_positions_sequence =
|
let mut backward_positions_sequence =
|
||||||
Tensor::arange2(-begin, -end, 1, (Kind::Float, device));
|
Tensor::arange_start(-begin, -end, (Kind::Float, device));
|
||||||
match self.clamp_len {
|
match self.clamp_len {
|
||||||
Some(clamp_value) if clamp_value > 0 => {
|
Some(clamp_value) if clamp_value > 0 => {
|
||||||
let _ = backward_positions_sequence.clamp_(-clamp_value, clamp_value);
|
let _ = backward_positions_sequence.clamp_(-clamp_value, clamp_value);
|
||||||
@ -512,7 +515,7 @@ impl XLNetModel {
|
|||||||
};
|
};
|
||||||
let seg_mat = token_type_ids_value
|
let seg_mat = token_type_ids_value
|
||||||
.unsqueeze(-1)
|
.unsqueeze(-1)
|
||||||
.ne1(&cat_ids.unsqueeze(0))
|
.ne_tensor(&cat_ids.unsqueeze(0))
|
||||||
.to_kind(Kind::Int64);
|
.to_kind(Kind::Int64);
|
||||||
Some(seg_mat.one_hot(2).to_kind(Kind::Float))
|
Some(seg_mat.one_hot(2).to_kind(Kind::Float))
|
||||||
} else {
|
} else {
|
||||||
@ -1461,8 +1464,8 @@ impl XLNetForQuestionAnswering {
|
|||||||
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze_dim(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze_dim(-1);
|
||||||
|
|
||||||
XLNetQuestionAnsweringOutput {
|
XLNetQuestionAnsweringOutput {
|
||||||
start_logits,
|
start_logits,
|
||||||
|
@ -428,17 +428,20 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
Some(&force_one_paragraph),
|
Some(&force_one_paragraph),
|
||||||
|
true,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(output.len(), 2);
|
assert_eq!(output.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[0],
|
output[0].text,
|
||||||
"Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
|
"Rust is a very simple and powerful library for building and running web applications. It is a simple, fast, and lightweight library that can be used to build web applications in a number of different ways.\n"
|
||||||
);
|
);
|
||||||
|
assert!((output[0].score.unwrap() - (-1.4666)).abs() < 1e-4);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[1],
|
output[1].text,
|
||||||
"There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
|
"There was a urn in the room, and I was sitting on it. I was like, \'What the hell is going on?\' And he said, \'Well, I\'m not sure. I\'m just going to go back to my room and get some coffee.\' And"
|
||||||
);
|
);
|
||||||
|
assert!((output[1].score.unwrap() - (-1.3545)).abs() < 1e-4);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -493,17 +496,20 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
Some(&force_one_paragraph),
|
Some(&force_one_paragraph),
|
||||||
|
true,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(output.len(), 2);
|
assert_eq!(output.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[0],
|
output[0].text,
|
||||||
"Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
|
"Rust is a simple, fast, and easy-to-use framework for building web applications. It is designed to be easy to use and maintain, and"
|
||||||
);
|
);
|
||||||
|
assert!((output[0].score.unwrap() - (-1.2750)).abs() < 1e-4);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[1],
|
output[1].text,
|
||||||
"There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
|
"There was a urn in the back of the room, and I was sitting on it, and it looked like it was going to explode. And then I"
|
||||||
);
|
);
|
||||||
|
assert!((output[1].score.unwrap() - (-1.3326)).abs() < 1e-4);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -97,11 +97,12 @@ fn mbart_translation() -> anyhow::Result<()> {
|
|||||||
None,
|
None,
|
||||||
target_language,
|
target_language,
|
||||||
None,
|
None,
|
||||||
|
false,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(output.len(), 1);
|
assert_eq!(output.len(), 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output[0],
|
output[0].text,
|
||||||
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
|
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user