diff --git a/Cargo.toml b/Cargo.toml
index ee8b15f..8c13ca7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -73,5 +73,6 @@ half = "1.8.2"
 anyhow = "1.0.51"
 csv = "1.1.6"
 criterion = "0.3.5"
+tokio = { version = "1.16.1", features = ["sync", "rt-multi-thread", "macros"] }
 torch-sys = "~0.6.1"
 tempfile = "3.2.0"
diff --git a/examples/async-sentiment.rs b/examples/async-sentiment.rs
new file mode 100644
index 0000000..e994a8b
--- /dev/null
+++ b/examples/async-sentiment.rs
@@ -0,0 +1,63 @@
+use std::{
+    sync::mpsc,
+    thread::{self, JoinHandle},
+};
+
+use anyhow::Result;
+use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel};
+use tokio::{sync::oneshot, task};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let (_handle, classifier) = SentimentClassifier::spawn();
+
+    let texts = vec![
+        "Classify this positive text".to_owned(),
+        "Classify this negative text".to_owned(),
+    ];
+    let sentiments = classifier.predict(texts).await?;
+    println!("Results: {:?}", sentiments);
+
+    Ok(())
+}
+
+/// Message type for the internal channel, passing around texts and return
+/// value senders
+type Message = (Vec<String>, oneshot::Sender<Vec<Sentiment>>);
+
+/// Runner for sentiment classification
+#[derive(Debug, Clone)]
+pub struct SentimentClassifier {
+    sender: mpsc::SyncSender<Message>,
+}
+
+impl SentimentClassifier {
+    /// Spawn a classifier on a separate thread and return a classifier instance
+    /// to interact with it
+    pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) {
+        let (sender, receiver) = mpsc::sync_channel(100);
+        let handle = thread::spawn(move || Self::runner(receiver));
+        (handle, SentimentClassifier { sender })
+    }
+
+    /// The classification runner itself
+    fn runner(receiver: mpsc::Receiver<Message>) -> Result<()> {
+        // The model must be created on a plain (non-async) thread; creating it in an async context panics
+        let model = SentimentModel::new(SentimentConfig::default())?;
+
+        while let Ok((texts, sender)) = receiver.recv() {
+            let texts: Vec<&str> = texts.iter().map(String::as_str).collect();
+            let sentiments = model.predict(texts);
+            sender.send(sentiments).expect("sending results");
+        }
+
+        Ok(())
+    }
+
+    /// Make the runner predict a sample and return the result
+    pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> {
+        let (sender, receiver) = oneshot::channel();
+        task::block_in_place(|| self.sender.send((texts, sender)))?;
+        Ok(receiver.await?)
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 3e533e2..de2bbca 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -559,6 +559,13 @@
 //! `python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
 //!
 //!
+//! ## Async execution
+//!
+//! Creating any of the models in an async context will cause panics! CPU-bound work such as running predictions should also be kept out of futures ([see here](https://docs.rs/tokio/latest/tokio/#cpu-bound-tasks-and-blocking-code)).
+//!
+//! It is recommended to spawn a dedicated thread for the models. The `async-sentiment` example shows one way to integrate models into async code.
+//!
+//!
 //! ## Citation
 //!
 //! If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
@@ -579,6 +586,9 @@
 //! Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights compatible with this Rust library.
 //! The list of ready-to-use pretrained models is listed at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
 
+// These are used abundantly in this code
+#![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)]
+
 pub mod albert;
 pub mod bart;
 pub mod bert;
diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs
index fdb4ce6..80a07a9 100644
--- a/src/pipelines/question_answering.rs
+++ b/src/pipelines/question_answering.rs
@@ -663,7 +663,7 @@ impl QuestionAnsweringModel {
         let mut features: Vec<QaFeature> = qa_inputs
             .iter()
             .enumerate()
-            .map(|(example_index, qa_example)| {
+            .flat_map(|(example_index, qa_example)| {
                 self.generate_features(
                     qa_example,
                     self.max_seq_len,
@@ -672,7 +672,6 @@
                     self.doc_stride,
                     example_index as i64,
                 )
             })
-            .flatten()
             .collect();
         let mut example_top_k_answers_map: HashMap<usize, Vec<Answer>> = HashMap::new();
diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs
index 78efcbd..d5dab08 100644
--- a/src/pipelines/sequence_classification.rs
+++ b/src/pipelines/sequence_classification.rs
@@ -705,7 +705,7 @@ impl SequenceClassificationModel {
         input: &[&str],
         threshold: f64,
     ) -> Result<Vec<Vec<Label>>, RustBertError> {
-        let input_tensor = self.prepare_for_model(input.to_vec());
+        let input_tensor = self.prepare_for_model(input);
         let output = no_grad(|| {
             let output = self.sequence_classifier.forward_t(
                 Some(&input_tensor),
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 77a5ddb..8b96bc5 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -826,8 +826,7 @@ impl TokenClassificationModel {
         let mut features: Vec<InputFeature> = input
             .iter()
             .enumerate()
-            .map(|(example_index, example)| self.generate_features(example, example_index))
-            .flatten()
+            .flat_map(|(example_index, example)| self.generate_features(example, example_index))
             .collect();
 
         let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
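
For reference, the SentimentClassifier handle added in examples/async-sentiment.rs derives Clone, so several async tasks can share the single model thread: clones share the same mpsc sender, requests funnel through the bounded channel, and each caller awaits its own oneshot reply. Below is a minimal, hypothetical fan-out sketch, not part of the patch; the texts and task count are made up, it assumes the SentimentClassifier type from the example above is in scope, and it relies on tokio's default multi-threaded runtime, since task::block_in_place panics on a current-thread runtime.

// Hypothetical usage sketch: fan requests out from multiple tokio tasks to
// the single thread that owns the model, via cloned classifier handles.
use anyhow::Result;

#[tokio::main] // default multi-threaded runtime; required by block_in_place
async fn main() -> Result<()> {
    let (_handle, classifier) = SentimentClassifier::spawn();

    // Spawn one task per input; each clone of the handle shares the same
    // mpsc sender into the model thread.
    let mut tasks = Vec::new();
    for text in vec!["First input", "Second input", "Third input"] {
        let classifier = classifier.clone();
        tasks.push(tokio::spawn(async move {
            classifier.predict(vec![text.to_owned()]).await
        }));
    }

    for task in tasks {
        // First `?` handles the JoinError, second the prediction Result
        println!("{:?}", task.await??);
    }
    Ok(())
}

The bounded sync_channel(100) also provides natural backpressure: once 100 requests are queued, further sends block, which is why predict wraps the send in task::block_in_place rather than calling it directly on the async executor.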