mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
add async example, documentation and fix clippy (#217)
This commit is contained in:
parent
61b7e854b3
commit
23c5d9112a
@ -73,5 +73,6 @@ half = "1.8.2"
|
||||
anyhow = "1.0.51"
|
||||
csv = "1.1.6"
|
||||
criterion = "0.3.5"
|
||||
tokio = { version = "1.16.1", features = ["sync", "rt-multi-thread", "macros"] }
|
||||
torch-sys = "~0.6.1"
|
||||
tempfile = "3.2.0"
|
||||
|
63
examples/async-sentiment.rs
Normal file
63
examples/async-sentiment.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use std::{
|
||||
sync::mpsc,
|
||||
thread::{self, JoinHandle},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use rust_bert::pipelines::sentiment::{Sentiment, SentimentConfig, SentimentModel};
|
||||
use tokio::{sync::oneshot, task};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let (_handle, classifier) = SentimentClassifier::spawn();
|
||||
|
||||
let texts = vec![
|
||||
"Classify this positive text".to_owned(),
|
||||
"Classify this negative text".to_owned(),
|
||||
];
|
||||
let sentiments = classifier.predict(texts).await?;
|
||||
println!("Results: {:?}", sentiments);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Message type for internal channel, passing around texts and return value
/// senders
///
/// Each request pairs the batch of texts to classify with the oneshot sender
/// through which the runner thread hands the predicted sentiments back to the
/// awaiting caller.
type Message = (Vec<String>, oneshot::Sender<Vec<Sentiment>>);
|
||||
|
||||
/// Runner for sentiment classification
///
/// A cloneable handle to the model thread created by `spawn`; clones share
/// the same underlying request channel, so any clone can submit work.
#[derive(Debug, Clone)]
pub struct SentimentClassifier {
    // Sending half of the bounded channel feeding requests to the runner
    // thread; when every clone (and thus every sender) is dropped, the
    // runner's receive loop terminates.
    sender: mpsc::SyncSender<Message>,
}
|
||||
|
||||
impl SentimentClassifier {
|
||||
/// Spawn a classifier on a separate thread and return a classifier instance
|
||||
/// to interact with it
|
||||
pub fn spawn() -> (JoinHandle<Result<()>>, SentimentClassifier) {
|
||||
let (sender, receiver) = mpsc::sync_channel(100);
|
||||
let handle = thread::spawn(move || Self::runner(receiver));
|
||||
(handle, SentimentClassifier { sender })
|
||||
}
|
||||
|
||||
/// The classification runner itself
|
||||
fn runner(receiver: mpsc::Receiver<Message>) -> Result<()> {
|
||||
// Needs to be in sync runtime, async doesn't work
|
||||
let model = SentimentModel::new(SentimentConfig::default())?;
|
||||
|
||||
while let Ok((texts, sender)) = receiver.recv() {
|
||||
let texts: Vec<&str> = texts.iter().map(String::as_str).collect();
|
||||
let sentiments = model.predict(texts);
|
||||
sender.send(sentiments).expect("sending results");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Make the runner predict a sample and return the result
|
||||
pub async fn predict(&self, texts: Vec<String>) -> Result<Vec<Sentiment>> {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
task::block_in_place(|| self.sender.send((texts, sender)))?;
|
||||
Ok(receiver.await?)
|
||||
}
|
||||
}
|
10
src/lib.rs
10
src/lib.rs
@ -559,6 +559,13 @@
|
||||
//! `python ./utils/convert_model.py path/to/pytorch_model.bin` where `path/to/pytorch_model.bin` is the location of the original Pytorch weights.
|
||||
//!
|
||||
//!
|
||||
//! ## Async execution
|
||||
//!
|
||||
//! Creating any of the models in async context will cause panics! Running extensive calculations like running predictions in a future should be avoided, too ([see here](https://docs.rs/tokio/latest/tokio/#cpu-bound-tasks-and-blocking-code)).
|
||||
//!
|
||||
//! It is recommended to spawn a separate thread for the models. The `async-sentiment` example displays a possible solution you could use to integrate models into async code.
|
||||
//!
|
||||
//!
|
||||
//! ## Citation
|
||||
//!
|
||||
//! If you use `rust-bert` for your work, please cite [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/):
|
||||
@ -579,6 +586,9 @@
|
||||
//! Thank you to [Hugging Face](https://huggingface.co) for hosting a set of weights compatible with this Rust library.
|
||||
//! The list of ready-to-use pretrained models is listed at [https://huggingface.co/models?filter=rust](https://huggingface.co/models?filter=rust).
|
||||
|
||||
// These are used abundantly in this code
|
||||
#![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)]
|
||||
|
||||
pub mod albert;
|
||||
pub mod bart;
|
||||
pub mod bert;
|
||||
|
@ -663,7 +663,7 @@ impl QuestionAnsweringModel {
|
||||
let mut features: Vec<QaFeature> = qa_inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(example_index, qa_example)| {
|
||||
.flat_map(|(example_index, qa_example)| {
|
||||
self.generate_features(
|
||||
qa_example,
|
||||
self.max_seq_len,
|
||||
@ -672,7 +672,6 @@ impl QuestionAnsweringModel {
|
||||
example_index as i64,
|
||||
)
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
let mut example_top_k_answers_map: HashMap<usize, Vec<Answer>> = HashMap::new();
|
||||
|
@ -705,7 +705,7 @@ impl SequenceClassificationModel {
|
||||
input: &[&str],
|
||||
threshold: f64,
|
||||
) -> Result<Vec<Vec<Label>>, RustBertError> {
|
||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||
let input_tensor = self.prepare_for_model(input);
|
||||
let output = no_grad(|| {
|
||||
let output = self.sequence_classifier.forward_t(
|
||||
Some(&input_tensor),
|
||||
|
@ -826,8 +826,7 @@ impl TokenClassificationModel {
|
||||
let mut features: Vec<InputFeature> = input
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(example_index, example)| self.generate_features(example, example_index))
|
||||
.flatten()
|
||||
.flat_map(|(example_index, example)| self.generate_features(example, example_index))
|
||||
.collect();
|
||||
|
||||
let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
|
||||
|
Loading…
Reference in New Issue
Block a user