Updated documentation and updated cache reorder for generation pipelines

This commit is contained in:
Guillaume B 2020-07-25 11:49:50 +02:00
parent 5b00074ab5
commit dda94dedce
5 changed files with 95 additions and 50 deletions

View File

@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]
[dependencies]
rust_tokenizers = {version = "~3.1.6", path = "E:/coding/backup-rust/rust-tokenizers/main/"}
rust_tokenizers = "~3.1.6"
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}

View File

@ -186,7 +186,8 @@ Output:
```
#### 7. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Models are currently available for English, German, Spanish and Dutch.
```rust
let ner_model = NERModel::new(Default::default())?;

View File

@ -768,21 +768,19 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
@ -1037,21 +1035,19 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
@ -1253,21 +1249,19 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
match past {
Cache::T5Cache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
@ -2088,18 +2082,15 @@ pub(crate) mod private_generation_utils {
) -> Option<Tensor> {
match past {
Cache::None => None,
Cache::GPT2Cache(cached_decoder_state) => {
match cached_decoder_state {
Some(value) => {
// let mut reordered_past = vec!();
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None
Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
Some(value) => {
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None => None,
None
}
}
None => None,
},
Cache::BARTCache(_) => {
panic!("Not implemented");
}

View File

@ -216,7 +216,8 @@
//! ```
//!
//! #### 7. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. The default NER model is an English BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
//! Additional pre-trained models are available for English, German, Spanish and Dutch.
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> failure::Fallible<()> {

View File

@ -12,12 +12,19 @@
//! # Named Entity Recognition pipeline
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text.
//! BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! Pretrained models are available for the following languages:
//! - English
//! - German
//! - Spanish
//! - Dutch
//!
//! The default NER model is an English BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
//! All resources for this model can be downloaded using the Python utility script included in this repository.
//! 1. Set-up a Python virtual environment and install dependencies (in ./requirements.txt)
//! 2. Run the conversion script python /utils/download-dependencies_bert_ner.py.
//! The dependencies will be downloaded to the user's home directory, under ~/rustbert/bert-ner
//!
//! The example below illustrates how to run the default English NER model:
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> failure::Fallible<()> {
@ -60,6 +67,51 @@
//! ]
//! # ;
//! ```
//!
//! To run the pipeline for another language, change the NERModel configuration from its default:
//!
//! ```no_run
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::ner::NERModel;
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::roberta::{
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! };
//! use tch::Device;
//! let ner_config = TokenClassificationConfig {
//! model_type: ModelType::XLMRoberta,
//! model_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
//! )),
//! config_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
//! )),
//! vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
//! )),
//! lower_case: false,
//! device: Device::cuda_if_available(),
//! ..Default::default()
//! };
//!
//! let ner_model = NERModel::new(ner_config)?;
//!
//! // Define input
//! let input = [
//! "Mein Name ist Amélie. Ich lebe in Paris.",
//! "Paris ist eine Stadt in Frankreich.",
//! ];
//! ```
//! The XLMRoberta models for the languages are defined as follows:
//!
//! | **Language** |**Model name**|
//! :-----:|:----:
//! English| XLM_ROBERTA_NER_EN |
//! German| XLM_ROBERTA_NER_DE |
//! Spanish| XLM_ROBERTA_NER_ES |
//! Dutch| XLM_ROBERTA_NER_NL |
//!
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};