From 466c6b6922e9b35fc2eb55e0df2f64774b53d951 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Wed, 28 Jul 2021 18:10:20 +0200 Subject: [PATCH] Updated doctests --- examples/translation_builder.rs | 2 +- src/lib.rs | 19 +++++--- src/pipelines/mod.rs | 22 +++++++--- .../translation/translation_pipeline.rs | 44 +++++++++++-------- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/examples/translation_builder.rs b/examples/translation_builder.rs index 0234e28..b457306 100644 --- a/examples/translation_builder.rs +++ b/examples/translation_builder.rs @@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> { .with_target_languages(vec![Language::Spanish]) .create_model()?; - let input_context_1 = "The quick brown fox jumps over the lazy dog."; + let input_context_1 = "This is a sentence to be translated"; let input_context_2 = "The dog did not wake up."; let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?; diff --git a/src/lib.rs b/src/lib.rs index 4cfd12b..0ed6991 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -168,23 +168,28 @@ //! ```no_run //! # fn main() -> anyhow::Result<()> { //! # use rust_bert::pipelines::generation_utils::LanguageGenerator; -//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; +//! use rust_bert::pipelines::common::ModelType; +//! use rust_bert::pipelines::translation::{ +//! Language, TranslationConfig, TranslationModel, TranslationModelBuilder, +//! }; //! use tch::Device; -//! let translation_config = -//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available()); -//! let mut model = TranslationModel::new(translation_config)?; +//! let model = TranslationModelBuilder::new() +//! .with_device(Device::cuda_if_available()) +//! .with_model_type(ModelType::Marian) +//! .with_source_languages(vec![Language::English]) +//! .with_target_languages(vec![Language::French]) +//! .create_model()?; //! //! let input = ["This is a sentence to be translated"]; //! -//! let output = model.translate(&input); +//! let output = model.translate(&input, None, Language::French); //! # Ok(()) //! # } //! ``` -//! //! Output: \ //! ```no_run //! # let output = -//! "Il s'agit d'une phrase à traduire" +//! " Il s'agit d'une phrase à traduire" //! # ; //! ``` //! diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs index 4523a43..a95b71b 100644 --- a/src/pipelines/mod.rs +++ b/src/pipelines/mod.rs @@ -56,18 +56,30 @@ //! ```no_run //! # fn main() -> anyhow::Result<()> { //! # use rust_bert::pipelines::generation_utils::LanguageGenerator; -//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; +//! use rust_bert::pipelines::common::ModelType; +//! use rust_bert::pipelines::translation::{ +//! Language, TranslationConfig, TranslationModel, TranslationModelBuilder, +//! }; //! use tch::Device; -//! let translation_config = -//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available()); -//! let mut model = TranslationModel::new(translation_config)?; +//! let model = TranslationModelBuilder::new() +//! .with_device(Device::cuda_if_available()) +//! .with_model_type(ModelType::Marian) +//! .with_source_languages(vec![Language::English]) +//! .with_target_languages(vec![Language::French]) +//! .create_model()?; //! //! let input = ["This is a sentence to be translated"]; //! -//! let output = model.translate(&input); +//! let output = model.translate(&input, None, Language::French); //! # Ok(()) //! # } //! ``` +//! Output: \ +//! ```no_run +//! # let output = +//! " Il s'agit d'une phrase à traduire" +//! # ; +//! ``` //! //! Output: \ //! ```no_run diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index 8b98d18..4db6a45 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -433,7 +433,7 @@ impl TranslationConfig { /// MarianVocabResources, /// }; /// use rust_bert::pipelines::common::ModelType; - /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig}; + /// use rust_bert::pipelines::translation::TranslationConfig; /// use rust_bert::resources::{RemoteResource, Resource}; /// use tch::Device; /// @@ -447,8 +447,8 @@ impl TranslationConfig { /// MarianVocabResources::ROMANCE2ENGLISH, /// )); /// - /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect(); - /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect(); + /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// /// let translation_config = TranslationConfig::new( /// ModelType::Marian, @@ -458,7 +458,7 @@ impl TranslationConfig { /// vocab_resource, /// source_languages, /// target_languages, - /// device: Device::cuda_if_available(), + /// Device::cuda_if_available(), /// ); /// # Ok(()) /// # } @@ -817,11 +817,14 @@ impl TranslationModel { /// /// ```no_run /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel}; - /// use tch::Device; - /// use rust_bert::resources::{Resource, RemoteResource}; - /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages}; + /// use rust_bert::marian::{ + /// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, + /// MarianVocabResources, + /// }; /// use rust_bert::pipelines::common::ModelType; + /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; + /// use rust_bert::resources::{RemoteResource, Resource}; + /// use tch::Device; /// /// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// MarianModelResources::ROMANCE2ENGLISH, @@ -833,8 +836,8 @@ impl TranslationModel { /// MarianVocabResources::ROMANCE2ENGLISH, /// )); /// - /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect(); - /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect(); + /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; + /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; /// /// let translation_config = TranslationConfig::new( /// ModelType::Marian, @@ -844,7 +847,7 @@ impl TranslationModel { /// vocab_resource, /// source_languages, /// target_languages, - /// device: Device::cuda_if_available(), + /// Device::cuda_if_available(), /// ); /// let mut summarization_model = TranslationModel::new(translation_config)?; /// # Ok(()) @@ -876,11 +879,14 @@ impl TranslationModel { /// /// ```no_run /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel, Language}; - /// use tch::Device; - /// use rust_bert::resources::{Resource, RemoteResource}; - /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages, MarianSpmResources}; + /// use rust_bert::marian::{ + /// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, + /// MarianTargetLanguages, MarianVocabResources, + /// }; /// use rust_bert::pipelines::common::ModelType; + /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}; + /// use rust_bert::resources::{RemoteResource, Resource}; + /// use tch::Device; /// /// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// MarianModelResources::ENGLISH2ROMANCE, @@ -892,10 +898,10 @@ impl TranslationModel { /// MarianVocabResources::ENGLISH2ROMANCE, /// )); /// let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - /// MarianSpmResources::ENGLISH2ROMANCE, + /// MarianSpmResources::ENGLISH2ROMANCE, /// )); - /// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE.iter().collect(); - /// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE.iter().collect(); + /// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE; + /// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE; /// /// let translation_config = TranslationConfig::new( /// ModelType::Marian, @@ -905,7 +911,7 @@ impl TranslationModel { /// merges_resource, /// source_languages, /// target_languages, - /// device: Device::cuda_if_available(), + /// Device::cuda_if_available(), /// ); /// let model = TranslationModel::new(translation_config)?; ///