Mirror of https://github.com/guillaume-be/rust-bert.git (synced 2024-09-11 12:55:34 +03:00)

commit b410c52b26 (parent a52279e24d)

clippy updates
@@ -104,6 +104,13 @@ pub struct AlbertConfig {

 impl Config<AlbertConfig> for AlbertConfig {}

+pub struct AlbertOutput {
+    pub hidden_state: Tensor,
+    pub pooled_output: Tensor,
+    pub all_hidden_states: Option<Vec<Tensor>>,
+    pub all_attentions: Option<Vec<Vec<Tensor>>>,
+}
+
 /// # ALBERT Base model
 /// Base architecture for ALBERT models. Task-specific models will be built from this common base model
 /// It is made of the following blocks:
@@ -223,15 +230,7 @@ impl AlbertModel {
         position_ids: Option<Tensor>,
         input_embeds: Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Tensor,
-            Option<Vec<Tensor>>,
-            Option<Vec<Vec<Tensor>>>,
-        ),
-        &'static str,
-    > {
+    ) -> Result<AlbertOutput, &'static str> {
         let (input_shape, device) = match &input_ids {
             Some(input_value) => match &input_embeds {
                 Some(_) => {
@@ -276,12 +275,12 @@ impl AlbertModel {
         let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
         let pooled_output = (self.pooler_activation)(&pooled_output);

-        Ok((
+        Ok(AlbertOutput {
             hidden_state,
             pooled_output,
             all_hidden_states,
             all_attentions,
-        ))
+        })
     }
 }

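Taken together, the two hunks above replace the four-element tuple returned by `AlbertModel::forward_t` with the named `AlbertOutput` struct. A toy reduction of the pattern; the names below are illustrative stand-ins, not the crate's API:

```rust
// Toy reduction of the tuple-to-struct refactor above.
struct Output {
    hidden_state: f32,
    pooled_output: f32,
}

fn forward() -> Result<Output, &'static str> {
    Ok(Output {
        hidden_state: 1.0,
        pooled_output: 2.0,
    })
}

fn main() -> Result<(), &'static str> {
    // The tuple version forced positional destructuring such as
    // `let (_, pooled_output, _, _) = forward()?;`, where transposing two
    // positions of the same type compiles silently. A named field cannot
    // be mixed up at the call site:
    let output = forward()?;
    println!("pooled = {}", output.pooled_output);
    Ok(())
}
```
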
@@ -450,7 +449,7 @@ impl AlbertForMaskedLM {
         input_embeds: Option<Tensor>,
         train: bool,
     ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
-        let (hidden_state, _, all_hidden_states, all_attentions) = self
+        let base_model_output = self
             .albert
             .forward_t(
                 input_ids,
@@ -461,8 +460,12 @@ impl AlbertForMaskedLM {
                 train,
             )
             .unwrap();
-        let prediction_scores = self.predictions.forward(&hidden_state);
-        (prediction_scores, all_hidden_states, all_attentions)
+        let prediction_scores = self.predictions.forward(&base_model_output.hidden_state);
+        (
+            prediction_scores,
+            base_model_output.all_hidden_states,
+            base_model_output.all_attentions,
+        )
     }
 }

@@ -587,7 +590,7 @@ impl AlbertForSequenceClassification {
         input_embeds: Option<Tensor>,
         train: bool,
     ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
-        let (_, pooled_output, all_hidden_states, all_attentions) = self
+        let base_model_output = self
             .albert
             .forward_t(
                 input_ids,
@@ -598,10 +601,15 @@ impl AlbertForSequenceClassification {
                 train,
             )
             .unwrap();
-        let logits = pooled_output
+        let logits = base_model_output
+            .pooled_output
             .apply_t(&self.dropout, train)
             .apply(&self.classifier);
-        (logits, all_hidden_states, all_attentions)
+        (
+            logits,
+            base_model_output.all_hidden_states,
+            base_model_output.all_attentions,
+        )
     }
 }

@@ -723,7 +731,7 @@ impl AlbertForTokenClassification {
         input_embeds: Option<Tensor>,
         train: bool,
     ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
-        let (sequence_output, _, all_hidden_states, all_attentions) = self
+        let base_model_output = self
             .albert
             .forward_t(
                 input_ids,
@@ -734,10 +742,15 @@ impl AlbertForTokenClassification {
                 train,
             )
             .unwrap();
-        let logits = sequence_output
+        let logits = base_model_output
+            .hidden_state
             .apply_t(&self.dropout, train)
             .apply(&self.classifier);
-        (logits, all_hidden_states, all_attentions)
+        (
+            logits,
+            base_model_output.all_hidden_states,
+            base_model_output.all_attentions,
+        )
     }
 }

@@ -854,7 +867,7 @@ impl AlbertForQuestionAnswering {
         Option<Vec<Tensor>>,
         Option<Vec<Vec<Tensor>>>,
     ) {
-        let (sequence_output, _, all_hidden_states, all_attentions) = self
+        let base_model_output = self
             .albert
             .forward_t(
                 input_ids,
@@ -865,12 +878,20 @@ impl AlbertForQuestionAnswering {
                 train,
             )
             .unwrap();
-        let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
+        let logits = base_model_output
+            .hidden_state
+            .apply(&self.qa_outputs)
+            .split(1, -1);
         let (start_logits, end_logits) = (&logits[0], &logits[1]);
         let start_logits = start_logits.squeeze1(-1);
         let end_logits = end_logits.squeeze1(-1);

-        (start_logits, end_logits, all_hidden_states, all_attentions)
+        (
+            start_logits,
+            end_logits,
+            base_model_output.all_hidden_states,
+            base_model_output.all_attentions,
+        )
     }
 }

@@ -1024,7 +1045,7 @@ impl AlbertForMultipleChoice {
             None => None,
         };

-        let (_, pooled_output, all_hidden_states, all_attentions) = self
+        let base_model_output = self
             .albert
             .forward_t(
                 input_ids,
@@ -1035,11 +1056,16 @@ impl AlbertForMultipleChoice {
                 train,
             )
             .unwrap();
-        let logits = pooled_output
+        let logits = base_model_output
+            .pooled_output
             .apply_t(&self.dropout, train)
             .apply(&self.classifier)
             .view((-1, num_choices));

-        Ok((logits, all_hidden_states, all_attentions))
+        Ok((
+            logits,
+            base_model_output.all_hidden_states,
+            base_model_output.all_attentions,
+        ))
     }
 }

@@ -11,7 +11,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use crate::albert::albert::Activation;
+use crate::albert::albert_model::Activation;
 use crate::albert::attention::AlbertSelfAttention;
 use crate::albert::AlbertConfig;
 use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};

@@ -11,7 +11,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example is provided in `examples/albert.rs`, run with `cargo run --example albert`.
+//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`.
 //! The example below illustrate a Masked language model example, the structure is similar for other models.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@@ -53,12 +53,12 @@
 //! # }
 //! ```

-mod albert;
+mod albert_model;
 mod attention;
 mod embeddings;
 mod encoder;

-pub use albert::{
+pub use albert_model::{
     AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
     AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
     AlbertModel, AlbertModelResources, AlbertVocabResources,

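The `mod albert;` to `mod albert_model;` rename is internal only: the adjusted `pub use albert_model::{...}` re-exports the same items under the unchanged `albert` module path, so downstream code such as the line below keeps compiling. The same holds for the analogous module renames in the other model files of this commit.

```rust
// Downstream imports are unaffected by the internal file rename,
// since the re-export keeps the public path stable.
use rust_bert::albert::{AlbertConfig, AlbertModel};
```
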
@@ -12,7 +12,7 @@
 // limitations under the License.

 use crate::bart::attention::{LayerState, SelfAttention};
-use crate::bart::bart::Activation;
+use crate::bart::bart_model::Activation;
 use crate::bart::embeddings::{
     EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
 };

@@ -12,7 +12,7 @@
 // limitations under the License.

 use crate::bart::attention::SelfAttention;
-use crate::bart::bart::Activation;
+use crate::bart::bart_model::Activation;
 use crate::bart::embeddings::{
     EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
 };

@@ -6,7 +6,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example is provided in `examples/bart.rs`, run with `cargo run --example bart`.
+//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`.
 //! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@@ -58,13 +58,13 @@
 //! ```

 mod attention;
-mod bart;
+mod bart_model;
 mod decoder;
 mod embeddings;
 mod encoder;

 pub use attention::LayerState;
-pub use bart::{
+pub use bart_model::{
     Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
     BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
     BartVocabResources,

@@ -11,7 +11,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use crate::bert::bert::{Activation, BertConfig};
+use crate::bert::bert_model::{Activation, BertConfig};
 use crate::common::activations::{_gelu, _mish, _relu};
 use crate::common::dropout::Dropout;
 use std::borrow::Borrow;

@@ -11,7 +11,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use crate::bert::bert::BertConfig;
+use crate::bert::bert_model::BertConfig;
 use crate::common::dropout::Dropout;
 use std::borrow::Borrow;
 use tch::nn::{embedding, EmbeddingConfig};

@@ -12,7 +12,7 @@
 // limitations under the License.

 use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
-use crate::bert::bert::BertConfig;
+use crate::bert::bert_model::BertConfig;
 use std::borrow::{Borrow, BorrowMut};
 use tch::{nn, Tensor};

@@ -10,7 +10,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example is provided in `examples/bert.rs`, run with `cargo run --example bert`.
+//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`.
 //! The example below illustrate a Masked language model example, the structure is similar for other models.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@@ -53,11 +53,11 @@
 //! ```

 mod attention;
-mod bert;
+mod bert_model;
 mod embeddings;
 pub(crate) mod encoder;

-pub use bert::{
+pub use bert_model::{
     Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
     BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
     BertModelResources, BertVocabResources,

@@ -11,7 +11,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
-use crate::distilbert::distilbert::DistilBertConfig;
+use crate::distilbert::distilbert_model::DistilBertConfig;
 use std::borrow::Borrow;
 use tch::kind::Kind::Float;
 use tch::{nn, Tensor};

@@ -11,7 +11,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
-use crate::distilbert::distilbert::DistilBertConfig;
+use crate::distilbert::distilbert_model::DistilBertConfig;
 use std::borrow::Borrow;
 use tch::kind::Kind::Float;
 use tch::nn::{embedding, EmbeddingConfig, ModuleT};

@@ -55,11 +55,11 @@
 //! ```

 mod attention;
-mod distilbert;
+mod distilbert_model;
 mod embeddings;
 mod transformer;

-pub use distilbert::{
+pub use distilbert_model::{
     Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
     DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
     DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,

@@ -13,7 +13,7 @@
 use crate::common::activations::{_gelu, _relu};
 use crate::common::dropout::Dropout;
 use crate::distilbert::attention::MultiHeadSelfAttention;
-use crate::distilbert::distilbert::{Activation, DistilBertConfig};
+use crate::distilbert::distilbert_model::{Activation, DistilBertConfig};
 use std::borrow::{Borrow, BorrowMut};
 use tch::nn::LayerNorm;
 use tch::{nn, Tensor};

@@ -13,7 +13,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
-use crate::electra::electra::ElectraConfig;
+use crate::electra::electra_model::ElectraConfig;
 use std::borrow::Borrow;
 use tch::nn::{embedding, EmbeddingConfig};
 use tch::{nn, Kind, Tensor};

@@ -56,10 +56,10 @@
 //! # }
 //! ```

-mod electra;
+mod electra_model;
 mod embeddings;

-pub use electra::{
+pub use electra_model::{
     ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
     ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
     ElectraModelResources, ElectraVocabResources,

@@ -13,7 +13,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
-use crate::gpt2::gpt2::Gpt2Config;
+use crate::gpt2::gpt2_model::Gpt2Config;
 use std::borrow::Borrow;
 use tch::kind::Kind::Float;
 use tch::nn::{Init, Module};
@@ -133,15 +133,15 @@ impl Attention {

     fn attention(
         &self,
-        q: &Tensor,
-        k: &Tensor,
-        v: &Tensor,
+        query: &Tensor,
+        key: &Tensor,
+        value: &Tensor,
         attention_mask: &Option<Tensor>,
         train: bool,
     ) -> (Tensor, Option<Tensor>) {
-        let mut w = q.matmul(&k);
+        let mut w = query.matmul(&key);
         if self.scale {
-            w = w / (*v.size().last().unwrap() as f64).sqrt();
+            w = w / (*value.size().last().unwrap() as f64).sqrt();
         }

         let (nd, ns) = (w.size()[2], w.size()[3]);
@@ -152,7 +152,7 @@ impl Attention {
             w = w + mask;
         }
         w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
-        let output = w.matmul(&v);
+        let output = w.matmul(&value);

         if self.output_attentions {
             (output, Some(w))

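Renaming the single-letter parameters to `query`/`key`/`value` addresses clippy's `many_single_char_names` lint without touching the math. For reference, a minimal sketch of the scaled dot-product these lines compute, written against `tch` with the new names; it assumes `key` arrives pre-transposed, matching `query.matmul(&key)` in the diff:

```rust
use tch::{Kind, Tensor};

// Minimal sketch of the attention core above, using the descriptive names
// introduced by this commit. `key` is assumed pre-transposed.
fn scaled_dot_product(query: &Tensor, key: &Tensor, value: &Tensor, scale: bool) -> Tensor {
    let mut w = query.matmul(key);
    if scale {
        // Same scaling as the diff: divide by sqrt of the last value dimension.
        w = w / (*value.size().last().unwrap() as f64).sqrt();
    }
    w.softmax(-1, Kind::Float).matmul(value)
}
```
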
@@ -56,10 +56,10 @@
 //! ```

 pub(crate) mod attention;
-mod gpt2;
+mod gpt2_model;
 pub(crate) mod transformer;

-pub use gpt2::{
+pub use gpt2_model::{
     GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
     Gpt2ModelResources, Gpt2VocabResources, GptActivation,
 };

@@ -15,7 +15,7 @@
 use crate::common::activations::{_gelu_new, _relu, _swish};
 use crate::common::dropout::Dropout;
 use crate::gpt2::attention::{Attention, GPTConv1D};
-use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
+use crate::gpt2::gpt2_model::{Gpt2Config, GptActivation};
 use std::borrow::Borrow;
 use tch::{nn, Tensor};

@@ -57,9 +57,9 @@
 //! # }
 //! ```

-mod marian;
+mod marian_model;

-pub use marian::{
+pub use marian_model::{
     MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
     MarianSpmResources, MarianVocabResources,
 };

@@ -6,7 +6,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example is provided in `examples/openai_gpt.rs`, run with `cargo run --example openai_gpt`.
+//! A full working example is provided in `examples/openai_gpt`, run with `cargo run --example openai_gpt`.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
 //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@@ -55,10 +55,10 @@
 //! # }
 //! ```

-mod openai_gpt;
+mod openai_gpt_model;
 mod transformer;

-pub use openai_gpt::{
+pub use openai_gpt_model::{
     OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
     OpenAiGptModelResources, OpenAiGptVocabResources,
 };

@@ -306,12 +306,10 @@ impl Conversation {
     pub fn get_last_input(&self) -> Option<&str> {
         if self.new_user_input.is_some() {
             Some(self.new_user_input.as_ref().unwrap().as_str())
-        } else {
-            if self.past_user_inputs.len() > 0 {
-                Some(self.past_user_inputs.last().unwrap().as_str())
-            } else {
-                None
-            }
+        } else if self.past_user_inputs.len() > 0 {
+            Some(self.past_user_inputs.last().unwrap().as_str())
+        } else {
+            None
         }
     }

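The nested `if` inside `else` collapses into `else if` with identical behavior, which is what clippy flags here. As a further, purely hypothetical simplification under the assumption that the fields are `Option<String>` and `Vec<String>` (as the accessors suggest), the whole body could become a combinator chain; this is a sketch, not the crate's code:

```rust
// Hypothetical combinator form of `get_last_input`, written as a free
// function over assumed field types.
fn get_last_input<'a>(
    new_user_input: &'a Option<String>,
    past_user_inputs: &'a [String],
) -> Option<&'a str> {
    new_user_input
        .as_deref()
        .or_else(|| past_user_inputs.last().map(String::as_str))
}
```
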
@@ -707,9 +707,9 @@ impl TokenClassificationModel {
             if sub_tokens.len() > 1 {
                 let (label_index, label) =
                     self.consolidate_labels(sub_tokens, label_aggregation_function);
-                let sentence = (&sub_tokens[0]).sentence;
-                let index = (&sub_tokens[0]).index;
-                let word_index = (&sub_tokens[0]).word_index;
+                let sentence = (sub_tokens[0]).sentence;
+                let index = (sub_tokens[0]).index;
+                let word_index = (sub_tokens[0]).word_index;
                 let offset_start = match &sub_tokens.first().unwrap().offset {
                     Some(offset) => Some(offset.begin),
                     None => None,

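The only change here is dropping a redundant `&`: clippy's `needless_borrow` notes that field access already operates on the indexed place, so borrowing first adds nothing. A self-contained illustration with a stand-in type:

```rust
// Illustration of the `needless_borrow` fix; `Token` is a stand-in,
// not the pipeline's type.
struct Token {
    index: u16,
}

fn main() {
    let sub_tokens = vec![Token { index: 3 }];
    let borrowed = (&sub_tokens[0]).index; // clippy: needless_borrow
    let direct = sub_tokens[0].index; // equivalent, no intermediate borrow
    assert_eq!(borrowed, direct);
}
```
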
@@ -10,7 +10,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example is provided in `examples/robert.rs`, run with `cargo run --example roberta`.
+//! A full working example is provided in `examples/roberta.rs`, run with `cargo run --example roberta`.
 //! The example below illustrate a Masked language model example, the structure is similar for other models.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@@ -63,10 +63,10 @@
 //! ```

 mod embeddings;
-mod roberta;
+mod roberta_model;

 pub use embeddings::RobertaEmbeddings;
-pub use roberta::{
+pub use roberta_model::{
     RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
     RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
     RobertaMergesResources, RobertaModelResources, RobertaVocabResources,

@@ -55,10 +55,10 @@ pub struct T5Attention {
     inner_dim: i64,
     output_attentions: bool,
     store_cache: bool,
-    q: nn::Linear,
-    k: nn::Linear,
-    v: nn::Linear,
-    o: nn::Linear,
+    query: nn::Linear,
+    key: nn::Linear,
+    value: nn::Linear,
+    output: nn::Linear,
     relative_attention_bias: Option<nn::Embedding>,
 }

@@ -82,10 +82,10 @@ impl T5Attention {
         };

         let inner_dim = config.num_heads * config.d_kv;
-        let k = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
-        let v = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
-        let q = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
-        let o = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
+        let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
+        let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
+        let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
+        let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);

         let dropout = Dropout::new(config.dropout_rate);
         let relative_attention_bias = if has_relative_attention_bias {
@@ -110,10 +110,10 @@ impl T5Attention {
             inner_dim,
             output_attentions,
             store_cache,
-            q,
-            k,
-            v,
-            o,
+            query,
+            key,
+            value,
+            output,
             relative_attention_bias,
         }
     }
@@ -155,17 +155,17 @@ impl T5Attention {
             None => real_query_length,
         };

-        let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.q), bs);
+        let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.query), bs);

         let (mut k, mut v) = if kv.is_none() {
             (
-                self.shape(hidden_states.apply(&self.k), bs),
-                self.shape(hidden_states.apply(&self.v), bs),
+                self.shape(hidden_states.apply(&self.key), bs),
+                self.shape(hidden_states.apply(&self.value), bs),
             )
         } else {
             (
-                self.shape(kv.as_ref().unwrap().apply(&self.k), bs),
-                self.shape(kv.as_ref().unwrap().apply(&self.v), bs),
+                self.shape(kv.as_ref().unwrap().apply(&self.key), bs),
+                self.shape(kv.as_ref().unwrap().apply(&self.value), bs),
             )
         };

@@ -219,7 +219,7 @@ impl T5Attention {
             .apply_t(&self.dropout, train);
         let context = self
             .unshape(attention_weights.matmul(&v), bs)
-            .apply(&self.o);
+            .apply(&self.output);

         let attention_weights = if self.output_attentions {
             Some(attention_weights)

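Note that the variable-store paths `"q"`, `"k"`, `"v"`, `"o"` are identical on both sides of the constructor hunk, so pretrained weights should keep resolving; only the Rust-side field names become descriptive. A short sketch of that separation (an illustrative helper, not the crate's constructor):

```rust
use tch::nn;

// Sketch: descriptive Rust names bound to the original single-letter
// variable-store paths, so converted checkpoint weights still load.
fn projections(p: &nn::Path, d_model: i64, inner_dim: i64) -> (nn::Linear, nn::Linear) {
    let query = nn::linear(p / "q", d_model, inner_dim, Default::default());
    let key = nn::linear(p / "k", d_model, inner_dim, Default::default());
    (query, key)
}
```
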
@@ -6,7 +6,7 @@
 //!
 //! # Model set-up and pre-trained weights loading
 //!
-//! A full working example (translation) is provided in `examples/t5.rs`, run with `cargo run --example t5`.
+//! A full working example (translation) is provided in `examples/t5`, run with `cargo run --example t5`.
 //! All models expect the following resources:
 //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
 //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@@ -51,10 +51,10 @@
 mod attention;
 mod encoder;
 mod layer_norm;
-mod t5;
+mod t5_model;

 pub use attention::LayerState;
-pub use t5::{
+pub use t5_model::{
     T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources, T5Prefix,
     T5VocabResources,
 };

@@ -717,7 +717,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
                 None,
                 train,
             ),
-            _ => Err("Cache not compatible with T5 Model")?,
+            _ => return Err("Cache not compatible with T5 Model"),
         };

         let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)

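The old arm built an `Err` and immediately applied `?` to it, which clippy's `try_err` lint flags: the `?` only re-wraps the error through `From` before returning, obscuring a plain early return. A minimal sketch of the same shape with an illustrative function:

```rust
// Minimal sketch of the `try_err` fix; the function is illustrative.
fn decoder_output(cached: bool) -> Result<i64, &'static str> {
    let output = match cached {
        true => 42,
        // Before: `false => Err("Cache not compatible with T5 Model")?,`
        // which constructs an Err only to immediately `?` it.
        false => return Err("Cache not compatible with T5 Model"),
    };
    Ok(output)
}

fn main() {
    assert!(decoder_output(false).is_err());
}
```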