make critical stability improvements

Replace the nested if-let/panic chains in the Rust FFI with match
expressions that report the underlying error in the panic message,
switch the Haskell bindings from unsafeUseAsCString to useAsCString so
the buffer handed to Rust is NUL-terminated, and add mass-decode and
incremental-decode tests against tokenizer configs fetched from
huggingface.co.

Torsten Scholak 2021-05-03 20:40:28 -04:00
parent 5ab5ca03f8
commit 41bf8695a9
No known key found for this signature in database
GPG Key ID: EF135E6C40866D80
4 changed files with 108 additions and 47 deletions

View File

@@ -15,14 +15,12 @@ use tokenizers::AddedToken;
pub extern "C" fn deserialize_tokenizer(cconfig: *const c_char) -> *mut Tokenizer {
unsafe {
let config = CStr::from_ptr(cconfig);
if let Ok(config_file) = config.to_str() {
if let Ok(tokenizer) = Tokenizer::from_file(config_file) {
return Box::into_raw(Box::new(tokenizer));
} else {
panic!("Unable to read tokenizer from file.");
}
} else {
panic!("Unable to read config.");
match config.to_str() {
Ok(config_file) => match Tokenizer::from_file(config_file) {
Ok(tokenizer) => return Box::into_raw(Box::new(tokenizer)),
Err(error) => panic!("Unable to read tokenizer from file: {:?}", error),
},
Err(error) => panic!("Unable to read config: {:?}", error),
}
}
}
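
The pattern repeated across this file: the old nested `if let`/`else panic!` chains discard the `Err` payload, while the new `match` form interpolates it into the panic message, so a crash at least says why. A minimal standalone sketch of the shape — the function and its parse step are hypothetical, not part of this crate:

use std::ffi::CStr;
use std::os::raw::c_char;

// Hypothetical FFI function illustrating the pattern: every fallible
// step is a `match` arm, so the `Err` value reaches the panic message
// instead of being dropped by an `if let ... else` branch.
pub extern "C" fn parse_port(cport: *const c_char) -> u16 {
    unsafe {
        match CStr::from_ptr(cport).to_str() {
            Ok(s) => match s.parse::<u16>() {
                Ok(port) => port,
                Err(error) => panic!("Unable to parse port: {:?}", error),
            },
            Err(error) => panic!("Unable to read C string: {:?}", error),
        }
    }
}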
@@ -31,14 +29,12 @@ pub extern "C" fn deserialize_tokenizer(cconfig: *const c_char) -> *mut Tokenize
pub extern "C" fn deserialize_tokenizer_from_json(cjson: *const c_char) -> *mut Tokenizer {
unsafe {
let json = CStr::from_ptr(cjson);
if let Ok(json_str) = json.to_str() {
if let Ok(tokenizer) = Tokenizer::from_str(json_str) {
return Box::into_raw(Box::new(tokenizer));
} else {
panic!("Unable to read tokenizer from json.");
}
} else {
panic!("Unable to read json string.");
match json.to_str() {
Ok(json_str) => match Tokenizer::from_str(json_str) {
Ok(tokenizer) => return Box::into_raw(Box::new(tokenizer)),
Err(error) => panic!("Unable to read tokenizer from json: {:?}", error),
},
Err(error) => panic!("Unable to read json string: {:?}", error),
}
}
}
@@ -51,14 +47,12 @@ pub extern "C" fn serialize_tokenizer(cconfig: *const c_char, ptr: *mut Tokenize
assert!(!ptr.is_null());
&mut *ptr
};
if let Ok(config_file) = config.to_str() {
if let Ok(res) = tokenizer.save(config_file, false) {
return res;
} else {
panic!("Unable to save tokenizer to file.");
}
} else {
panic!("Unable to read config.");
match config.to_str() {
Ok(config_file) => match tokenizer.save(config_file, false) {
Ok(res) => return res,
Err(error) => panic!("Unable to save tokenizer to file: {:?}", error),
},
Err(error) => panic!("Unable to read config: {:?}", error),
}
}
}
@@ -160,12 +154,19 @@ pub extern "C" fn decode(clength: c_uint, cids: *const c_uint, ptr: *mut Tokeniz
unsafe { ids.push(*p) };
}
ids.shrink_to_fit();
if let Ok(res) = tokenizer.decode(ids, false) {
let c_str = CString::new(res).unwrap();
let ptr = c_str.into_raw();
return ptr;
} else {
panic!("Unable to decode ids.");
let ids_ = ids.clone();
match tokenizer.decode(ids, false) {
Ok(res) => {
let res_ = res.clone();
match CString::new(res) {
Ok(c_str) => {
let ptr = c_str.into_raw();
return ptr;
}
Err(error) => panic!("Unable to convert tokenizer result to CString: {:?} {:?} {:?}", error, res_, ids_),
}
},
Err(error) => panic!("Unable to decode ids: {:?}", error),
}
}
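
Beyond the match conversion, `decode` now clones `ids` and `res` up front purely so the failure arms can report what was being decoded. `CString::new` is the one step here that can fail even on valid UTF-8: it rejects any string containing an interior NUL byte. A quick std-only illustration:

use std::ffi::CString;

fn main() {
    // Succeeds: no interior NUL bytes.
    assert!(CString::new("hello").is_ok());

    // Fails with a NulError: C strings are NUL-terminated, so an
    // embedded \0 cannot be represented. This is the case the new
    // diagnostic panic in `decode` guards against.
    let err = CString::new("hel\0lo").unwrap_err();
    assert_eq!(err.nul_position(), 3);
}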
@@ -177,11 +178,12 @@ pub extern "C" fn add_special_token(ctoken: *const c_char, ptr: *mut Tokenizer)
assert!(!ptr.is_null());
&mut *ptr
};
if let Ok(s) = cstring.to_str() {
let token = AddedToken::from(s, true);
tokenizer.add_special_tokens(&[token]);
} else {
panic!("Unable to read token.");
match cstring.to_str() {
Ok(s) => {
let token = AddedToken::from(s, true);
tokenizer.add_special_tokens(&[token]);
}
Err(error) => panic!("Unable to read token: {:?}", error),
}
}
}

View File

@@ -4,8 +4,7 @@ module Tokenizers where
import Control.Applicative (empty)
import Control.Exception (bracket)
import Data.ByteString (ByteString)
import Data.ByteString.Unsafe (unsafeUseAsCString)
import Data.ByteString (ByteString, useAsCString)
import Foreign.C.String (CString, peekCString, withCString)
import Foreign.C.Types (CInt, CUInt (..))
import Foreign.Marshal.Array (withArrayLen)
@@ -57,7 +56,7 @@ foreign import ccall unsafe "deserialize_tokenizer_from_json"
createTokenizerFromJSONConfig :: ByteString -> IO Tokenizer
createTokenizerFromJSONConfig json =
unsafeUseAsCString
useAsCString
json
( \cjson ->
Tokenizer
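
This one-word swap is likely the critical fix on the Haskell side: `unsafeUseAsCString` passes a pointer into the ByteString's existing buffer, zero-copy and with no guarantee of a trailing NUL byte, whereas `useAsCString` copies the bytes and appends the terminator. The Rust side reads that pointer with `CStr::from_ptr`, which scans memory until it finds a 0 byte, so an unterminated buffer means reading past its end. A sketch of the Rust-side expectation, with a hypothetical JSON payload:

use std::ffi::CStr;

fn main() {
    // `deserialize_tokenizer_from_json` hands its pointer to
    // `CStr::from_ptr`, which walks forward until it hits a 0 byte.
    // `useAsCString` guarantees that byte exists; `unsafeUseAsCString`
    // does not, which is undefined behavior on the Rust side.
    let json = b"{\"version\":\"1.0\"}\0";
    let cstr = unsafe { CStr::from_ptr(json.as_ptr().cast()) };
    assert_eq!(cstr.to_str().unwrap(), "{\"version\":\"1.0\"}");
}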

View File

@@ -3,19 +3,27 @@
module Main where
import qualified Data.ByteString.Lazy as LBS (toStrict)
import Data.Hashable (hash)
import qualified Network.HTTP.Client as HTTP
import qualified Network.HTTP.Client.TLS as HTTP
import qualified Test.Tasty as T
import qualified Test.Tasty.HUnit as H
import qualified Data.ByteString as BS (readFile)
import Tokenizers (Tokenizer, addSpecialToken, cleanTokens, createTokenizerFromConfigFile, createTokenizerFromJSONConfig, decode, encode, freeTokenizer, getIDs, getTokens, mkRobertaTokenizer)
import Tokenizers (Tokenizer, addSpecialToken, cleanTokens, createTokenizerFromJSONConfig, decode, encode, freeTokenizer, getIDs, getTokens, mkRobertaTokenizer)
data TestItem
= Group String [TestItem]
| EncodeBart String [Int]
| DecodeBart [Int] String
| MassDecodeBart [Int] [Int]
| IncrementalDecodeBart [Int] [Int]
| EncodeRoberta String [Int]
| DecodeRoberta [Int] String
| EncodeT5 String [Int]
| DecodeT5 [Int] String
| MassDecodeT5 [Int] [Int]
| IncrementalDecodeT5 [Int] [Int]
| IncrementalDecodeT5Fail [Int] [Int]
deriving stock (Eq, Show)
data TestTokenizers = TestTokenizers
@@ -60,7 +68,14 @@ bartTests =
EncodeBart "<s>Hello <mask>!</s><pad>" [0, 31414, 50264, 328, 2, 1],
EncodeBart "<s> Hello <mask> ! </s> <pad>" [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
DecodeBart [0, 31414, 50264, 328, 2, 1] "<s>Hello<mask>!</s><pad>",
DecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1] "<s> Hello <mask> ! </s> <pad>"
DecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1] "<s> Hello <mask> ! </s> <pad>",
MassDecodeBart [0, 21959, 1721, 44664, 2103, 4, 42351, 11974, 2103] ([0 .. 50107] <> [50109 .. 100000]),
IncrementalDecodeBart [0] [31414, 50264, 328, 2, 1],
IncrementalDecodeBart [0, 31414] [50264, 328, 2, 1],
IncrementalDecodeBart [0, 31414, 50264, 328] [2, 1],
IncrementalDecodeBart [0, 1437, 1437] [20920, 1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
IncrementalDecodeBart [0, 1437, 1437, 20920] [1437, 1437, 50264, 1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1],
IncrementalDecodeBart [0, 1437, 1437, 20920, 1437, 1437, 50264] [1437, 27785, 1437, 1437, 1437, 1437, 2, 1437, 1437, 1]
]
robertaTests :: [TestItem]
@@ -76,7 +91,11 @@ t5Tests =
[ EncodeT5 "<pad>Hello world!</s>" [0, 8774, 296, 55, 1],
EncodeT5 "<pad>Hello <extra_id_0>!</s><pad>" [0, 8774, 32099, 3, 55, 1, 0],
EncodeT5 "<pad> Hello <extra_id_0> ! </s> <pad>" [0, 8774, 32099, 3, 55, 1, 0],
DecodeT5 [0, 8774, 32099, 3, 55, 1, 0] "<pad> Hello<extra_id_0> !</s><pad>"
DecodeT5 [0, 8774, 32099, 3, 55, 1, 0] "<pad> Hello<extra_id_0> !</s><pad>",
MassDecodeT5 [0, 8774, 32099, 3, 55, 1] [0 .. 100000],
IncrementalDecodeT5Fail [0] [8774, 32099, 3, 55, 1, 0],
IncrementalDecodeT5 [0, 8774, 32099, 3, 55] [1, 0],
IncrementalDecodeT5Fail [0, 4219, 834, 7, 9963, 1820, 1738] [3476]
]
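
Two kinds of new cases above: `MassDecode*` decodes `ids <> [token]` for every token in a large id range, smoking out panics such as the interior-NUL `CString` failure handled earlier, and `IncrementalDecode*` checks whether decoding distributes over concatenation. `IncrementalDecodeT5Fail` records splits where T5's SentencePiece-style decoding merges text across the boundary, so the equality is expected to fail. The property, written as a Rust helper against the `decode` signature the FFI code above uses — a sketch, not part of either codebase:

use tokenizers::Tokenizer;

// Does decoding distribute over concatenation for this split?
// IncrementalDecode* tests assert this holds; the *Fail variants
// assert it does not (e.g. T5 merging whitespace at the boundary).
fn decodes_incrementally(tokenizer: &Tokenizer, ids: &[u32], other_ids: &[u32]) -> bool {
    let prefix = tokenizer.decode(ids.to_vec(), false).unwrap();
    let suffix = tokenizer.decode(other_ids.to_vec(), false).unwrap();
    let whole = tokenizer.decode([ids, other_ids].concat(), false).unwrap();
    whole == prefix + &suffix
}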
testData :: TestItem
@@ -95,12 +114,26 @@ testTree =
freeTokenizers
(toTest testData)
where
createTokenizer url expectedHash = do
manager <- HTTP.newTlsManagerWith HTTP.tlsManagerSettings
request <- HTTP.parseRequest url
response <- HTTP.httpLbs request manager
let body = LBS.toStrict . HTTP.responseBody $ response
H.assertEqual "Unexpected json hash" expectedHash (hash body)
createTokenizerFromJSONConfig body
createTokenizers = do
bartTokenizer <- do
json <- BS.readFile "models/bart-base-tokenizer.json"
createTokenizerFromJSONConfig json
robertaTokenizer <- createTokenizerFromConfigFile "models/roberta-base-tokenizer.json"
t5Tokenizer <- createTokenizerFromConfigFile "models/t5-base-tokenizer.json"
bartTokenizer <-
createTokenizer
"https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json"
(-5675567303366998911)
robertaTokenizer <-
createTokenizer
"https://huggingface.co/roberta-base/resolve/main/tokenizer.json"
(-5675567303366998911)
t5Tokenizer <-
createTokenizer
"https://huggingface.co/t5-base/resolve/main/tokenizer.json"
(-6144928463468424742)
pure $ TestTokenizers {..}
freeTokenizers TestTokenizers {..} = do
freeTokenizer bartTokenizer
@@ -118,6 +151,15 @@ testTree =
TestTokenizers {..} <- mtokenizers
s <- decode bartTokenizer ids
H.assertEqual "Unexpected decoding result" expected s
toTest (MassDecodeBart ids tokens) mtokenizers = H.testCase ("MassDecode " <> show ids) $ do
TestTokenizers {..} <- mtokenizers
mapM_ (\token -> decode bartTokenizer (ids <> [token])) tokens
toTest (IncrementalDecodeBart ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
TestTokenizers {..} <- mtokenizers
s <- decode bartTokenizer ids
s' <- decode bartTokenizer otherIds
s'' <- decode bartTokenizer $ ids <> otherIds
H.assertEqual "Unexpected decoding result" s'' (s <> s')
toTest (EncodeRoberta s expected) mtokenizers = H.testCase ("Encode " <> show s) $ do
TestTokenizers {..} <- mtokenizers
enc <- encode robertaTokenizer s
@@ -136,6 +178,21 @@ testTree =
TestTokenizers {..} <- mtokenizers
s <- decode t5Tokenizer ids
H.assertEqual "Unexpected decoding result" expected s
toTest (MassDecodeT5 ids tokens) mtokenizers = H.testCase ("MassDecode " <> show ids) $ do
TestTokenizers {..} <- mtokenizers
mapM_ (\token -> decode t5Tokenizer (ids <> [token])) tokens
toTest (IncrementalDecodeT5 ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
TestTokenizers {..} <- mtokenizers
s <- decode t5Tokenizer ids
s' <- decode t5Tokenizer otherIds
s'' <- decode t5Tokenizer $ ids <> otherIds
H.assertEqual "Unexpected decoding result" s'' (s <> s')
toTest (IncrementalDecodeT5Fail ids otherIds) mtokenizers = H.testCase ("Incrementally decode " <> show ids <> " " <> show otherIds) $ do
TestTokenizers {..} <- mtokenizers
s <- decode t5Tokenizer ids
s' <- decode t5Tokenizer otherIds
s'' <- decode t5Tokenizer $ ids <> otherIds
H.assertBool "Unexpected decoding result" (s'' /= (s <> s'))
-- | Run 'stack ghci --test' to get a REPL for the tests.
main :: IO ()

View File

@@ -34,6 +34,9 @@ test-suite tokenizers-test
build-depends:
base >= 4.7 && < 5
, bytestring
, hashable
, http-client
, http-client-tls
, tasty
, tasty-hunit
, tokenizers