Updated tch::nn::Path handling to be consistent across the crate

This commit is contained in:
Guillaume B 2020-06-30 19:47:08 +02:00
parent a067faf574
commit afa03cd46e
22 changed files with 173 additions and 108 deletions

View File

@ -138,7 +138,7 @@ impl AlbertModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
/// let albert: AlbertModel = AlbertModel::new(&p.root() / "albert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertModel
where

View File

@ -13,6 +13,7 @@
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
@ -31,33 +32,37 @@ pub struct AlbertSelfAttention {
}
impl AlbertSelfAttention {
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let p = p.borrow();
let query = nn::linear(
&p / "query",
p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense = nn::linear(
&p / "dense",
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
@ -76,11 +81,8 @@ impl AlbertSelfAttention {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&p / "LayerNorm",
vec![config.hidden_size],
layer_norm_config,
);
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
AlbertSelfAttention {
num_attention_heads: config.num_attention_heads,

View File

@ -19,7 +19,6 @@ use tch::{nn, Kind, Tensor};
/// # Embeddings implementation for Albert model
#[derive(Debug)]
/// # Embeddings implementation for Electra model
pub struct AlbertEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,

View File

@ -12,6 +12,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
@ -72,19 +73,24 @@ pub struct SelfAttention {
}
impl SelfAttention {
pub fn new(
p: nn::Path,
pub fn new<'p, P>(
p: P,
embed_dim: i64,
num_heads: i64,
dropout: f64,
encoder_decoder_attention: bool,
store_cache: bool,
output_attentions: bool,
) -> SelfAttention {
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
let out_proj = nn::linear(&p / "out_proj", embed_dim, embed_dim, Default::default());
) -> SelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let k_proj = nn::linear(p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(p / "q_proj", embed_dim, embed_dim, Default::default());
let out_proj = nn::linear(p / "out_proj", embed_dim, embed_dim, Default::default());
let head_dim = embed_dim / num_heads;
let scaling = (head_dim as f64).powf(-0.5);

View File

@ -246,7 +246,7 @@ impl BartModel {
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
/// let bart: BartModel = BartModel::new(&p.root() / "bart", &config, generation_mode);
/// ```
pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartModel
where
@ -452,14 +452,17 @@ impl BartForConditionalGeneration {
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode);
/// ```
pub fn new(
p: &nn::Path,
pub fn new<'p, P>(
p: P,
config: &BartConfig,
generation_mode: bool,
) -> BartForConditionalGeneration {
let base_model = BartModel::new(p / "model", config, generation_mode);
) -> BartForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let base_model = BartModel::new(p.borrow() / "model", config, generation_mode);
BartForConditionalGeneration { base_model }
}
@ -653,7 +656,7 @@ impl BartForSequenceClassification {
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// BartForSequenceClassification::new(&p.root() / "bart", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification
where

View File

@ -19,7 +19,7 @@ use crate::bart::embeddings::{
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::{nn, Tensor};
@ -37,7 +37,12 @@ pub struct DecoderLayer {
}
impl DecoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
pub fn new<'p, P>(p: P, config: &BartConfig) -> DecoderLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
@ -47,7 +52,7 @@ impl DecoderLayer {
None => false,
};
let self_attention = SelfAttention::new(
&p / "self_attn",
p / "self_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
@ -56,7 +61,7 @@ impl DecoderLayer {
output_attention,
);
let encoder_attention = SelfAttention::new(
&p / "encoder_attn",
p / "encoder_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
@ -65,12 +70,12 @@ impl DecoderLayer {
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let encoder_attention_layer_norm = nn::layer_norm(
&p / "encoder_attn_layer_norm",
p / "encoder_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
@ -89,20 +94,20 @@ impl DecoderLayer {
Activation::tanh => _tanh,
});
let fc1 = nn::linear(
&p / "fc1",
p / "fc1",
config.d_model,
config.decoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
p / "fc2",
config.decoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
p / "final_layer_norm",
vec![config.d_model],
layer_norm_config,
);
@ -182,7 +187,11 @@ pub struct BartDecoder {
}
impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
pub fn new<'p, P>(p: P, config: &BartConfig, generation_mode: bool) -> BartDecoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_past = match config.output_past {
Some(value) => value,
None => true,
@ -222,7 +231,7 @@ impl BartDecoder {
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
@ -237,13 +246,13 @@ impl BartDecoder {
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
@ -251,7 +260,7 @@ impl BartDecoder {
};
let mut layers: Vec<DecoderLayer> = vec![];
let p_layers = &p / "layers";
let p_layers = p / "layers";
for layer_index in 0..config.decoder_layers {
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
}

View File

@ -11,6 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Borrow;
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
@ -43,12 +44,15 @@ pub struct LearnedPositionalEmbedding {
}
impl LearnedPositionalEmbedding {
pub fn new(
p: nn::Path,
pub fn new<'p, P>(
p: P,
num_embeddings: i64,
embedding_dim: i64,
padding_index: i64,
) -> LearnedPositionalEmbedding {
) -> LearnedPositionalEmbedding
where
P: Borrow<nn::Path<'p>>,
{
let embedding_config = EmbeddingConfig {
padding_idx: padding_index,
..Default::default()
@ -56,7 +60,7 @@ impl LearnedPositionalEmbedding {
let num_embeddings = num_embeddings + padding_index + 1;
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, embedding_config);
embedding(p.borrow(), num_embeddings, embedding_dim, embedding_config);
LearnedPositionalEmbedding {
embedding,
padding_index,
@ -86,13 +90,20 @@ pub struct SinusoidalPositionalEmbedding {
}
impl SinusoidalPositionalEmbedding {
pub fn new(
p: nn::Path,
pub fn new<'p, P>(
p: P,
num_embeddings: i64,
embedding_dim: i64,
) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, Default::default());
) -> SinusoidalPositionalEmbedding
where
P: Borrow<nn::Path<'p>>,
{
let embedding: nn::Embedding = embedding(
p.borrow(),
num_embeddings,
embedding_dim,
Default::default(),
);
SinusoidalPositionalEmbedding { embedding }
}

View File

@ -19,7 +19,7 @@ use crate::bart::embeddings::{
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
@ -35,7 +35,12 @@ pub struct EncoderLayer {
}
impl EncoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
pub fn new<'p, P>(p: P, config: &BartConfig) -> EncoderLayer
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
@ -45,7 +50,7 @@ impl EncoderLayer {
None => false,
};
let self_attention = SelfAttention::new(
&p / "self_attn",
p / "self_attn",
config.d_model,
config.encoder_attention_heads,
config.attention_dropout,
@ -54,7 +59,7 @@ impl EncoderLayer {
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
@ -72,20 +77,20 @@ impl EncoderLayer {
Activation::tanh => _tanh,
});
let fc1 = nn::linear(
&p / "fc1",
p / "fc1",
config.d_model,
config.encoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
p / "fc2",
config.encoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
p / "final_layer_norm",
vec![config.d_model],
layer_norm_config,
);
@ -136,7 +141,11 @@ pub struct BartEncoder {
}
impl BartEncoder {
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartEncoder
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false,
@ -172,7 +181,7 @@ impl BartEncoder {
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
@ -187,13 +196,13 @@ impl BartEncoder {
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
@ -201,7 +210,7 @@ impl BartEncoder {
};
let mut layers: Vec<EncoderLayer> = vec![];
let p_layers = &p / "layers";
let p_layers = p / "layers";
for layer_index in 0..config.encoder_layers {
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
}

View File

@ -30,27 +30,31 @@ pub struct BertSelfAttention {
}
impl BertSelfAttention {
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let p = p.borrow();
let query = nn::linear(
&p / "query",
p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),

View File

@ -145,7 +145,7 @@ impl<T: BertEmbedding> BertModel<T> {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&(&p.root() / "bert"), &config);
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertModel<T>
where
@ -442,7 +442,7 @@ impl BertForMaskedLM {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config);
/// let bert = BertForMaskedLM::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMaskedLM
where
@ -569,7 +569,7 @@ impl BertForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForSequenceClassification::new(&(&p.root() / "bert"), &config);
/// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification
where
@ -709,7 +709,7 @@ impl BertForMultipleChoice {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForMultipleChoice::new(&(&p.root() / "bert"), &config);
/// let bert = BertForMultipleChoice::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForMultipleChoice
where
@ -852,7 +852,7 @@ impl BertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForTokenClassification::new(&(&p.root() / "bert"), &config);
/// let bert = BertForTokenClassification::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForTokenClassification
where
@ -991,7 +991,7 @@ impl BertForQuestionAnswering {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForQuestionAnswering::new(&(&p.root() / "bert"), &config);
/// let bert = BertForQuestionAnswering::new(&p.root() / "bert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForQuestionAnswering
where

View File

@ -65,7 +65,7 @@ impl BertEmbedding for BertEmbeddings {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// let bert_embeddings = BertEmbeddings::new(&p.root() / "bert_embeddings", &config);
/// ```
fn new<'p, P>(p: P, config: &BertConfig) -> BertEmbeddings
where

View File

@ -12,6 +12,7 @@
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
@ -28,11 +29,15 @@ pub struct MultiHeadSelfAttention {
}
impl MultiHeadSelfAttention {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> MultiHeadSelfAttention {
let q_lin = nn::linear(&p / "q_lin", config.dim, config.dim, Default::default());
let k_lin = nn::linear(&p / "k_lin", config.dim, config.dim, Default::default());
let v_lin = nn::linear(&p / "v_lin", config.dim, config.dim, Default::default());
let out_lin = nn::linear(&p / "out_lin", config.dim, config.dim, Default::default());
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> MultiHeadSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let q_lin = nn::linear(p / "q_lin", config.dim, config.dim, Default::default());
let k_lin = nn::linear(p / "k_lin", config.dim, config.dim, Default::default());
let v_lin = nn::linear(p / "v_lin", config.dim, config.dim, Default::default());
let out_lin = nn::linear(p / "out_lin", config.dim, config.dim, Default::default());
let dropout = Dropout::new(config.attention_dropout);

View File

@ -154,15 +154,15 @@ impl DistilBertModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&p.root() / "distilbert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModel
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "distilbert";
let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
let transformer = Transformer::new(&p / "transformer", config);
let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
let transformer = Transformer::new(p.borrow() / "transformer", config);
DistilBertModel {
embeddings,
transformer,
@ -269,7 +269,7 @@ impl DistilBertModelClassifier {
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// DistilBertModelClassifier::new(&p.root() / "distilbert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier
where
@ -404,7 +404,7 @@ impl DistilBertModelMaskedLM {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert = DistilBertModelMaskedLM::new(&p.root() / "distilbert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelMaskedLM
where
@ -538,7 +538,7 @@ impl DistilBertForQuestionAnswering {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert = DistilBertForQuestionAnswering::new(&p.root() / "distilbert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForQuestionAnswering
where
@ -656,7 +656,7 @@ impl DistilBertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config);
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertForTokenClassification
where

View File

@ -26,15 +26,19 @@ pub struct FeedForwardNetwork {
}
impl FeedForwardNetwork {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> FeedForwardNetwork
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let lin1 = nn::linear(
&p / "lin1",
p / "lin1",
config.dim,
config.hidden_dim,
Default::default(),
);
let lin2 = nn::linear(
&p / "lin2",
p / "lin2",
config.hidden_dim,
config.dim,
Default::default(),

View File

@ -130,7 +130,7 @@ impl ElectraModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
/// let electra_model: ElectraModel = ElectraModel::new(&p.root() / "electra", &config);
/// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraModel
where
@ -325,7 +325,7 @@ impl ElectraDiscriminatorHead {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
/// let discriminator_head = ElectraDiscriminatorHead::new(&p.root() / "electra", &config);
/// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraDiscriminatorHead
where
@ -430,7 +430,7 @@ impl ElectraGeneratorHead {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
/// let generator_head = ElectraGeneratorHead::new(&p.root() / "electra", &config);
/// ```
pub fn new<'p, P>(p: P, config: &ElectraConfig) -> ElectraGeneratorHead
where

View File

@ -245,7 +245,7 @@ impl Gpt2Model {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
/// let gpt2: Gpt2Model = Gpt2Model::new(&p.root() / "gpt2", &config);
/// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> Gpt2Model
where
@ -533,7 +533,7 @@ impl GPT2LMHeadModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&p.root() / "gpt2", &config);
/// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> GPT2LMHeadModel
where

View File

@ -13,6 +13,7 @@
use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::Borrow;
use tch::nn::Init;
use tch::{nn, Tensor};
@ -257,13 +258,17 @@ impl MarianForConditionalGeneration {
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// BartForConditionalGeneration::new(&p.root() / "bart", &config, generation_mode);
/// ```
pub fn new(
p: &nn::Path,
pub fn new<'p, P>(
p: P,
config: &BartConfig,
generation_mode: bool,
) -> MarianForConditionalGeneration {
) -> MarianForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = BartModel::new(p / "model", config, generation_mode);
let final_logits_bias = p.var(
"final_logits_bias",

View File

@ -104,7 +104,7 @@ impl OpenAiGptModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&p.root() / "gpt", &config);
/// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAiGptModel
where
@ -320,7 +320,7 @@ impl OpenAIGPTLMHeadModel {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&p.root() / "gpt", &config);
/// ```
pub fn new<'p, P>(p: P, config: &Gpt2Config) -> OpenAIGPTLMHeadModel
where

View File

@ -68,9 +68,10 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
TokenizedInput, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
use tch::{nn, no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)]
/// # Label generated by a `SequenceClassificationModel`
@ -176,7 +177,10 @@ impl SequenceClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {

View File

@ -123,11 +123,12 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
Tokenizer, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
use tch::{nn, no_grad, Device, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
@ -290,7 +291,10 @@ impl TokenClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
where
P: Borrow<nn::Path<'p>>,
{
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {

View File

@ -68,7 +68,7 @@ impl BertEmbedding for RobertaEmbeddings {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// let robert_embeddings = RobertaEmbeddings::new(&p.root() / "bert_embeddings", &config);
/// ```
fn new<'p, P>(p: P, config: &BertConfig) -> RobertaEmbeddings
where

View File

@ -147,7 +147,7 @@ impl RobertaForMaskedLM {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
/// let roberta = RobertaForMaskedLM::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMaskedLM
where
@ -325,7 +325,7 @@ impl RobertaForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
/// let roberta = RobertaForSequenceClassification::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForSequenceClassification
where
@ -453,7 +453,7 @@ impl RobertaForMultipleChoice {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
/// let roberta = RobertaForMultipleChoice::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForMultipleChoice
where
@ -598,7 +598,7 @@ impl RobertaForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
/// let roberta = RobertaForTokenClassification::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForTokenClassification
where
@ -739,7 +739,7 @@ impl RobertaForQuestionAnswering {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
/// let roberta = RobertaForQuestionAnswering::new(&p.root() / "roberta", &config);
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> RobertaForQuestionAnswering
where