mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Shallow clone optimization (#243)
* Fixed Clippy warnings
* Shallow clone optimization (reduce tensor copy)
* Updated changelog and fixed Clippy warnings
This commit is contained in:
parent
6f1888e8f9
commit
ba584653bc
@ -130,7 +130,7 @@ impl AlbertLayerGroup {
|
||||
None
|
||||
};
|
||||
|
||||
let mut hidden_state = hidden_states.copy();
|
||||
let mut hidden_state = hidden_states.shallow_clone();
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
|
||||
for layer in &self.layers {
|
||||
|
@ -28,8 +28,8 @@ pub struct LayerState {
|
||||
impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
LayerState {
|
||||
prev_key: self.prev_key.copy(),
|
||||
prev_value: self.prev_value.copy(),
|
||||
prev_key: self.prev_key.shallow_clone(),
|
||||
prev_value: self.prev_value.shallow_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -138,8 +138,8 @@ impl BartAttention {
|
||||
|
||||
let new_layer_state = if self.store_cache {
|
||||
Some(LayerState {
|
||||
prev_key: key_states.copy(),
|
||||
prev_value: value_states.copy(),
|
||||
prev_key: key_states.shallow_clone(),
|
||||
prev_value: value_states.shallow_clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -109,7 +109,7 @@ impl EncoderLayer {
|
||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||
let output = output.apply(&self.self_attention_layer_norm);
|
||||
|
||||
let residual = output.copy();
|
||||
let residual = output.shallow_clone();
|
||||
let output = (self.activation.get_fn())(&output.apply(&self.fc1));
|
||||
let output = output
|
||||
.apply_t(&self.activation_dropout, train)
|
||||
|
@ -337,7 +337,7 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
|
||||
let encoder_hidden_states_shape = encoder_hidden_states.size();
|
||||
let encoder_mask = match encoder_mask {
|
||||
Some(value) => value.copy(),
|
||||
Some(value) => value.shallow_clone(),
|
||||
None => Tensor::ones(
|
||||
&[
|
||||
encoder_hidden_states_shape[0],
|
||||
@ -416,7 +416,7 @@ impl BertPredictionHeadTransform {
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
((&self.activation.get_fn())(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
|
||||
(self.activation.get_fn()(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -205,7 +205,7 @@ impl DebertaDisentangledSelfAttention {
|
||||
let r_pos = if query_layer_size[1] != key_layer_size[1] {
|
||||
build_relative_position(key_layer_size[1], key_layer_size[1], query_layer.device())
|
||||
} else {
|
||||
relative_pos.copy()
|
||||
relative_pos.shallow_clone()
|
||||
};
|
||||
let p2c_pos = (-r_pos + attention_span).clamp(0, attention_span * 2 - 1);
|
||||
let mut p2c_att = key_layer
|
||||
|
@ -333,10 +333,10 @@ impl DebertaV2Encoder {
|
||||
}
|
||||
attention_weights = layer_output.1;
|
||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
attentions.push(attention_weights.as_ref().unwrap().shallow_clone());
|
||||
};
|
||||
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
||||
hidden_states.push(output_states.as_ref().unwrap().copy());
|
||||
hidden_states.push(output_states.as_ref().unwrap().shallow_clone());
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -150,7 +150,6 @@ impl Transformer {
|
||||
None
|
||||
};
|
||||
|
||||
// let mut hidden_state = input.copy();
|
||||
let mut hidden_state: Option<Tensor> = None;
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
|
||||
|
@ -175,7 +175,7 @@ impl Attention {
|
||||
}
|
||||
None => (key, value),
|
||||
};
|
||||
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
|
||||
let present = Tensor::stack(&[key.transpose(-2, -1), value.shallow_clone()], 0);
|
||||
let (a, attentions) = self.attention(&query, &key, &value, attention_mask, train);
|
||||
|
||||
let a = self
|
||||
|
@ -386,7 +386,7 @@ impl Gpt2Model {
|
||||
(
|
||||
value
|
||||
.iter()
|
||||
.map(|v| Some(v.copy()))
|
||||
.map(|v| Some(v.shallow_clone()))
|
||||
.collect::<Vec<Option<Tensor>>>(),
|
||||
value[0].size()[3],
|
||||
)
|
||||
@ -399,7 +399,7 @@ impl Gpt2Model {
|
||||
};
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value.copy(),
|
||||
Some(value) => value.shallow_clone(),
|
||||
None => Tensor::arange_start(
|
||||
layer_past_length,
|
||||
seq_length + layer_past_length,
|
||||
|
@ -29,8 +29,8 @@ pub struct LayerState {
|
||||
impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
LayerState {
|
||||
prev_key: self.prev_key.copy(),
|
||||
prev_value: self.prev_value.as_ref().map(|value| value.copy()),
|
||||
prev_key: self.prev_key.shallow_clone(),
|
||||
prev_value: self.prev_value.as_ref().map(|value| value.shallow_clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -211,8 +211,8 @@ impl GptNeoSelfAttention {
|
||||
};
|
||||
|
||||
let layer_state = Some(LayerState {
|
||||
prev_key: key.copy(),
|
||||
prev_value: Some(value.copy()),
|
||||
prev_key: key.shallow_clone(),
|
||||
prev_value: Some(value.shallow_clone()),
|
||||
});
|
||||
|
||||
let (attention_output, attention_weights) =
|
||||
|
@ -107,7 +107,7 @@ impl MBartEncoderLayer {
|
||||
.forward_t(&output, None, encoder_attention_mask, None, train);
|
||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||
|
||||
let residual = output.copy();
|
||||
let residual = output.shallow_clone();
|
||||
let output = output.apply(&self.final_layer_norm);
|
||||
let output = (self.activation.get_fn())(&output.apply(&self.fc1));
|
||||
let output = output
|
||||
|
@ -221,7 +221,7 @@ impl OpenAiGptModel {
|
||||
let seq_length = input_shape[1];
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value.copy(),
|
||||
Some(value) => value.shallow_clone(),
|
||||
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
|
||||
};
|
||||
|
||||
|
@ -507,7 +507,7 @@ pub(crate) mod private_generation_utils {
|
||||
&Tensor::arange_start(1, vocab_size, (Int64, logits.device())),
|
||||
&sorted_indices_to_remove
|
||||
.slice(1, 0, vocab_size - 1, 1)
|
||||
.copy(),
|
||||
.shallow_clone(),
|
||||
);
|
||||
let _ = sorted_indices_to_remove.index_fill_(
|
||||
1,
|
||||
@ -749,8 +749,8 @@ pub(crate) mod private_generation_utils {
|
||||
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
|
||||
self.split_bad_word_ids(gen_opt.bad_word_ids);
|
||||
let mut static_bad_words_mask: Option<Tensor> = None;
|
||||
let mut attention_mask = attention_mask.copy();
|
||||
let mut input_ids = input_ids.copy();
|
||||
let mut attention_mask = attention_mask.shallow_clone();
|
||||
let mut input_ids = input_ids.shallow_clone();
|
||||
let mut past: Cache = Cache::None;
|
||||
let mut outputs: Tensor;
|
||||
let mut current_length = cur_len;
|
||||
@ -759,10 +759,10 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
while current_length < gen_opt.max_length {
|
||||
let prepared_input = self.prepare_inputs_for_generation(
|
||||
input_ids.copy(),
|
||||
input_ids.shallow_clone(),
|
||||
encoder_outputs.as_ref(),
|
||||
past,
|
||||
attention_mask.copy(),
|
||||
attention_mask.shallow_clone(),
|
||||
);
|
||||
let temp = self
|
||||
.get_model()
|
||||
@ -1020,10 +1020,10 @@ pub(crate) mod private_generation_utils {
|
||||
);
|
||||
}
|
||||
let prepared_input = self.prepare_inputs_for_generation(
|
||||
input_ids.copy(),
|
||||
input_ids.shallow_clone(),
|
||||
encoder_outputs.as_ref(),
|
||||
past,
|
||||
attention_mask.copy(),
|
||||
attention_mask.shallow_clone(),
|
||||
);
|
||||
let temp = self
|
||||
.get_model()
|
||||
@ -1240,10 +1240,10 @@ pub(crate) mod private_generation_utils {
|
||||
saved_beam_scores.as_ref().map(|step_wise_scores| {
|
||||
Tensor::stack(step_wise_scores, 1)
|
||||
.get(effective_beam_id)
|
||||
.copy()
|
||||
.shallow_clone()
|
||||
});
|
||||
hypotheses[batch_index as usize].add(
|
||||
input_ids.get(effective_beam_id).copy(),
|
||||
input_ids.get(effective_beam_id).shallow_clone(),
|
||||
beam_token_score,
|
||||
saved_beam_scores,
|
||||
);
|
||||
@ -1301,7 +1301,7 @@ pub(crate) mod private_generation_utils {
|
||||
}
|
||||
|
||||
if let Some(scores_output) = saved_beam_scores.as_mut() {
|
||||
scores_output.push(beam_scores.copy());
|
||||
scores_output.push(beam_scores.shallow_clone());
|
||||
}
|
||||
if done.iter().all(|&x| x) {
|
||||
break;
|
||||
@ -2105,10 +2105,10 @@ impl Clone for BeamHypotheses {
|
||||
.map(|(score, tensor, scores_tensor)| {
|
||||
(
|
||||
*score,
|
||||
tensor.copy(),
|
||||
tensor.shallow_clone(),
|
||||
scores_tensor
|
||||
.as_ref()
|
||||
.map(|scores_tensor| scores_tensor.copy()),
|
||||
.map(|scores_tensor| scores_tensor.shallow_clone()),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<(f64, Tensor, Option<Tensor>)>>(),
|
||||
|
@ -31,8 +31,8 @@ pub struct LayerState {
|
||||
impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
LayerState {
|
||||
prev_key: self.prev_key.copy(),
|
||||
prev_value: self.prev_value.copy(),
|
||||
prev_key: self.prev_key.shallow_clone(),
|
||||
prev_value: self.prev_value.shallow_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ use tch::{nn, Device, Kind, Tensor};
|
||||
fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
|
||||
let left_block = Tensor::ones(&[ngram, sequence_length, sequence_length], (kind, device))
|
||||
* get_negative_infinity(kind).unwrap();
|
||||
let right_block = left_block.copy();
|
||||
let right_block = left_block.shallow_clone();
|
||||
for stream_idx in 0..ngram {
|
||||
let _ = right_block.get(stream_idx).fill_diagonal_(0, false);
|
||||
let _ = left_block.get(stream_idx).triu_(-stream_idx + 1);
|
||||
|
@ -68,7 +68,7 @@ impl ProphetNetPositionalEmbeddings {
|
||||
attention_mask.cumsum(1, Kind::Int64) * attention_mask + self.padding_idx
|
||||
}
|
||||
}
|
||||
Some(value) => value.copy(),
|
||||
Some(value) => value.shallow_clone(),
|
||||
};
|
||||
|
||||
(calc_position_ids.apply(&self.embeddings), calc_position_ids)
|
||||
|
@ -40,10 +40,10 @@ impl Clone for LayerState {
|
||||
let prev_buckets = self
|
||||
.prev_buckets
|
||||
.as_ref()
|
||||
.map(|prev_buckets| prev_buckets.copy());
|
||||
.map(|prev_buckets| prev_buckets.shallow_clone());
|
||||
LayerState {
|
||||
prev_buckets,
|
||||
prev_states: self.prev_states.copy(),
|
||||
prev_states: self.prev_states.shallow_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -417,7 +417,7 @@ impl LSHSelfAttention {
|
||||
None,
|
||||
)?;
|
||||
let key_value_bucket_idx = look_adjacent(
|
||||
query_bucket_idx.copy(),
|
||||
query_bucket_idx.shallow_clone(),
|
||||
self.num_chunks_before,
|
||||
self.num_chunks_after,
|
||||
);
|
||||
@ -446,7 +446,7 @@ impl LSHSelfAttention {
|
||||
(query_bucket_idx, key_value_bucket_idx)
|
||||
} else {
|
||||
(
|
||||
sorted_bucket_indices_per_hash.copy(),
|
||||
sorted_bucket_indices_per_hash.shallow_clone(),
|
||||
sorted_bucket_indices_per_hash,
|
||||
)
|
||||
};
|
||||
@ -849,7 +849,7 @@ impl LSHSelfAttention {
|
||||
false
|
||||
}
|
||||
} {
|
||||
(sorted_bucket_idx.unwrap().copy(), None)
|
||||
(sorted_bucket_idx.unwrap().shallow_clone(), None)
|
||||
} else {
|
||||
(
|
||||
Tensor::arange(sequence_length, (Kind::Int64, query_key_vectors.device()))
|
||||
@ -1121,7 +1121,7 @@ impl LocalSelfAttention {
|
||||
self.num_attention_heads,
|
||||
None,
|
||||
)?;
|
||||
let key_indices = query_indices.copy();
|
||||
let key_indices = query_indices.shallow_clone();
|
||||
|
||||
key_vectors = look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after);
|
||||
value_vectors =
|
||||
@ -1130,7 +1130,7 @@ impl LocalSelfAttention {
|
||||
look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after);
|
||||
(query_indices, key_indices)
|
||||
} else {
|
||||
(indices.copy(), indices.copy())
|
||||
(indices.shallow_clone(), indices.shallow_clone())
|
||||
};
|
||||
|
||||
let mut query_key_dots = query_vectors.matmul(&key_vectors.transpose(-1, -2));
|
||||
@ -1356,7 +1356,7 @@ impl ReformerAttention {
|
||||
if original_sequence_length > 1 {
|
||||
Some(buckets_value.slice(3, 0, original_sequence_length, 1))
|
||||
} else {
|
||||
Some(buckets_value.copy())
|
||||
Some(buckets_value.shallow_clone())
|
||||
}
|
||||
} else {
|
||||
Some(Tensor::cat(
|
||||
|
@ -291,8 +291,8 @@ impl ReformerEncoder {
|
||||
original_sequence_length: i64,
|
||||
train: bool,
|
||||
) -> Result<ReformerModelOutput, RustBertError> {
|
||||
let mut hidden_state = hidden_states.copy();
|
||||
let mut attention_output = hidden_states.copy();
|
||||
let mut hidden_state = hidden_states.shallow_clone();
|
||||
let mut attention_output = hidden_states.shallow_clone();
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(Vec::with_capacity(self.layers.len()))
|
||||
} else {
|
||||
|
@ -30,8 +30,8 @@ pub struct LayerState {
|
||||
impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
LayerState {
|
||||
prev_key: self.prev_key.copy(),
|
||||
prev_value: self.prev_value.copy(),
|
||||
prev_key: self.prev_key.shallow_clone(),
|
||||
prev_value: self.prev_value.shallow_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -176,15 +176,15 @@ impl T5Attention {
|
||||
k = Tensor::cat(&[&layer_state.prev_key, &k], 2);
|
||||
v = Tensor::cat(&[&layer_state.prev_value, &v], 2);
|
||||
} else {
|
||||
k = layer_state.prev_key.copy();
|
||||
v = layer_state.prev_value.copy();
|
||||
k = layer_state.prev_key.shallow_clone();
|
||||
v = layer_state.prev_value.shallow_clone();
|
||||
}
|
||||
};
|
||||
|
||||
layer_state = if self.is_decoder & self.store_cache {
|
||||
Some(LayerState {
|
||||
prev_key: k.copy(),
|
||||
prev_value: v.copy(),
|
||||
prev_key: k.shallow_clone(),
|
||||
prev_value: v.shallow_clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
@ -434,7 +434,7 @@ impl T5Stack {
|
||||
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
|
||||
let encoder_hidden_states_shape = encoder_hidden_states.size();
|
||||
let encoder_mask = match encoder_attention_mask {
|
||||
Some(value) => value.copy(),
|
||||
Some(value) => value.shallow_clone(),
|
||||
None => Tensor::ones(
|
||||
&[
|
||||
encoder_hidden_states_shape[0],
|
||||
|
@ -29,7 +29,7 @@ pub struct LayerState {
|
||||
impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
LayerState {
|
||||
prev_content: self.prev_content.copy(),
|
||||
prev_content: self.prev_content.shallow_clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ impl XLNetFeedForward {
|
||||
|
||||
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
||||
let output = input.apply(&self.layer_1);
|
||||
let output: Tensor = (&self.activation.get_fn())(&output);
|
||||
let output: Tensor = self.activation.get_fn()(&output);
|
||||
let output = output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.layer_2)
|
||||
|
@ -516,9 +516,9 @@ impl XLNetModel {
|
||||
&[m_len, batch_size],
|
||||
(Kind::Int64, token_type_ids_value.device()),
|
||||
);
|
||||
Tensor::cat(&[mem_pad, token_type_ids_value.copy()], 0)
|
||||
Tensor::cat(&[mem_pad, token_type_ids_value.shallow_clone()], 0)
|
||||
} else {
|
||||
token_type_ids_value.copy()
|
||||
token_type_ids_value.shallow_clone()
|
||||
};
|
||||
let seg_mat = token_type_ids_value
|
||||
.unsqueeze(-1)
|
||||
|
Loading…
Reference in New Issue
Block a user