diff --git a/CHANGELOG.md b/CHANGELOG.md index 84eccab..611b243 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format - Addition of All-MiniLM-L6-V2 model weights - Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT) - Addition of Masked Language Model pipeline, allowing to predict masked words. +- Support for the CodeBERT language model with pretrained models for language detection and masked token prediction ## Changed - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`). diff --git a/examples/codebert.rs b/examples/codebert.rs new file mode 100644 index 0000000..d683484 --- /dev/null +++ b/examples/codebert.rs @@ -0,0 +1,85 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate anyhow; + +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel}; +use rust_bert::pipelines::sequence_classification::{ + SequenceClassificationConfig, SequenceClassificationModel, +}; +use rust_bert::resources::RemoteResource; +use rust_bert::roberta::{ + RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources, +}; + +fn main() -> anyhow::Result<()> { + // Language identification + let sequence_classification_config = SequenceClassificationConfig::new( + ModelType::Roberta, + RemoteResource::from_pretrained(RobertaModelResources::CODEBERTA_LANGUAGE_ID), + RemoteResource::from_pretrained(RobertaConfigResources::CODEBERTA_LANGUAGE_ID), + RemoteResource::from_pretrained(RobertaVocabResources::CODEBERTA_LANGUAGE_ID), + Some(RemoteResource::from_pretrained( + RobertaMergesResources::CODEBERTA_LANGUAGE_ID, + )), + false, + None, + None, + ); + + let sequence_classification_model = + SequenceClassificationModel::new(sequence_classification_config)?; + + // Define input + let input = [ + "def f(x):\ + return x**2", + "outcome := rand.Intn(6) + 1", + ]; + + // Run model + let output = sequence_classification_model.predict(input); + for label in output { + println!("{:?}", label); + } + + // Masked language model + let config = MaskedLanguageConfig::new( + ModelType::Roberta, + RemoteResource::from_pretrained(RobertaModelResources::CODEBERT_MLM), + RemoteResource::from_pretrained(RobertaConfigResources::CODEBERT_MLM), + RemoteResource::from_pretrained(RobertaVocabResources::CODEBERT_MLM), + Some(RemoteResource::from_pretrained( + RobertaMergesResources::CODEBERT_MLM, + )), + false, + None, + None, + Some(String::from("")), + ); + + let mask_language_model = MaskedLanguageModel::new(config)?; + // Define input + let input = [ + "if (x is not None) (x>1)", + " x = if let (x_option) {}", + ]; + + // Run model + let output = mask_language_model.predict(input)?; + for sentence_output in output { + println!("{:?}", sentence_output); + } + + Ok(()) +} diff --git a/src/roberta/roberta_model.rs b/src/roberta/roberta_model.rs index 8bce245..a664342 100644 --- a/src/roberta/roberta_model.rs +++ b/src/roberta/roberta_model.rs @@ -69,11 +69,21 @@ impl RobertaModelResources { "xlm-roberta-ner-es/model", "https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot", ); - /// Shared under Apache 2.0 licenseat . Modified with conversion to C-array format. + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = ( "all-distilroberta-v1/model", "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot", ); + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. + pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = ( + "codeberta-language-id/model", + "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/rust_model.ot", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const CODEBERT_MLM: (&'static str, &'static str) = ( + "codebert-mlm/model", + "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/rust_model.ot", + ); } impl RobertaConfigResources { @@ -117,6 +127,16 @@ impl RobertaConfigResources { "all-distilroberta-v1/config", "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json", ); + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. + pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = ( + "codeberta-language-id/config", + "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/config.json", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const CODEBERT_MLM: (&'static str, &'static str) = ( + "codebert-mlm/config", + "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/config.json", + ); } impl RobertaVocabResources { @@ -160,6 +180,16 @@ impl RobertaVocabResources { "all-distilroberta-v1/vocab", "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json", ); + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. + pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = ( + "codeberta-language-id/vocab", + "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/vocab.json", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const CODEBERT_MLM: (&'static str, &'static str) = ( + "codebert-mlm/vocab", + "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/vocab.json", + ); } impl RobertaMergesResources { @@ -183,6 +213,16 @@ impl RobertaMergesResources { "all-distilroberta-v1/merges", "https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt", ); + /// Shared under Apache 2.0 license by the HuggingFace Inc. team at . Modified with conversion to C-array format. + pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = ( + "codeberta-language-id/merges", + "https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/merges.txt", + ); + /// Shared under MIT license by the Microsoft team at . Modified with conversion to C-array format. + pub const CODEBERT_MLM: (&'static str, &'static str) = ( + "codebert-mlm/merges", + "https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/merges.txt", + ); } pub struct RobertaLMHead {