From afa03cd46efe9638828d10587aff9b4fcb37aec0 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Tue, 30 Jun 2020 19:47:08 +0200 Subject: [PATCH] Updated tch::nn::Path handling to be consistent across crate --- src/albert/albert.rs | 2 +- src/albert/attention.rs | 22 +++++++------- src/albert/embeddings.rs | 1 - src/bart/attention.rs | 20 ++++++++----- src/bart/bart.rs | 17 ++++++----- src/bart/decoder.rs | 37 +++++++++++++++--------- src/bart/embeddings.rs | 29 +++++++++++++------ src/bart/encoder.rs | 33 +++++++++++++-------- src/bert/attention.rs | 12 +++++--- src/bert/bert.rs | 12 ++++---- src/bert/embeddings.rs | 2 +- src/distilbert/attention.rs | 15 ++++++---- src/distilbert/distilbert.rs | 14 ++++----- src/distilbert/transformer.rs | 10 +++++-- src/electra/electra.rs | 6 ++-- src/gpt2/gpt2.rs | 4 +-- src/marian/marian.rs | 13 ++++++--- src/openai_gpt/openai_gpt.rs | 4 +-- src/pipelines/sequence_classification.rs | 8 +++-- src/pipelines/token_classification.rs | 8 +++-- src/roberta/embeddings.rs | 2 +- src/roberta/roberta.rs | 10 +++---- 22 files changed, 173 insertions(+), 108 deletions(-) diff --git a/src/albert/albert.rs b/src/albert/albert.rs index fad88b2..834bdda 100644 --- a/src/albert/albert.rs +++ b/src/albert/albert.rs @@ -138,7 +138,7 @@ impl AlbertModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = AlbertConfig::from_file(config_path); - /// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config); + /// let albert: AlbertModel = AlbertModel::new(&p.root() / "albert", &config); /// ``` pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertModel where diff --git a/src/albert/attention.rs b/src/albert/attention.rs index 8fd17c8..6b51c0b 100644 --- a/src/albert/attention.rs +++ b/src/albert/attention.rs @@ -13,6 +13,7 @@ use crate::albert::AlbertConfig; use crate::common::dropout::Dropout; +use std::borrow::Borrow; use tch::kind::Kind::Float; use tch::{nn, Tensor}; @@ -31,33 +32,37 @@ pub struct AlbertSelfAttention { } impl AlbertSelfAttention { - pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention { + pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertSelfAttention + where + P: Borrow<nn::Path<'p>>, + { assert_eq!( config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads" ); + let p = p.borrow(); let query = nn::linear( - &p / "query", + p / "query", config.hidden_size, config.hidden_size, Default::default(), ); let key = nn::linear( - &p / "key", + p / "key", config.hidden_size, config.hidden_size, Default::default(), ); let value = nn::linear( - &p / "value", + p / "value", config.hidden_size, config.hidden_size, Default::default(), ); let dense = nn::linear( - &p / "dense", + p / "dense", config.hidden_size, config.hidden_size, Default::default(), @@ -76,11 +81,8 @@ impl AlbertSelfAttention { eps: layer_norm_eps, ..Default::default() }; - let layer_norm = nn::layer_norm( - &p / "LayerNorm", - vec![config.hidden_size], - layer_norm_config, - ); + let layer_norm = + nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config); AlbertSelfAttention { num_attention_heads: config.num_attention_heads, diff --git a/src/albert/embeddings.rs b/src/albert/embeddings.rs index 9372f50..c05c648 100644 --- a/src/albert/embeddings.rs +++ b/src/albert/embeddings.rs @@ -19,7 +19,6 @@ use tch::{nn, Kind, Tensor}; /// # Embeddings implementation for Albert model #[derive(Debug)] -/// # Embeddings implementation for Electra model pub struct 
AlbertEmbeddings { word_embeddings: nn::Embedding, position_embeddings: nn::Embedding, diff --git a/src/bart/attention.rs b/src/bart/attention.rs index afdc65b..394916d 100644 --- a/src/bart/attention.rs +++ b/src/bart/attention.rs @@ -12,6 +12,7 @@ // limitations under the License. use crate::common::dropout::Dropout; +use std::borrow::Borrow; use tch::kind::Kind::Float; use tch::{nn, Tensor}; @@ -72,19 +73,24 @@ pub struct SelfAttention { } impl SelfAttention { - pub fn new( - p: nn::Path, + pub fn new<'p, P>( + p: P, embed_dim: i64, num_heads: i64, dropout: f64, encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool, - ) -> SelfAttention { - let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default()); - let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default()); - let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default()); - let out_proj = nn::linear(&p / "out_proj", embed_dim, embed_dim, Default::default()); + ) -> SelfAttention + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); + + let k_proj = nn::linear(p / "k_proj", embed_dim, embed_dim, Default::default()); + let v_proj = nn::linear(p / "v_proj", embed_dim, embed_dim, Default::default()); + let q_proj = nn::linear(p / "q_proj", embed_dim, embed_dim, Default::default()); + let out_proj = nn::linear(p / "out_proj", embed_dim, embed_dim, Default::default()); let head_dim = embed_dim / num_heads; let scaling = (head_dim as f64).powf(-0.5); diff --git a/src/bart/bart.rs b/src/bart/bart.rs index 0b084d7..d46721a 100644 --- a/src/bart/bart.rs +++ b/src/bart/bart.rs @@ -246,7 +246,7 @@ impl BartModel { /// let p = nn::VarStore::new(device); /// let config = BartConfig::from_file(config_path); /// let generation_mode = true; - /// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode); + /// let bart: BartModel = BartModel::new(&p.root() / "bart", &config, generation_mode); /// ``` pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartModel where @@ -452,14 +452,17 @@ impl BartForConditionalGeneration { /// let config = BartConfig::from_file(config_path); /// let generation_mode = true; /// let bart: BartForConditionalGeneration = - /// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); + /// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode); /// ``` - pub fn new( - p: &nn::Path, + pub fn new<'p, P>( + p: P, config: &BartConfig, generation_mode: bool, - ) -> BartForConditionalGeneration { - let base_model = BartModel::new(p / "model", config, generation_mode); + ) -> BartForConditionalGeneration + where + P: Borrow<nn::Path<'p>>, + { + let base_model = BartModel::new(p.borrow() / "model", config, generation_mode); BartForConditionalGeneration { base_model } } @@ -653,7 +656,7 @@ impl BartForSequenceClassification { /// let config = BartConfig::from_file(config_path); /// let generation_mode = true; /// let bart: BartForSequenceClassification = - /// BartForSequenceClassification::new(&(&p.root() / "bart"), &config); + /// BartForSequenceClassification::new(&p.root() / "bart", &config); /// ``` pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification where diff --git a/src/bart/decoder.rs b/src/bart/decoder.rs index ecef95c..c6bf2d7 100644 --- a/src/bart/decoder.rs +++ b/src/bart/decoder.rs @@ -19,7 +19,7 @@ use crate::bart::embeddings::{ use crate::bart::BartConfig; use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, 
_tanh}; use crate::common::dropout::Dropout; -use std::borrow::BorrowMut; +use std::borrow::{Borrow, BorrowMut}; use tch::kind::Kind::Int64; use tch::{nn, Tensor}; @@ -37,7 +37,12 @@ pub struct DecoderLayer { } impl DecoderLayer { - pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer { + pub fn new<'p, P>(p: P, config: &BartConfig) -> DecoderLayer + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); + let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let output_attention = match config.output_attentions { None => false, }; let self_attention = SelfAttention::new( - &p / "self_attn", + p / "self_attn", config.d_model, config.decoder_attention_heads, config.attention_dropout, @@ -56,7 +61,7 @@ impl DecoderLayer { output_attention, ); let encoder_attention = SelfAttention::new( - &p / "encoder_attn", + p / "encoder_attn", config.d_model, config.decoder_attention_heads, config.attention_dropout, @@ -65,12 +70,12 @@ impl DecoderLayer { output_attention, ); let self_attention_layer_norm = nn::layer_norm( - &p / "self_attn_layer_norm", + p / "self_attn_layer_norm", vec![config.d_model], layer_norm_config, ); let encoder_attention_layer_norm = nn::layer_norm( - &p / "encoder_attn_layer_norm", + p / "encoder_attn_layer_norm", vec![config.d_model], layer_norm_config, ); @@ -89,20 +94,20 @@ impl DecoderLayer { Activation::tanh => _tanh, }); let fc1 = nn::linear( - &p / "fc1", + p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default(), ); let fc2 = nn::linear( - &p / "fc2", + p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default(), ); let final_layer_norm = nn::layer_norm( - &p / "final_layer_norm", + p / "final_layer_norm", vec![config.d_model], layer_norm_config, ); @@ -182,7 +187,11 @@ pub struct BartDecoder { } impl BartDecoder { - pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder { + pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartDecoder + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); let output_past = match config.output_past { Some(value) => value, None => true, }; @@ -222,7 +231,7 @@ impl BartDecoder { ..Default::default() }; Some(nn::layer_norm( - &p / "layernorm_embedding", + p / "layernorm_embedding", vec![config.d_model], layer_norm_config, )) @@ -237,13 +246,13 @@ impl BartDecoder { let embed_positions = if static_position_embeddings { EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new( - &p / "embed_positions", + p / "embed_positions", config.max_position_embeddings, config.d_model, )) } else { EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new( - &p / "embed_positions", + p / "embed_positions", config.max_position_embeddings, config.d_model, pad_token_id, )) }; let mut layers: Vec<DecoderLayer> = vec![]; - let p_layers = &p / "layers"; + let p_layers = p / "layers"; for layer_index in 0..config.decoder_layers { layers.push(DecoderLayer::new(&p_layers / layer_index, config)); } diff --git a/src/bart/embeddings.rs b/src/bart/embeddings.rs index 46000f0..56dac93 100644 --- a/src/bart/embeddings.rs +++ b/src/bart/embeddings.rs @@ -11,6 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::borrow::Borrow; use tch::kind::Kind::Int64; use tch::nn::{embedding, EmbeddingConfig}; use tch::{nn, Tensor}; @@ -43,12 +44,15 @@ pub struct LearnedPositionalEmbedding { } impl LearnedPositionalEmbedding { - pub fn new( - p: nn::Path, + pub fn new<'p, P>( + p: P, num_embeddings: i64, embedding_dim: i64, padding_index: i64, - ) -> LearnedPositionalEmbedding { + ) -> LearnedPositionalEmbedding + where + P: Borrow<nn::Path<'p>>, + { let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() }; let num_embeddings = num_embeddings + padding_index + 1; let embedding: nn::Embedding = - embedding(p, num_embeddings, embedding_dim, embedding_config); + embedding(p.borrow(), num_embeddings, embedding_dim, embedding_config); LearnedPositionalEmbedding { embedding, padding_index, @@ -86,13 +90,20 @@ pub struct SinusoidalPositionalEmbedding { } impl SinusoidalPositionalEmbedding { - pub fn new( - p: nn::Path, + pub fn new<'p, P>( + p: P, num_embeddings: i64, embedding_dim: i64, - ) -> SinusoidalPositionalEmbedding { - let embedding: nn::Embedding = - embedding(p, num_embeddings, embedding_dim, Default::default()); + ) -> SinusoidalPositionalEmbedding + where + P: Borrow<nn::Path<'p>>, + { + let embedding: nn::Embedding = embedding( + p.borrow(), + num_embeddings, + embedding_dim, + Default::default(), + ); SinusoidalPositionalEmbedding { embedding } } diff --git a/src/bart/encoder.rs b/src/bart/encoder.rs index 20adff3..784ed1e 100644 --- a/src/bart/encoder.rs +++ b/src/bart/encoder.rs @@ -19,7 +19,7 @@ use crate::bart::embeddings::{ use crate::bart::BartConfig; use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh}; use crate::common::dropout::Dropout; -use std::borrow::BorrowMut; +use std::borrow::{Borrow, BorrowMut}; use tch::kind::Kind::Bool; use tch::{nn, Tensor}; @@ -35,7 +35,12 @@ pub struct EncoderLayer { } impl EncoderLayer { - pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer { + pub fn new<'p, P>(p: P, config: &BartConfig) -> EncoderLayer + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); + let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let output_attention = match config.output_attentions { None => false, }; let self_attention = SelfAttention::new( - &p / "self_attn", + p / "self_attn", config.d_model, config.encoder_attention_heads, config.attention_dropout, @@ -54,7 +59,7 @@ impl EncoderLayer { output_attention, ); let self_attention_layer_norm = nn::layer_norm( - &p / "self_attn_layer_norm", + p / "self_attn_layer_norm", vec![config.d_model], layer_norm_config, ); @@ -72,20 +77,20 @@ impl EncoderLayer { Activation::tanh => _tanh, }); let fc1 = nn::linear( - &p / "fc1", + p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default(), ); let fc2 = nn::linear( - &p / "fc2", + p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default(), ); let final_layer_norm = nn::layer_norm( - &p / "final_layer_norm", + p / "final_layer_norm", vec![config.d_model], layer_norm_config, ); @@ -136,7 +141,11 @@ pub struct BartEncoder { } impl BartEncoder { - pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder { + pub fn new<'p, P>(p: P, config: &BartConfig) -> BartEncoder + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); let output_attentions = match config.output_attentions { Some(value) => value, None => false, }; @@ -172,7 +181,7 @@ impl BartEncoder { ..Default::default() }; Some(nn::layer_norm( - &p / "layernorm_embedding", + p / "layernorm_embedding", vec![config.d_model], 
layer_norm_config, )) @@ -187,13 +196,13 @@ impl BartEncoder { let embed_positions = if static_position_embeddings { EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new( - &p / "embed_positions", + p / "embed_positions", config.max_position_embeddings, config.d_model, )) } else { EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new( - &p / "embed_positions", + p / "embed_positions", config.max_position_embeddings, config.d_model, pad_token_id, )) }; let mut layers: Vec<EncoderLayer> = vec![]; - let p_layers = &p / "layers"; + let p_layers = p / "layers"; for layer_index in 0..config.encoder_layers { layers.push(EncoderLayer::new(&p_layers / layer_index, config)); } diff --git a/src/bert/attention.rs b/src/bert/attention.rs index 4f8615c..8fedb00 100644 --- a/src/bert/attention.rs +++ b/src/bert/attention.rs @@ -30,27 +30,31 @@ pub struct BertSelfAttention { } impl BertSelfAttention { - pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention { + pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfAttention + where + P: Borrow<nn::Path<'p>>, + { assert_eq!( config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads" ); + let p = p.borrow(); let query = nn::linear( - &p / "query", + p / "query", config.hidden_size, config.hidden_size, Default::default(), ); let key = nn::linear( - &p / "key", + p / "key", config.hidden_size, config.hidden_size, Default::default(), ); let value = nn::linear( - &p / "value", + p / "value", config.hidden_size, config.hidden_size, Default::default(), diff --git a/src/bert/bert.rs b/src/bert/bert.rs index 6068195..27d28d5 100644 --- a/src/bert/bert.rs +++ b/src/bert/bert.rs @@ -145,7 +145,7 @@ impl BertModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert: BertModel = BertModel::new(&(&p.root() / "bert"), &config); + /// let bert: BertModel = BertModel::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel where @@ -442,7 +442,7 @@ impl BertForMaskedLM { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config); + /// let bert = BertForMaskedLM::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM where @@ -569,7 +569,7 @@ impl BertForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert = BertForSequenceClassification::new(&(&p.root() / "bert"), &config); + /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification where @@ -709,7 +709,7 @@ impl BertForMultipleChoice { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert = BertForMultipleChoice::new(&(&p.root() / "bert"), &config); + /// let bert = BertForMultipleChoice::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice where @@ -852,7 +852,7 @@ impl BertForTokenClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert 
= BertForTokenClassification::new(&(&p.root() / "bert"), &config); + /// let bert = BertForTokenClassification::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForTokenClassification where @@ -991,7 +991,7 @@ impl BertForQuestionAnswering { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert = BertForQuestionAnswering::new(&(&p.root() / "bert"), &config); + /// let bert = BertForQuestionAnswering::new(&p.root() / "bert", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering where diff --git a/src/bert/embeddings.rs b/src/bert/embeddings.rs index 9005d69..84c2186 100644 --- a/src/bert/embeddings.rs +++ b/src/bert/embeddings.rs @@ -65,7 +65,7 @@ impl BertEmbedding for BertEmbeddings { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config); + /// let bert_embeddings = BertEmbeddings::new(&p.root() / "bert_embeddings", &config); /// ``` fn new<'p, P>(p: P, config: &BertConfig) -> BertEmbeddings where diff --git a/src/distilbert/attention.rs b/src/distilbert/attention.rs index 636af3b..5108faa 100644 --- a/src/distilbert/attention.rs +++ b/src/distilbert/attention.rs @@ -12,6 +12,7 @@ use crate::common::dropout::Dropout; use crate::distilbert::distilbert::DistilBertConfig; +use std::borrow::Borrow; use tch::kind::Kind::Float; use tch::{nn, Tensor}; @@ -28,11 +29,15 @@ pub struct MultiHeadSelfAttention { } impl MultiHeadSelfAttention { - pub fn new(p: nn::Path, config: &DistilBertConfig) -> MultiHeadSelfAttention { - let q_lin = nn::linear(&p / "q_lin", config.dim, config.dim, Default::default()); - let k_lin = nn::linear(&p / "k_lin", config.dim, config.dim, Default::default()); - let v_lin = nn::linear(&p / "v_lin", config.dim, config.dim, Default::default()); - let out_lin = nn::linear(&p / "out_lin", config.dim, config.dim, Default::default()); + pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> MultiHeadSelfAttention + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); + let q_lin = nn::linear(p / "q_lin", config.dim, config.dim, Default::default()); + let k_lin = nn::linear(p / "k_lin", config.dim, config.dim, Default::default()); + let v_lin = nn::linear(p / "v_lin", config.dim, config.dim, Default::default()); + let out_lin = nn::linear(p / "out_lin", config.dim, config.dim, Default::default()); let dropout = Dropout::new(config.attention_dropout); diff --git a/src/distilbert/distilbert.rs b/src/distilbert/distilbert.rs index 3d0769d..91183e2 100644 --- a/src/distilbert/distilbert.rs +++ b/src/distilbert/distilbert.rs @@ -154,15 +154,15 @@ impl DistilBertModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); - /// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config); + /// let distil_bert: DistilBertModel = DistilBertModel::new(&p.root() / "distilbert", &config); /// ``` pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel where P: Borrow<nn::Path<'p>>, { let p = p.borrow() / "distilbert"; - let embeddings = DistilBertEmbedding::new(&p / "embeddings", config); - let transformer = Transformer::new(&p / "transformer", config); + let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config); + let transformer = 
Transformer::new(p.borrow() / "transformer", config); DistilBertModel { embeddings, transformer, @@ -269,7 +269,7 @@ impl DistilBertModelClassifier { /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); /// let distil_bert: DistilBertModelClassifier = - /// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config); + /// DistilBertModelClassifier::new(&p.root() / "distilbert", &config); /// ``` pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier where @@ -404,7 +404,7 @@ impl DistilBertModelMaskedLM { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); - /// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config); + /// let distil_bert = DistilBertModelMaskedLM::new(&p.root() / "distilbert", &config); /// ``` pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM where @@ -538,7 +538,7 @@ impl DistilBertForQuestionAnswering { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); - /// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config); + /// let distil_bert = DistilBertForQuestionAnswering::new(&p.root() / "distilbert", &config); /// ``` pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering where @@ -656,7 +656,7 @@ impl DistilBertForTokenClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = DistilBertConfig::from_file(config_path); - /// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config); + /// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config); /// ``` pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForTokenClassification where diff --git a/src/distilbert/transformer.rs b/src/distilbert/transformer.rs index 9a82cb2..d78aa72 100644 --- a/src/distilbert/transformer.rs +++ b/src/distilbert/transformer.rs @@ -26,15 +26,19 @@ pub struct FeedForwardNetwork { } impl FeedForwardNetwork { - pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork { + pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> FeedForwardNetwork + where + P: Borrow<nn::Path<'p>>, + { + let p = p.borrow(); let lin1 = nn::linear( - &p / "lin1", + p / "lin1", config.dim, config.hidden_dim, Default::default(), ); let lin2 = nn::linear( - &p / "lin2", + p / "lin2", config.hidden_dim, config.dim, Default::default(), diff --git a/src/electra/electra.rs b/src/electra/electra.rs index 31c4642..1175e4e 100644 --- a/src/electra/electra.rs +++ b/src/electra/electra.rs @@ -130,7 +130,7 @@ impl ElectraModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = ElectraConfig::from_file(config_path); - /// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config); + /// let electra_model: ElectraModel = ElectraModel::new(&p.root() / "electra", &config); /// ``` pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel where @@ -325,7 +325,7 @@ impl ElectraDiscriminatorHead { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = ElectraConfig::from_file(config_path); - /// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config); + /// let discriminator_head = ElectraDiscriminatorHead::new(&p.root() / 
"electra", &config); /// ``` pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead where @@ -430,7 +430,7 @@ impl ElectraGeneratorHead { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = ElectraConfig::from_file(config_path); - /// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config); + /// let generator_head = ElectraGeneratorHead::new(&p.root() / "electra", &config); /// ``` pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead where diff --git a/src/gpt2/gpt2.rs b/src/gpt2/gpt2.rs index 7d505f7..5ee55ce 100644 --- a/src/gpt2/gpt2.rs +++ b/src/gpt2/gpt2.rs @@ -245,7 +245,7 @@ impl Gpt2Model { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = Gpt2Config::from_file(config_path); - /// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config); + /// let gpt2: Gpt2Model = Gpt2Model::new(&p.root() / "gpt2", &config); /// ``` pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model where @@ -533,7 +533,7 @@ impl GPT2LMHeadModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = Gpt2Config::from_file(config_path); - /// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config); + /// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&p.root() / "gpt2", &config); /// ``` pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel where diff --git a/src/marian/marian.rs b/src/marian/marian.rs index d7b7dbb..1427d57 100644 --- a/src/marian/marian.rs +++ b/src/marian/marian.rs @@ -13,6 +13,7 @@ use crate::bart::{BartConfig, BartModel, LayerState}; use crate::pipelines::generation::{Cache, LMHeadModel}; +use std::borrow::Borrow; use tch::nn::Init; use tch::{nn, Tensor}; @@ -257,13 +258,17 @@ impl MarianForConditionalGeneration { /// let config = BartConfig::from_file(config_path); /// let generation_mode = true; /// let bart: BartForConditionalGeneration = - /// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); + /// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode); /// ``` - pub fn new( - p: &nn::Path, + pub fn new<'p, P>( + p: P, config: &BartConfig, generation_mode: bool, - ) -> MarianForConditionalGeneration { + ) -> MarianForConditionalGeneration + where + P: Borrow>, + { + let p = p.borrow(); let base_model = BartModel::new(p / "model", config, generation_mode); let final_logits_bias = p.var( "final_logits_bias", diff --git a/src/openai_gpt/openai_gpt.rs b/src/openai_gpt/openai_gpt.rs index b2484db..f5dcf0b 100644 --- a/src/openai_gpt/openai_gpt.rs +++ b/src/openai_gpt/openai_gpt.rs @@ -104,7 +104,7 @@ impl OpenAiGptModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = Gpt2Config::from_file(config_path); - /// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config); + /// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&p.root() / "gpt", &config); /// ``` pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel where @@ -320,7 +320,7 @@ impl OpenAIGPTLMHeadModel { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = Gpt2Config::from_file(config_path); - /// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config); + /// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&p.root() / "gpt", &config); /// ``` pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel where 
diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 9d90c55..a282410 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -68,9 +68,10 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{ TokenizedInput, TruncationStrategy, }; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::collections::HashMap; use tch::nn::VarStore; -use tch::{no_grad, Device, Kind, Tensor}; +use tch::{nn, no_grad, Device, Kind, Tensor}; #[derive(Debug, Serialize, Deserialize)] /// # Label generated by a `SequenceClassificationModel` @@ -176,7 +177,10 @@ impl SequenceClassificationOption { /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot) /// * `config` - A configuration (the model type of the configuration must be compatible with the value for /// `model_type`) - pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self { + pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self + where + P: Borrow<nn::Path<'p>>, + { match model_type { ModelType::Bert => { if let ConfigOption::Bert(config) = config { diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 47bf2cb..7f9cb95 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -123,11 +123,12 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{ Tokenizer, TruncationStrategy, }; use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; use std::cmp::min; use std::collections::HashMap; use tch::kind::Kind::Float; use tch::nn::VarStore; -use tch::{no_grad, Device, Tensor}; +use tch::{nn, no_grad, Device, Tensor}; #[derive(Debug, Clone, Serialize, Deserialize)] /// # Token generated by a `TokenClassificationModel` @@ -290,7 +291,10 @@ impl TokenClassificationOption { /// * `p` - `tch::nn::Path` path to the model file to load (e.g. 
model.ot) /// * `config` - A configuration (the model type of the configuration must be compatible with the value for /// `model_type`) - pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self { + pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self + where + P: Borrow<nn::Path<'p>>, + { match model_type { ModelType::Bert => { if let ConfigOption::Bert(config) = config { diff --git a/src/roberta/embeddings.rs b/src/roberta/embeddings.rs index 53e8dbe..20573f9 100644 --- a/src/roberta/embeddings.rs +++ b/src/roberta/embeddings.rs @@ -68,7 +68,7 @@ impl BertEmbedding for RobertaEmbeddings { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config); + /// let robert_embeddings = RobertaEmbeddings::new(&p.root() / "bert_embeddings", &config); /// ``` fn new<'p, P>(p: P, config: &BertConfig) -> RobertaEmbeddings where diff --git a/src/roberta/roberta.rs b/src/roberta/roberta.rs index 387852c..312c01b 100644 --- a/src/roberta/roberta.rs +++ b/src/roberta/roberta.rs @@ -147,7 +147,7 @@ impl RobertaForMaskedLM { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config); + /// let roberta = RobertaForMaskedLM::new(&p.root() / "roberta", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM where @@ -325,7 +325,7 @@ impl RobertaForSequenceClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config); + /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification where @@ -453,7 +453,7 @@ impl RobertaForMultipleChoice { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config); + /// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice where @@ -598,7 +598,7 @@ impl RobertaForTokenClassification { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config); + /// let roberta = RobertaForTokenClassification::new(&p.root() / "roberta", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForTokenClassification where @@ -739,7 +739,7 @@ impl RobertaForQuestionAnswering { /// let device = Device::Cpu; /// let p = nn::VarStore::new(device); /// let config = BertConfig::from_file(config_path); - /// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config); + /// let roberta = RobertaForQuestionAnswering::new(&p.root() / "roberta", &config); /// ``` pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering where
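
Note (not part of the patch): the recurring change above replaces concrete `p: nn::Path` / `p: &nn::Path` parameters with a generic `P: Borrow<nn::Path<'p>>` bound plus a single `let p = p.borrow();` at the top of each constructor, so every `new` accepts either an owned sub-path or a reference to one. The following minimal sketch is illustrative only; `MyLayer` and its field are hypothetical, while the `tch` calls (`nn::VarStore`, `nn::linear`, `Path` division) are the real ones used throughout the patch. It shows why both the older `&(&p.root() / "bert")` call style and the simplified `&p.root() / "bert"` style used in the updated doc examples satisfy the same bound.

use std::borrow::Borrow;
use tch::{nn, Device};

// Hypothetical layer, used only to illustrate the constructor signature
// adopted across the crate in this patch.
struct MyLayer {
    linear: nn::Linear,
}

impl MyLayer {
    fn new<'p, P>(p: P, dim: i64) -> MyLayer
    where
        P: Borrow<nn::Path<'p>>,
    {
        // Normalise to `&nn::Path` once; `&Path / "name"` then yields an owned
        // sub-path, exactly as in the rewritten constructor bodies above.
        let p = p.borrow();
        let linear = nn::linear(p / "linear", dim, dim, Default::default());
        MyLayer { linear }
    }
}

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // An owned path (the call style the updated doc examples use)...
    let _by_value = MyLayer::new(&vs.root() / "layer_a", 8);
    // ...and a borrowed path (the older `&(&p.root() / "...")` style) both
    // satisfy `P: Borrow<nn::Path<'p>>`.
    let _by_ref = MyLayer::new(&(&vs.root() / "layer_b"), 8);
}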