Repository: https://github.com/guillaume-be/rust-bert.git (mirror, synced 2024-10-26 22:19:05 +03:00)
Commit: 60591c0644 (parent: 2acbaf6627)

Updated DeBERTa configuration parsing
@@ -57,7 +57,7 @@ all-tests = []
 features = ["doc-only"]
 
 [dependencies]
-rust_tokenizers = "~7.0.0"
+rust_tokenizers = {version = "~7.0.1", path = "E:/Coding/rust-tokenizers/main"}
 tch = "~0.6.1"
 serde_json = "1.0.68"
 serde = { version = "1.0.130", features = ["derive"] }

@@ -10,9 +10,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use crate::{Activation, Config};
-use serde::{Deserialize, Serialize};
+use crate::{Activation, Config, RustBertError};
+use serde::de::{SeqAccess, Visitor};
+use serde::{de, Deserialize, Deserializer, Serialize};
 use std::collections::HashMap;
+use std::fmt;
+use std::str::FromStr;
 
 /// # DeBERTa Pretrained model weight files
 pub struct DebertaModelResources;

@@ -67,6 +70,47 @@ pub enum PositionAttentionType {
     p2p,
 }
 
+impl FromStr for PositionAttentionType {
+    type Err = RustBertError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "p2c" => Ok(PositionAttentionType::p2c),
+            "c2p" => Ok(PositionAttentionType::c2p),
+            "p2p" => Ok(PositionAttentionType::p2p),
+            _ => Err(RustBertError::InvalidConfigurationError(format!(
+                "Position attention type `{}` not in accepted variants (`p2c`, `c2p`, `p2p`)",
+                s
+            ))),
+        }
+    }
+}
+
+#[allow(non_camel_case_types)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct PositionAttentionTypes {
+    types: Vec<PositionAttentionType>,
+}
+
+impl FromStr for PositionAttentionTypes {
+    type Err = RustBertError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let types = s
+            .to_lowercase()
+            .split('|')
+            .map(|s| PositionAttentionType::from_str(s))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(PositionAttentionTypes { types })
+    }
+}
+
+impl Default for PositionAttentionTypes {
+    fn default() -> Self {
+        PositionAttentionTypes { types: vec![] }
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 /// # DeBERTa model configuration
 /// Defines the DeBERTa model architecture (e.g. number of layers, hidden layer size, label mapping...)

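The two FromStr implementations above compose: PositionAttentionTypes lower-cases the raw configuration value and splits it on `|`, delegating each piece to PositionAttentionType. A minimal sketch of the expected behaviour (hypothetical usage, not part of the commit; it has to live in the defining module, since the `types` field is private):

// Hypothetical usage sketch (same module as the types above).
fn parse_examples() {
    use std::str::FromStr;

    // "C2P|p2c" is lower-cased, split on `|`, and parsed piecewise.
    let combined = PositionAttentionTypes::from_str("C2P|p2c").unwrap();
    assert_eq!(combined.types.len(), 2);

    // Variants outside p2c/c2p/p2p surface as InvalidConfigurationError.
    assert!(PositionAttentionType::from_str("c2c").is_err());
}
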
@@ -82,10 +126,10 @@ pub struct DebertaConfig {
     pub num_hidden_layers: i64,
     pub type_vocab_size: i64,
     pub vocab_size: i64,
-    pub embedding_size: i64,
     pub relative_attention: bool,
     pub position_biased_input: bool,
-    pub pos_att_type: Option<Vec<PositionAttentionType>>,
+    #[serde(default, deserialize_with = "deserialize_attention_type")]
+    pub pos_att_type: Option<PositionAttentionTypes>,
     pub pooler_dropout: Option<f64>,
     pub pooler_hidden: Option<Activation>,
    pub pooler_hidden_size: Option<i64>,

@@ -98,4 +142,41 @@ pub struct DebertaConfig {
     pub label2id: Option<HashMap<String, i64>>,
 }
 
+fn deserialize_attention_type<'de, D>(
+    deserializer: D,
+) -> Result<Option<PositionAttentionTypes>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    struct AttentionTypeVisitor;
+
+    impl<'de> Visitor<'de> for AttentionTypeVisitor {
+        type Value = PositionAttentionTypes;
+
+        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+            formatter.write_str("null, string or sequence")
+        }
+
+        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+        where
+            E: de::Error,
+        {
+            Ok(FromStr::from_str(value).unwrap())
+        }
+
+        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
+        where
+            S: SeqAccess<'de>,
+        {
+            let mut types = vec![];
+            while let Some(attention_type) = seq.next_element::<String>()? {
+                types.push(FromStr::from_str(attention_type.as_str()).unwrap())
+            }
+            Ok(PositionAttentionTypes { types })
+        }
+    }
+
+    deserializer.deserialize_any(AttentionTypeVisitor).map(Some)
+}
+
 impl Config for DebertaConfig {}

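Taken together with the #[serde(default, deserialize_with = "deserialize_attention_type")] attribute on the field, pos_att_type now accepts either a JSON string ("c2p|p2c") or a JSON array (["c2p", "p2c"]), presumably to cover both shapes seen in published DeBERTa config.json files. A sketch of the behaviour against a minimal stand-in struct (hypothetical; the real field lives on DebertaConfig):

use serde::Deserialize;

// Hypothetical probe struct exercising the deserializer in isolation.
#[derive(Deserialize)]
struct Probe {
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pos_att_type: Option<PositionAttentionTypes>,
}

fn probe_examples() {
    // Both JSON shapes parse to the same value...
    let _a: Probe = serde_json::from_str(r#"{"pos_att_type": "c2p|p2c"}"#).unwrap();
    let _b: Probe = serde_json::from_str(r#"{"pos_att_type": ["c2p", "p2c"]}"#).unwrap();
    // ...and a missing key falls back to None via #[serde(default)].
    let _c: Probe = serde_json::from_str("{}").unwrap();
}

One caveat worth flagging: visit_str and visit_seq call .unwrap() on the inner parse, so an unknown variant panics rather than returning a serde error, and no visit_unit is implemented even though the expecting message advertises null.
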
@@ -1 +1,6 @@
 mod deberta_model;
+
+pub use deberta_model::{
+    DebertaConfig, DebertaConfigResources, DebertaMergesResources, DebertaModelResources,
+    DebertaVocabResources,
+};

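With these re-exports, downstream code reaches the DeBERTa types from the module root rather than the private deberta_model module. A sketch, assuming the module is exposed as rust_bert::deberta and using the crate's Config::from_file loading pattern (the path is illustrative):

use rust_bert::deberta::DebertaConfig;
use rust_bert::Config;

fn load_config() {
    // Parse a local config.json into the strongly-typed configuration.
    let config = DebertaConfig::from_file("path/to/config.json");
    println!("{} hidden layers", config.num_hidden_layers);
}
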
@@ -5,6 +5,8 @@ import subprocess
 import argparse
 import sys
 
+from torch import Tensor
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")

@@ -22,8 +24,11 @@ if __name__ == "__main__":
         if args.skip_embeddings:
             if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
                 continue
-        nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
-        print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
+        if isinstance(v, Tensor):
+            nps[k] = np.ascontiguousarray(v.cpu().numpy().astype(np.float32))
+            print(f'converted {k} - {str(sys.getsizeof(nps[k]))} bytes')
+        else:
+            print(f'skipped non-tensor object: {k}')
     np.savez(target_folder / 'model.npz', **nps)
 
     source = str(target_folder / 'model.npz')

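The new isinstance guard matters because a checkpoint's state dict can carry non-tensor entries (counters, metadata and the like), which previously crashed on .cpu().numpy(). The saved model.npz is then read back on the Rust side; a sketch of that step (assuming tch's Tensor::read_npz / Tensor::save_multi pair, which the crate's convert-tensor helper appears to rely on; file names are illustrative):

fn main() -> Result<(), tch::TchError> {
    // Read the converted archive and re-save it in tch's native multi-tensor format.
    let tensors = tch::Tensor::read_npz("model.npz")?;
    tch::Tensor::save_multi(&tensors, "rust_model.ot")?;
    Ok(())
}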