Mixed resources (#291)

* - made `merges` resource optional for all pipelines
- allow mixing local and remote resources for pipelines

* Updated changelog

* Fixed Clippy warnings
This commit is contained in:
guillaume-be 2022-10-30 07:39:52 +00:00 committed by GitHub
parent 78da0f4814
commit 340be36ed9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 203 additions and 150 deletions

View File

@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file. The format
## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` now optional for all pipelines
- Allow mixing local and remote resources in pipelines
## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.

View File

@ -17,7 +17,9 @@ fn create_text_generation_model() -> TextGenerationModel {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 30,
do_sample: true,

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 32,
do_sample: false,

View File

@ -30,9 +30,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
@ -41,7 +38,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
min_length: 100,
max_length: 100,
do_sample: true,

View File

@ -27,9 +27,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
@ -39,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
max_length: 32,
do_sample: false,
num_beams: 3,

View File

@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
num_beams: 1,
length_penalty: 1.0,
min_length: 56,

View File

@ -33,8 +33,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::Pegasus,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
length_penalty: 1.0,
num_beams: 4,
no_repeat_ngram_size: 3,

View File

@ -35,8 +35,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
length_penalty: 1.2,
num_beams: 4,
no_repeat_ngram_size: 3,

View File

@ -26,8 +26,8 @@ fn main() -> anyhow::Result<()> {
ModelType::T5,
weights_resource,
config_resource,
vocab_resource.clone(),
vocab_resource,
None,
);
let summarization_model = SummarizationModel::new(summarization_config)?;

View File

@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -36,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -26,8 +26,6 @@ fn main() -> anyhow::Result<()> {
let config_resource =
RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let merges_resource =
RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
@ -37,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
None,
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -22,7 +22,6 @@ fn main() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let source_languages = [
Language::English,
@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
None,
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -1067,7 +1067,15 @@ impl BartGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"BART expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::Bart,

View File

@ -708,7 +708,15 @@ impl GPT2Generator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT2 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::GPT2,

View File

@ -683,7 +683,15 @@ impl GptNeoGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT-Neo expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::GPTNeo,

View File

@ -44,7 +44,7 @@
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! merges_resource: Some(merges_resource),
//! num_beams: 4,
//! no_repeat_ngram_size: 3,
//! device: Device::cuda_if_available(),

View File

@ -617,7 +617,15 @@ impl M2M100Generator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"M2M100 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::M2M100,

View File

@ -837,7 +837,16 @@ impl MarianGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<MarianGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let sentence_piece_path = generate_config.merges_resource.get_local_path()?;
let sentence_piece_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"Marian expects a merges (SentencePiece model) resources to be provided"
.to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::Marian,

View File

@ -470,7 +470,15 @@ impl OpenAIGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?;
let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::OpenAiGpt,

View File

@ -82,7 +82,7 @@ pub struct ConversationConfig {
/// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -131,9 +131,9 @@ impl Default for ConversationConfig {
vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Box::new(RemoteResource::from_pretrained(
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
)),
))),
min_length: 0,
max_length: 1000,
min_length_for_response: 64,

View File

@ -103,7 +103,7 @@ pub struct GenerateConfig {
/// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -143,7 +143,9 @@ impl Default for GenerateConfig {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)),
merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 20,
do_sample: true,

View File

@ -166,18 +166,20 @@ impl QuestionAnsweringConfig {
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,
@ -210,12 +212,12 @@ impl QuestionAnsweringConfig {
/// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new<R>(
pub fn custom_new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
@ -225,7 +227,9 @@ impl QuestionAnsweringConfig {
max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
QuestionAnsweringConfig {
model_type,

View File

@ -134,18 +134,20 @@ impl SequenceClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
SequenceClassificationConfig {
model_type,

View File

@ -92,7 +92,7 @@ pub struct SummarizationConfig {
/// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -135,22 +135,24 @@ impl SummarizationConfig {
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
) -> SummarizationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
SummarizationConfig {
model_type,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 56,
max_length: 142,
do_sample: false,
@ -178,7 +180,9 @@ impl Default for SummarizationConfig {
RemoteResource::from_pretrained(BartModelResources::BART_CNN),
RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
RemoteResource::from_pretrained(BartMergesResources::BART_CNN),
Some(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
)
}
}

View File

@ -63,7 +63,7 @@ pub struct TextGenerationConfig {
/// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
@ -106,22 +106,24 @@ impl TextGenerationConfig {
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
) -> TextGenerationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
TextGenerationConfig {
model_type,
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 0,
max_length: 20,
do_sample: true,
@ -149,7 +151,9 @@ impl Default for TextGenerationConfig {
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM),
Some(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2_MEDIUM,
)),
)
}
}

View File

@ -254,19 +254,21 @@ impl TokenClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
TokenClassificationConfig {
model_type,

View File

@ -38,7 +38,7 @@
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! Some(merges_resource),
//! source_languages,
//! target_languages,
//! Device::cuda_if_available(),

View File

@ -383,7 +383,7 @@ impl TranslationModelBuilder {
translation_resources.model_resource,
translation_resources.config_resource,
translation_resources.vocab_resource,
translation_resources.merges_resource,
Some(translation_resources.merges_resource),
translation_resources.source_languages,
translation_resources.target_languages,
device,

View File

@ -380,7 +380,7 @@ pub struct TranslationConfig {
/// Vocab resource
pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource
pub merges_resource: Box<dyn ResourceProvider + Send>,
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Supported source languages
pub source_languages: HashSet<Language>,
/// Supported target languages
@ -428,11 +428,8 @@ impl TranslationConfig {
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::marian::{
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources};
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::RemoteResource;
@ -441,6 +438,7 @@ impl TranslationConfig {
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -449,8 +447,8 @@ impl TranslationConfig {
/// ModelType::Marian,
/// model_resource,
/// config_resource,
/// vocab_resource.clone(),
/// vocab_resource,
/// Some(spm_resource),
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
@ -458,18 +456,20 @@ impl TranslationConfig {
/// # Ok(())
/// # }
/// ```
pub fn new<R, S, T>(
pub fn new<RM, RC, RV, S, T>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: R,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
source_languages: S,
target_languages: T,
device: impl Into<Option<Device>>,
) -> TranslationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
S: AsRef<[Language]>,
T: AsRef<[Language]>,
{
@ -480,7 +480,7 @@ impl TranslationConfig {
model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
source_languages: source_languages.as_ref().iter().cloned().collect(),
target_languages: target_languages.as_ref().iter().cloned().collect(),
device,
@ -786,11 +786,8 @@ impl TranslationModel {
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::marian::{
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources};
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::RemoteResource;
@ -799,6 +796,7 @@ impl TranslationModel {
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -807,8 +805,8 @@ impl TranslationModel {
/// ModelType::Marian,
/// model_resource,
/// config_resource,
/// vocab_resource.clone(),
/// vocab_resource,
/// Some(spm_resource),
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
@ -863,7 +861,7 @@ impl TranslationModel {
/// model_resource,
/// config_resource,
/// vocab_resource,
/// merges_resource,
/// Some(merges_resource),
/// source_languages,
/// target_languages,
/// Device::cuda_if_available(),
@ -911,8 +909,8 @@ impl TranslationModel {
mod test {
use super::*;
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
MarianVocabResources,
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use crate::resources::RemoteResource;
@ -923,6 +921,7 @@ mod test {
let config_resource =
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -931,8 +930,8 @@ mod test {
ModelType::Marian,
model_resource,
config_resource,
vocab_resource.clone(),
vocab_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -159,18 +159,20 @@ impl ZeroShotClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>(
pub fn new<RM, RC, RV>(
model_type: ModelType,
model_resource: R,
config_resource: R,
vocab_resource: R,
merges_resource: Option<R>,
model_resource: RM,
config_resource: RC,
vocab_resource: RV,
merges_resource: Option<RV>,
lower_case: bool,
strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>,
) -> ZeroShotClassificationConfig
where
R: ResourceProvider + Send + 'static,
RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{
ZeroShotClassificationConfig {
model_type,

View File

@ -38,8 +38,8 @@
//! model_type: ModelType::ProphetNet,
//! model_resource: weights_resource,
//! config_resource,
//! vocab_resource: vocab_resource.clone(),
//! merges_resource: vocab_resource,
//! vocab_resource,
//! merges_resource: None,
//! length_penalty: 1.2,
//! num_beams: 4,
//! no_repeat_ngram_size: 3,

View File

@ -30,9 +30,6 @@
//! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let model_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetModelResources::XLNET_BASE_CASED,
//! ));
@ -41,7 +38,7 @@
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! merges_resource: None,
//! max_length: 56,
//! do_sample: true,
//! num_beams: 3,

View File

@ -93,7 +93,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
@ -154,7 +154,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
num_beams: 4,
min_length: 56,
max_length: 142,

View File

@ -120,7 +120,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 40,
do_sample: false,
num_beams: 1,
@ -152,7 +152,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
num_beams: 5,
@ -196,7 +196,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
num_beams: 5,
@ -253,7 +253,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
num_beams: 5,
@ -309,7 +309,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 20,
do_sample: false,
@ -382,7 +382,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 1,
device: Device::Cpu,
@ -432,7 +432,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 1,
device: Device::Cpu,
@ -498,7 +498,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 3,
device: Device::Cpu,
@ -579,7 +579,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 3,
device: Device::Cpu,
@ -629,7 +629,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 1,
device: Device::Cpu,
@ -685,7 +685,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
do_sample: false,
num_beams: 2,
device: Device::Cpu,

View File

@ -128,7 +128,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 32,
do_sample: false,

View File

@ -81,7 +81,7 @@ fn m2m100_translation() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -26,7 +26,7 @@ fn test_translation() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
Some(merges_resource),
source_languages,
target_languages,
Device::cuda_if_available(),

View File

@ -122,7 +122,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 40,
do_sample: false,
num_beams: 1,
@ -164,7 +164,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
early_stopping: true,
@ -217,7 +217,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
early_stopping: true,
@ -286,7 +286,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: Some(merges_resource),
max_length: 20,
do_sample: false,
num_beams: 5,

View File

@ -22,8 +22,8 @@ fn pegasus_summarization_greedy() -> anyhow::Result<()> {
model_type: ModelType::Pegasus,
model_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
num_beams: 4,
no_repeat_ngram_size: 3,
device: Device::cuda_if_available(),

View File

@ -24,8 +24,8 @@ fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
model_type: ModelType::ProphetNet,
model_resource: weights_resource,
config_resource,
vocab_resource: vocab_resource.clone(),
merges_resource: vocab_resource,
vocab_resource,
merges_resource: None,
length_penalty: 1.2,
num_beams: 4,
no_repeat_ngram_size: 3,

View File

@ -39,9 +39,6 @@ fn test_generation_reformer() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
@ -51,7 +48,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
min_length: 100,
max_length: 100,
do_sample: false,

View File

@ -10,7 +10,6 @@ fn test_translation_t5() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let source_languages = [
Language::English,
@ -30,7 +29,7 @@ fn test_translation_t5() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
None,
source_languages,
target_languages,
Device::cuda_if_available(),
@ -69,7 +68,7 @@ fn test_summarization_t5() -> anyhow::Result<()> {
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: None,
min_length: 30,
max_length: 200,
early_stopping: true,

View File

@ -202,9 +202,6 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
@ -214,7 +211,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
model_resource,
config_resource,
vocab_resource,
merges_resource,
merges_resource: None,
max_length: 32,
do_sample: false,
num_beams: 3,