mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Fixed Clippy warnings (#204)
This commit is contained in:
parent
b444780c18
commit
4175942cc4
@ -46,7 +46,6 @@ fn sst2_forward_pass(iters: u64, model: &SentimentModel, sst2_data: &[String]) -
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Record {
|
struct Record {
|
||||||
sentence: String,
|
sentence: String,
|
||||||
label: i8,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
|
fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
|
@ -22,7 +22,6 @@ use std::{env, fs};
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Record {
|
struct Record {
|
||||||
sentence: String,
|
sentence: String,
|
||||||
label: i8,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
|
fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
|
||||||
@ -47,7 +46,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
|
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
|
||||||
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
||||||
sst2_path.push("train.tsv");
|
sst2_path.push("train.tsv");
|
||||||
let inputs = ss2_processor(sst2_path).unwrap();
|
let inputs = &ss2_processor(sst2_path).unwrap()[..100];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let batch_size = 64;
|
let batch_size = 64;
|
||||||
|
@ -1427,7 +1427,7 @@ pub struct GeneratedIndicesOutput {
|
|||||||
pub score: Option<f64>,
|
pub score: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy, Default)]
|
||||||
/// # Generation options for text generation.
|
/// # Generation options for text generation.
|
||||||
/// When provided to a `generate` method, these options will take priority over the `GenerateConfig` used to create the
|
/// When provided to a `generate` method, these options will take priority over the `GenerateConfig` used to create the
|
||||||
/// `LanguageGenerator`. Some of these options may be left as `None`, options without a value will individually default
|
/// `LanguageGenerator`. Some of these options may be left as `None`, options without a value will individually default
|
||||||
@ -1476,32 +1476,6 @@ pub struct GenerateOptions<'a> {
|
|||||||
pub output_scores: bool,
|
pub output_scores: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for GenerateOptions<'_> {
|
|
||||||
fn default() -> Self {
|
|
||||||
GenerateOptions {
|
|
||||||
min_length: None,
|
|
||||||
max_length: None,
|
|
||||||
max_new_tokens: None,
|
|
||||||
early_stopping: None,
|
|
||||||
num_return_sequences: None,
|
|
||||||
num_beams: None,
|
|
||||||
num_beam_groups: None,
|
|
||||||
do_sample: None,
|
|
||||||
temperature: None,
|
|
||||||
top_k: None,
|
|
||||||
top_p: None,
|
|
||||||
repetition_penalty: None,
|
|
||||||
length_penalty: None,
|
|
||||||
no_repeat_ngram_size: None,
|
|
||||||
diversity_penalty: None,
|
|
||||||
decoder_start_token_id: None,
|
|
||||||
forced_bos_token_id: None,
|
|
||||||
prefix_allowed_tokens_fn: None,
|
|
||||||
bad_word_ids: None,
|
|
||||||
output_scores: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
macro_rules! unpack_config {
|
macro_rules! unpack_config {
|
||||||
($field_name:ident, $generate_options: ident, $generate_config: ident) => {
|
($field_name:ident, $generate_options: ident, $generate_config: ident) => {
|
||||||
$generate_options.map_or($generate_config.$field_name, |opts| {
|
$generate_options.map_or($generate_config.$field_name, |opts| {
|
||||||
|
@ -936,10 +936,8 @@ pub struct LocalSelfAttention {
|
|||||||
num_chunks_after: i64,
|
num_chunks_after: i64,
|
||||||
is_decoder: bool,
|
is_decoder: bool,
|
||||||
dropout: Dropout,
|
dropout: Dropout,
|
||||||
pad_token_id: i64,
|
|
||||||
num_attention_heads: i64,
|
num_attention_heads: i64,
|
||||||
attention_head_size: i64,
|
attention_head_size: i64,
|
||||||
hidden_size: i64,
|
|
||||||
query: nn::Linear,
|
query: nn::Linear,
|
||||||
key: nn::Linear,
|
key: nn::Linear,
|
||||||
value: nn::Linear,
|
value: nn::Linear,
|
||||||
@ -965,7 +963,6 @@ impl LocalSelfAttention {
|
|||||||
let num_chunks_before = config.local_num_chunks_before.unwrap_or(1);
|
let num_chunks_before = config.local_num_chunks_before.unwrap_or(1);
|
||||||
let num_chunks_after = config.local_num_chunks_after.unwrap_or(0);
|
let num_chunks_after = config.local_num_chunks_after.unwrap_or(0);
|
||||||
let is_decoder = config.is_decoder;
|
let is_decoder = config.is_decoder;
|
||||||
let pad_token_id = config.pad_token_id;
|
|
||||||
|
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
|
||||||
@ -994,10 +991,8 @@ impl LocalSelfAttention {
|
|||||||
num_chunks_after,
|
num_chunks_after,
|
||||||
is_decoder,
|
is_decoder,
|
||||||
dropout,
|
dropout,
|
||||||
pad_token_id,
|
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_size,
|
attention_head_size,
|
||||||
hidden_size,
|
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::common::embeddings::process_ids_embeddings_pair;
|
use crate::common::embeddings::process_ids_embeddings_pair;
|
||||||
use crate::reformer::attention_utils::get_least_common_mult_chunk_len;
|
|
||||||
use crate::reformer::ReformerConfig;
|
use crate::reformer::ReformerConfig;
|
||||||
use crate::RustBertError;
|
use crate::RustBertError;
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
@ -25,7 +24,6 @@ use tch::{nn, Kind, Tensor};
|
|||||||
pub struct AxialPositionEmbeddings {
|
pub struct AxialPositionEmbeddings {
|
||||||
weights: Vec<Tensor>,
|
weights: Vec<Tensor>,
|
||||||
axial_pos_shape: Vec<i64>,
|
axial_pos_shape: Vec<i64>,
|
||||||
least_common_mult_chunk_length: i64,
|
|
||||||
dropout_prob: f64,
|
dropout_prob: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,12 +44,6 @@ impl AxialPositionEmbeddings {
|
|||||||
)));
|
)));
|
||||||
};
|
};
|
||||||
|
|
||||||
let least_common_mult_chunk_length = get_least_common_mult_chunk_len(
|
|
||||||
&config.attn_layers,
|
|
||||||
config.lsh_attn_chunk_length,
|
|
||||||
config.local_attn_chunk_length,
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut weights: Vec<Tensor> = vec![];
|
let mut weights: Vec<Tensor> = vec![];
|
||||||
let p_weights = p / "weights";
|
let p_weights = p / "weights";
|
||||||
for (axis_index, axial_pos_embd_dim) in config.axial_pos_embds_dim.iter().enumerate() {
|
for (axis_index, axial_pos_embd_dim) in config.axial_pos_embds_dim.iter().enumerate() {
|
||||||
@ -64,7 +56,6 @@ impl AxialPositionEmbeddings {
|
|||||||
Ok(AxialPositionEmbeddings {
|
Ok(AxialPositionEmbeddings {
|
||||||
weights,
|
weights,
|
||||||
axial_pos_shape,
|
axial_pos_shape,
|
||||||
least_common_mult_chunk_length,
|
|
||||||
dropout_prob: config.hidden_dropout_prob,
|
dropout_prob: config.hidden_dropout_prob,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,6 @@ pub struct T5Attention {
|
|||||||
is_bidirectional: bool,
|
is_bidirectional: bool,
|
||||||
has_relative_attention_bias: bool,
|
has_relative_attention_bias: bool,
|
||||||
relative_attention_num_buckets: i64,
|
relative_attention_num_buckets: i64,
|
||||||
d_model: i64,
|
|
||||||
d_kv: i64,
|
d_kv: i64,
|
||||||
n_heads: i64,
|
n_heads: i64,
|
||||||
dropout: Dropout,
|
dropout: Dropout,
|
||||||
@ -106,7 +105,6 @@ impl T5Attention {
|
|||||||
is_bidirectional,
|
is_bidirectional,
|
||||||
has_relative_attention_bias,
|
has_relative_attention_bias,
|
||||||
relative_attention_num_buckets: config.relative_attention_num_buckets,
|
relative_attention_num_buckets: config.relative_attention_num_buckets,
|
||||||
d_model: config.d_model,
|
|
||||||
d_kv: config.d_kv,
|
d_kv: config.d_kv,
|
||||||
n_heads: config.num_heads,
|
n_heads: config.num_heads,
|
||||||
dropout,
|
dropout,
|
||||||
|
@ -42,9 +42,6 @@ impl LayerState {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct XLNetRelativeAttention {
|
pub struct XLNetRelativeAttention {
|
||||||
num_attention_heads: i64,
|
|
||||||
attention_head_size: i64,
|
|
||||||
hidden_size: i64,
|
|
||||||
dropout: Dropout,
|
dropout: Dropout,
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
@ -135,9 +132,6 @@ impl XLNetRelativeAttention {
|
|||||||
let scale = 1f64 / ((config.d_head as f64).powf(0.5f64));
|
let scale = 1f64 / ((config.d_head as f64).powf(0.5f64));
|
||||||
|
|
||||||
XLNetRelativeAttention {
|
XLNetRelativeAttention {
|
||||||
num_attention_heads: config.n_head,
|
|
||||||
attention_head_size: config.d_head,
|
|
||||||
hidden_size: config.d_model,
|
|
||||||
dropout,
|
dropout,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
query,
|
query,
|
||||||
|
Loading…
Reference in New Issue
Block a user