mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-03 23:57:15 +03:00
ort 2.0
This commit is contained in:
parent
70ed5b4c95
commit
db92493edc
@ -295,7 +295,6 @@ impl MaskedLanguageOption {
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
@ -304,11 +303,7 @@ impl MaskedLanguageOption {
|
||||
"An encoder file must be provided for masked language ONNX models.".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
|
||||
}
|
||||
/// Returns the `ModelType` for this MaskedLanguageOption
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
|
@ -38,9 +38,11 @@ impl ONNXEnvironmentConfig {
|
||||
pub fn from_device(device: Device) -> Self {
|
||||
let mut execution_providers = Vec::new();
|
||||
if let Device::Cuda(device_id) = device {
|
||||
CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id as i32)
|
||||
.build()
|
||||
execution_providers.push(
|
||||
CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id as i32)
|
||||
.build(),
|
||||
);
|
||||
};
|
||||
execution_providers.push(ExecutionProviderDispatch::CPU(
|
||||
CPUExecutionProvider::default(),
|
||||
@ -53,8 +55,10 @@ impl ONNXEnvironmentConfig {
|
||||
|
||||
///Build a session builder from an `ONNXEnvironmentConfig`.
|
||||
pub fn get_session_builder(&self) -> Result<SessionBuilder, RustBertError> {
|
||||
let mut session_builder =
|
||||
SessionBuilder::new()?.with_execution_providers(&self.execution_providers)?;
|
||||
let mut session_builder = SessionBuilder::new()?;
|
||||
if let Some(execution_providers) = &self.execution_providers {
|
||||
session_builder = session_builder.with_execution_providers(execution_providers)?;
|
||||
};
|
||||
match &self.optimization_level {
|
||||
Some(GraphOptimizationLevel::Level3) | None => {}
|
||||
Some(GraphOptimizationLevel::Level2) => {
|
||||
|
@ -1,29 +1,26 @@
|
||||
use crate::RustBertError;
|
||||
use ndarray::{ArrayBase, ArrayD, CowArray, CowRepr, IxDyn};
|
||||
|
||||
use ort::{Session, Value};
|
||||
use ort::Value;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use tch::{Kind, Tensor};
|
||||
|
||||
pub(crate) fn ort_tensor_to_tch(ort_tensor: &Value) -> Result<Tensor, RustBertError> {
|
||||
let ort_tensor = ort_tensor.try_extract::<f32>()?.view().to_owned();
|
||||
let ort_tensor = ort_tensor.extract_tensor::<f32>()?.view().to_owned();
|
||||
Ok(Tensor::try_from(ort_tensor)?)
|
||||
}
|
||||
|
||||
pub(crate) fn array_to_ort<'a>(
|
||||
session: &Session,
|
||||
array: &'a TypedArray<'a>,
|
||||
) -> Result<Value<'a>, RustBertError> {
|
||||
pub(crate) fn array_to_ort<'a>(array: &'a TypedArray<'a>) -> Result<Value, RustBertError> {
|
||||
match &array {
|
||||
TypedArray::I64(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::F32(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::I32(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::F64(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::F16(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::I16(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::I8(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::UI8(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::BF16(array) => Ok(Value::from_array(session.allocator(), array)?),
|
||||
TypedArray::I64(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::F32(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::I32(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::F64(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::F16(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::I16(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::I8(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::UI8(array) => Ok(Value::from_array(array)?),
|
||||
TypedArray::BF16(array) => Ok(Value::from_array(array)?),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@ use crate::pipelines::onnx::config::{
|
||||
use crate::pipelines::onnx::conversion::{array_to_ort, ort_tensor_to_tch, tch_tensor_to_ndarray};
|
||||
use crate::pipelines::onnx::models::ONNXLayerCache;
|
||||
use crate::RustBertError;
|
||||
use ort::Session;
|
||||
use ort::{Session, SessionInputs};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tch::Tensor;
|
||||
@ -95,10 +95,12 @@ impl ONNXDecoder {
|
||||
|
||||
let input_values = inputs_arrays
|
||||
.iter()
|
||||
.map(|array| array_to_ort(&self.session, array).unwrap())
|
||||
.map(|array| array_to_ort(array).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = self.session.run(input_values)?;
|
||||
let outputs = self
|
||||
.session
|
||||
.run(SessionInputs::from(input_values.as_slice()))?;
|
||||
|
||||
let lm_logits =
|
||||
ort_tensor_to_tch(&outputs[*self.name_mapping.output_names.get("logits").unwrap()])?;
|
||||
|
@ -5,7 +5,7 @@ use crate::pipelines::onnx::config::{
|
||||
};
|
||||
use crate::pipelines::onnx::conversion::{array_to_ort, ort_tensor_to_tch, tch_tensor_to_ndarray};
|
||||
use crate::RustBertError;
|
||||
use ort::Session;
|
||||
use ort::{Session, SessionInputs};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tch::Tensor;
|
||||
@ -147,10 +147,11 @@ impl ONNXEncoder {
|
||||
|
||||
let input_values = inputs_arrays
|
||||
.iter()
|
||||
.map(|array| array_to_ort(&self.session, array).unwrap())
|
||||
.map(|array| array_to_ort(array).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = self.session.run(input_values)?;
|
||||
let outputs = self
|
||||
.session
|
||||
.run(SessionInputs::from(input_values.as_slice()))?;
|
||||
|
||||
let last_hidden_state = self
|
||||
.name_mapping
|
||||
@ -183,7 +184,7 @@ impl ONNXEncoder {
|
||||
.output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains("hidden_states"))
|
||||
.map(|(_, position)| outputs.get(*position))
|
||||
.map(|(name, _)| outputs.get(name.as_str()))
|
||||
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
|
||||
.collect::<Option<Vec<_>>>();
|
||||
|
||||
@ -192,7 +193,7 @@ impl ONNXEncoder {
|
||||
.output_names
|
||||
.iter()
|
||||
.filter(|(name, _)| name.contains("attentions"))
|
||||
.map(|(_, position)| outputs.get(*position))
|
||||
.map(|(name, _)| outputs.get(name.as_str()))
|
||||
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
|
||||
.collect::<Option<Vec<_>>>();
|
||||
(hidden_states, attentions)
|
||||
|
@ -9,7 +9,7 @@ use crate::pipelines::onnx::encoder::ONNXEncoder;
|
||||
use crate::{Config, RustBertError};
|
||||
|
||||
use crate::pipelines::onnx::conversion;
|
||||
use ort::Value;
|
||||
use ort::SessionOutputs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
@ -1054,7 +1054,7 @@ impl ONNXLayerCache {
|
||||
/// Helper function to create a cache layer from an ONNX model output.
|
||||
/// Assumes that the output names for cached keys and values contain `key` and `value` in their name, respectively.
|
||||
pub fn from_ort_output(
|
||||
ort_output: &[Value],
|
||||
ort_output: &SessionOutputs,
|
||||
key_value_names: &HashMap<String, usize>,
|
||||
) -> Result<ONNXLayerCache, RustBertError> {
|
||||
let values = key_value_names
|
||||
|
@ -486,7 +486,6 @@ impl QuestionAnsweringOption {
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
@ -495,11 +494,7 @@ impl QuestionAnsweringOption {
|
||||
"An encoder file must be provided for question answering ONNX models.".to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
|
@ -402,7 +402,6 @@ impl SequenceClassificationOption {
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
@ -412,11 +411,7 @@ impl SequenceClassificationOption {
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
|
@ -244,7 +244,7 @@ impl SummarizationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
|
||||
ONNXConditionalGenerator::new(config.into(), None, None)?,
|
||||
ONNXConditionalGenerator::new(config.into(), None)?,
|
||||
)),
|
||||
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
|
||||
config.into(),
|
||||
@ -273,7 +273,7 @@ impl SummarizationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
|
||||
)),
|
||||
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(
|
||||
BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
|
@ -219,7 +219,7 @@ impl TextGenerationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
|
||||
ONNXCausalGenerator::new(config.into(), None, None)?,
|
||||
ONNXCausalGenerator::new(config.into(), None)?,
|
||||
)),
|
||||
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
|
||||
config.into(),
|
||||
@ -254,7 +254,7 @@ impl TextGenerationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
|
||||
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
|
||||
)),
|
||||
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
|
||||
GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
|
@ -516,7 +516,6 @@ impl TokenClassificationOption {
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
@ -526,11 +525,7 @@ impl TokenClassificationOption {
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this TokenClassificationOption
|
||||
|
@ -1123,7 +1123,7 @@ impl TranslationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
|
||||
ONNXConditionalGenerator::new(config.into(), None, None)?,
|
||||
ONNXConditionalGenerator::new(config.into(), None)?,
|
||||
)),
|
||||
(ModelType::Marian, _) => Ok(TranslationOption::Marian(MarianGenerator::new(
|
||||
config.into(),
|
||||
@ -1150,7 +1150,7 @@ impl TranslationOption {
|
||||
match (config.model_type, &config.model_resource) {
|
||||
#[cfg(feature = "onnx")]
|
||||
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
|
||||
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
|
||||
)),
|
||||
(ModelType::Marian, _) => Ok(TranslationOption::Marian(
|
||||
MarianGenerator::new_with_tokenizer(config.into(), tokenizer)?,
|
||||
|
@ -413,7 +413,6 @@ impl ZeroShotClassificationOption {
|
||||
#[cfg(feature = "onnx")]
|
||||
pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
|
||||
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
|
||||
let environment = onnx_config.get_environment()?;
|
||||
let encoder_file = config
|
||||
.model_resource
|
||||
.get_onnx_local_paths()?
|
||||
@ -423,11 +422,7 @@ impl ZeroShotClassificationOption {
|
||||
.to_string(),
|
||||
))?;
|
||||
|
||||
Ok(Self::ONNX(ONNXEncoder::new(
|
||||
encoder_file,
|
||||
&environment,
|
||||
&onnx_config,
|
||||
)?))
|
||||
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
|
Loading…
Reference in New Issue
Block a user