CodeBERT Pretrained models and examples (#322)

* Addition of Codebert examples

* Addition of CodeBERT pretrained models, CodeBERT example
This commit is contained in:
guillaume-be 2023-01-20 19:02:33 +00:00 committed by GitHub
parent f12e8ef475
commit 0fc5ce6ad4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 1 deletions

View File

@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format
- Addition of All-MiniLM-L6-V2 model weights
- Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
- Addition of Masked Language Model pipeline, allowing to predict masked words.
- Support for the CodeBERT language model with pretrained models for language detection and masked token prediction
## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).

85
examples/codebert.rs Normal file
View File

@ -0,0 +1,85 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
};
use rust_bert::resources::RemoteResource;
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
fn main() -> anyhow::Result<()> {
// Language identification
let sequence_classification_config = SequenceClassificationConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERTA_LANGUAGE_ID),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERTA_LANGUAGE_ID),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERTA_LANGUAGE_ID),
Some(RemoteResource::from_pretrained(
RobertaMergesResources::CODEBERTA_LANGUAGE_ID,
)),
false,
None,
None,
);
let sequence_classification_model =
SequenceClassificationModel::new(sequence_classification_config)?;
// Define input
let input = [
"def f(x):\
return x**2",
"outcome := rand.Intn(6) + 1",
];
// Run model
let output = sequence_classification_model.predict(input);
for label in output {
println!("{:?}", label);
}
// Masked language model
let config = MaskedLanguageConfig::new(
ModelType::Roberta,
RemoteResource::from_pretrained(RobertaModelResources::CODEBERT_MLM),
RemoteResource::from_pretrained(RobertaConfigResources::CODEBERT_MLM),
RemoteResource::from_pretrained(RobertaVocabResources::CODEBERT_MLM),
Some(RemoteResource::from_pretrained(
RobertaMergesResources::CODEBERT_MLM,
)),
false,
None,
None,
Some(String::from("<mask>")),
);
let mask_language_model = MaskedLanguageModel::new(config)?;
// Define input
let input = [
"if (x is not None) <mask> (x>1)",
"<mask> x = if let <mask>(x_option) {}",
];
// Run model
let output = mask_language_model.predict(input)?;
for sentence_output in output {
println!("{:?}", sentence_output);
}
Ok(())
}

View File

@ -69,11 +69,21 @@ impl RobertaModelResources {
"xlm-roberta-ner-es/model",
"https://huggingface.co/xlm-roberta-large-finetuned-conll02-spanish/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 licenseat <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/model",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
"codeberta-language-id/model",
"https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
pub const CODEBERT_MLM: (&'static str, &'static str) = (
"codebert-mlm/model",
"https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/rust_model.ot",
);
}
impl RobertaConfigResources {
@ -117,6 +127,16 @@ impl RobertaConfigResources {
"all-distilroberta-v1/config",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
"codeberta-language-id/config",
"https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/config.json",
);
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
pub const CODEBERT_MLM: (&'static str, &'static str) = (
"codebert-mlm/config",
"https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/config.json",
);
}
impl RobertaVocabResources {
@ -160,6 +180,16 @@ impl RobertaVocabResources {
"all-distilroberta-v1/vocab",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
"codeberta-language-id/vocab",
"https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/vocab.json",
);
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
pub const CODEBERT_MLM: (&'static str, &'static str) = (
"codebert-mlm/vocab",
"https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/vocab.json",
);
}
impl RobertaMergesResources {
@ -183,6 +213,16 @@ impl RobertaMergesResources {
"all-distilroberta-v1/merges",
"https://huggingface.co/sentence-transformers/all-distilroberta-v1/resolve/main/merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at <https://huggingface.co/huggingface/CodeBERTa-language-id>. Modified with conversion to C-array format.
pub const CODEBERTA_LANGUAGE_ID: (&'static str, &'static str) = (
"codeberta-language-id/merges",
"https://huggingface.co/huggingface/CodeBERTa-language-id/resolve/main/merges.txt",
);
/// Shared under MIT license by the Microsoft team at <https://github.com/microsoft/CodeBERT>. Modified with conversion to C-array format.
pub const CODEBERT_MLM: (&'static str, &'static str) = (
"codebert-mlm/merges",
"https://huggingface.co/microsoft/codebert-base-mlm/resolve/main/merges.txt",
);
}
pub struct RobertaLMHead {