Mixed resources (#291)

* - made `merges` resource optional for all pipelines
- allow mixing local and remote resources for pipelines

* Updated changelog

* Fixed Clippy warnings
This commit is contained in:
guillaume-be 2022-10-30 07:39:52 +00:00 committed by GitHub
parent 78da0f4814
commit 340be36ed9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 203 additions and 150 deletions

View File

@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file. The format
## Changed ## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`) - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` is now optional for all pipelines
- Allow mixing local and remote resources in pipelines
## Fixed ## Fixed
- Fixed configuration check for RoBERTa models for sentence classification. - Fixed configuration check for RoBERTa models for sentence classification.

View File

@ -17,7 +17,9 @@ fn create_text_generation_model() -> TextGenerationModel {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)), model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)), config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)), vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)), merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0, min_length: 0,
max_length: 30, max_length: 30,
do_sample: true, do_sample: true,

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 32, max_length: 32,
do_sample: false, do_sample: false,

View File

@ -30,9 +30,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Box::new(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT, ReformerModelResources::CRIME_AND_PUNISHMENT,
)); ));
@ -41,7 +38,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: None,
min_length: 100, min_length: 100,
max_length: 100, max_length: 100,
do_sample: true, do_sample: true,

View File

@ -27,9 +27,6 @@ fn main() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Box::new(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));
@ -39,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: None,
max_length: 32, max_length: 32,
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,

View File

@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
num_beams: 1, num_beams: 1,
length_penalty: 1.0, length_penalty: 1.0,
min_length: 56, min_length: 56,

View File

@ -33,8 +33,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::Pegasus, model_type: ModelType::Pegasus,
model_resource: weights_resource, model_resource: weights_resource,
config_resource, config_resource,
vocab_resource: vocab_resource.clone(), vocab_resource,
merges_resource: vocab_resource, merges_resource: None,
length_penalty: 1.0, length_penalty: 1.0,
num_beams: 4, num_beams: 4,
no_repeat_ngram_size: 3, no_repeat_ngram_size: 3,

View File

@ -35,8 +35,8 @@ fn main() -> anyhow::Result<()> {
model_type: ModelType::ProphetNet, model_type: ModelType::ProphetNet,
model_resource: weights_resource, model_resource: weights_resource,
config_resource, config_resource,
vocab_resource: vocab_resource.clone(), vocab_resource,
merges_resource: vocab_resource, merges_resource: None,
length_penalty: 1.2, length_penalty: 1.2,
num_beams: 4, num_beams: 4,
no_repeat_ngram_size: 3, no_repeat_ngram_size: 3,

View File

@ -26,8 +26,8 @@ fn main() -> anyhow::Result<()> {
ModelType::T5, ModelType::T5,
weights_resource, weights_resource,
config_resource, config_resource,
vocab_resource.clone(),
vocab_resource, vocab_resource,
None,
); );
let summarization_model = SummarizationModel::new(summarization_config)?; let summarization_model = SummarizationModel::new(summarization_config)?;

View File

