Shallow clone optimization (#243)

* Fixed Clippy warnings

* Shallow clone optimization (reduce tensor copy)

* Updated changelog and fixed Clippy warnings
Authored by guillaume-be on 2022-04-10 08:52:37 +01:00, committed by GitHub
parent 6f1888e8f9
commit ba584653bc
23 changed files with 57 additions and 58 deletions
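
For context (illustrative, not part of this commit): in tch, `Tensor::copy` allocates new storage and copies the data, while `Tensor::shallow_clone` only creates a new handle that aliases the existing storage, so no tensor data is moved. A minimal sketch of the difference, assuming the tch API available to rust-bert at the time (`Tensor::of_slice`, `copy`, `shallow_clone`, `data_ptr`):

```rust
use tch::Tensor;

fn main() {
    // Illustrative only; `of_slice` and `data_ptr` are assumed from the tch API.
    let source = Tensor::of_slice(&[1.0f32, 2.0, 3.0]);

    // `copy` allocates fresh storage and copies every element.
    let deep = source.copy();

    // `shallow_clone` returns a new handle to the same underlying storage.
    let shallow = source.shallow_clone();

    // The shallow clone aliases the original buffer; the deep copy does not.
    assert_eq!(source.data_ptr(), shallow.data_ptr());
    assert_ne!(source.data_ptr(), deep.data_ptr());
}
```

This is why the hunks below simply swap `.copy()` for `.shallow_clone()`.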

View File

@ -130,7 +130,7 @@ impl AlbertLayerGroup {
None
};
-let mut hidden_state = hidden_states.copy();
+let mut hidden_state = hidden_states.shallow_clone();
let mut attention_weights: Option<Tensor>;
for layer in &self.layers {

View File

@ -28,8 +28,8 @@ pub struct LayerState {
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
-prev_key: self.prev_key.copy(),
-prev_value: self.prev_value.copy(),
+prev_key: self.prev_key.shallow_clone(),
+prev_value: self.prev_value.shallow_clone(),
}
}
}
@ -138,8 +138,8 @@ impl BartAttention {
let new_layer_state = if self.store_cache {
Some(LayerState {
-prev_key: key_states.copy(),
-prev_value: value_states.copy(),
+prev_key: key_states.shallow_clone(),
+prev_value: value_states.shallow_clone(),
})
} else {
None

View File

@ -109,7 +109,7 @@ impl EncoderLayer {
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
-let residual = output.copy();
+let residual = output.shallow_clone();
let output = (self.activation.get_fn())(&output.apply(&self.fc1));
let output = output
.apply_t(&self.activation_dropout, train)

View File

@ -337,7 +337,7 @@ impl<T: BertEmbedding> BertModel<T> {
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
let encoder_hidden_states_shape = encoder_hidden_states.size();
let encoder_mask = match encoder_mask {
-Some(value) => value.copy(),
+Some(value) => value.shallow_clone(),
None => Tensor::ones(
&[
encoder_hidden_states_shape[0],
@ -416,7 +416,7 @@ impl BertPredictionHeadTransform {
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
-((&self.activation.get_fn())(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
+(self.activation.get_fn()(&hidden_states.apply(&self.dense))).apply(&self.layer_norm)
}
}

View File

@ -205,7 +205,7 @@ impl DebertaDisentangledSelfAttention {
let r_pos = if query_layer_size[1] != key_layer_size[1] {
build_relative_position(key_layer_size[1], key_layer_size[1], query_layer.device())
} else {
-relative_pos.copy()
+relative_pos.shallow_clone()
};
let p2c_pos = (-r_pos + attention_span).clamp(0, attention_span * 2 - 1);
let mut p2c_att = key_layer

View File

@ -333,10 +333,10 @@ impl DebertaV2Encoder {
}
attention_weights = layer_output.1;
if let Some(attentions) = all_attentions.borrow_mut() {
-attentions.push(attention_weights.as_ref().unwrap().copy());
+attentions.push(attention_weights.as_ref().unwrap().shallow_clone());
};
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
-hidden_states.push(output_states.as_ref().unwrap().copy());
+hidden_states.push(output_states.as_ref().unwrap().shallow_clone());
};
}

View File

@ -150,7 +150,6 @@ impl Transformer {
None
};
-// let mut hidden_state = input.copy();
let mut hidden_state: Option<Tensor> = None;
let mut attention_weights: Option<Tensor>;

View File

@ -175,7 +175,7 @@ impl Attention {
}
None => (key, value),
};
-let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
+let present = Tensor::stack(&[key.transpose(-2, -1), value.shallow_clone()], 0);
let (a, attentions) = self.attention(&query, &key, &value, attention_mask, train);
let a = self

View File

@ -386,7 +386,7 @@ impl Gpt2Model {
(
value
.iter()
-.map(|v| Some(v.copy()))
+.map(|v| Some(v.shallow_clone()))
.collect::<Vec<Option<Tensor>>>(),
value[0].size()[3],
)
@ -399,7 +399,7 @@ impl Gpt2Model {
};
let position_ids = match position_ids {
-Some(value) => value.copy(),
+Some(value) => value.shallow_clone(),
None => Tensor::arange_start(
layer_past_length,
seq_length + layer_past_length,

View File

@ -29,8 +29,8 @@ pub struct LayerState {
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
-prev_key: self.prev_key.copy(),
-prev_value: self.prev_value.as_ref().map(|value| value.copy()),
+prev_key: self.prev_key.shallow_clone(),
+prev_value: self.prev_value.as_ref().map(|value| value.shallow_clone()),
}
}
}
@ -211,8 +211,8 @@ impl GptNeoSelfAttention {
};
let layer_state = Some(LayerState {
-prev_key: key.copy(),
-prev_value: Some(value.copy()),
+prev_key: key.shallow_clone(),
+prev_value: Some(value.shallow_clone()),
});
let (attention_output, attention_weights) =

View File

@ -107,7 +107,7 @@ impl MBartEncoderLayer {
.forward_t(&output, None, encoder_attention_mask, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
-let residual = output.copy();
+let residual = output.shallow_clone();
let output = output.apply(&self.final_layer_norm);
let output = (self.activation.get_fn())(&output.apply(&self.fc1));
let output = output

View File

@ -221,7 +221,7 @@ impl OpenAiGptModel {
let seq_length = input_shape[1];
let position_ids = match position_ids {
-Some(value) => value.copy(),
+Some(value) => value.shallow_clone(),
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
};

View File

@ -507,7 +507,7 @@ pub(crate) mod private_generation_utils {
&Tensor::arange_start(1, vocab_size, (Int64, logits.device())),
&sorted_indices_to_remove
.slice(1, 0, vocab_size - 1, 1)
-.copy(),
+.shallow_clone(),
);
let _ = sorted_indices_to_remove.index_fill_(
1,
@ -749,8 +749,8 @@ pub(crate) mod private_generation_utils {
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
self.split_bad_word_ids(gen_opt.bad_word_ids);
let mut static_bad_words_mask: Option<Tensor> = None;
-let mut attention_mask = attention_mask.copy();
-let mut input_ids = input_ids.copy();
+let mut attention_mask = attention_mask.shallow_clone();
+let mut input_ids = input_ids.shallow_clone();
let mut past: Cache = Cache::None;
let mut outputs: Tensor;
let mut current_length = cur_len;
@ -759,10 +759,10 @@ pub(crate) mod private_generation_utils {
while current_length < gen_opt.max_length {
let prepared_input = self.prepare_inputs_for_generation(
-input_ids.copy(),
+input_ids.shallow_clone(),
encoder_outputs.as_ref(),
past,
-attention_mask.copy(),
+attention_mask.shallow_clone(),
);
let temp = self
.get_model()
@ -1020,10 +1020,10 @@ pub(crate) mod private_generation_utils {
);
}
let prepared_input = self.prepare_inputs_for_generation(
-input_ids.copy(),
+input_ids.shallow_clone(),
encoder_outputs.as_ref(),
past,
-attention_mask.copy(),
+attention_mask.shallow_clone(),
);
let temp = self
.get_model()
@ -1240,10 +1240,10 @@ pub(crate) mod private_generation_utils {
saved_beam_scores.as_ref().map(|step_wise_scores| {
Tensor::stack(step_wise_scores, 1)
.get(effective_beam_id)
-.copy()
+.shallow_clone()
});
hypotheses[batch_index as usize].add(
-input_ids.get(effective_beam_id).copy(),
+input_ids.get(effective_beam_id).shallow_clone(),
beam_token_score,
saved_beam_scores,
);
@ -1301,7 +1301,7 @@ pub(crate) mod private_generation_utils {
}
if let Some(scores_output) = saved_beam_scores.as_mut() {
-scores_output.push(beam_scores.copy());
+scores_output.push(beam_scores.shallow_clone());
}
if done.iter().all(|&x| x) {
break;
@ -2105,10 +2105,10 @@ impl Clone for BeamHypotheses {
.map(|(score, tensor, scores_tensor)| {
(
*score,
-tensor.copy(),
+tensor.shallow_clone(),
scores_tensor
.as_ref()
-.map(|scores_tensor| scores_tensor.copy()),
+.map(|scores_tensor| scores_tensor.shallow_clone()),
)
})
.collect::<Vec<(f64, Tensor, Option<Tensor>)>>(),

View File

@ -31,8 +31,8 @@ pub struct LayerState {
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
-prev_key: self.prev_key.copy(),
-prev_value: self.prev_value.copy(),
+prev_key: self.prev_key.shallow_clone(),
+prev_value: self.prev_value.shallow_clone(),
}
}
}

View File

@ -27,7 +27,7 @@ use tch::{nn, Device, Kind, Tensor};
fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
let left_block = Tensor::ones(&[ngram, sequence_length, sequence_length], (kind, device))
* get_negative_infinity(kind).unwrap();
-let right_block = left_block.copy();
+let right_block = left_block.shallow_clone();
for stream_idx in 0..ngram {
let _ = right_block.get(stream_idx).fill_diagonal_(0, false);
let _ = left_block.get(stream_idx).triu_(-stream_idx + 1);

View File

@ -68,7 +68,7 @@ impl ProphetNetPositionalEmbeddings {
attention_mask.cumsum(1, Kind::Int64) * attention_mask + self.padding_idx
}
}
-Some(value) => value.copy(),
+Some(value) => value.shallow_clone(),
};
(calc_position_ids.apply(&self.embeddings), calc_position_ids)

View File

@ -40,10 +40,10 @@ impl Clone for LayerState {
let prev_buckets = self
.prev_buckets
.as_ref()
-.map(|prev_buckets| prev_buckets.copy());
+.map(|prev_buckets| prev_buckets.shallow_clone());
LayerState {
prev_buckets,
-prev_states: self.prev_states.copy(),
+prev_states: self.prev_states.shallow_clone(),
}
}
}
@ -417,7 +417,7 @@ impl LSHSelfAttention {
None,
)?;
let key_value_bucket_idx = look_adjacent(
-query_bucket_idx.copy(),
+query_bucket_idx.shallow_clone(),
self.num_chunks_before,
self.num_chunks_after,
);
@ -446,7 +446,7 @@ impl LSHSelfAttention {
(query_bucket_idx, key_value_bucket_idx)
} else {
(
-sorted_bucket_indices_per_hash.copy(),
+sorted_bucket_indices_per_hash.shallow_clone(),
sorted_bucket_indices_per_hash,
)
};
@ -849,7 +849,7 @@ impl LSHSelfAttention {
false
}
} {
-(sorted_bucket_idx.unwrap().copy(), None)
+(sorted_bucket_idx.unwrap().shallow_clone(), None)
} else {
(
Tensor::arange(sequence_length, (Kind::Int64, query_key_vectors.device()))
@ -1121,7 +1121,7 @@ impl LocalSelfAttention {
self.num_attention_heads,
None,
)?;
-let key_indices = query_indices.copy();
+let key_indices = query_indices.shallow_clone();
key_vectors = look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after);
value_vectors =
@ -1130,7 +1130,7 @@ impl LocalSelfAttention {
look_adjacent(key_indices, self.num_chunks_before, self.num_chunks_after);
(query_indices, key_indices)
} else {
-(indices.copy(), indices.copy())
+(indices.shallow_clone(), indices.shallow_clone())
};
let mut query_key_dots = query_vectors.matmul(&key_vectors.transpose(-1, -2));
@ -1356,7 +1356,7 @@ impl ReformerAttention {
if original_sequence_length > 1 {
Some(buckets_value.slice(3, 0, original_sequence_length, 1))
} else {
-Some(buckets_value.copy())
+Some(buckets_value.shallow_clone())
}
} else {
Some(Tensor::cat(

View File

@ -291,8 +291,8 @@ impl ReformerEncoder {
original_sequence_length: i64,
train: bool,
) -> Result<ReformerModelOutput, RustBertError> {
-let mut hidden_state = hidden_states.copy();
-let mut attention_output = hidden_states.copy();
+let mut hidden_state = hidden_states.shallow_clone();
+let mut attention_output = hidden_states.shallow_clone();
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.layers.len()))
} else {

View File

@ -30,8 +30,8 @@ pub struct LayerState {
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
-prev_key: self.prev_key.copy(),
-prev_value: self.prev_value.copy(),
+prev_key: self.prev_key.shallow_clone(),
+prev_value: self.prev_value.shallow_clone(),
}
}
}
@ -176,15 +176,15 @@ impl T5Attention {
k = Tensor::cat(&[&layer_state.prev_key, &k], 2);
v = Tensor::cat(&[&layer_state.prev_value, &v], 2);
} else {
-k = layer_state.prev_key.copy();
-v = layer_state.prev_value.copy();
+k = layer_state.prev_key.shallow_clone();
+v = layer_state.prev_value.shallow_clone();
}
};
layer_state = if self.is_decoder & self.store_cache {
Some(LayerState {
-prev_key: k.copy(),
-prev_value: v.copy(),
+prev_key: k.shallow_clone(),
+prev_value: v.shallow_clone(),
})
} else {
None

View File

@ -434,7 +434,7 @@ impl T5Stack {
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
let encoder_hidden_states_shape = encoder_hidden_states.size();
let encoder_mask = match encoder_attention_mask {
-Some(value) => value.copy(),
+Some(value) => value.shallow_clone(),
None => Tensor::ones(
&[
encoder_hidden_states_shape[0],

View File

@ -29,7 +29,7 @@ pub struct LayerState {
impl Clone for LayerState {
fn clone(&self) -> Self {
LayerState {
-prev_content: self.prev_content.copy(),
+prev_content: self.prev_content.shallow_clone(),
}
}
}

View File

@ -67,7 +67,7 @@ impl XLNetFeedForward {
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
let output = input.apply(&self.layer_1);
-let output: Tensor = (&self.activation.get_fn())(&output);
+let output: Tensor = self.activation.get_fn()(&output);
let output = output
.apply_t(&self.dropout, train)
.apply(&self.layer_2)

View File

@ -516,9 +516,9 @@ impl XLNetModel {
&[m_len, batch_size],
(Kind::Int64, token_type_ids_value.device()),
);
-Tensor::cat(&[mem_pad, token_type_ids_value.copy()], 0)
+Tensor::cat(&[mem_pad, token_type_ids_value.shallow_clone()], 0)
} else {
-token_type_ids_value.copy()
+token_type_ids_value.shallow_clone()
};
let seg_mat = token_type_ids_value
.unsqueeze(-1)