mirror of
https://github.com/hasktorch/tokenizers.git
synced 2024-09-11 05:15:46 +03:00
make critical stability improvements
This commit is contained in:
parent
5ab5ca03f8
commit
41bf8695a9
@ -15,14 +15,12 @@ use tokenizers::AddedToken;
|
||||
pub extern "C" fn deserialize_tokenizer(cconfig: *const c_char) -> *mut Tokenizer {
|
||||
unsafe {
|
||||
let config = CStr::from_ptr(cconfig);
|
||||
if let Ok(config_file) = config.to_str() {
|
||||
if let Ok(tokenizer) = Tokenizer::from_file(config_file) {
|
||||
return Box::into_raw(Box::new(tokenizer));
|
||||
} else {
|
||||
panic!("Unable to read tokenizer from file.");
|
||||
}
|
||||
} else {
|
||||
panic!("Unable to read config.");
|
||||
match config.to_str() {
|
||||
Ok(config_file) => match Tokenizer::from_file(config_file) {
|
||||
Ok(tokenizer) => return Box::into_raw(Box::new(tokenizer)),
|
||||
Err(error) => panic!("Unable to read tokenizer from file: {:?}", error),
|
||||
},
|
||||
Err(error) => panic!("Unable to read config: {:?}", error),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -31,14 +29,12 @@ pub extern "C" fn deserialize_tokenizer(cconfig: *const c_char) -> *mut Tokenize
|
||||
pub extern "C" fn deserialize_tokenizer_from_json(cjson: *const c_char) -> *mut Tokenizer {
|
||||
unsafe {
|
||||
let json = CStr::from_ptr(cjson);
|
||||
if let Ok(json_str) = json.to_str() {
|
||||
if let Ok(tokenizer) = Tokenizer::from_str(json_str) {
|
||||
return Box::into_raw(Box::new(tokenizer));
|
||||
} else {
|
||||
panic!("Unable to read tokenizer from json.");
|
||||
}
|
||||
} else {
|
||||
panic!("Unable to read json string.");
|
||||
match json.to_str() {
|
||||
Ok(json_str) => match Tokenizer::from_str(json_str) {
|
||||
Ok(tokenizer) => return Box::into_raw(Box::new(tokenizer)),
|
||||
Err(error) => panic!("Unable to read tokenizer from json: {:?}", error),
|
||||
},
|
||||
Err(error) => panic!("Unable to read json string: {:?}", error),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -51,14 +47,12 @@ pub extern "C" fn serialize_tokenizer(cconfig: *const c_char, ptr: *mut Tokenize
|
||||
assert!(!ptr.is_null());
|
||||
&mut *ptr
|
||||
};
|
||||
if let Ok(config_file) = config.to_str() {
|
||||
if let Ok(res) = tokenizer.save(config_file, false) {
|
||||
return res;
|
||||
} else {
|
||||
panic!("Unable to save tokenizer to file.");
|
||||
}
|
||||
} else {
|
||||
panic!("Unable to read config.");
|
||||
match config.to_str() {
|
||||
Ok(config_file) => match tokenizer.save(config_file, false) {
|
||||
Ok(res) => return res,
|
||||
Err(error) => panic!("Unable to save tokenizer to file: {:?}", error),
|
||||
},
|
||||
Err(error) => panic!("Unable to read config: {:?}", error),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -160,12 +154,19 @@ pub extern "C" fn decode(clength: c_uint, cids: *const c_uint, ptr: *mut Tokeniz
|
||||
unsafe { ids.push(*p) };
|
||||
}
|
||||
ids.shrink_to_fit();
|
||||
if let Ok(res) = tokenizer.decode(ids, false) {
|
||||
let c_str = CString::new(res).unwrap();
|
||||
let ptr = c_str.into_raw();
|
||||
return ptr;
|
||||
} else {
|
||||
panic!("Unable to decode ids.");
|
||||
let ids_ = ids.clone();
|
||||
match tokenizer.decode(ids, false) {
|
||||
Ok(res) => {
|
||||
let res_ = res.clone();
|
||||
match CString::new(res) {
|
||||
Ok(c_str) => {
|
||||
let ptr = c_str.into_raw();
|
||||
return ptr;
|
||||
}
|
||||
Err(error) => panic!("Unable to convert tokenizer result to CString: {:?} {:?} {:?}", error, res_, ids_),
|
||||
}
|
||||
},
|
||||
Err(error) => panic!("Unable to decode ids: {:?}", error),
|
||||
}
|
||||
}
|
||||
|
||||
@ -177,11 +178,12 @@ pub extern "C" fn add_special_token(ctoken: *const c_char, ptr: *mut Tokenizer)
|
||||
assert!(!ptr.is_null());
|
||||
&mut *ptr
|
||||
};
|
||||
if let Ok(s) = cstring.to_str() {
|
||||
let token = AddedToken::from(s, true);
|
||||
tokenizer.add_special_tokens(&[token]);
|
||||
} else {
|
||||
panic!("Unable to read token.");
|
||||
match cstring.to_str() {
|
||||
Ok(s) => {
|
||||
let token = AddedToken::from(s, true);
|
||||
tokenizer.add_special_tokens(&[token]);
|
||||
}
|
||||
Err(error) => panic!("Unable to read token: {:?}", error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,8 +4,7 @@ module Tokenizers where
|
||||
|
||||
import Control.Applicative (empty)
|
||||
import Control.Exception (bracket)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.ByteString.Unsafe (unsafeUseAsCString)
|
||||
import Data.ByteString (ByteString, useAsCString)
|
||||
import Foreign.C.String (CString, peekCString, withCString)
|
||||
import Foreign.C.Types (CInt, CUInt (..))
|
||||
import Foreign.Marshal.Array (withArrayLen)
|
||||
@ -57,7 +56,7 @@ foreign import ccall unsafe "deserialize_tokenizer_from_json"
|
||||
|
||||
createTokenizerFromJSONConfig :: ByteString -> IO Tokenizer
|
||||
createTokenizerFromJSONConfig json =
|
||||
unsafeUseAsCString
|
||||
useAsCString
|
||||
json
|
||||
( \cjson ->
|
||||
Tokenizer
|
||||
|
@ -3,19 +3,27 @@
|
||||
|
||||
module Main where
|
||||
|
||||
import qualified Data.ByteString.Lazy as LBS (toStrict)
|
||||
import Data.Hashable (hash)
|
||||
import qualified Network.HTTP.Client as HTTP
|
||||
import qualified Network.HTTP.Client.TLS as HTTP
|
||||
import qualified Test.Tasty as T
|
||||
import qualified Test.Tasty.HUnit as H
|
||||
import qualified Data.ByteString as BS (readFile)
|
||||
import Tokenizers (Tokenizer, addSpecialToken, cleanTokens, createTokenizerFromConfigFile, createTokenizerFromJSONConfig, decode, encode, freeTokenizer, getIDs, getTokens, mkRobertaTokenizer)
|
||||
import Tokenizers (Tokenizer, addSpecialToken, cleanTokens, createTokenizerFromJSONConfig, decode, encode, freeTokenizer, getIDs, getTokens, mkRobertaTokenizer)
|
||||
|
||||
data TestItem
|
||||
= Group String [TestItem]
|
||||
| EncodeBart String [Int]
|
||||
| DecodeBart [Int] String
|
||||
| MassDecodeBart [Int] [Int]
|
||||
| IncrementalDecodeBart [Int] [Int]
|
||||
| EncodeRoberta String [Int]
|
||||
| DecodeRoberta [Int] String
|
||||
| EncodeT5 String [Int]
|
||||
| DecodeT5 [Int] String
|
||||
| MassDecodeT5 [Int] [Int]
|
||||
| IncrementalDecodeT5 [Int] [Int]
|
||||
| IncrementalDecodeT5Fail [Int] [Int]
|
||||
deriving stock (Eq, Show)
|
||||
|
||||
data TestTokenizers = TestTokenizers
|
||||
@ -60,7 +68,14 @@ bartTests =
|
||||
EncodeBart "<s>Hello <mask>!</s><pad>" [0, 31414, 50264, 328, 2, 1],
|
||||
EncodeBart "<s> Hello <mask> ! </s> <pad>" [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
|
||||
DecodeBart [0, 31414, 50264, 328, 2, 1] "<s>Hello<mask>!</s><pad>",
|
||||
DecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1] "<s> Hello <mask> ! </s> <pad>"
|
||||
DecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1] "<s> Hello <mask> ! </s> <pad>",
|
||||
MassDecodeBart [0, 21959, 1721, 44664, 2103, 4, 42351, 11974, 2103] ([0 .. 50107] <> [50109 .. 100000]),
|
||||
IncrementalDecodeBart [0] [31414, 50264, 328, 2, 1],
|
||||
IncrementalDecodeBart [0, 31414] [50264, 328, 2, 1],
|
||||
IncrementalDecodeBart [0, 31414, 50264, 328] [2, 1],
|
||||
IncrementalDecodeBart [0, 1437, 1437] [20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
|
||||
IncrementalDecodeBart [0, 1437, 1437, 20920] [1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
|
||||
IncrementalDecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264] [1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1]
|
||||
]
|
||||
|
||||
robertaTests :: [TestItem]
|
||||
@ -76,7 +91,11 @@ t5Tests =
|
||||
[ EncodeT5 "<pad>Hello world!</s>" [0, 8774, 296, 55, 1],
|
||||
EncodeT5 "<pad>Hello <extra_id_0>!</s><pad>" [0, 8774, 32099, 3, 55, 1, 0],
|
||||
EncodeT5 "<pad> Hello <extra_id_0> ! </s> <pad>" [0, 8774, 32099, 3, 55, 1, 0],
|
||||
DecodeT5 [0, 8774, 32099, 3, 55, 1, 0] "<pad> Hello<extra_id_0> !</s><pad>"
|
||||
DecodeT5 [0, 8774, 32099, 3, 55, 1, 0] "<pad> Hello<extra_id_0> !</s><pad>",
|
||||
MassDecodeT5 [0, 8774, 32099, 3, 55, 1] [0 .. 100000],
|
||||
IncrementalDecodeT5Fail [0] [8774, 32099, 3, 55, 1, 0],
|
||||
IncrementalDecodeT5 [0, 8774, 32099, 3, 55] [1, 0],
|
||||
IncrementalDecodeT5Fail [0, 4219, 834, 7, 9963, 1820, 1738] [3476]
|
||||
]
|
||||
|
||||
testData :: TestItem
|
||||
@ -95,12 +114,26 @@ testTree =
|
||||
freeTokenizers
|
||||
(toTest testData)
|
||||
where
|
||||
createTokenizer url expectedHash = do
|
||||
manager <- HTTP.newTlsManagerWith HTTP.tlsManagerSettings
|
||||
request <- HTTP.parseRequest url
|
||||
response <- HTTP.httpLbs request manager
|
||||
let body = LBS.toStrict . HTTP.responseBody $ response
|
||||
H.assertEqual "Unexpected json hash" expectedHash (hash body)
|
||||
createTokenizerFromJSONConfig body
|
||||
createTokenizers = do
|
||||
bartTokenizer <- do
|
||||
json <- BS.readFile "models/bart-base-tokenizer.json"
|
||||
createTokenizerFromJSONConfig json
|
||||
robertaTokenizer <- createTokenizerFromConfigFile "models/roberta-base-tokenizer.json"
|
||||
t5Tokenizer <- createTokenizerFromConfigFile "models/t5-base-tokenizer.json"
|
||||
bartTokenizer <-
|
||||
createTokenizer
|
||||
"https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json"
|
||||
(-5675567303366998911)
|
||||
robertaTokenizer <-
|
||||
createTokenizer
|
||||
"https://huggingface.co/roberta-base/resolve/main/tokenizer.json"
|
||||
(-5675567303366998911)
|
||||
t5Tokenizer <-
|
||||
createTokenizer
|
||||
"https://huggingface.co/t5-base/resolve/main/tokenizer.json"
|
||||
(-6144928463468424742)
|
||||
pure $ TestTokenizers {..}
|
||||
freeTokenizers TestTokenizers {..} = do
|
||||
freeTokenizer bartTokenizer
|
||||
@ -118,6 +151,15 @@ testTree =
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
s <- decode bartTokenizer ids
|
||||
H.assertEqual "Unexpected decoding result" expected s
|
||||
toTest (MassDecodeBart ids tokens) mtokenizers = H.testCase ("MassDecode " <> show ids) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
mapM_ (\token -> decode bartTokenizer (ids <> [token])) tokens
|
||||
toTest (IncrementalDecodeBart ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
s <- decode bartTokenizer ids
|
||||
s' <- decode bartTokenizer otherIds
|
||||
s'' <- decode bartTokenizer $ ids <> otherIds
|
||||
H.assertEqual "Unexpected decoding result" s'' (s <> s')
|
||||
toTest (EncodeRoberta s expected) mtokenizers = H.testCase ("Encode " <> show s) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
enc <- encode robertaTokenizer s
|
||||
@ -136,6 +178,21 @@ testTree =
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
s <- decode t5Tokenizer ids
|
||||
H.assertEqual "Unexpected decoding result" expected s
|
||||
toTest (MassDecodeT5 ids tokens) mtokenizers = H.testCase ("MassDecode " <> show ids) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
mapM_ (\token -> decode t5Tokenizer (ids <> [token])) tokens
|
||||
toTest (IncrementalDecodeT5 ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
s <- decode t5Tokenizer ids
|
||||
s' <- decode t5Tokenizer otherIds
|
||||
s'' <- decode t5Tokenizer $ ids <> otherIds
|
||||
H.assertEqual "Unexpected decoding result" s'' (s <> s')
|
||||
toTest (IncrementalDecodeT5Fail ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
|
||||
TestTokenizers {..} <- mtokenizers
|
||||
s <- decode t5Tokenizer ids
|
||||
s' <- decode t5Tokenizer otherIds
|
||||
s'' <- decode t5Tokenizer $ ids <> otherIds
|
||||
H.assertBool "Unexpected decoding result" (s'' /= (s <> s'))
|
||||
|
||||
-- | Run 'stack ghci --test' to get a REPL for the tests.
|
||||
main :: IO ()
|
||||
|
@ -34,6 +34,9 @@ test-suite tokenizers-test
|
||||
build-depends:
|
||||
base >= 4.7 && < 5
|
||||
, bytestring
|
||||
, hashable
|
||||
, http-client
|
||||
, http-client-tls
|
||||
, tasty
|
||||
, tasty-hunit
|
||||
, tokenizers
|
||||
|
Loading…
Reference in New Issue
Block a user