@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, Some(merges_resource),
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -36,7 +36,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, Some(merges_resource),
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -26,8 +26,6 @@ fn main() -> anyhow::Result<()> {
let config_resource = let config_resource =
RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY); RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY);
let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY); let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let merges_resource =
RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY);
let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY; let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY;
let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY; let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY;
@ -37,7 +35,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, None,
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -22,7 +22,6 @@ fn main() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE); let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE); let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE);
let source_languages = [ let source_languages = [
Language::English, Language::English,
@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, None,
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -1067,7 +1067,15 @@ impl BartGenerator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?; let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"BART expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::Bart, ModelType::Bart,

View File

@ -708,7 +708,15 @@ impl GPT2Generator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<GPT2Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?; let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT2 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::GPT2, ModelType::GPT2,

View File

@ -683,7 +683,15 @@ impl GptNeoGenerator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<GptNeoGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?; let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT-Neo expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::GPTNeo, ModelType::GPTNeo,

View File

@ -44,7 +44,7 @@
//! model_resource, //! model_resource,
//! config_resource, //! config_resource,
//! vocab_resource, //! vocab_resource,
//! merges_resource, //! merges_resource: Some(merges_resource),
//! num_beams: 4, //! num_beams: 4,
//! no_repeat_ngram_size: 3, //! no_repeat_ngram_size: 3,
//! device: Device::cuda_if_available(), //! device: Device::cuda_if_available(),

View File

@ -617,7 +617,15 @@ impl M2M100Generator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<M2M100Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?; let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"M2M100 expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::M2M100, ModelType::M2M100,

View File

@ -837,7 +837,16 @@ impl MarianGenerator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<MarianGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<MarianGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let sentence_piece_path = generate_config.merges_resource.get_local_path()?; let sentence_piece_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"Marian expects a merges (SentencePiece model) resources to be provided"
.to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::Marian, ModelType::Marian,

View File

@ -470,7 +470,15 @@ impl OpenAIGenerator {
/// ``` /// ```
pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> { pub fn new(generate_config: GenerateConfig) -> Result<OpenAIGenerator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?; let vocab_path = generate_config.vocab_resource.get_local_path()?;
let merges_path = generate_config.merges_resource.get_local_path()?; let merges_path = generate_config
.merges_resource
.as_ref()
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"GPT expects a merges resources to be provided".to_string(),
)
})?
.get_local_path()?;
let tokenizer = TokenizerOption::from_file( let tokenizer = TokenizerOption::from_file(
ModelType::OpenAiGpt, ModelType::OpenAiGpt,

View File

@ -82,7 +82,7 @@ pub struct ConversationConfig {
/// Vocab resource (default: DialoGPT-medium) /// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Box<dyn ResourceProvider + Send>, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: DialoGPT-medium) /// Merges resource (default: DialoGPT-medium)
pub merges_resource: Box<dyn ResourceProvider + Send>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -131,9 +131,9 @@ impl Default for ConversationConfig {
vocab_resource: Box::new(RemoteResource::from_pretrained( vocab_resource: Box::new(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM, Gpt2VocabResources::DIALOGPT_MEDIUM,
)), )),
merges_resource: Box::new(RemoteResource::from_pretrained( merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM, Gpt2MergesResources::DIALOGPT_MEDIUM,
)), ))),
min_length: 0, min_length: 0,
max_length: 1000, max_length: 1000,
min_length_for_response: 64, min_length_for_response: 64,

View File

@ -103,7 +103,7 @@ pub struct GenerateConfig {
/// Vocab resource (default: pretrained GPT2 model) /// Vocab resource (default: pretrained GPT2 model)
pub vocab_resource: Box<dyn ResourceProvider + Send>, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained GPT2 model) /// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Box<dyn ResourceProvider + Send>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -143,7 +143,9 @@ impl Default for GenerateConfig {
model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)), model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)),
config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)), config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)),
vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)), vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)),
merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)), merges_resource: Some(Box::new(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2,
))),
min_length: 0, min_length: 0,
max_length: 20, max_length: 20,
do_sample: true, do_sample: true,

View File

@ -166,18 +166,20 @@ impl QuestionAnsweringConfig {
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: Option<R>, merges_resource: Option<RV>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> QuestionAnsweringConfig ) -> QuestionAnsweringConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_type, model_type,
@ -210,12 +212,12 @@ impl QuestionAnsweringConfig {
/// * max_query_length - Optional maximum question token length. Defaults to 64. /// * max_query_length - Optional maximum question token length. Defaults to 64.
/// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128. /// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128.
/// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15. /// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15.
pub fn custom_new<R>( pub fn custom_new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: Option<R>, merges_resource: Option<RV>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
@ -225,7 +227,9 @@ impl QuestionAnsweringConfig {
max_answer_length: impl Into<Option<usize>>, max_answer_length: impl Into<Option<usize>>,
) -> QuestionAnsweringConfig ) -> QuestionAnsweringConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_type, model_type,

View File

@ -134,18 +134,20 @@ impl SequenceClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: Option<R>, merges_resource: Option<RV>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> SequenceClassificationConfig ) -> SequenceClassificationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
SequenceClassificationConfig { SequenceClassificationConfig {
model_type, model_type,

View File

@ -92,7 +92,7 @@ pub struct SummarizationConfig {
/// Vocab resource (default: pretrained BART model on CNN-DM) /// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Box<dyn ResourceProvider + Send>, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM) /// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Box<dyn ResourceProvider + Send>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -135,22 +135,24 @@ impl SummarizationConfig {
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: R, merges_resource: Option<RV>,
) -> SummarizationConfig ) -> SummarizationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
SummarizationConfig { SummarizationConfig {
model_type, model_type,
model_resource: Box::new(model_resource), model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource), config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource), vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource), merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 56, min_length: 56,
max_length: 142, max_length: 142,
do_sample: false, do_sample: false,
@ -178,7 +180,9 @@ impl Default for SummarizationConfig {
RemoteResource::from_pretrained(BartModelResources::BART_CNN), RemoteResource::from_pretrained(BartModelResources::BART_CNN),
RemoteResource::from_pretrained(BartConfigResources::BART_CNN), RemoteResource::from_pretrained(BartConfigResources::BART_CNN),
RemoteResource::from_pretrained(BartVocabResources::BART_CNN), RemoteResource::from_pretrained(BartVocabResources::BART_CNN),
RemoteResource::from_pretrained(BartMergesResources::BART_CNN), Some(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
) )
} }
} }

