Fixed Clippy warnings (#309)

* Fixed Clippy warnings
* Updated `tch` dependency
* Updated README to clarify the required `LIBTORCH` version for the current repository versus the published package versions

* Fixed Clippy warnings (2)

* Fixed Clippy warnings (3)
guillaume-be 2022-12-21 17:52:26 +00:00 committed by GitHub
parent dae899fea6
commit fdf5503163
48 changed files with 130 additions and 138 deletions
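
Most of the diff below applies a handful of recurring Clippy fixes. The most frequent is `clippy::needless_borrow` on model construction: the constructors are generic over `Borrow<nn::Path>`, so `vs.root()` can be passed by value and the explicit `&` is redundant. A minimal sketch of the pattern, using a hypothetical `DummyModel` rather than an actual rust-bert type:

```rust
use std::borrow::Borrow;
use tch::{nn, Device};

struct DummyModel;

impl DummyModel {
    // rust-bert constructors follow this shape: generic over Borrow<nn::Path>.
    fn new<'p, P: Borrow<nn::Path<'p>>>(_p: P) -> Self {
        DummyModel
    }
}

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // Before (lints as clippy::needless_borrow): DummyModel::new(&vs.root());
    // After: pass the freshly created `nn::Path` by value.
    let _model = DummyModel::new(vs.root());
}
```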

View File

@@ -70,7 +70,7 @@ features = ["doc-only"]
[dependencies]
rust_tokenizers = "~7.0.2"
tch = "~0.9.0"
tch = "~0.10.1"
serde_json = "1.0.82"
serde = { version = "1.0.140", features = ["derive"] }
ordered-float = "3.0.0"
@@ -88,6 +88,6 @@ anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.9.0"
torch-sys = "0.10.0"
tempfile = "3.3.0"
itertools = "0.10.3"
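
The `tch = "~0.10.1"` and `torch-sys = "0.10.0"` bumps above pair with the `libtorch` upgrade to `v1.13.1` documented in the README change below.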

View File

@@ -76,8 +76,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcu117.zip` for a Linux version with CUDA11.
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.13.1`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-1.13.1%2Bcu117.zip` for a Linux version with CUDA 11. **NOTE:** When using `rust-bert` as a dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` version on the published package [readme](https://crates.io/crates/rust-bert), as it may differ from the version documented here (which applies to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
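
The environment-variable block under this heading falls outside the hunk's context window and is unchanged by this commit; for reference, the Linux instructions at this point in the README typically amount to the standard `tch` setup (assumed here, not part of the diff):

```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```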

View File

@@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(&vs.root(), &config);
let model = DebertaForSequenceClassification::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input

View File

@@ -1101,7 +1101,7 @@ impl BartGenerator {
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let model = BartForConditionalGeneration::new(&var_store.root(), &config);
let model = BartForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -1131,7 +1131,7 @@ impl BartGenerator {
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1337,6 +1337,6 @@ mod test {
let vs = tch::nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(BartModel::new(&vs.root(), &config));
let _: Box<dyn Send> = Box::new(BartModel::new(vs.root(), &config));
}
}
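
The `(0..self.get_vocab_size() as i64)` changes here, and in the M2M100, Marian, MBart and Pegasus generators below, drop casts that Clippy flags as unnecessary once the value is already `i64`. A minimal sketch of the pattern, with a hypothetical `Generator` type:

```rust
// Minimal sketch of the redundant-cast cleanup; `Generator` is a
// hypothetical stand-in, not a rust-bert type.
struct Generator {
    vocab_size: i64,
}

impl Generator {
    fn get_vocab_size(&self) -> i64 {
        self.vocab_size
    }

    fn impossible_tokens(&self, token_ids: &[i64]) -> Vec<i64> {
        // Before: `(0..self.get_vocab_size() as i64)` -- a cast of i64 to i64.
        (0..self.get_vocab_size())
            .filter(|pos| !token_ids.contains(pos))
            .collect()
    }
}

fn main() {
    let generator = Generator { vocab_size: 5 };
    assert_eq!(generator.impossible_tokens(&[0, 2]), vec![1, 3, 4]);
}
```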

View File

@@ -24,7 +24,7 @@ use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::Init;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Kind, Tensor};
/// # BERT Pretrained model weight files
@@ -507,7 +507,7 @@ impl BertLMPredictionHead {
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
BertLMPredictionHead {
transform,
@@ -1301,9 +1301,9 @@ mod test {
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = tch::nn::VarStore::new(device);
let vs = nn::VarStore::new(device);
let config = BertConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(&vs.root(), &config));
let _: Box<dyn Send> = Box::new(BertModel::<BertEmbeddings>::new(vs.root(), &config));
}
}
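
The `Init::KaimingUniform` → `DEFAULT_KAIMING_UNIFORM` substitutions in this file (and in the linear layer, `prophetnet`, `roberta` and `xlnet` sources below) track the `tch` 0.10 API, where the parameter-free enum variant gave way to a configurable Kaiming initializer, with `tch::nn::init::DEFAULT_KAIMING_UNIFORM` as the drop-in default. A small sketch of the new call site, assuming `tch` ~0.10:

```rust
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Device};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    // tch 0.9 spelled this as: root.var("bias", &[8], Init::KaimingUniform)
    let bias = root.var("bias", &[8], DEFAULT_KAIMING_UNIFORM);
    println!("{:?}", bias.size());
}
```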

View File

@@ -11,6 +11,7 @@
// limitations under the License.
use std::borrow::Borrow;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, Module, Path};
use tch::Tensor;
@@ -22,7 +23,7 @@ pub struct LinearNoBiasConfig {
impl Default for LinearNoBiasConfig {
fn default() -> Self {
LinearNoBiasConfig {
ws_init: Init::KaimingUniform,
ws_init: DEFAULT_KAIMING_UNIFORM,
}
}
}

View File

@@ -1070,6 +1070,6 @@ mod test {
let vs = tch::nn::VarStore::new(device);
let config = FNetConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(FNetModel::new(&vs.root(), &config, true));
let _: Box<dyn Send> = Box::new(FNetModel::new(vs.root(), &config, true));
}
}

View File

@@ -742,7 +742,7 @@ impl GPT2Generator {
let mut var_store = nn::VarStore::new(device);
let config = Gpt2Config::from_file(config_path);
let model = GPT2LMHeadModel::new(&var_store.root(), &config);
let model = GPT2LMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = tokenizer.get_bos_id();

View File

@@ -716,7 +716,7 @@ impl GptNeoGenerator {
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = GptNeoConfig::from_file(config_path);
let model = GptNeoForCausalLM::new(&var_store.root(), &config)?;
let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
let bos_token_id = tokenizer.get_bos_id();

View File

@@ -71,7 +71,7 @@ impl SinusoidalPositionalEmbedding {
) -> Tensor {
let half_dim = embedding_dim / 2;
let emb = -(10000f64.ln() as f64) / ((half_dim - 1) as f64);
let emb = -(10000f64.ln()) / ((half_dim - 1) as f64);
let emb = (Tensor::arange(half_dim, (Kind::Float, device)) * emb).exp();
let emb =
Tensor::arange(num_embeddings, (Kind::Float, device)).unsqueeze(1) * emb.unsqueeze(0);

View File

@@ -651,7 +651,7 @@ impl M2M100Generator {
let mut var_store = nn::VarStore::new(device);
let config = M2M100Config::from_file(config_path);
let model = M2M100ForConditionalGeneration::new(&var_store.root(), &config);
let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -681,7 +681,7 @@ impl M2M100Generator {
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -887,6 +887,6 @@ mod test {
let vs = tch::nn::VarStore::new(device);
let config = M2M100Config::from_file(config_path);
let _: Box<dyn Send> = Box::new(M2M100Model::new(&vs.root(), &config));
let _: Box<dyn Send> = Box::new(M2M100Model::new(vs.root(), &config));
}
}

View File

@@ -872,7 +872,7 @@ impl MarianGenerator {
let mut var_store = nn::VarStore::new(device);
let config = BartConfig::from_file(config_path);
let model = MarianForConditionalGeneration::new(&var_store.root(), &config);
let model = MarianForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -904,7 +904,7 @@ impl MarianGenerator {
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());

View File

@@ -900,7 +900,7 @@ impl MBartGenerator {
let mut var_store = nn::VarStore::new(device);
let config = MBartConfig::from_file(config_path);
let model = MBartForConditionalGeneration::new(&var_store.root(), &config);
let model = MBartForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -930,7 +930,7 @@ impl MBartGenerator {
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
@@ -1136,6 +1136,6 @@ mod test {
let vs = tch::nn::VarStore::new(device);
let config = MBartConfig::from_file(config_path);
let _: Box<dyn Send> = Box::new(MBartModel::new(&vs.root(), &config));
let _: Box<dyn Send> = Box::new(MBartModel::new(vs.root(), &config));
}
}

View File

@@ -19,6 +19,7 @@ use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::nn::{Init, LayerNormConfig, Module};
use tch::{nn, Kind, Tensor};
@@ -292,7 +293,7 @@ impl MobileBertLMPredictionHead {
config.hidden_size - config.embedding_size,
config.vocab_size,
],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let bias = p.var("bias", &[config.vocab_size], Init::Const(0.0));

View File

@@ -504,7 +504,7 @@ impl OpenAIGenerator {
let mut var_store = nn::VarStore::new(device);
let config = Gpt2Config::from_file(config_path);
let model = OpenAIGPTLMHeadModel::new(&var_store.root(), &config);
let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = tokenizer.get_bos_id();

View File

@@ -624,7 +624,7 @@ impl PegasusConditionalGenerator {
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = PegasusConfig::from_file(config_path);
let model = PegasusForConditionalGeneration::new(&var_store.root(), &config);
let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
@@ -654,7 +654,7 @@ impl PegasusConditionalGenerator {
}
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size())
.filter(|pos| !token_ids.contains(pos))
.collect();
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());

View File

@@ -403,7 +403,7 @@ pub(crate) mod private_generation_utils {
prev_output_tokens: &Tensor,
repetition_penalty: f64,
) {
for i in 0..(batch_size * num_beams as i64) {
for i in 0..(batch_size * num_beams) {
for token_position in 0..prev_output_tokens.get(i).size()[0] {
let token = prev_output_tokens.get(i).int64_value(&[token_position]);
let updated_value = &next_token_logits.double_value(&[i, token]);
@@ -826,8 +826,8 @@ pub(crate) mod private_generation_utils {
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
gen_opt.no_repeat_ngram_size as i64,
current_length as i64,
gen_opt.no_repeat_ngram_size,
current_length,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
@@ -875,7 +875,7 @@ pub(crate) mod private_generation_utils {
}
self.top_k_top_p_filtering(
&mut next_token_logits,
gen_opt.top_k as i64,
gen_opt.top_k,
gen_opt.top_p,
1,
);
@@ -915,7 +915,7 @@ pub(crate) mod private_generation_utils {
&sentence_with_eos
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length as i64 + 1,
current_length + 1,
);
unfinished_sentences = -unfinished_sentences * (sentence_with_eos - 1);
}
@@ -943,7 +943,7 @@ pub(crate) mod private_generation_utils {
&unfinished_sentences
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length as i64,
current_length,
);
break;
}
@@ -1927,10 +1927,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let batch_size = *input_ids.size().first().unwrap();
let (effective_batch_size, effective_batch_mult) = match do_sample {
true => (
batch_size * num_return_sequences as i64,
num_return_sequences as i64,
),
true => (batch_size * num_return_sequences, num_return_sequences),
false => (batch_size, 1),
};
@@ -1946,7 +1943,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let encoder_outputs = self.encode(&input_ids, Some(&attention_mask)).unwrap();
let expanded_batch_indices = Tensor::arange(batch_size, (Int64, input_ids.device()))
.view((-1, 1))
.repeat(&[1, num_beams as i64 * effective_batch_mult])
.repeat(&[1, num_beams * effective_batch_mult])
.view(-1);
Some(encoder_outputs.index_select(0, &expanded_batch_indices))
} else {
@@ -1959,19 +1956,19 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
input_ids
.unsqueeze(1)
.expand(
&[batch_size, effective_batch_mult * num_beams as i64, cur_len],
&[batch_size, effective_batch_mult * num_beams, cur_len],
true,
)
.contiguous()
.view((effective_batch_size * num_beams as i64, cur_len)),
.view((effective_batch_size * num_beams, cur_len)),
attention_mask
.unsqueeze(1)
.expand(
&[batch_size, effective_batch_mult * num_beams as i64, cur_len],
&[batch_size, effective_batch_mult * num_beams, cur_len],
true,
)
.contiguous()
.view((effective_batch_size * num_beams as i64, cur_len)),
.view((effective_batch_size * num_beams, cur_len)),
)
} else {
(input_ids, attention_mask)
@@ -1982,7 +1979,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
.expect("decoder start id must be specified for encoder decoders")
});
let input_ids = Tensor::full(
&[effective_batch_size * num_beams as i64, 1],
&[effective_batch_size * num_beams, 1],
decoder_start_token_id,
(Int64, input_ids.device()),
);
@@ -1990,15 +1987,11 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
attention_mask
.unsqueeze(1)
.expand(
&[
batch_size,
effective_batch_mult * num_beams as i64,
input_ids_len,
],
&[batch_size, effective_batch_mult * num_beams, input_ids_len],
true,
)
.contiguous()
.view((effective_batch_size * num_beams as i64, input_ids_len))
.view((effective_batch_size * num_beams, input_ids_len))
} else {
attention_mask
};
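
The cluster of removed `as i64` casts in this file suggests the underlying generation parameters (`num_beams`, `num_return_sequences`, `top_k`, `no_repeat_ngram_size`, `current_length`, ...) are now carried as `i64` end to end, making the per-use casts redundant. A hedged sketch of the idea, with illustrative names only:

```rust
// Sketch only: storing generation parameters as i64 up front removes the
// scattered `as i64` casts at each use site. All names are illustrative.
struct GenOptions {
    num_beams: i64,
    num_return_sequences: i64,
}

fn effective_batch(batch_size: i64, opts: &GenOptions, do_sample: bool) -> (i64, i64) {
    if do_sample {
        // Previously spelled `batch_size * opts.num_return_sequences as i64`
        // when the field had an unsigned type.
        (batch_size * opts.num_return_sequences, opts.num_return_sequences)
    } else {
        (batch_size, 1)
    }
}

fn main() {
    let opts = GenOptions { num_beams: 4, num_return_sequences: 3 };
    let (effective_batch_size, _mult) = effective_batch(2, &opts, true);
    assert_eq!(effective_batch_size * opts.num_beams, 24);
}
```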

View File

@@ -144,13 +144,13 @@ impl<'a> KeywordExtractionModel<'a> {
config: KeywordExtractionConfig<'a>,
) -> Result<KeywordExtractionModel<'a>, RustBertError> {
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
&config
config
.sentence_embeddings_config
.tokenizer_config_resource
.get_local_path()?,
);
let sentence_bert_config = SentenceEmbeddingsSentenceBertConfig::from_file(
&config
config
.sentence_embeddings_config
.sentence_bert_config_resource
.get_local_path()?,

View File

@@ -407,7 +407,7 @@ impl MaskedLanguageModel {
.unwrap_or(usize::MAX);
let language_encode =
MaskedLanguageOption::new(config.model_type, &var_store.root(), &model_config)?;
MaskedLanguageOption::new(config.model_type, var_store.root(), &model_config)?;
var_store.load(weights_path)?;
let mask_token = config.mask_token;
Ok(MaskedLanguageModel {

View File

@@ -620,7 +620,7 @@ impl QuestionAnsweringModel {
let qa_model = QuestionAnsweringOption::new(
question_answering_config.model_type,
&var_store.root(),
var_store.root(),
&model_config,
)?;
@@ -878,7 +878,7 @@ impl QuestionAnsweringModel {
max_seq_length - sequence_pair_added_tokens - encoded_query.ids.len();
let mut start_token = 0_usize;
while (spans.len() * doc_stride as usize) < encoded_context.ids.len() {
while (spans.len() * doc_stride) < encoded_context.ids.len() {
let end_token = min(start_token + max_context_length, encoded_context.ids.len());
let sub_encoded_context = TokenIdsWithOffsets {
ids: encoded_context.ids[start_token..end_token].to_vec(),

View File

@@ -130,7 +130,7 @@ impl Dense {
bias: dense_conf.bias,
};
let linear = nn::linear(
&vs_dense.root(),
vs_dense.root(),
dense_conf.in_features,
dense_conf.out_features,
linear_conf,

View File

@@ -230,11 +230,8 @@ impl SentenceEmbeddingsModel {
transformer_type,
transformer_config_resource.get_local_path()?,
);
let transformer = SentenceEmbeddingsOption::new(
transformer_type,
&var_store.root(),
&transformer_config,
)?;
let transformer =
SentenceEmbeddingsOption::new(transformer_type, var_store.root(), &transformer_config)?;
var_store.load(transformer_weights_resource.get_local_path()?)?;
// Setup pooling layer
@@ -310,7 +307,7 @@ impl SentenceEmbeddingsModel {
Tensor::of_slice(
&input
.iter()
.map(|&e| if e == pad_token_id { 0_i64 } else { 1_i64 })
.map(|&e| i64::from(e != pad_token_id))
.collect::<Vec<_>>(),
)
})
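
The attention-mask change above is Clippy's `bool_to_int_with_if` pattern: an `if`/`else` that returns `0`/`1` collapses into a direct conversion from the comparison. A standalone sketch (function name hypothetical):

```rust
// Sketch of the mask construction after the clippy::bool_to_int_with_if fix:
// an if/else returning 0/1 collapses into a direct bool-to-int conversion.
fn attention_mask(input: &[i64], pad_token_id: i64) -> Vec<i64> {
    input.iter().map(|&e| i64::from(e != pad_token_id)).collect()
}

fn main() {
    assert_eq!(attention_mask(&[5, 0, 7, 0], 0), vec![1, 0, 1, 0]);
}
```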

View File

@@ -593,7 +593,7 @@ impl SequenceClassificationModel {
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
SequenceClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
var_store.load(weights_path)?;
Ok(SequenceClassificationModel {

View File

@@ -699,7 +699,7 @@ impl TokenClassificationModel {
.map(|v| v as usize)
.unwrap_or(usize::MAX);
let token_sequence_classifier =
TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
TokenClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
let label_mapping = model_config.get_label_mapping().clone();
let batch_size = config.batch_size;
var_store.load(weights_path)?;
@@ -749,7 +749,7 @@ impl TokenClassificationModel {
let mut start_token = 0_usize;
let total_length = encoded_input.ids.len();
while (spans.len() * doc_stride as usize) < encoded_input.ids.len() {
while (spans.len() * doc_stride) < encoded_input.ids.len() {
let end_token = min(start_token + max_content_length, total_length);
let sub_encoded_input = TokenIdsWithOffsets {
ids: encoded_input.ids[start_token..end_token].to_vec(),
@@ -994,8 +994,8 @@ impl TokenClassificationModel {
position_idx: i64,
word_index: u16,
) -> Token {
let label_id = labels.int64_value(&[position_idx as i64]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
let label_id = labels.int64_value(&[position_idx]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
let offsets = &sentence_tokens.offsets[position_idx as usize];

View File

@@ -575,7 +575,7 @@ impl ZeroShotClassificationModel {
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let zero_shot_classifier =
ZeroShotClassificationOption::new(config.model_type, &var_store.root(), &model_config)?;
ZeroShotClassificationOption::new(config.model_type, var_store.root(), &model_config)?;
var_store.load(weights_path)?;
Ok(ZeroShotClassificationModel {
tokenizer,

View File

@@ -21,7 +21,7 @@ use crate::prophetnet::embeddings::ProphetNetPositionalEmbeddings;
use crate::prophetnet::ProphetNetConfig;
use crate::RustBertError;
use std::borrow::{Borrow, BorrowMut};
use tch::nn::Init;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Device, Kind, Tensor};
fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
@@ -210,7 +210,7 @@ impl ProphetNetDecoder {
let ngram_embeddings = p_ngram_embedding.var(
"weight",
&[config.ngram, config.hidden_size],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let output_attentions = config.output_attentions.unwrap_or(false);

View File

@@ -965,7 +965,7 @@ impl ProphetNetConditionalGenerator {
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = ProphetNetConfig::from_file(config_path);
let model = ProphetNetForConditionalGeneration::new(&var_store.root(), &config)?;
let model = ProphetNetForConditionalGeneration::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id);

View File

@@ -354,7 +354,7 @@ impl ReformerModel {
let must_pad_to_match_chunk_length =
(input_shape.last().unwrap() % self.least_common_mult_chunk_length != 0)
& (*input_shape.last().unwrap() as i64 > self.min_chunk_length)
& (*input_shape.last().unwrap() > self.min_chunk_length)
& old_layer_states.is_none();
let start_idx_pos_encodings = if let Some(layer_states) = &old_layer_states {
@@ -1091,7 +1091,7 @@ impl ReformerGenerator {
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = ReformerConfig::from_file(config_path);
let model = ReformerModelWithLMHead::new(&var_store.root(), &config)?;
let model = ReformerModelWithLMHead::new(var_store.root(), &config)?;
var_store.load(weights_path)?;
let bos_token_id = tokenizer.get_bos_id();

View File

@@ -17,7 +17,7 @@ use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use std::borrow::Borrow;
use tch::nn::Init;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Tensor};
/// # RoBERTa Pretrained model weight files
@@ -218,7 +218,7 @@ impl RobertaLMHead {
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
let bias = p.var("bias", &[config.vocab_size], DEFAULT_KAIMING_UNIFORM);
RobertaLMHead {
dense,

View File

@@ -881,7 +881,7 @@ impl T5Generator {
let mut var_store = nn::VarStore::new(device);
let config = T5Config::from_file(config_path);
let model = T5ForConditionalGeneration::new(&var_store.root(), &config);
let model = T5ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id.unwrap_or(-1));

View File

@@ -15,7 +15,7 @@
use crate::common::dropout::Dropout;
use crate::xlnet::XLNetConfig;
use std::borrow::Borrow;
use tch::nn::Init;
use tch::nn::init::DEFAULT_KAIMING_UNIFORM;
use tch::{nn, Kind, Tensor};
#[derive(Debug)]
@@ -72,52 +72,52 @@ impl XLNetRelativeAttention {
let query = p.var(
"q",
&[config.d_model, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let key = p.var(
"k",
&[config.d_model, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let value = p.var(
"v",
&[config.d_model, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let output = p.var(
"o",
&[config.d_model, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let pos = p.var(
"r",
&[config.d_model, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let r_r_bias = p.var(
"r_r_bias",
&[config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let r_s_bias = p.var(
"r_s_bias",
&[config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let r_w_bias = p.var(
"r_w_bias",
&[config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let seg_embed = p.var(
"seg_embed",
&[2, config.n_head, config.d_head],
Init::KaimingUniform,
DEFAULT_KAIMING_UNIFORM,
);
let dropout = Dropout::new(config.dropout);

View File

@@ -1648,7 +1648,7 @@ impl XLNetGenerator {
let mut var_store = nn::VarStore::new(device);
let config = XLNetConfig::from_file(config_path);
let model = XLNetLMHeadModel::new(&var_store.root(), &config);
let model = XLNetLMHeadModel::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = Some(config.bos_token_id);

View File

@@ -35,7 +35,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
let albert_model = AlbertForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -109,7 +109,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
let albert_model = AlbertForSequenceClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -170,7 +170,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
let albert_model = AlbertForMultipleChoice::new(vs.root(), &config);
// Define input
let input = [
@@ -242,7 +242,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForTokenClassification::new(&vs.root(), &config);
let albert_model = AlbertForTokenClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -303,7 +303,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
let albert_model = AlbertForQuestionAnswering::new(vs.root(), &config);
// Define input
let input = [

View File

@@ -35,7 +35,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = BertConfig::from_file(config_path);
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
let bert_model = BertForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -162,7 +162,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
let bert_model = BertForSequenceClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -219,7 +219,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
let mut config = BertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
let bert_model = BertForMultipleChoice::new(vs.root(), &config);
// Define input
let input = [
@@ -283,7 +283,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
let bert_model = BertForTokenClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -340,7 +340,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
let mut config = BertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
let bert_model = BertForQuestionAnswering::new(vs.root(), &config);
// Define input
let input = [

View File

@@ -41,7 +41,7 @@ fn deberta_natural_language_inference() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(&vs.root(), &config);
let model = DebertaForSequenceClassification::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -96,7 +96,7 @@ fn deberta_masked_lm() -> anyhow::Result<()> {
let mut config = DebertaConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let deberta_model = DebertaForMaskedLM::new(&vs.root(), &config);
let deberta_model = DebertaForMaskedLM::new(vs.root(), &config);
// Generate random input
let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -170,7 +170,7 @@ fn deberta_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaForTokenClassification::new(&vs.root(), &config);
let model = DebertaForTokenClassification::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -225,7 +225,7 @@ fn deberta_for_question_answering() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForQuestionAnswering::new(&vs.root(), &config);
let model = DebertaForQuestionAnswering::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -22,7 +22,7 @@ fn deberta_v2_masked_lm() -> anyhow::Result<()> {
let mut config = DebertaV2Config::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let deberta_model = DebertaV2ForMaskedLM::new(&vs.root(), &config);
let deberta_model = DebertaV2ForMaskedLM::new(vs.root(), &config);
// Generate random input
let input_tensor = Tensor::randint(42, &[32, 128], (Kind::Int64, device));
@@ -88,7 +88,7 @@ fn deberta_v2_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Neutral"));
dummy_label_mapping.insert(2, String::from("Negative"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaV2ForSequenceClassification::new(&vs.root(), &config);
let model = DebertaV2ForSequenceClassification::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -142,7 +142,7 @@ fn deberta_v2_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = DebertaV2ForTokenClassification::new(&vs.root(), &config);
let model = DebertaV2ForTokenClassification::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -190,7 +190,7 @@ fn deberta_v2_for_question_answering() -> anyhow::Result<()> {
let tokenizer =
DeBERTaV2Tokenizer::from_file(vocab_path.to_str().unwrap(), false, false, false)?;
let config = DebertaV2Config::from_file(config_path);
let model = DebertaV2ForQuestionAnswering::new(&vs.root(), &config);
let model = DebertaV2ForQuestionAnswering::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -61,7 +61,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = DistilBertConfig::from_file(config_path);
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
let distil_bert_model = DistilBertModelMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -140,7 +140,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
let mut config = DistilBertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
let distil_bert_model = DistilBertForQuestionAnswering::new(vs.root(), &config);
// Define input
let input = [
@@ -211,7 +211,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
let distil_bert_model = DistilBertForTokenClassification::new(vs.root(), &config);
// Define input
let input = [

View File

@@ -37,7 +37,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
false,
)?;
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input

View File

@@ -32,7 +32,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
let mut config = ElectraConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
let electra_model = ElectraForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -114,7 +114,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
let electra_model = ElectraDiscriminator::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input

View File

@@ -30,7 +30,7 @@ fn fnet_masked_lm() -> anyhow::Result<()> {
let tokenizer: FNetTokenizer =
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let config = FNetConfig::from_file(config_path);
let fnet_model = FNetForMaskedLM::new(&vs.root(), &config);
let fnet_model = FNetForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -138,7 +138,7 @@ fn fnet_for_multiple_choice() -> anyhow::Result<()> {
let mut config = FNetConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let fnet_model = FNetForMultipleChoice::new(&vs.root(), &config);
let fnet_model = FNetForMultipleChoice::new(vs.root(), &config);
// Define input
let input = [
@@ -201,7 +201,7 @@ fn fnet_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
config.output_hidden_states = Some(true);
let fnet_model = FNetForTokenClassification::new(&vs.root(), &config);
let fnet_model = FNetForTokenClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -256,7 +256,7 @@ fn fnet_for_question_answering() -> anyhow::Result<()> {
FNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, false)?;
let mut config = FNetConfig::from_file(config_path);
config.output_hidden_states = Some(true);
let fnet_model = FNetForQuestionAnswering::new(&vs.root(), &config);
let fnet_model = FNetForQuestionAnswering::new(vs.root(), &config);
// Define input
let input = [

View File

@@ -35,7 +35,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
false,
)?;
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
let gpt2_model = GPT2LMHeadModel::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input

View File

@@ -40,7 +40,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
let mut config = GptNeoConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let gpt_neo_model = GptNeoForCausalLM::new(&vs.root(), &config)?;
let gpt_neo_model = GptNeoForCausalLM::new(vs.root(), &config)?;
vs.load(weights_path)?;
// Define input

View File

@@ -197,7 +197,7 @@ fn longformer_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForSequenceClassification::new(&vs.root(), &config);
let model = LongformerForSequenceClassification::new(vs.root(), &config);
// Define input
let input = ["Very positive sentence", "Second sentence input"];
@@ -258,7 +258,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
false,
)?;
let config = LongformerConfig::from_file(config_path);
let model = LongformerForMultipleChoice::new(&vs.root(), &config);
let model = LongformerForMultipleChoice::new(vs.root(), &config);
// Define input
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -337,7 +337,7 @@ fn longformer_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForTokenClassification::new(&vs.root(), &config);
let model = LongformerForTokenClassification::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];

View File

@@ -35,7 +35,7 @@ fn mobilebert_masked_model() -> anyhow::Result<()> {
let mut config = MobileBertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
let mobilebert_model = MobileBertForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -130,7 +130,7 @@ fn mobilebert_for_sequence_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral"));
config.id2label = Some(dummy_label_mapping);
let model = MobileBertForSequenceClassification::new(&vs.root(), &config);
let model = MobileBertForSequenceClassification::new(vs.root(), &config);
// Define input
let input = ["Very positive sentence", "Second sentence input"];
@@ -176,7 +176,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
let vs = nn::VarStore::new(device);
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = MobileBertConfig::from_file(config_path);
let model = MobileBertForMultipleChoice::new(&vs.root(), &config);
let model = MobileBertForMultipleChoice::new(vs.root(), &config);
// Define input
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -240,7 +240,7 @@ fn mobilebert_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = MobileBertForTokenClassification::new(&vs.root(), &config);
let model = MobileBertForTokenClassification::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -287,7 +287,7 @@ fn mobilebert_for_question_answering() -> anyhow::Result<()> {
let vs = nn::VarStore::new(device);
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = MobileBertConfig::from_file(config_path);
let model = MobileBertForQuestionAnswering::new(&vs.root(), &config);
let model = MobileBertForQuestionAnswering::new(vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];

View File

@@ -39,7 +39,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
true,
)?;
let config = OpenAiGptConfig::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
let openai_gpt = OpenAIGPTLMHeadModel::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input

View File

@@ -98,7 +98,7 @@ fn reformer_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let reformer_model = ReformerForSequenceClassification::new(&vs.root(), &config)?;
let reformer_model = ReformerForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = [
@@ -159,7 +159,7 @@ fn reformer_for_question_answering() -> anyhow::Result<()> {
let mut config = ReformerConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let reformer_model = ReformerForQuestionAnswering::new(&vs.root(), &config)?;
let reformer_model = ReformerForQuestionAnswering::new(vs.root(), &config)?;
// Define input
let input = [

View File

@@ -41,7 +41,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
false,
)?;
let config = RobertaConfig::from_file(config_path);
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
let roberta_model = RobertaForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -136,7 +136,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
let roberta_model = RobertaForSequenceClassification::new(vs.root(), &config);
// Define input
let input = [
@@ -201,7 +201,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
let mut config = RobertaConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
let roberta_model = RobertaForMultipleChoice::new(vs.root(), &config);
// Define input
let input = [
@@ -273,7 +273,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
let roberta_model = RobertaForTokenClassification::new(vs.root(), &config);
// Define input
let input = [

View File

@@ -141,7 +141,7 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
let tokenizer: XLNetTokenizer =
XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
let config = XLNetConfig::from_file(config_path);
let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
let xlnet_model = XLNetLMHeadModel::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
@@ -257,7 +257,7 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
config.id2label = Some(dummy_label_mapping);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let xlnet_model = XLNetForSequenceClassification::new(&vs.root(), &config)?;
let xlnet_model = XLNetForSequenceClassification::new(vs.root(), &config)?;
// Define input
let input = ["Very positive sentence", "Second sentence input"];
@@ -322,7 +322,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
let vs = nn::VarStore::new(device);
let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = XLNetConfig::from_file(config_path);
let xlnet_model = XLNetForMultipleChoice::new(&vs.root(), &config)?;
let xlnet_model = XLNetForMultipleChoice::new(vs.root(), &config)?;
// Define input
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
@@ -396,7 +396,7 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let xlnet_model = XLNetForTokenClassification::new(&vs.root(), &config)?;
let xlnet_model = XLNetForTokenClassification::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
@@ -453,7 +453,7 @@ fn xlnet_for_question_answering() -> anyhow::Result<()> {
let vs = nn::VarStore::new(device);
let tokenizer = XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = XLNetConfig::from_file(config_path);
let xlnet_model = XLNetForQuestionAnswering::new(&vs.root(), &config)?;
let xlnet_model = XLNetForQuestionAnswering::new(vs.root(), &config)?;
// Define input
let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];