mirror of
https://github.com/hasktorch/tokenizers.git
synced 2024-10-05 17:37:44 +03:00
add wordpiece tokenizer
This commit is contained in:
parent
d96a1ae974
commit
915cea6e1a
@ -2,22 +2,26 @@ use std::ffi::CStr;
|
||||
use std::ffi::CString;
|
||||
use std::mem::forget;
|
||||
use std::os::raw::{c_char, c_int, c_uint};
|
||||
use std::mem;
|
||||
|
||||
use tokenizers::models::bpe::BpeBuilder;
|
||||
use tokenizers::models::bpe::BPE;
|
||||
use tokenizers::models::unigram::*;
|
||||
use tokenizers::tokenizer::Encoding;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::models::wordpiece::WordPiece;
|
||||
use tokenizers::pre_tokenizers::byte_level::ByteLevel;
|
||||
use tokenizers::processors::roberta::RobertaProcessing;
|
||||
|
||||
use tokenizers::tokenizer::Encoding;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn mk_t5_tokenizer(cvocab_file: *const c_char, ctokenizer_file: *const c_char,) -> *mut Tokenizer {
|
||||
pub extern "C" fn mk_wordpiece_tokenizer(cvocab: *const c_char) -> *mut Tokenizer {
|
||||
unsafe {
|
||||
// let t = Tokenizer::new();
|
||||
unimplemented!()
|
||||
let vocab = CStr::from_ptr(cvocab);
|
||||
if let Ok(vocab_file) = vocab.to_str() {
|
||||
let wp_builder = WordPiece::from_file(vocab_file);
|
||||
let wp = wp_builder.build().unwrap();
|
||||
let mut tokenizer = Tokenizer::new(wp);
|
||||
return Box::into_raw(Box::new(tokenizer));
|
||||
} else {
|
||||
panic!("Unable to read parameters.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,12 +34,12 @@ pub extern "C" fn mk_roberta_tokenizer(
|
||||
let vocab = CStr::from_ptr(cvocab);
|
||||
let merges = CStr::from_ptr(cmerges);
|
||||
if let (Ok(vocab_file), Ok(merges_file)) = (vocab.to_str(), merges.to_str()) {
|
||||
let bpe_builder = BPE::from_file(vocab_file, merges_file);
|
||||
let bpe = bpe_builder.build().unwrap();
|
||||
let mut tokenizer = Tokenizer::new(bpe);
|
||||
tokenizer.with_pre_tokenizer(ByteLevel::default());
|
||||
tokenizer.with_post_processor(RobertaProcessing::default());
|
||||
return Box::into_raw(Box::new(tokenizer));
|
||||
let bpe_builder = BPE::from_file(vocab_file, merges_file);
|
||||
let bpe = bpe_builder.build().unwrap();
|
||||
let mut tokenizer = Tokenizer::new(bpe);
|
||||
tokenizer.with_pre_tokenizer(ByteLevel::default());
|
||||
tokenizer.with_post_processor(RobertaProcessing::default());
|
||||
return Box::into_raw(Box::new(tokenizer));
|
||||
} else {
|
||||
panic!("Unable to read parameters.");
|
||||
}
|
||||
@ -43,7 +47,10 @@ pub extern "C" fn mk_roberta_tokenizer(
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn mk_tokenizer(cvocab: *const c_char, cmerges: *const c_char) -> *mut Tokenizer {
|
||||
pub extern "C" fn mk_bpe_tokenizer(
|
||||
cvocab: *const c_char,
|
||||
cmerges: *const c_char,
|
||||
) -> *mut Tokenizer {
|
||||
unsafe {
|
||||
let vocab = CStr::from_ptr(cvocab);
|
||||
let merges = CStr::from_ptr(cmerges);
|
||||
@ -77,7 +84,7 @@ pub extern "C" fn encode(text: *const c_char, ptr: *mut Tokenizer) -> *mut Encod
|
||||
#[repr(C)]
|
||||
pub struct CTokens {
|
||||
length: c_int,
|
||||
data: *const *const c_char
|
||||
data: *const *const c_char,
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
@ -101,7 +108,10 @@ pub extern "C" fn get_tokens(ptr: *mut Encoding) -> *mut CTokens {
|
||||
c_char_vec.push(value);
|
||||
}
|
||||
|
||||
let array = CTokens { length: cstr_vec.len() as c_int, data: c_char_vec.as_ptr()};
|
||||
let array = CTokens {
|
||||
length: cstr_vec.len() as c_int,
|
||||
data: c_char_vec.as_ptr(),
|
||||
};
|
||||
// todo - do this without leaking
|
||||
forget(cstr_vec);
|
||||
forget(c_char_vec);
|
||||
@ -112,7 +122,7 @@ pub extern "C" fn get_tokens(ptr: *mut Encoding) -> *mut CTokens {
|
||||
#[repr(C)]
|
||||
pub struct CIDs {
|
||||
length: c_uint,
|
||||
data: *const c_uint
|
||||
data: *const c_uint,
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
@ -130,7 +140,10 @@ pub extern "C" fn get_ids(ptr: *mut Encoding) -> *mut CIDs {
|
||||
}
|
||||
*/
|
||||
// forget(result);
|
||||
let mut array = CIDs { length: result.len() as c_uint, data: result.as_ptr()};
|
||||
let mut array = CIDs {
|
||||
length: result.len() as c_uint,
|
||||
data: result.as_ptr(),
|
||||
};
|
||||
return Box::into_raw(Box::new(array));
|
||||
}
|
||||
}
|
||||
|
6
bindings/haskell/tokenizers-haskell/hie.yaml
Normal file
6
bindings/haskell/tokenizers-haskell/hie.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
cradle:
|
||||
stack:
|
||||
- path: "./src"
|
||||
component: "tokenizers:lib"
|
||||
- path: "./test/Spec.hs"
|
||||
component: "tokenizers:test:tokenizers-test"
|
@ -1 +1 @@
|
||||
../target/release
|
||||
../../../target/release
|
@ -2,47 +2,72 @@
|
||||
|
||||
module Tokenizers where
|
||||
|
||||
import Foreign.Ptr
|
||||
-- import Foreign.ForeignPtr
|
||||
import Foreign.Storable
|
||||
import Foreign.C.String ( CString, newCString, peekCString )
|
||||
import Foreign.C.Types
|
||||
import Foreign.C.String (CString, newCString, peekCString)
|
||||
import Foreign.C.Types (CInt, CUInt)
|
||||
import Foreign.Ptr (Ptr, castPtr)
|
||||
import Foreign.Storable (Storable (peek, peekByteOff))
|
||||
|
||||
data CTokenizer
|
||||
|
||||
data CEncoding
|
||||
|
||||
data CTokens
|
||||
|
||||
data CIDs
|
||||
|
||||
data Tokenizer = Tokenizer {
|
||||
tok :: Ptr CTokenizer,
|
||||
vocab :: String,
|
||||
merges :: String
|
||||
}
|
||||
data Tokenizer = Tokenizer
|
||||
{ tok :: Ptr CTokenizer,
|
||||
vocab :: FilePath,
|
||||
merges :: Maybe FilePath
|
||||
}
|
||||
|
||||
newtype Encoding = Encoding {
|
||||
enc :: Ptr CEncoding
|
||||
}
|
||||
newtype Encoding = Encoding
|
||||
{ enc :: Ptr CEncoding
|
||||
}
|
||||
|
||||
instance Show Tokenizer where
|
||||
show (Tokenizer _ vocab merges) = "Huggingface Tokenizer Object\n vocab: " ++ vocab ++ "\n merges: " ++ merges
|
||||
show (Tokenizer _ vocab merges) =
|
||||
"Huggingface Tokenizer Object\n"
|
||||
<> " vocab: "
|
||||
<> vocab
|
||||
<> "\n"
|
||||
<> maybe mempty (" merges: " <>) merges
|
||||
|
||||
foreign import ccall unsafe "mk_tokenizer" r_mk_tokenizer :: CString -> CString -> IO (Ptr CTokenizer)
|
||||
foreign import ccall unsafe "mk_wordpiece_tokenizer"
|
||||
r_mk_wordpiece_tokenizer ::
|
||||
CString -> IO (Ptr CTokenizer)
|
||||
|
||||
mkTokenizer vocab merges = do
|
||||
mkWordPieceTokenizer :: FilePath -> IO Tokenizer
|
||||
mkWordPieceTokenizer vocab = do
|
||||
cvocab <- newCString $ vocab ++ "\0"
|
||||
result <- r_mk_wordpiece_tokenizer cvocab
|
||||
pure (Tokenizer result vocab Nothing)
|
||||
|
||||
foreign import ccall unsafe "mk_bpe_tokenizer"
|
||||
r_mk_bpe_tokenizer ::
|
||||
CString -> CString -> IO (Ptr CTokenizer)
|
||||
|
||||
mkBPETokenizer :: FilePath -> FilePath -> IO Tokenizer
|
||||
mkBPETokenizer vocab merges = do
|
||||
cvocab <- newCString $ vocab ++ "\0"
|
||||
cmerges <- newCString $ merges ++ "\0"
|
||||
result <- r_mk_tokenizer cvocab cmerges
|
||||
pure (Tokenizer result vocab merges)
|
||||
result <- r_mk_bpe_tokenizer cvocab cmerges
|
||||
pure (Tokenizer result vocab (Just merges))
|
||||
|
||||
foreign import ccall unsafe "mk_roberta_tokenizer" r_mk_roberta_tokenizer :: CString -> CString -> IO (Ptr CTokenizer)
|
||||
|
||||
foreign import ccall unsafe "mk_roberta_tokenizer"
|
||||
r_mk_roberta_tokenizer ::
|
||||
CString -> CString -> IO (Ptr CTokenizer)
|
||||
|
||||
mkRobertaTokenizer :: FilePath -> FilePath -> IO Tokenizer
|
||||
mkRobertaTokenizer vocab merges = do
|
||||
cvocab <- newCString $ vocab ++ "\0"
|
||||
cmerges <- newCString $ merges ++ "\0"
|
||||
result <- r_mk_roberta_tokenizer cvocab cmerges
|
||||
pure (Tokenizer result vocab merges)
|
||||
pure (Tokenizer result vocab (Just merges))
|
||||
|
||||
foreign import ccall unsafe "encode" r_encode :: CString -> Ptr CTokenizer -> IO (Ptr CEncoding)
|
||||
foreign import ccall unsafe "encode"
|
||||
r_encode ::
|
||||
CString -> Ptr CTokenizer -> IO (Ptr CEncoding)
|
||||
|
||||
encode :: Tokenizer -> String -> IO Encoding
|
||||
encode (Tokenizer tokenizer _ _) text = do
|
||||
@ -50,7 +75,9 @@ encode (Tokenizer tokenizer _ _) text = do
|
||||
encoding <- r_encode str tokenizer
|
||||
pure (Encoding encoding)
|
||||
|
||||
foreign import ccall unsafe "get_tokens" r_get_tokens :: Ptr CEncoding -> IO (Ptr CTokens)
|
||||
foreign import ccall unsafe "get_tokens"
|
||||
r_get_tokens ::
|
||||
Ptr CEncoding -> IO (Ptr CTokens)
|
||||
|
||||
getTokens :: Encoding -> IO [String]
|
||||
getTokens (Encoding encoding) = do
|
||||
@ -58,27 +85,28 @@ getTokens (Encoding encoding) = do
|
||||
-- 1st value of struct is the # of tokens
|
||||
sz <- fromIntegral <$> (peek (castPtr ptr) :: IO CInt) :: IO Int
|
||||
-- 2nd value of struct is the array of tokens
|
||||
tokens <- peekByteOff ptr 8 :: IO (Ptr CString)
|
||||
mapM
|
||||
(\idx -> peekByteOff tokens (step*idx) >>= peekCString)
|
||||
[0 .. sz-1]
|
||||
tokens <- peekByteOff ptr 8 :: IO (Ptr CString)
|
||||
mapM
|
||||
(\idx -> peekByteOff tokens (step * idx) >>= peekCString)
|
||||
[0 .. sz -1]
|
||||
where
|
||||
step = 8
|
||||
|
||||
cleanTokens :: String -> String
|
||||
cleanTokens xs = [x | x <- xs, x `notElem` "\288"]
|
||||
|
||||
foreign import ccall unsafe "get_ids" r_get_ids :: Ptr CEncoding -> IO (Ptr CIDs)
|
||||
|
||||
getIDs:: Encoding -> IO [Int]
|
||||
getIDs :: Encoding -> IO [Int]
|
||||
getIDs (Encoding encoding) = do
|
||||
ptr <- r_get_ids encoding
|
||||
-- 1st value of struct is the # of tokens
|
||||
sz <- fromIntegral <$> (peek (castPtr ptr) :: IO CUInt) :: IO Int
|
||||
-- print $ "SIZE " ++ show sz
|
||||
-- 2nd value of struct is the array of tokens
|
||||
tokens <- peekByteOff ptr 8 :: IO (Ptr CUInt)
|
||||
(mapM
|
||||
(\idx -> (peekByteOff tokens (step*idx) :: IO CUInt) >>= (pure . fromIntegral))
|
||||
[0 .. sz-1])
|
||||
tokens <- peekByteOff ptr 8 :: IO (Ptr CUInt)
|
||||
mapM
|
||||
(\idx -> (peekByteOff tokens (step * idx) :: IO CUInt) >>= (pure . fromIntegral))
|
||||
[0 .. sz -1]
|
||||
where
|
||||
step = 4
|
||||
|
@ -1,5 +1,4 @@
|
||||
resolver:
|
||||
url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/16/28.yaml
|
||||
resolver: lts-17.9
|
||||
|
||||
packages:
|
||||
- .
|
||||
|
@ -28,4 +28,3 @@ main = do
|
||||
test "hi hi hi hi hi hi hi hi" tokenizer
|
||||
test "hi there hi there hi there hi there hi there hi there hi there hi there" tokenizer
|
||||
test "hello world. Let's try tokenizing this. hi hi hi and hello hello" tokenizer
|
||||
|
@ -10,21 +10,32 @@ build-type: Simple
|
||||
cabal-version: >=1.10
|
||||
|
||||
library
|
||||
exposed-modules: Tokenizers
|
||||
hs-source-dirs: src
|
||||
exposed-modules:
|
||||
Tokenizers
|
||||
Paths_tokenizers
|
||||
autogen-modules:
|
||||
Paths_tokenizers
|
||||
hs-source-dirs:
|
||||
src
|
||||
ghc-options: -W -Wall -dcore-lint
|
||||
build-depends:
|
||||
base >= 4.7 && < 5
|
||||
extra-libraries:
|
||||
tokenizers_haskell
|
||||
default-language: Haskell2010
|
||||
build-depends: base >= 4.7 && < 5
|
||||
extra-libraries: tokenizers_haskell
|
||||
|
||||
executable haskell-test
|
||||
hs-source-dirs: src
|
||||
main-is: Main.hs
|
||||
other-modules: Tokenizers
|
||||
default-language: Haskell2010
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, tokenizers
|
||||
extra-libraries: tokenizers_haskell
|
||||
-- extra-lib-dirs: ./lib
|
||||
test-suite tokenizers-test
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: Spec.hs
|
||||
hs-source-dirs:
|
||||
test
|
||||
ghc-options: -W -Wall -dcore-lint
|
||||
build-depends:
|
||||
base >= 4.7 && < 5
|
||||
, tokenizers
|
||||
extra-libraries:
|
||||
tokenizers_haskell
|
||||
default-language: Haskell2010
|
||||
|
||||
-- executable download-vocab
|
||||
-- hs-source-dirs: src
|
||||
|
Loading…
Reference in New Issue
Block a user