View File

@ -63,7 +63,7 @@ pub struct TextGenerationConfig {
/// Vocab resource (default: pretrained BART model on CNN-DM) /// Vocab resource (default: pretrained BART model on CNN-DM)
pub vocab_resource: Box<dyn ResourceProvider + Send>, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource (default: pretrained BART model on CNN-DM) /// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Box<dyn ResourceProvider + Send>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
@ -106,22 +106,24 @@ impl TextGenerationConfig {
/// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json) /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json)
/// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt).
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: R, merges_resource: Option<RV>,
) -> TextGenerationConfig ) -> TextGenerationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
TextGenerationConfig { TextGenerationConfig {
model_type, model_type,
model_resource: Box::new(model_resource), model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource), config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource), vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource), merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 0, min_length: 0,
max_length: 20, max_length: 20,
do_sample: true, do_sample: true,
@ -149,7 +151,9 @@ impl Default for TextGenerationConfig {
RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM), RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM), RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM), RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM),
RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM), Some(RemoteResource::from_pretrained(
Gpt2MergesResources::GPT2_MEDIUM,
)),
) )
} }
} }

View File

@ -254,19 +254,21 @@ impl TokenClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta. /// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: Option<R>, merges_resource: Option<RV>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
label_aggregation_function: LabelAggregationOption, label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig ) -> TokenClassificationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
TokenClassificationConfig { TokenClassificationConfig {
model_type, model_type,

View File

@ -38,7 +38,7 @@
//! model_resource, //! model_resource,
//! config_resource, //! config_resource,
//! vocab_resource, //! vocab_resource,
//! merges_resource, //! Some(merges_resource),
//! source_languages, //! source_languages,
//! target_languages, //! target_languages,
//! Device::cuda_if_available(), //! Device::cuda_if_available(),

View File

@ -383,7 +383,7 @@ impl TranslationModelBuilder {
translation_resources.model_resource, translation_resources.model_resource,
translation_resources.config_resource, translation_resources.config_resource,
translation_resources.vocab_resource, translation_resources.vocab_resource,
translation_resources.merges_resource, Some(translation_resources.merges_resource),
translation_resources.source_languages, translation_resources.source_languages,
translation_resources.target_languages, translation_resources.target_languages,
device, device,

View File

@ -380,7 +380,7 @@ pub struct TranslationConfig {
/// Vocab resource /// Vocab resource
pub vocab_resource: Box<dyn ResourceProvider + Send>, pub vocab_resource: Box<dyn ResourceProvider + Send>,
/// Merges resource /// Merges resource
pub merges_resource: Box<dyn ResourceProvider + Send>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Supported source languages /// Supported source languages
pub source_languages: HashSet<Language>, pub source_languages: HashSet<Language>,
/// Supported target languages /// Supported target languages
@ -428,11 +428,8 @@ impl TranslationConfig {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// # fn main() -> anyhow::Result<()> { /// # fn main() -> anyhow::Result<()> { ///
/// use rust_bert::marian::{ /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources};
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::TranslationConfig; /// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::RemoteResource; /// use rust_bert::resources::RemoteResource;
@ -441,6 +438,7 @@ impl TranslationConfig {
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH); /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
/// ///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -449,8 +447,8 @@ impl TranslationConfig {
/// ModelType::Marian, /// ModelType::Marian,
/// model_resource, /// model_resource,
/// config_resource, /// config_resource,
/// vocab_resource.clone(),
/// vocab_resource, /// vocab_resource,
/// Some(spm_resource),
/// source_languages, /// source_languages,
/// target_languages, /// target_languages,
/// Device::cuda_if_available(), /// Device::cuda_if_available(),
@ -458,18 +456,20 @@ impl TranslationConfig {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub fn new<R, S, T>( pub fn new<RM, RC, RV, S, T>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: R, merges_resource: Option<RV>,
source_languages: S, source_languages: S,
target_languages: T, target_languages: T,
device: impl Into<Option<Device>>, device: impl Into<Option<Device>>,
) -> TranslationConfig ) -> TranslationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
S: AsRef<[Language]>, S: AsRef<[Language]>,
T: AsRef<[Language]>, T: AsRef<[Language]>,
{ {
@ -480,7 +480,7 @@ impl TranslationConfig {
model_resource: Box::new(model_resource), model_resource: Box::new(model_resource),
config_resource: Box::new(config_resource), config_resource: Box::new(config_resource),
vocab_resource: Box::new(vocab_resource), vocab_resource: Box::new(vocab_resource),
merges_resource: Box::new(merges_resource), merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
source_languages: source_languages.as_ref().iter().cloned().collect(), source_languages: source_languages.as_ref().iter().cloned().collect(),
target_languages: target_languages.as_ref().iter().cloned().collect(), target_languages: target_languages.as_ref().iter().cloned().collect(),
device, device,
@ -786,11 +786,8 @@ impl TranslationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// # fn main() -> anyhow::Result<()> { /// # fn main() -> anyhow::Result<()> { ///
/// use rust_bert::marian::{ /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources};
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::RemoteResource; /// use rust_bert::resources::RemoteResource;
@ -799,6 +796,7 @@ impl TranslationModel {
/// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH); /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
/// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
/// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
/// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
/// ///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -807,8 +805,8 @@ impl TranslationModel {
/// ModelType::Marian, /// ModelType::Marian,
/// model_resource, /// model_resource,
/// config_resource, /// config_resource,
/// vocab_resource.clone(),
/// vocab_resource, /// vocab_resource,
/// Some(spm_resource),
/// source_languages, /// source_languages,
/// target_languages, /// target_languages,
/// Device::cuda_if_available(), /// Device::cuda_if_available(),
@ -863,7 +861,7 @@ impl TranslationModel {
/// model_resource, /// model_resource,
/// config_resource, /// config_resource,
/// vocab_resource, /// vocab_resource,
/// merges_resource, /// Some(merges_resource),
/// source_languages, /// source_languages,
/// target_languages, /// target_languages,
/// Device::cuda_if_available(), /// Device::cuda_if_available(),
@ -911,8 +909,8 @@ impl TranslationModel {
mod test { mod test {
use super::*; use super::*;
use crate::marian::{ use crate::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianVocabResources, MarianTargetLanguages, MarianVocabResources,
}; };
use crate::resources::RemoteResource; use crate::resources::RemoteResource;
@ -923,6 +921,7 @@ mod test {
let config_resource = let config_resource =
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
@ -931,8 +930,8 @@ mod test {
ModelType::Marian, ModelType::Marian,
model_resource, model_resource,
config_resource, config_resource,
vocab_resource.clone(),
vocab_resource, vocab_resource,
Some(merges_resource),
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -159,18 +159,20 @@ impl ZeroShotClassificationConfig {
/// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
pub fn new<R>( pub fn new<RM, RC, RV>(
model_type: ModelType, model_type: ModelType,
model_resource: R, model_resource: RM,
config_resource: R, config_resource: RC,
vocab_resource: R, vocab_resource: RV,
merges_resource: Option<R>, merges_resource: Option<RV>,
lower_case: bool, lower_case: bool,
strip_accents: impl Into<Option<bool>>, strip_accents: impl Into<Option<bool>>,
add_prefix_space: impl Into<Option<bool>>, add_prefix_space: impl Into<Option<bool>>,
) -> ZeroShotClassificationConfig ) -> ZeroShotClassificationConfig
where where
R: ResourceProvider + Send + 'static, RM: ResourceProvider + Send + 'static,
RC: ResourceProvider + Send + 'static,
RV: ResourceProvider + Send + 'static,
{ {
ZeroShotClassificationConfig { ZeroShotClassificationConfig {
model_type, model_type,

View File

@ -38,8 +38,8 @@
//! model_type: ModelType::ProphetNet, //! model_type: ModelType::ProphetNet,
//! model_resource: weights_resource, //! model_resource: weights_resource,
//! config_resource, //! config_resource,
//! vocab_resource: vocab_resource.clone(), //! vocab_resource,
//! merges_resource: vocab_resource, //! merges_resource: None,
//! length_penalty: 1.2, //! length_penalty: 1.2,
//! num_beams: 4, //! num_beams: 4,
//! no_repeat_ngram_size: 3, //! no_repeat_ngram_size: 3,

View File

@ -30,9 +30,6 @@
//! let vocab_resource = Box::new(RemoteResource::from_pretrained( //! let vocab_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED, //! XLNetVocabResources::XLNET_BASE_CASED,
//! )); //! ));
//! let merges_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let model_resource = Box::new(RemoteResource::from_pretrained( //! let model_resource = Box::new(RemoteResource::from_pretrained(
//! XLNetModelResources::XLNET_BASE_CASED, //! XLNetModelResources::XLNET_BASE_CASED,
//! )); //! ));
@ -41,7 +38,7 @@
//! model_resource, //! model_resource,
//! config_resource, //! config_resource,
//! vocab_resource, //! vocab_resource,
//! merges_resource, //! merges_resource: None,
//! max_length: 56, //! max_length: 56,
//! do_sample: true, //! do_sample: true,
//! num_beams: 3, //! num_beams: 3,

View File

@ -93,7 +93,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
num_beams: 1, num_beams: 1,
length_penalty: 1.0, length_penalty: 1.0,
min_length: 56, min_length: 56,
@ -154,7 +154,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
num_beams: 4, num_beams: 4,
min_length: 56, min_length: 56,
max_length: 142, max_length: 142,

View File

@ -120,7 +120,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 40, max_length: 40,
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
@ -152,7 +152,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
@ -196,7 +196,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
@ -253,7 +253,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
@ -309,7 +309,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
@ -382,7 +382,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
device: Device::Cpu, device: Device::Cpu,
@ -432,7 +432,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
device: Device::Cpu, device: Device::Cpu,
@ -498,7 +498,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,
device: Device::Cpu, device: Device::Cpu,
@ -579,7 +579,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,
device: Device::Cpu, device: Device::Cpu,
@ -629,7 +629,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
device: Device::Cpu, device: Device::Cpu,
@ -685,7 +685,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
do_sample: false, do_sample: false,
num_beams: 2, num_beams: 2,
device: Device::Cpu, device: Device::Cpu,

View File

@ -128,7 +128,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 32, max_length: 32,
do_sample: false, do_sample: false,

View File

@ -81,7 +81,7 @@ fn m2m100_translation() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, Some(merges_resource),
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -26,7 +26,7 @@ fn test_translation() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, Some(merges_resource),
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),

View File

@ -122,7 +122,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 40, max_length: 40,
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
@ -164,7 +164,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
@ -217,7 +217,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
@ -286,7 +286,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: Some(merges_resource),
max_length: 20, max_length: 20,
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,

View File

@ -22,8 +22,8 @@ fn pegasus_summarization_greedy() -> anyhow::Result<()> {
model_type: ModelType::Pegasus, model_type: ModelType::Pegasus,
model_resource, model_resource,
config_resource, config_resource,
vocab_resource: vocab_resource.clone(), vocab_resource,
merges_resource: vocab_resource, merges_resource: None,
num_beams: 4, num_beams: 4,
no_repeat_ngram_size: 3, no_repeat_ngram_size: 3,
device: Device::cuda_if_available(), device: Device::cuda_if_available(),

View File

@ -24,8 +24,8 @@ fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
model_type: ModelType::ProphetNet, model_type: ModelType::ProphetNet,
model_resource: weights_resource, model_resource: weights_resource,
config_resource, config_resource,
vocab_resource: vocab_resource.clone(), vocab_resource,
merges_resource: vocab_resource, merges_resource: None,
length_penalty: 1.2, length_penalty: 1.2,
num_beams: 4, num_beams: 4,
no_repeat_ngram_size: 3, no_repeat_ngram_size: 3,

View File

@ -39,9 +39,6 @@ fn test_generation_reformer() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT, ReformerVocabResources::CRIME_AND_PUNISHMENT,
)); ));
let merges_resource = Box::new(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let model_resource = Box::new(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT, ReformerModelResources::CRIME_AND_PUNISHMENT,
)); ));
@ -51,7 +48,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: None,
min_length: 100, min_length: 100,
max_length: 100, max_length: 100,
do_sample: false, do_sample: false,

View File

@ -10,7 +10,6 @@ fn test_translation_t5() -> anyhow::Result<()> {
let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL); let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL);
let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL);
let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL); let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL);
let source_languages = [ let source_languages = [
Language::English, Language::English,
@ -30,7 +29,7 @@ fn test_translation_t5() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, None,
source_languages, source_languages,
target_languages, target_languages,
Device::cuda_if_available(), Device::cuda_if_available(),
@ -69,7 +68,7 @@ fn test_summarization_t5() -> anyhow::Result<()> {
model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)), model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)),
config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)), config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)),
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), merges_resource: None,
min_length: 30, min_length: 30,
max_length: 200, max_length: 200,
early_stopping: true, early_stopping: true,

View File

@ -202,9 +202,6 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
let vocab_resource = Box::new(RemoteResource::from_pretrained( let vocab_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED, XLNetVocabResources::XLNET_BASE_CASED,
)); ));
let merges_resource = Box::new(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let model_resource = Box::new(RemoteResource::from_pretrained( let model_resource = Box::new(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED, XLNetModelResources::XLNET_BASE_CASED,
)); ));
@ -214,7 +211,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource, merges_resource: None,
max_length: 32, max_length: 32,
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,