This commit is contained in:
Guillaume Becquin 2023-12-02 14:29:06 +00:00
parent 70ed5b4c95
commit db92493edc
No known key found for this signature in database
GPG Key ID: D23E3F3D92A4157D
13 changed files with 46 additions and 67 deletions

View File

@ -295,7 +295,6 @@ impl MaskedLanguageOption {
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &MaskedLanguageConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
@ -304,11 +303,7 @@ impl MaskedLanguageOption {
"An encoder file must be provided for masked language ONNX models.".to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
}
/// Returns the `ModelType` for this MaskedLanguageOption
pub fn model_type(&self) -> ModelType {

View File

@ -38,9 +38,11 @@ impl ONNXEnvironmentConfig {
pub fn from_device(device: Device) -> Self {
let mut execution_providers = Vec::new();
if let Device::Cuda(device_id) = device {
CUDAExecutionProvider::default()
.with_device_id(device_id as i32)
.build()
execution_providers.push(
CUDAExecutionProvider::default()
.with_device_id(device_id as i32)
.build(),
);
};
execution_providers.push(ExecutionProviderDispatch::CPU(
CPUExecutionProvider::default(),
@ -53,8 +55,10 @@ impl ONNXEnvironmentConfig {
/// Builds a session builder from an `ONNXEnvironmentConfig`.
pub fn get_session_builder(&self) -> Result<SessionBuilder, RustBertError> {
let mut session_builder =
SessionBuilder::new()?.with_execution_providers(&self.execution_providers)?;
let mut session_builder = SessionBuilder::new()?;
if let Some(execution_providers) = &self.execution_providers {
session_builder = session_builder.with_execution_providers(execution_providers)?;
};
match &self.optimization_level {
Some(GraphOptimizationLevel::Level3) | None => {}
Some(GraphOptimizationLevel::Level2) => {

View File

@ -1,29 +1,26 @@
use crate::RustBertError;
use ndarray::{ArrayBase, ArrayD, CowArray, CowRepr, IxDyn};
use ort::{Session, Value};
use ort::Value;
use std::convert::{TryFrom, TryInto};
use tch::{Kind, Tensor};
pub(crate) fn ort_tensor_to_tch(ort_tensor: &Value) -> Result<Tensor, RustBertError> {
let ort_tensor = ort_tensor.try_extract::<f32>()?.view().to_owned();
let ort_tensor = ort_tensor.extract_tensor::<f32>()?.view().to_owned();
Ok(Tensor::try_from(ort_tensor)?)
}
pub(crate) fn array_to_ort<'a>(
session: &Session,
array: &'a TypedArray<'a>,
) -> Result<Value<'a>, RustBertError> {
pub(crate) fn array_to_ort<'a>(array: &'a TypedArray<'a>) -> Result<Value, RustBertError> {
match &array {
TypedArray::I64(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::F32(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::I32(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::F64(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::F16(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::I16(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::I8(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::UI8(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::BF16(array) => Ok(Value::from_array(session.allocator(), array)?),
TypedArray::I64(array) => Ok(Value::from_array(array)?),
TypedArray::F32(array) => Ok(Value::from_array(array)?),
TypedArray::I32(array) => Ok(Value::from_array(array)?),
TypedArray::F64(array) => Ok(Value::from_array(array)?),
TypedArray::F16(array) => Ok(Value::from_array(array)?),
TypedArray::I16(array) => Ok(Value::from_array(array)?),
TypedArray::I8(array) => Ok(Value::from_array(array)?),
TypedArray::UI8(array) => Ok(Value::from_array(array)?),
TypedArray::BF16(array) => Ok(Value::from_array(array)?),
}
}

View File

@ -7,7 +7,7 @@ use crate::pipelines::onnx::config::{
use crate::pipelines::onnx::conversion::{array_to_ort, ort_tensor_to_tch, tch_tensor_to_ndarray};
use crate::pipelines::onnx::models::ONNXLayerCache;
use crate::RustBertError;
use ort::Session;
use ort::{Session, SessionInputs};
use std::collections::HashMap;
use std::path::PathBuf;
use tch::Tensor;
@ -95,10 +95,12 @@ impl ONNXDecoder {
let input_values = inputs_arrays
.iter()
.map(|array| array_to_ort(&self.session, array).unwrap())
.map(|array| array_to_ort(array).unwrap())
.collect::<Vec<_>>();
let outputs = self.session.run(input_values)?;
let outputs = self
.session
.run(SessionInputs::from(input_values.as_slice()))?;
let lm_logits =
ort_tensor_to_tch(&outputs[*self.name_mapping.output_names.get("logits").unwrap()])?;

View File

@ -5,7 +5,7 @@ use crate::pipelines::onnx::config::{
};
use crate::pipelines::onnx::conversion::{array_to_ort, ort_tensor_to_tch, tch_tensor_to_ndarray};
use crate::RustBertError;
use ort::Session;
use ort::{Session, SessionInputs};
use std::collections::HashMap;
use std::path::PathBuf;
use tch::Tensor;
@ -147,10 +147,11 @@ impl ONNXEncoder {
let input_values = inputs_arrays
.iter()
.map(|array| array_to_ort(&self.session, array).unwrap())
.map(|array| array_to_ort(array).unwrap())
.collect::<Vec<_>>();
let outputs = self.session.run(input_values)?;
let outputs = self
.session
.run(SessionInputs::from(input_values.as_slice()))?;
let last_hidden_state = self
.name_mapping
@ -183,7 +184,7 @@ impl ONNXEncoder {
.output_names
.iter()
.filter(|(name, _)| name.contains("hidden_states"))
.map(|(_, position)| outputs.get(*position))
.map(|(name, _)| outputs.get(name.as_str()))
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
.collect::<Option<Vec<_>>>();
@ -192,7 +193,7 @@ impl ONNXEncoder {
.output_names
.iter()
.filter(|(name, _)| name.contains("attentions"))
.map(|(_, position)| outputs.get(*position))
.map(|(name, _)| outputs.get(name.as_str()))
.map(|array| array.map(|array_value| ort_tensor_to_tch(array_value).unwrap()))
.collect::<Option<Vec<_>>>();
(hidden_states, attentions)

View File

@ -9,7 +9,7 @@ use crate::pipelines::onnx::encoder::ONNXEncoder;
use crate::{Config, RustBertError};
use crate::pipelines::onnx::conversion;
use ort::Value;
use ort::SessionOutputs;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::{nn, Device, Kind, Tensor};
@ -1054,7 +1054,7 @@ impl ONNXLayerCache {
/// Helper function to create a cache layer from an ONNX model output.
/// Assumes that the output names for cached keys and values contain `key` and `value` in their name, respectively.
pub fn from_ort_output(
ort_output: &[Value],
ort_output: &SessionOutputs,
key_value_names: &HashMap<String, usize>,
) -> Result<ONNXLayerCache, RustBertError> {
let values = key_value_names

View File

@ -486,7 +486,6 @@ impl QuestionAnsweringOption {
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &QuestionAnsweringConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
@ -495,11 +494,7 @@ impl QuestionAnsweringOption {
"An encoder file must be provided for question answering ONNX models.".to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
}
/// Returns the `ModelType` for this QuestionAnsweringOption

View File

@ -402,7 +402,6 @@ impl SequenceClassificationOption {
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &SequenceClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
@ -412,11 +411,7 @@ impl SequenceClassificationOption {
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
}
/// Returns the `ModelType` for this SequenceClassificationOption

View File

@ -244,7 +244,7 @@ impl SummarizationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new(config.into(), None, None)?,
ONNXConditionalGenerator::new(config.into(), None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(BartGenerator::new(
config.into(),
@ -273,7 +273,7 @@ impl SummarizationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(SummarizationOption::ONNX(
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
)),
(ModelType::Bart, _) => Ok(SummarizationOption::Bart(
BartGenerator::new_with_tokenizer(config.into(), tokenizer)?,

View File

@ -219,7 +219,7 @@ impl TextGenerationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
ONNXCausalGenerator::new(config.into(), None, None)?,
ONNXCausalGenerator::new(config.into(), None)?,
)),
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(GPT2Generator::new(
config.into(),
@ -254,7 +254,7 @@ impl TextGenerationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TextGenerationOption::ONNX(
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
ONNXCausalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
)),
(ModelType::GPT2, _) => Ok(TextGenerationOption::GPT2(
GPT2Generator::new_with_tokenizer(config.into(), tokenizer)?,

View File

@ -516,7 +516,6 @@ impl TokenClassificationOption {
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &TokenClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
@ -526,11 +525,7 @@ impl TokenClassificationOption {
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
}
/// Returns the `ModelType` for this TokenClassificationOption

View File

@ -1123,7 +1123,7 @@ impl TranslationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
ONNXConditionalGenerator::new(config.into(), None, None)?,
ONNXConditionalGenerator::new(config.into(), None)?,
)),
(ModelType::Marian, _) => Ok(TranslationOption::Marian(MarianGenerator::new(
config.into(),
@ -1150,7 +1150,7 @@ impl TranslationOption {
match (config.model_type, &config.model_resource) {
#[cfg(feature = "onnx")]
(_, &ModelResource::ONNX(_)) => Ok(TranslationOption::ONNX(
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None, None)?,
ONNXConditionalGenerator::new_with_tokenizer(config.into(), tokenizer, None)?,
)),
(ModelType::Marian, _) => Ok(TranslationOption::Marian(
MarianGenerator::new_with_tokenizer(config.into(), tokenizer)?,

View File

@ -413,7 +413,6 @@ impl ZeroShotClassificationOption {
#[cfg(feature = "onnx")]
pub fn new_onnx(config: &ZeroShotClassificationConfig) -> Result<Self, RustBertError> {
let onnx_config = ONNXEnvironmentConfig::from_device(config.device);
let environment = onnx_config.get_environment()?;
let encoder_file = config
.model_resource
.get_onnx_local_paths()?
@ -423,11 +422,7 @@ impl ZeroShotClassificationOption {
.to_string(),
))?;
Ok(Self::ONNX(ONNXEncoder::new(
encoder_file,
&environment,
&onnx_config,
)?))
Ok(Self::ONNX(ONNXEncoder::new(encoder_file, &onnx_config)?))
}
/// Returns the `ModelType` for this ZeroShotClassificationOption