Updated DeBERTa configuration parsing

Guillaume Becquin 2021-11-28 12:43:35 +01:00
parent 2acbaf6627
commit 60591c0644
4 changed files with 98 additions and 7 deletions

View File

@@ -57,7 +57,7 @@ all-tests = []
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~7.0.0"
rust_tokenizers = {version = "~7.0.1", path = "E:/Coding/rust-tokenizers/main"}
tch = "~0.6.1"
serde_json = "1.0.68"
serde = { version = "1.0.130", features = ["derive"] }

View File

@@ -10,9 +10,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::{Activation, Config};
use serde::{Deserialize, Serialize};
use crate::{Activation, Config, RustBertError};
use serde::de::{SeqAccess, Visitor};
use serde::{de, Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
/// # DeBERTa Pretrained model weight files
pub struct DebertaModelResources;
@@ -67,6 +70,47 @@ pub enum PositionAttentionType {
    p2p,
}
impl FromStr for PositionAttentionType {
    type Err = RustBertError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "p2c" => Ok(PositionAttentionType::p2c),
            "c2p" => Ok(PositionAttentionType::c2p),
            "p2p" => Ok(PositionAttentionType::p2p),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
                s
            ))),
        }
    }
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PositionAttentionTypes {
    types: Vec<PositionAttentionType>,
}

impl FromStr for PositionAttentionTypes {
    type Err = RustBertError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let types = s
            .to_lowercase()
            .split('|')
            .map(PositionAttentionType::from_str)
            .collect::<Result<Vec<_>, _>>()?;
        Ok(PositionAttentionTypes { types })
    }
}

impl Default for PositionAttentionTypes {
    fn default() -> Self {
        PositionAttentionTypes { types: vec![] }
    }
}
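// Illustrative sketch, not part of this diff: how the `FromStr` implementations
// above behave. The `?` assumes a caller that returns `RustBertError`.
//
//     let single = PositionAttentionType::from_str("c2p")?;
//     let combined = PositionAttentionTypes::from_str("c2p|p2p")?;
//     assert_eq!(combined.types.len(), 2);
//     // "c2x" would yield an InvalidConfigurationError listing the
//     // accepted variants (`p2c`, `c2p`, `p2p`).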
#[derive(Debug, Serialize, Deserialize)]
/// # DeBERTa model configuration
/// Defines the DeBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)
@@ -82,10 +126,10 @@ pub struct DebertaConfig {
    pub num_hidden_layers: i64,
    pub type_vocab_size: i64,
    pub vocab_size: i64,
    pub embedding_size: i64,
    pub relative_attention: bool,
    pub position_biased_input: bool,
    pub pos_att_type: Option<Vec<PositionAttentionType>>,
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    pub pooler_dropout: Option<f64>,
    pub pooler_hidden: Option<Activation>,
    pub pooler_hidden_size: Option<i64>,
@@ -98,4 +142,41 @@ pub struct DebertaConfig {
    pub label2id: Option<HashMap<String, i64>>,
}
/// Deserializes `pos_att_type` from either a single pipe-separated string
/// (`"c2p|p2p"`) or a sequence of strings (`["c2p", "p2p"]`), the two formats
/// found in pretrained DeBERTa configuration files.
fn deserialize_attention_type<'de, D>(
    deserializer: D,
) -> Result<Option<PositionAttentionTypes>, D::Error>
where
    D: Deserializer<'de>,
{
    struct AttentionTypeVisitor;

    impl<'de> Visitor<'de> for AttentionTypeVisitor {
        type Value = PositionAttentionTypes;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("null, string or sequence")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Surface invalid variants as a deserialization error rather
            // than panicking on malformed configuration files.
            PositionAttentionTypes::from_str(value).map_err(de::Error::custom)
        }

        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
        where
            S: SeqAccess<'de>,
        {
            let mut types = vec![];
            while let Some(attention_type) = seq.next_element::<String>()? {
                types.push(
                    PositionAttentionType::from_str(attention_type.as_str())
                        .map_err(de::Error::custom)?,
                );
            }
            Ok(PositionAttentionTypes { types })
        }
    }

    deserializer.deserialize_any(AttentionTypeVisitor).map(Some)
}
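// Illustrative sketch, not from this commit: the JSON shapes the field now
// accepts, with `#[serde(default)]` covering a missing key.
//
//     {"pos_att_type": "c2p|p2p"}      -> Some(types: [c2p, p2p])
//     {"pos_att_type": ["c2p", "p2p"]} -> Some(types: [c2p, p2p])
//     {}  (key absent)                 -> None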
impl Config for DebertaConfig {}

View File

@@ -1 +1,6 @@
mod deberta_model;

pub use deberta_model::{
    DebertaConfig, DebertaConfigResources, DebertaMergesResources, DebertaModelResources,
    DebertaVocabResources,
};
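// Illustrative sketch, assuming this is the crate's `deberta` module as in
// rust-bert's layout: downstream code can now load a configuration through
// the re-exported type.
//
//     use rust_bert::deberta::DebertaConfig;
//     use rust_bert::Config;
//     let config = DebertaConfig::from_file("path/to/config.json");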

View File

@@ -5,6 +5,8 @@ import subprocess
import argparse
import sys

from torch import Tensor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
@ -22,8 +24,11 @@ if __name__ == "__main__":
if args.skip_embeddings:
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
continue
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
if isinstance(v, Tensor):
nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
else:
print(f'skipped non-tensor object: {k}')
np.savez(target_folder / 'model.npz', **nps)
source = str(target_folder / 'model.npz')
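# Illustrative sketch, names hypothetical: the isinstance guard above matters
# for checkpoints that bundle non-tensor metadata alongside the weights, e.g.
#
#     weights = {"encoder.weight": torch.zeros(2, 2), "global_step": 10000}
#
# "encoder.weight" is converted to float32 and stored; "global_step" is
# reported as a skipped non-tensor object instead of crashing on `.cpu()`.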