add async example, documentation and fix clippy (#217)

Flix 2022-01-30 12:51:58 +01:00 committed by GitHub
parent 61b7e854b3
commit 23c5d9112a
6 changed files with 77 additions and 5 deletions

Cargo.toml

@@ -73,5 +73,6 @@ half = "1.8.2"
 anyhow = "1.0.51"
 csv = "1.1.6"
 criterion = "0.3.5"
+tokio = { version = "1.16.1", features = ["sync", "rt-multi-thread", "macros"] }
 torch-sys = "~0.6.1"
 tempfile = "3.2.0"
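
Note on the chosen features: `sync` gates `tokio::sync::oneshot`, `macros` provides `#[tokio::main]`, and `rt-multi-thread` is needed because the new example calls `tokio::task::block_in_place`, which panics on a current-thread runtime.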

examples/async-sentiment.rs

@@ -0,0 +1,63 @@
+use std::{
+    sync::mpsc,
+    thread::{self, JoinHandle},
+};
+
+use anyhow::Result;
+use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel};
+use tokio::{sync::oneshot, task};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let (_handle, classifier) = SentimentClassifier::spawn();
+
+    let texts = vec![
+        "Classify this positive text".to_owned(),
+        "Classify this negative text".to_owned(),
+    ];
+    let sentiments = classifier.predict(texts).await?;
+    println!("Results: {:?}", sentiments);
+
+    Ok(())
+}
+
+/// Message type for internal channel, passing around texts and return value
+/// senders
+type Message = (Vec<String>, oneshot::Sender<Vec<Sentiment>>);
+
+/// Runner for sentiment classification
+#[derive(Debug, Clone)]
+pub struct SentimentClassifier {
+    sender: mpsc::SyncSender<Message>,
+}
+
+impl SentimentClassifier {
+    /// Spawn a classifier on a separate thread and return a classifier instance
+    /// to interact with it
+    pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) {
+        let (sender, receiver) = mpsc::sync_channel(100);
+        let handle = thread::spawn(move || Self::runner(receiver));
+        (handle, SentimentClassifier { sender })
+    }
+
+    /// The classification runner itself
+    fn runner(receiver: mpsc::Receiver<Message>) -> Result<()> {
+        // Needs to be in sync runtime, async doesn't work
+        let model = SentimentModel::new(SentimentConfig::default())?;
+
+        while let Ok((texts, sender)) = receiver.recv() {
+            let texts: Vec<&str> = texts.iter().map(String::as_str).collect();
+            let sentiments = model.predict(texts);
+            sender.send(sentiments).expect("sending results");
+        }
+
+        Ok(())
+    }
+
+    /// Make the runner predict a sample and return the result
+    pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> {
+        let (sender, receiver) = oneshot::channel();
+        task::block_in_place(|| self.sender.send((texts, sender)))?;
+        Ok(receiver.await?)
+    }
+}
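
A few notes on the design: the bounded `mpsc::sync_channel(100)` applies backpressure when the model thread falls behind; `block_in_place` keeps the potentially blocking `send` from stalling the async executor (and is why the multi-threaded runtime is required); and each request carries its own `oneshot` sender, so results reach the right caller without shared state.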

src/lib.rs

@@ -559,6 +559,13 @@
 //! `python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
 //!
 //!
+//! ## Async execution
+//!
+//! Creating any of the models in an async context will cause panics! Running heavy computations such as predictions inside a future should also be avoided ([see here](https://docs.rs/tokio/latest/tokio/#cpu-bound-tasks-and-blocking-code)).
+//!
+//! It is recommended to spawn a separate thread for the models. The `async-sentiment` example shows one way to integrate the models into async code.
+//!
+//!
 //! ## Citation
 //!
 //! If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
@@ -579,6 +586,9 @@
 //! Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights compatible with this Rust library.
 //! The list of ready-to-use pretrained models is listed at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
 
+// These are used abundantly in this code
+#![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)]
+
 pub mod albert;
 pub mod bart;
 pub mod bert;
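
For one-off blocking or CPU-bound calls, the tokio documentation linked in the new doc section also describes `task::spawn_blocking`; here is a minimal, generic sketch of that pattern (toy computation, not rust-bert's API):

```rust
use tokio::task;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Hand CPU-bound work to tokio's blocking thread pool so the
    // async executor threads stay free for other tasks.
    let sum = task::spawn_blocking(|| (0..1_000_000u64).sum::<u64>()).await?;
    println!("sum = {}", sum);
    Ok(())
}
```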

src/pipelines/question_answering.rs

@@ -663,7 +663,7 @@ impl QuestionAnsweringModel {
         let mut features: Vec<QaFeature> = qa_inputs
             .iter()
             .enumerate()
-            .map(|(example_index, qa_example)| {
+            .flat_map(|(example_index, qa_example)| {
                 self.generate_features(
                     qa_example,
                     self.max_seq_len,
@@ -672,7 +672,6 @@ impl QuestionAnsweringModel {
                 example_index as i64,
             )
         })
-        .flatten()
         .collect();
 
         let mut example_top_k_answers_map: HashMap<usize, Vec<Answer>> = HashMap::new();
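
The rewrite here matches clippy's `map_flatten` lint, and the same change appears again in the token classification pipeline below. A minimal standalone illustration (toy data, not the pipeline's types):

```rust
fn main() {
    let nested = vec![vec![1, 2], vec![3]];

    // Before: `map` followed by `flatten`, which clippy's `map_flatten` lint flags.
    let flattened: Vec<i32> = nested.iter().map(|v| v.iter().copied()).flatten().collect();

    // After: a single `flat_map` expresses the same operation in one adapter.
    let flat_mapped: Vec<i32> = nested.iter().flat_map(|v| v.iter().copied()).collect();

    assert_eq!(flattened, flat_mapped);
}
```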

src/pipelines/sequence_classification.rs

@@ -705,7 +705,7 @@ impl SequenceClassificationModel {
         input: &[&str],
         threshold: f64,
     ) -> Result<Vec<Vec<Label>>, RustBertError> {
-        let input_tensor = self.prepare_for_model(input.to_vec());
+        let input_tensor = self.prepare_for_model(input);
         let output = no_grad(|| {
             let output = self.sequence_classifier.forward_t(
                 Some(&input_tensor),
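
This change avoids cloning the input slice into a fresh `Vec`. A minimal illustration of the pattern with a hypothetical helper (not the actual `prepare_for_model` signature):

```rust
// Hypothetical helper: borrowing the slice is enough for read-only work.
fn total_len(texts: &[&str]) -> usize {
    texts.iter().map(|t| t.len()).sum()
}

fn main() {
    let input: Vec<&str> = vec!["one", "two"];
    // Pass a borrow rather than allocating a new Vec with `input.to_vec()`.
    assert_eq!(total_len(&input), 6);
}
```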

src/pipelines/token_classification.rs

@@ -826,8 +826,7 @@ impl TokenClassificationModel {
         let mut features: Vec<InputFeature> = input
             .iter()
             .enumerate()
-            .map(|(example_index, example)| self.generate_features(example, example_index))
-            .flatten()
+            .flat_map(|(example_index, example)| self.generate_features(example, example_index))
             .collect();
 
         let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();