Updated tch::nn::Path handling to be consistent across the crate

This commit is contained in:
Guillaume B 2020-06-30 19:47:08 +02:00
parent a067faf574
commit afa03cd46e
22 changed files with 173 additions and 108 deletions

View File

@ -138,7 +138,7 @@ impl AlbertModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config); /// let albert: AlbertModel = AlbertModel::new(&p.root() / "albert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertModel pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertModel
where where

View File

@ -13,6 +13,7 @@
use crate::albert::AlbertConfig; use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -31,33 +32,37 @@ pub struct AlbertSelfAttention {
} }
impl AlbertSelfAttention { impl AlbertSelfAttention {
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention { pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
assert_eq!( assert_eq!(
config.hidden_size % config.num_attention_heads, config.hidden_size % config.num_attention_heads,
0, 0,
"Hidden size not a multiple of the number of attention heads" "Hidden size not a multiple of the number of attention heads"
); );
let p = p.borrow();
let query = nn::linear( let query = nn::linear(
&p / "query", p / "query",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
); );
let key = nn::linear( let key = nn::linear(
&p / "key", p / "key",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
); );
let value = nn::linear( let value = nn::linear(
&p / "value", p / "value",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
); );
let dense = nn::linear( let dense = nn::linear(
&p / "dense", p / "dense",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
@ -76,11 +81,8 @@ impl AlbertSelfAttention {
eps: layer_norm_eps, eps: layer_norm_eps,
..Default::default() ..Default::default()
}; };
let layer_norm = nn::layer_norm( let layer_norm =
&p / "LayerNorm", nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
vec![config.hidden_size],
layer_norm_config,
);
AlbertSelfAttention { AlbertSelfAttention {
num_attention_heads: config.num_attention_heads, num_attention_heads: config.num_attention_heads,

View File

@ -19,7 +19,6 @@ use tch::{nn, Kind, Tensor};
/// # Embeddings implementation for Albert model /// # Embeddings implementation for Albert model
#[derive(Debug)] #[derive(Debug)]
/// # Embeddings implementation for Electra model
pub struct AlbertEmbeddings { pub struct AlbertEmbeddings {
word_embeddings: nn::Embedding, word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding, position_embeddings: nn::Embedding,

View File

@ -12,6 +12,7 @@
// limitations under the License. // limitations under the License.
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -72,19 +73,24 @@ pub struct SelfAttention {
} }
impl SelfAttention { impl SelfAttention {
pub fn new( pub fn new<'p, P>(
p: nn::Path, p: P,
embed_dim: i64, embed_dim: i64,
num_heads: i64, num_heads: i64,
dropout: f64, dropout: f64,
encoder_decoder_attention: bool, encoder_decoder_attention: bool,
store_cache: bool, store_cache: bool,
output_attentions: bool, output_attentions: bool,
) -> SelfAttention { ) -> SelfAttention
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default()); where
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default()); P: Borrow<nn::Path<'p>>,
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default()); {
let out_proj = nn::linear(&p / "out_proj", embed_dim, embed_dim, Default::default()); let p = p.borrow();
let k_proj = nn::linear(p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(p / "q_proj", embed_dim, embed_dim, Default::default());
let out_proj = nn::linear(p / "out_proj", embed_dim, embed_dim, Default::default());
let head_dim = embed_dim / num_heads; let head_dim = embed_dim / num_heads;
let scaling = (head_dim as f64).powf(-0.5); let scaling = (head_dim as f64).powf(-0.5);

View File

@ -246,7 +246,7 @@ impl BartModel {
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode); /// let bart: BartModel = BartModel::new(&p.root() / "bart", &config, generation_mode);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartModel pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartModel
where where
@ -452,14 +452,17 @@ impl BartForConditionalGeneration {
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartForConditionalGeneration = /// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); /// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode);
/// ``` /// ```
pub fn new( pub fn new<'p, P>(
p: &nn::Path, p: P,
config: &BartConfig, config: &BartConfig,
generation_mode: bool, generation_mode: bool,
) -> BartForConditionalGeneration { ) -> BartForConditionalGeneration
let base_model = BartModel::new(p / "model", config, generation_mode); where
P: Borrow<nn::Path<'p>>,
{
let base_model = BartModel::new(p.borrow() / "model", config, generation_mode);
BartForConditionalGeneration { base_model } BartForConditionalGeneration { base_model }
} }
@ -653,7 +656,7 @@ impl BartForSequenceClassification {
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartForSequenceClassification = /// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config); /// BartForSequenceClassification::new(&p.root() / "bart", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification
where where

View File

@ -19,7 +19,7 @@ use crate::bart::embeddings::{
use crate::bart::BartConfig; use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh}; use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use std::borrow::BorrowMut; use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64; use tch::kind::Kind::Int64;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -37,7 +37,12 @@ pub struct DecoderLayer {
} }
impl DecoderLayer { impl DecoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer { pub fn new<'p, P>(p: P, config: &BartConfig) -> DecoderLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = nn::LayerNormConfig { let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5, eps: 1e-5,
..Default::default() ..Default::default()
@ -47,7 +52,7 @@ impl DecoderLayer {
None => false, None => false,
}; };
let self_attention = SelfAttention::new( let self_attention = SelfAttention::new(
&p / "self_attn", p / "self_attn",
config.d_model, config.d_model,
config.decoder_attention_heads, config.decoder_attention_heads,
config.attention_dropout, config.attention_dropout,
@ -56,7 +61,7 @@ impl DecoderLayer {
output_attention, output_attention,
); );
let encoder_attention = SelfAttention::new( let encoder_attention = SelfAttention::new(
&p / "encoder_attn", p / "encoder_attn",
config.d_model, config.d_model,
config.decoder_attention_heads, config.decoder_attention_heads,
config.attention_dropout, config.attention_dropout,
@ -65,12 +70,12 @@ impl DecoderLayer {
output_attention, output_attention,
); );
let self_attention_layer_norm = nn::layer_norm( let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm", p / "self_attn_layer_norm",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
); );
let encoder_attention_layer_norm = nn::layer_norm( let encoder_attention_layer_norm = nn::layer_norm(
&p / "encoder_attn_layer_norm", p / "encoder_attn_layer_norm",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
); );
@ -89,20 +94,20 @@ impl DecoderLayer {
Activation::tanh => _tanh, Activation::tanh => _tanh,
}); });
let fc1 = nn::linear( let fc1 = nn::linear(
&p / "fc1", p / "fc1",
config.d_model, config.d_model,
config.decoder_ffn_dim, config.decoder_ffn_dim,
Default::default(), Default::default(),
); );
let fc2 = nn::linear( let fc2 = nn::linear(
&p / "fc2", p / "fc2",
config.decoder_ffn_dim, config.decoder_ffn_dim,
config.d_model, config.d_model,
Default::default(), Default::default(),
); );
let final_layer_norm = nn::layer_norm( let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm", p / "final_layer_norm",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
); );
@ -182,7 +187,11 @@ pub struct BartDecoder {
} }
impl BartDecoder { impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder { pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartDecoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_past = match config.output_past { let output_past = match config.output_past {
Some(value) => value, Some(value) => value,
None => true, None => true,
@ -222,7 +231,7 @@ impl BartDecoder {
..Default::default() ..Default::default()
}; };
Some(nn::layer_norm( Some(nn::layer_norm(
&p / "layernorm_embedding", p / "layernorm_embedding",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
)) ))
@ -237,13 +246,13 @@ impl BartDecoder {
let embed_positions = if static_position_embeddings { let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new( EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions", p / "embed_positions",
config.max_position_embeddings, config.max_position_embeddings,
config.d_model, config.d_model,
)) ))
} else { } else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new( EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions", p / "embed_positions",
config.max_position_embeddings, config.max_position_embeddings,
config.d_model, config.d_model,
pad_token_id, pad_token_id,
@ -251,7 +260,7 @@ impl BartDecoder {
}; };
let mut layers: Vec<DecoderLayer> = vec![]; let mut layers: Vec<DecoderLayer> = vec![];
let p_layers = &p / "layers"; let p_layers = p / "layers";
for layer_index in 0..config.decoder_layers { for layer_index in 0..config.decoder_layers {
layers.push(DecoderLayer::new(&p_layers / layer_index, config)); layers.push(DecoderLayer::new(&p_layers / layer_index, config));
} }

View File

@ -11,6 +11,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::borrow::Borrow;
use tch::kind::Kind::Int64; use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig}; use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -43,12 +44,15 @@ pub struct LearnedPositionalEmbedding {
} }
impl LearnedPositionalEmbedding { impl LearnedPositionalEmbedding {
pub fn new( pub fn new<'p, P>(
p: nn::Path, p: P,
num_embeddings: i64, num_embeddings: i64,
embedding_dim: i64, embedding_dim: i64,
padding_index: i64, padding_index: i64,
) -> LearnedPositionalEmbedding { ) -> LearnedPositionalEmbedding
where
P: Borrow<nn::Path<'p>>,
{
let embedding_config = EmbeddingConfig { let embedding_config = EmbeddingConfig {
padding_idx: padding_index, padding_idx: padding_index,
..Default::default() ..Default::default()
@ -56,7 +60,7 @@ impl LearnedPositionalEmbedding {
let num_embeddings = num_embeddings + padding_index + 1; let num_embeddings = num_embeddings + padding_index + 1;
let embedding: nn::Embedding = let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, embedding_config); embedding(p.borrow(), num_embeddings, embedding_dim, embedding_config);
LearnedPositionalEmbedding { LearnedPositionalEmbedding {
embedding, embedding,
padding_index, padding_index,
@ -86,13 +90,20 @@ pub struct SinusoidalPositionalEmbedding {
} }
impl SinusoidalPositionalEmbedding { impl SinusoidalPositionalEmbedding {
pub fn new( pub fn new<'p, P>(
p: nn::Path, p: P,
num_embeddings: i64, num_embeddings: i64,
embedding_dim: i64, embedding_dim: i64,
) -> SinusoidalPositionalEmbedding { ) -> SinusoidalPositionalEmbedding
let embedding: nn::Embedding = where
embedding(p, num_embeddings, embedding_dim, Default::default()); P: Borrow<nn::Path<'p>>,
{
let embedding: nn::Embedding = embedding(
p.borrow(),
num_embeddings,
embedding_dim,
Default::default(),
);
SinusoidalPositionalEmbedding { embedding } SinusoidalPositionalEmbedding { embedding }
} }

View File

@ -19,7 +19,7 @@ use crate::bart::embeddings::{
use crate::bart::BartConfig; use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh}; use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use std::borrow::BorrowMut; use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Bool; use tch::kind::Kind::Bool;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -35,7 +35,12 @@ pub struct EncoderLayer {
} }
impl EncoderLayer { impl EncoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer { pub fn new<'p, P>(p: P, config: &BartConfig) -> EncoderLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = nn::LayerNormConfig { let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5, eps: 1e-5,
..Default::default() ..Default::default()
@ -45,7 +50,7 @@ impl EncoderLayer {
None => false, None => false,
}; };
let self_attention = SelfAttention::new( let self_attention = SelfAttention::new(
&p / "self_attn", p / "self_attn",
config.d_model, config.d_model,
config.encoder_attention_heads, config.encoder_attention_heads,
config.attention_dropout, config.attention_dropout,
@ -54,7 +59,7 @@ impl EncoderLayer {
output_attention, output_attention,
); );
let self_attention_layer_norm = nn::layer_norm( let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm", p / "self_attn_layer_norm",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
); );
@ -72,20 +77,20 @@ impl EncoderLayer {
Activation::tanh => _tanh, Activation::tanh => _tanh,
}); });
let fc1 = nn::linear( let fc1 = nn::linear(
&p / "fc1", p / "fc1",
config.d_model, config.d_model,
config.encoder_ffn_dim, config.encoder_ffn_dim,
Default::default(), Default::default(),
); );
let fc2 = nn::linear( let fc2 = nn::linear(
&p / "fc2", p / "fc2",
config.encoder_ffn_dim, config.encoder_ffn_dim,
config.d_model, config.d_model,
Default::default(), Default::default(),
); );
let final_layer_norm = nn::layer_norm( let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm", p / "final_layer_norm",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
); );
@ -136,7 +141,11 @@ pub struct BartEncoder {
} }
impl BartEncoder { impl BartEncoder {
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder { pub fn new<'p, P>(p: P, config: &BartConfig) -> BartEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false, None => false,
@ -172,7 +181,7 @@ impl BartEncoder {
..Default::default() ..Default::default()
}; };
Some(nn::layer_norm( Some(nn::layer_norm(
&p / "layernorm_embedding", p / "layernorm_embedding",
vec![config.d_model], vec![config.d_model],
layer_norm_config, layer_norm_config,
)) ))
@ -187,13 +196,13 @@ impl BartEncoder {
let embed_positions = if static_position_embeddings { let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new( EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions", p / "embed_positions",
config.max_position_embeddings, config.max_position_embeddings,
config.d_model, config.d_model,
)) ))
} else { } else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new( EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions", p / "embed_positions",
config.max_position_embeddings, config.max_position_embeddings,
config.d_model, config.d_model,
pad_token_id, pad_token_id,
@ -201,7 +210,7 @@ impl BartEncoder {
}; };
let mut layers: Vec<EncoderLayer> = vec![]; let mut layers: Vec<EncoderLayer> = vec![];
let p_layers = &p / "layers"; let p_layers = p / "layers";
for layer_index in 0..config.encoder_layers { for layer_index in 0..config.encoder_layers {
layers.push(EncoderLayer::new(&p_layers / layer_index, config)); layers.push(EncoderLayer::new(&p_layers / layer_index, config));
} }

View File

@ -30,27 +30,31 @@ pub struct BertSelfAttention {
} }
impl BertSelfAttention { impl BertSelfAttention {
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention { pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
assert_eq!( assert_eq!(
config.hidden_size % config.num_attention_heads, config.hidden_size % config.num_attention_heads,
0, 0,
"Hidden size not a multiple of the number of attention heads" "Hidden size not a multiple of the number of attention heads"
); );
let p = p.borrow();
let query = nn::linear( let query = nn::linear(
&p / "query", p / "query",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
); );
let key = nn::linear( let key = nn::linear(
&p / "key", p / "key",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),
); );
let value = nn::linear( let value = nn::linear(
&p / "value", p / "value",
config.hidden_size, config.hidden_size,
config.hidden_size, config.hidden_size,
Default::default(), Default::default(),

View File

@ -145,7 +145,7 @@ impl<T: BertEmbedding> BertModel<T> {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&(&p.root() / "bert"), &config); /// let bert: BertModel<BertEmbeddings> = BertModel::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T> pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T>
where where
@ -442,7 +442,7 @@ impl BertForMaskedLM {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config); /// let bert = BertForMaskedLM::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM
where where
@ -569,7 +569,7 @@ impl BertForSequenceClassification {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert = BertForSequenceClassification::new(&(&p.root() / "bert"), &config); /// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification
where where
@ -709,7 +709,7 @@ impl BertForMultipleChoice {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert = BertForMultipleChoice::new(&(&p.root() / "bert"), &config); /// let bert = BertForMultipleChoice::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice
where where
@ -852,7 +852,7 @@ impl BertForTokenClassification {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert = BertForTokenClassification::new(&(&p.root() / "bert"), &config); /// let bert = BertForTokenClassification::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForTokenClassification pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForTokenClassification
where where
@ -991,7 +991,7 @@ impl BertForQuestionAnswering {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert = BertForQuestionAnswering::new(&(&p.root() / "bert"), &config); /// let bert = BertForQuestionAnswering::new(&p.root() / "bert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering
where where

View File

@ -65,7 +65,7 @@ impl BertEmbedding for BertEmbeddings {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config); /// let bert_embeddings = BertEmbeddings::new(&p.root() / "bert_embeddings", &config);
/// ``` /// ```
fn new<'p, P>(p: P, config: &BertConfig) -> BertEmbeddings fn new<'p, P>(p: P, config: &BertConfig) -> BertEmbeddings
where where

View File

@ -12,6 +12,7 @@
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig; use crate::distilbert::distilbert::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -28,11 +29,15 @@ pub struct MultiHeadSelfAttention {
} }
impl MultiHeadSelfAttention { impl MultiHeadSelfAttention {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> MultiHeadSelfAttention { pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> MultiHeadSelfAttention
let q_lin = nn::linear(&p / "q_lin", config.dim, config.dim, Default::default()); where
let k_lin = nn::linear(&p / "k_lin", config.dim, config.dim, Default::default()); P: Borrow<nn::Path<'p>>,
let v_lin = nn::linear(&p / "v_lin", config.dim, config.dim, Default::default()); {
let out_lin = nn::linear(&p / "out_lin", config.dim, config.dim, Default::default()); let p = p.borrow();
let q_lin = nn::linear(p / "q_lin", config.dim, config.dim, Default::default());
let k_lin = nn::linear(p / "k_lin", config.dim, config.dim, Default::default());
let v_lin = nn::linear(p / "v_lin", config.dim, config.dim, Default::default());
let out_lin = nn::linear(p / "out_lin", config.dim, config.dim, Default::default());
let dropout = Dropout::new(config.attention_dropout); let dropout = Dropout::new(config.attention_dropout);

View File

@ -154,15 +154,15 @@ impl DistilBertModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config); /// let distil_bert: DistilBertModel = DistilBertModel::new(&p.root() / "distilbert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel
where where
P: Borrow<nn::Path<'p>>, P: Borrow<nn::Path<'p>>,
{ {
let p = p.borrow() / "distilbert"; let p = p.borrow() / "distilbert";
let embeddings = DistilBertEmbedding::new(&p / "embeddings", config); let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
let transformer = Transformer::new(&p / "transformer", config); let transformer = Transformer::new(p.borrow() / "transformer", config);
DistilBertModel { DistilBertModel {
embeddings, embeddings,
transformer, transformer,
@ -269,7 +269,7 @@ impl DistilBertModelClassifier {
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = /// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config); /// DistilBertModelClassifier::new(&p.root() / "distilbert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier
where where
@ -404,7 +404,7 @@ impl DistilBertModelMaskedLM {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config); /// let distil_bert = DistilBertModelMaskedLM::new(&p.root() / "distilbert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM
where where
@ -538,7 +538,7 @@ impl DistilBertForQuestionAnswering {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config); /// let distil_bert = DistilBertForQuestionAnswering::new(&p.root() / "distilbert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering
where where
@ -656,7 +656,7 @@ impl DistilBertForTokenClassification {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config); /// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForTokenClassification pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForTokenClassification
where where

View File

@ -26,15 +26,19 @@ pub struct FeedForwardNetwork {
} }
impl FeedForwardNetwork { impl FeedForwardNetwork {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork { pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> FeedForwardNetwork
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let lin1 = nn::linear( let lin1 = nn::linear(
&p / "lin1", p / "lin1",
config.dim, config.dim,
config.hidden_dim, config.hidden_dim,
Default::default(), Default::default(),
); );
let lin2 = nn::linear( let lin2 = nn::linear(
&p / "lin2", p / "lin2",
config.hidden_dim, config.hidden_dim,
config.dim, config.dim,
Default::default(), Default::default(),

View File

@ -130,7 +130,7 @@ impl ElectraModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path); /// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config); /// let electra_model: ElectraModel = ElectraModel::new(&p.root() / "electra", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel
where where
@ -325,7 +325,7 @@ impl ElectraDiscriminatorHead {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path); /// let config = ElectraConfig::from_file(config_path);
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config); /// let discriminator_head = ElectraDiscriminatorHead::new(&p.root() / "electra", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead
where where
@ -430,7 +430,7 @@ impl ElectraGeneratorHead {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path); /// let config = ElectraConfig::from_file(config_path);
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config); /// let generator_head = ElectraGeneratorHead::new(&p.root() / "electra", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead
where where

View File

@ -245,7 +245,7 @@ impl Gpt2Model {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config); /// let gpt2: Gpt2Model = Gpt2Model::new(&p.root() / "gpt2", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model
where where
@ -533,7 +533,7 @@ impl GPT2LMHeadModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config); /// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&p.root() / "gpt2", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel
where where

View File

@ -13,6 +13,7 @@
use crate::bart::{BartConfig, BartModel, LayerState}; use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel}; use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::Borrow;
use tch::nn::Init; use tch::nn::Init;
use tch::{nn, Tensor}; use tch::{nn, Tensor};
@ -257,13 +258,17 @@ impl MarianForConditionalGeneration {
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartForConditionalGeneration = /// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); /// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode);
/// ``` /// ```
pub fn new( pub fn new<'p, P>(
p: &nn::Path, p: P,
config: &BartConfig, config: &BartConfig,
generation_mode: bool, generation_mode: bool,
) -> MarianForConditionalGeneration { ) -> MarianForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = BartModel::new(p / "model", config, generation_mode); let base_model = BartModel::new(p / "model", config, generation_mode);
let final_logits_bias = p.var( let final_logits_bias = p.var(
"final_logits_bias", "final_logits_bias",

View File

@ -104,7 +104,7 @@ impl OpenAiGptModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config); /// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&p.root() / "gpt", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel
where where
@ -320,7 +320,7 @@ impl OpenAIGPTLMHeadModel {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config); /// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&p.root() / "gpt", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel
where where

View File

@ -68,9 +68,10 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
TokenizedInput, TruncationStrategy, TokenizedInput, TruncationStrategy,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap; use std::collections::HashMap;
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor}; use tch::{nn, no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
/// # Label generated by a `SequenceClassificationModel` /// # Label generated by a `SequenceClassificationModel`
@ -176,7 +177,10 @@ impl SequenceClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot) /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for /// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`) /// `model_type`)
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self { pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
where
P: Borrow<nn::Path<'p>>,
{
match model_type { match model_type {
ModelType::Bert => { ModelType::Bert => {
if let ConfigOption::Bert(config) = config { if let ConfigOption::Bert(config) = config {

View File

@ -123,11 +123,12 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
Tokenizer, TruncationStrategy, Tokenizer, TruncationStrategy,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::min; use std::cmp::min;
use std::collections::HashMap; use std::collections::HashMap;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::nn::VarStore; use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor}; use tch::{nn, no_grad, Device, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel` /// # Token generated by a `TokenClassificationModel`
@ -290,7 +291,10 @@ impl TokenClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot) /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for /// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`) /// `model_type`)
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self { pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
where
P: Borrow<nn::Path<'p>>,
{
match model_type { match model_type {
ModelType::Bert => { ModelType::Bert => {
if let ConfigOption::Bert(config) = config { if let ConfigOption::Bert(config) = config {

View File

@ -68,7 +68,7 @@ impl BertEmbedding for RobertaEmbeddings {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config); /// let robert_embeddings = RobertaEmbeddings::new(&p.root() / "bert_embeddings", &config);
/// ``` /// ```
fn new<'p, P>(p: P, config: &BertConfig) -> RobertaEmbeddings fn new<'p, P>(p: P, config: &BertConfig) -> RobertaEmbeddings
where where

View File

@ -147,7 +147,7 @@ impl RobertaForMaskedLM {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config); /// let roberta = RobertaForMaskedLM::new(&p.root() / "roberta", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM
where where
@ -325,7 +325,7 @@ impl RobertaForSequenceClassification {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config); /// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification
where where
@ -453,7 +453,7 @@ impl RobertaForMultipleChoice {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config); /// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice
where where
@ -598,7 +598,7 @@ impl RobertaForTokenClassification {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config); /// let roberta = RobertaForTokenClassification::new(&p.root() / "roberta", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForTokenClassification pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForTokenClassification
where where
@ -739,7 +739,7 @@ impl RobertaForQuestionAnswering {
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config); /// let roberta = RobertaForQuestionAnswering::new(&p.root() / "roberta", &config);
/// ``` /// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering
where where