add wordpiece tokenizer

Torsten Scholak 2021-04-12 23:00:49 -04:00
parent d96a1ae974
commit 915cea6e1a
9 changed files with 126 additions and 70 deletions

View File

@@ -2,22 +2,26 @@ use std::ffi::CStr;
use std::ffi::CString;
use std::mem::forget;
use std::os::raw::{c_char, c_int, c_uint};
use std::mem;
use tokenizers::models::bpe::BpeBuilder;
use tokenizers::models::bpe::BPE;
use tokenizers::models::unigram::*;
use tokenizers::tokenizer::Encoding;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::processors::roberta::RobertaProcessing;
use tokenizers::tokenizer::Encoding;
use tokenizers::tokenizer::Tokenizer;
#[no_mangle]
pub extern "C" fn mk_t5_tokenizer(cvocab_file: *const c_char, ctokenizer_file: *const c_char,) -> *mut Tokenizer {
pub extern "C" fn mk_wordpiece_tokenizer(cvocab: *const c_char) -> *mut Tokenizer {
unsafe {
// let t = Tokenizer::new();
unimplemented!()
let vocab = CStr::from_ptr(cvocab);
if let Ok(vocab_file) = vocab.to_str() {
let wp_builder = WordPiece::from_file(vocab_file);
let wp = wp_builder.build().unwrap();
let mut tokenizer = Tokenizer::new(wp);
return Box::into_raw(Box::new(tokenizer));
} else {
panic!("Unable to read parameters.");
}
}
}
@@ -30,12 +34,12 @@ pub extern "C" fn mk_roberta_tokenizer(
let vocab = CStr::from_ptr(cvocab);
let merges = CStr::from_ptr(cmerges);
if let (Ok(vocab_file), Ok(merges_file)) = (vocab.to_str(), merges.to_str()) {
let bpe_builder = BPE::from_file(vocab_file, merges_file);
let bpe = bpe_builder.build().unwrap();
let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_post_processor(RobertaProcessing::default());
return Box::into_raw(Box::new(tokenizer));
let bpe_builder = BPE::from_file(vocab_file, merges_file);
let bpe = bpe_builder.build().unwrap();
let mut tokenizer = Tokenizer::new(bpe);
tokenizer.with_pre_tokenizer(ByteLevel::default());
tokenizer.with_post_processor(RobertaProcessing::default());
return Box::into_raw(Box::new(tokenizer));
} else {
panic!("Unable to read parameters.");
}
@@ -43,7 +47,10 @@ pub extern "C" fn mk_roberta_tokenizer(
}
#[no_mangle]
pub extern "C" fn mk_tokenizer(cvocab: *const c_char, cmerges: *const c_char) -> *mut Tokenizer {
pub extern "C" fn mk_bpe_tokenizer(
cvocab: *const c_char,
cmerges: *const c_char,
) -> *mut Tokenizer {
unsafe {
let vocab = CStr::from_ptr(cvocab);
let merges = CStr::from_ptr(cmerges);
@@ -77,7 +84,7 @@ pub extern "C" fn encode(text: *const c_char, ptr: *mut Tokenizer) -> *mut Encod
#[repr(C)]
pub struct CTokens {
length: c_int,
data: *const *const c_char
data: *const *const c_char,
}
#[no_mangle]
@@ -101,7 +108,10 @@ pub extern "C" fn get_tokens(ptr: *mut Encoding) -> *mut CTokens {
c_char_vec.push(value);
}
let array = CTokens { length: cstr_vec.len() as c_int, data: c_char_vec.as_ptr()};
let array = CTokens {
length: cstr_vec.len() as c_int,
data: c_char_vec.as_ptr(),
};
// todo - do this without leaking
forget(cstr_vec);
forget(c_char_vec);
@@ -112,7 +122,7 @@ pub extern "C" fn get_tokens(ptr: *mut Encoding) -> *mut CTokens {
#[repr(C)]
pub struct CIDs {
length: c_uint,
data: *const c_uint
data: *const c_uint,
}
#[no_mangle]
@@ -130,7 +140,10 @@ pub extern "C" fn get_ids(ptr: *mut Encoding) -> *mut CIDs {
}
*/
// forget(result);
let mut array = CIDs { length: result.len() as c_uint, data: result.as_ptr()};
let mut array = CIDs {
length: result.len() as c_uint,
data: result.as_ptr(),
};
return Box::into_raw(Box::new(array));
}
}
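The #[repr(C)] structs above are read back field by field on the Haskell side. As a minimal sketch of the layout assumption behind those reads (readCTokens is a hypothetical helper, not part of this commit; it assumes a 64-bit target where the 4-byte length field is padded so the data pointer starts at byte offset 8):

import Foreign.C.String (CString)
import Foreign.C.Types (CInt)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (peek, peekByteOff)

-- Opaque stand-in for the #[repr(C)] CTokens struct built on the Rust side.
data CTokens

-- Assumed layout on a 64-bit target:
--   offset 0: c_int length (4 bytes, then 4 bytes of padding)
--   offset 8: const char **data
readCTokens :: Ptr CTokens -> IO (Int, Ptr CString)
readCTokens ptr = do
  len <- peek (castPtr ptr) :: IO CInt          -- length field at offset 0
  dat <- peekByteOff ptr 8 :: IO (Ptr CString)  -- data pointer after the padding
  pure (fromIntegral len, dat)

The bindings below hard-code the same offsets (peekByteOff ptr 8, with a step of 8 bytes per pointer and 4 bytes per c_uint id).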

View File

@@ -0,0 +1,6 @@
cradle:
stack:
- path: "./src"
component: "tokenizers:lib"
- path: "./test/Spec.hs"
component: "tokenizers:test:tokenizers-test"

View File

@@ -1 +1 @@
../target/release
../../../target/release

View File

@@ -2,47 +2,72 @@
module Tokenizers where
import Foreign.Ptr
-- import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.C.String ( CString, newCString, peekCString )
import Foreign.C.Types
import Foreign.C.String (CString, newCString, peekCString)
import Foreign.C.Types (CInt, CUInt)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable (peek, peekByteOff))
data CTokenizer
data CEncoding
data CTokens
data CIDs
data Tokenizer = Tokenizer {
tok :: Ptr CTokenizer,
vocab :: String,
merges :: String
}
data Tokenizer = Tokenizer
{ tok :: Ptr CTokenizer,
vocab :: FilePath,
merges :: Maybe FilePath
}
newtype Encoding = Encoding {
enc :: Ptr CEncoding
}
newtype Encoding = Encoding
{ enc :: Ptr CEncoding
}
instance Show Tokenizer where
show (Tokenizer _ vocab merges) = "Huggingface Tokenizer Object\n vocab: " ++ vocab ++ "\n merges: " ++ merges
show (Tokenizer _ vocab merges) =
"Huggingface Tokenizer Object\n"
<> " vocab: "
<> vocab
<> "\n"
<> maybe mempty (" merges: " <>) merges
foreign import ccall unsafe "mk_tokenizer" r_mk_tokenizer :: CString -> CString -> IO (Ptr CTokenizer)
foreign import ccall unsafe "mk_wordpiece_tokenizer"
r_mk_wordpiece_tokenizer ::
CString -> IO (Ptr CTokenizer)
mkTokenizer vocab merges = do
mkWordPieceTokenizer :: FilePath -> IO Tokenizer
mkWordPieceTokenizer vocab = do
cvocab <- newCString $ vocab ++ "\0"
result <- r_mk_wordpiece_tokenizer cvocab
pure (Tokenizer result vocab Nothing)
foreign import ccall unsafe "mk_bpe_tokenizer"
r_mk_bpe_tokenizer ::
CString -> CString -> IO (Ptr CTokenizer)
mkBPETokenizer :: FilePath -> FilePath -> IO Tokenizer
mkBPETokenizer vocab merges = do
cvocab <- newCString $ vocab ++ "\0"
cmerges <- newCString $ merges ++ "\0"
result <- r_mk_tokenizer cvocab cmerges
pure (Tokenizer result vocab merges)
result <- r_mk_bpe_tokenizer cvocab cmerges
pure (Tokenizer result vocab (Just merges))
foreign import ccall unsafe "mk_roberta_tokenizer" r_mk_roberta_tokenizer :: CString -> CString -> IO (Ptr CTokenizer)
foreign import ccall unsafe "mk_roberta_tokenizer"
r_mk_roberta_tokenizer ::
CString -> CString -> IO (Ptr CTokenizer)
mkRobertaTokenizer :: FilePath -> FilePath -> IO Tokenizer
mkRobertaTokenizer vocab merges = do
cvocab <- newCString $ vocab ++ "\0"
cmerges <- newCString $ merges ++ "\0"
result <- r_mk_roberta_tokenizer cvocab cmerges
pure (Tokenizer result vocab merges)
pure (Tokenizer result vocab (Just merges))
foreign import ccall unsafe "encode" r_encode :: CString -> Ptr CTokenizer -> IO (Ptr CEncoding)
foreign import ccall unsafe "encode"
r_encode ::
CString -> Ptr CTokenizer -> IO (Ptr CEncoding)
encode :: Tokenizer -> String -> IO Encoding
encode (Tokenizer tokenizer _ _) text = do
@@ -50,7 +75,9 @@ encode (Tokenizer tokenizer _ _) text = do
encoding <- r_encode str tokenizer
pure (Encoding encoding)
foreign import ccall unsafe "get_tokens" r_get_tokens :: Ptr CEncoding -> IO (Ptr CTokens)
foreign import ccall unsafe "get_tokens"
r_get_tokens ::
Ptr CEncoding -> IO (Ptr CTokens)
getTokens :: Encoding -> IO [String]
getTokens (Encoding encoding) = do
@@ -58,27 +85,28 @@ getTokens (Encoding encoding) = do
-- 1st value of struct is the # of tokens
sz <- fromIntegral <$> (peek (castPtr ptr) :: IO CInt) :: IO Int
-- 2nd value of struct is the array of tokens
tokens <- peekByteOff ptr 8 :: IO (Ptr CString)
mapM
(\idx -> peekByteOff tokens (step*idx) >>= peekCString)
[0 .. sz-1]
tokens <- peekByteOff ptr 8 :: IO (Ptr CString)
mapM
(\idx -> peekByteOff tokens (step * idx) >>= peekCString)
[0 .. sz -1]
where
step = 8
cleanTokens :: String -> String
cleanTokens xs = [x | x <- xs, x `notElem` "\288"]
foreign import ccall unsafe "get_ids" r_get_ids :: Ptr CEncoding -> IO (Ptr CIDs)
getIDs:: Encoding -> IO [Int]
getIDs :: Encoding -> IO [Int]
getIDs (Encoding encoding) = do
ptr <- r_get_ids encoding
-- 1st value of struct is the # of tokens
sz <- fromIntegral <$> (peek (castPtr ptr) :: IO CUInt) :: IO Int
-- print $ "SIZE " ++ show sz
-- 2nd value of struct is the array of tokens
tokens <- peekByteOff ptr 8 :: IO (Ptr CUInt)
(mapM
(\idx -> (peekByteOff tokens (step*idx) :: IO CUInt) >>= (pure . fromIntegral))
[0 .. sz-1])
tokens <- peekByteOff ptr 8 :: IO (Ptr CUInt)
mapM
(\idx -> (peekByteOff tokens (step * idx) :: IO CUInt) >>= (pure . fromIntegral))
[0 .. sz -1]
where
step = 4
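Taken together, the new binding gives an end-to-end path from a WordPiece vocabulary file to tokens and ids. A minimal usage sketch, assuming the library is linked as configured in the cabal file and that a BERT-style vocabulary exists at the placeholder path used below:

import Tokenizers (encode, getIDs, getTokens, mkWordPieceTokenizer)

main :: IO ()
main = do
  -- build a WordPiece tokenizer from a vocabulary file (path is a placeholder)
  tokenizer <- mkWordPieceTokenizer "bert-base-uncased-vocab.txt"
  -- run the Rust-side encoder over an input string
  encoding <- encode tokenizer "hello world"
  -- read the token strings and ids back out of the returned C structs
  getTokens encoding >>= print
  getIDs encoding >>= print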

View File

@@ -1,5 +1,4 @@
resolver:
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/16/28.yaml
resolver: lts-17.9
packages:
- .

View File

@@ -28,4 +28,3 @@ main = do
test "hi hi hi hi hi hi hi hi" tokenizer
test "hi there hi there hi there hi there hi there hi there hi there hi there" tokenizer
test "hello world. Let's try tokenizing this. hi hi hi and hello hello" tokenizer

View File

@@ -10,21 +10,32 @@ build-type: Simple
cabal-version: >=1.10
library
exposed-modules: Tokenizers
hs-source-dirs: src
exposed-modules:
Tokenizers
Paths_tokenizers
autogen-modules:
Paths_tokenizers
hs-source-dirs:
src
ghc-options: -W -Wall -dcore-lint
build-depends:
base >= 4.7 && < 5
extra-libraries:
tokenizers_haskell
default-language: Haskell2010
build-depends: base >= 4.7 && < 5
extra-libraries: tokenizers_haskell
executable haskell-test
hs-source-dirs: src
main-is: Main.hs
other-modules: Tokenizers
default-language: Haskell2010
build-depends: base >= 4.7 && < 5
, tokenizers
extra-libraries: tokenizers_haskell
-- extra-lib-dirs: ./lib
test-suite tokenizers-test
type: exitcode-stdio-1.0
main-is: Spec.hs
hs-source-dirs:
test
ghc-options: -W -Wall -dcore-lint
build-depends:
base >= 4.7 && < 5
, tokenizers
extra-libraries:
tokenizers_haskell
default-language: Haskell2010
-- executable download-vocab
-- hs-source-dirs: src

View File

@@ -6,7 +6,7 @@ with pkgs;
let
shell = mkShell {
nativeBuildInputs = [ cargo libiconv pkgconfig ];
nativeBuildInputs = [ cargo rustc rls libiconv pkgconfig ];
};
in