Updated doctests

This commit is contained in:
Guillaume B 2021-07-28 18:10:20 +02:00
parent d6c5c47b48
commit 466c6b6922
4 changed files with 55 additions and 32 deletions

View File

@ -26,7 +26,7 @@ fn main() -> anyhow::Result<()> {
.with_target_languages(vec![Language::Spanish])
.create_model()?;
let input_context_1 = "The quick brown fox jumps over the lazy dog.";
let input_context_1 = "This is a sentence to be translated";
let input_context_2 = "The dog did not wake up.";
let output = model.translate(&[input_context_1, input_context_2], None, Language::Spanish)?;

View File

@ -168,23 +168,28 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::translation::{
//! Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
//! };
//! use tch::Device;
//! let translation_config =
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//! let model = TranslationModelBuilder::new()
//! .with_device(Device::cuda_if_available())
//! .with_model_type(ModelType::Marian)
//! .with_source_languages(vec![Language::English])
//! .with_target_languages(vec![Language::French])
//! .create_model()?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//! let output = model.translate(&input, None, Language::French);
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//! # let output =
//! "Il s'agit d'une phrase à traduire"
//! " Il s'agit d'une phrase à traduire"
//! # ;
//! ```
//!

View File

@ -56,18 +56,30 @@
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! # use rust_bert::pipelines::generation_utils::LanguageGenerator;
//! use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::translation::{
//! Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
//! };
//! use tch::Device;
//! let translation_config =
//! TranslationConfig::new(OldLanguage::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//! let model = TranslationModelBuilder::new()
//! .with_device(Device::cuda_if_available())
//! .with_model_type(ModelType::Marian)
//! .with_source_languages(vec![Language::English])
//! .with_target_languages(vec![Language::French])
//! .create_model()?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//! let output = model.translate(&input, None, Language::French);
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//! # let output =
//! " Il s'agit d'une phrase à traduire"
//! # ;
//! ```
//!
//! Output: \
//! ```no_run

View File

@ -433,7 +433,7 @@ impl TranslationConfig {
/// MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig};
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use rust_bert::resources::{RemoteResource, Resource};
/// use tch::Device;
///
@ -447,8 +447,8 @@ impl TranslationConfig {
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect();
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect();
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
@ -458,7 +458,7 @@ impl TranslationConfig {
/// vocab_resource,
/// source_languages,
/// target_languages,
/// device: Device::cuda_if_available(),
/// Device::cuda_if_available(),
/// );
/// # Ok(())
/// # }
@ -817,11 +817,14 @@ impl TranslationModel {
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel};
/// use tch::Device;
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages};
/// use rust_bert::marian::{
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages,
/// MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianModelResources::ROMANCE2ENGLISH,
@ -833,8 +836,8 @@ impl TranslationModel {
/// MarianVocabResources::ROMANCE2ENGLISH,
/// ));
///
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH.iter().collect();
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH.iter().collect();
/// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
/// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
@ -844,7 +847,7 @@ impl TranslationModel {
/// vocab_resource,
/// source_languages,
/// target_languages,
/// device: Device::cuda_if_available(),
/// Device::cuda_if_available(),
/// );
/// let mut summarization_model = TranslationModel::new(translation_config)?;
/// # Ok(())
@ -876,11 +879,14 @@ impl TranslationModel {
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::translation::{OldLanguage, TranslationConfig, TranslationModel, Language};
/// use tch::Device;
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianVocabResources, MarianSourceLanguages, MarianTargetLanguages, MarianSpmResources};
/// use rust_bert::marian::{
/// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
/// MarianTargetLanguages, MarianVocabResources,
/// };
/// use rust_bert::pipelines::common::ModelType;
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use tch::Device;
///
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianModelResources::ENGLISH2ROMANCE,
@ -892,10 +898,10 @@ impl TranslationModel {
/// MarianVocabResources::ENGLISH2ROMANCE,
/// ));
/// let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
/// MarianSpmResources::ENGLISH2ROMANCE,
/// MarianSpmResources::ENGLISH2ROMANCE,
/// ));
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE.iter().collect();
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE.iter().collect();
/// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
/// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
///
/// let translation_config = TranslationConfig::new(
/// ModelType::Marian,
@ -905,7 +911,7 @@ impl TranslationModel {
/// merges_resource,
/// source_languages,
/// target_languages,
/// device: Device::cuda_if_available(),
/// Device::cuda_if_available(),
/// );
/// let model = TranslationModel::new(translation_config)?;
///