Initialize repository

This commit is contained in:
Taku Kudo 2017-03-07 19:43:50 +09:00
commit 2928ce5307
93 changed files with 262295 additions and 0 deletions

27
CONTRIBUTING Normal file

@ -0,0 +1,27 @@
Want to contribute? Great! First, read this page (including the small print at the end).
### Before you contribute
Before we can use your code, you must sign the
[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)
(CLA), which you can do online. The CLA is necessary mainly because you own the
copyright to your changes, even after your contribution becomes part of our
codebase, so we need your permission to use and distribute your code. We also
need to be sure of various other things—for instance that you'll tell us if you
know that your code infringes on other people's patents. You don't have to sign
the CLA until after you've submitted your code for review and a member has
approved it, but you must do it before we can put your code into our codebase.
Before you start working on a larger contribution, you should get in touch with
us first through the issue tracker with your idea so that we can help out and
possibly guide you. Coordinating up front makes it much easier to avoid
frustration later on.
### Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose.
### The small print
Contributions made by corporations are covered by a different agreement than
the one above, the
[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).

202
LICENSE Normal file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

22
Makefile.am Normal file

@ -0,0 +1,22 @@
AUTOMAKE_OPTIONS = foreign
SUBDIRS = src
EXTRA_DIRS = m4 third_party data doc
EXTRA_DIST = README.md LICENSE
ACLOCAL_AMFLAGS = -I third_party/m4
dist-hook:
for subdir in $(EXTRA_DIRS); do \
cp -rp $$subdir $(distdir); \
rm -f $(distdir)/$$subdir/*~; \
rm -f $(distdir)/$$subdir/*.{bak,orig}; \
rm -rf $(distdir)/$$subdir/CVS; \
rm -rf $(distdir)/$$subdir/.svn; \
rm -rf $(distdir)/.svn; \
rm -rf $(distdir)/*/.svn; \
rm -rf $(distdir)/*/*/.svn; \
rm -rf $(distdir)/$$subdir/*/CVS; \
rm -rf $(distdir)/$$subdir/*/.svn; \
rm -rf $(distdir)/$$subdir/*/.pb.cc; \
find $(distdir) -name .svn | xargs rm -fr; \
done

229
README.md Normal file

@ -0,0 +1,229 @@
# SentencePiece
SentencePiece is an unsupervised text tokenizer and detokenizer mainly for
Neural Network-based text generation systems where the vocabulary size
is predetermined prior to neural model training. SentencePiece implements
**sub-word units** (also known as **wordpieces** [[Wu et al.](https://arxiv.org/pdf/1609.08144.pdf)]
[[Schuster et al.](https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/37842.pdf)]
and **byte-pair-encoding (BPE)** [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)]) with the extension of direct
training from raw sentences. SentencePiece allows us to make a purely end-to-end
system that does not depend on language-specific pre/postprocessing.
**This is not an official Google product.**
## Technical highlights
- **Purely data driven**: SentencePiece trains tokenization and detokenization
models from only raw sentences. No pre-tokenization ([Moses tokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl)/[MeCab](http://taku910.github.io/mecab/)/[KyTea](http://www.phontron.com/kytea/)) is required.
- **Language independent**: SentencePiece treats the sentences just as sequences of Unicode characters. There is no language-dependent logic.
- **Fast and lightweight**: Segmentation speed is around 50k sentences/sec, and memory footprint is around 6MB.
- **Self-contained**: The same tokenization/detokenization is obtained as long as the same model file is used.
- **Direct vocabulary id generation**: SentencePiece manages vocabulary to id mapping and can directly generate vocabulary id sequences from raw sentences.
- **NFKC-based normalization**: SentencePiece performs NFKC-based text normalization.
## Overview
### What is SentencePiece?
SentencePiece is an unsupervised text tokenizer and detokenizer designed mainly for Neural Network-based text generation, for example Neural Network Machine Translation. SentencePiece is a re-implementation of **sub-word units** (also known as **wordpieces** [[Wu et al.](https://arxiv.org/pdf/1609.08144.pdf)][[Schuster et al.](https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/37842.pdf)] and **byte-pair-encoding (BPE)** [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)]). Unlike previous sub-word approaches that train tokenizers from pre-tokenized sentences, SentencePiece directly trains the tokenizer and detokenizer from raw sentences.
SentencePiece may look like a kind of unsupervised word segmentation, but there are several differences and constraints.
#### The number of unique tokens is predetermined
Neural Machine Translation models typically operate with a fixed
vocabulary. Unlike most unsupervised word segmentation algorithms, which
assume an infinite vocabulary, SentencePiece trains the segmentation model such
that the final vocabulary size is fixed, e.g., 8k, 16k, or 32k.
#### Whitespace is treated as a basic symbol
The first step of natural language processing is text tokenization. For
example, a standard English tokenizer segments the text "Hello World." into the
following three tokens.
> [Hello] [World] [.]
One observation is that the original input and tokenized sequence are **NOT
reversibly convertible**. For instance, the information that no space exists
between “World” and “.” is dropped from the tokenized sequence, since e.g., `Tokenize(“World.”) == Tokenize(“World .”)`
SentencePiece treats the input text just as a sequence of Unicode characters. Whitespace is also handled as a normal symbol. To handle the whitespace as a basic token explicitly, SentencePiece first escapes the whitespace with a meta symbol "▁" (U+2581) as follows.
> Hello▁World.
Then this text is segmented into small pieces, for example:
> [Hello] [▁Wor] [ld] [.]
Since the whitespace is preserved in the segmented text, we can detokenize the text without any ambiguities.
```
detokenized = ''.join(pieces).replace('▁', ' ')
```
This feature makes it possible to perform detokenization without relying on language-specific resources.
Note that we cannot apply the same lossless conversions when splitting the
sentence with standard word segmenters, since they treat the whitespace as a
special symbol. Tokenized sequences do not preserve the information necessary to restore the original sentence.
* (en) Hello world. → [Hello] [World] [.] \(A space between Hello and World\)
* (ja) こんにちは世界。 → [こんにちは] [世界] [。] \(No space between こんにちは and 世界\)
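The escape/detokenize round trip above can be sketched in a few lines of Python (a toy illustration only; real SentencePiece learns the segmentation from data, and the meta symbol "▁" is U+2581):

```python
META = "\u2581"  # "▁" (U+2581), the meta symbol used for whitespace

def escape(text: str) -> str:
    """Replace spaces with the meta symbol before segmentation."""
    return text.replace(" ", META)

def detokenize(pieces: list[str]) -> str:
    """Concatenate pieces and restore spaces; lossless by construction."""
    return "".join(pieces).replace(META, " ")

# e.g. the segmentation [Hello] [▁Wor] [ld] [.] shown above
pieces = ["Hello", META + "Wor", "ld", "."]
assert detokenize(pieces) == "Hello World."
```

Because the whitespace survives inside the pieces, no language-specific detokenizer is needed.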
## Required packages
The following tools and libraries are required to build SentencePiece:
* GNU autotools (autoconf automake libtool)
* C++11 compiler
* libprotobuf
On Ubuntu, autotools and protobuf can be installed with apt-get:
```
% sudo apt-get install autoconf automake libtool libprotobuf-c++ protobuf-compiler
```
## Build and Install SentencePiece
```
% cd /path/to/sentencepiece
% ./autogen.sh
% ./configure
% make
% make check
% sudo make install
```
## Train SentencePiece Model
```
% spm_train --input=<input> --model_prefix=<model_name> --vocab_size=8000 --model_type=<type>
```
* `--input`: one-sentence-per-line **raw** corpus file. No need to run
tokenizer, normalizer or preprocessor. By default, SentencePiece normalizes
the input with Unicode NFKC. You can pass a comma-separated list of files.
* `--model_prefix`: output model name prefix. `<model_name>.model` and `<model_name>.vocab` are generated.
* `--vocab_size`: vocabulary size, e.g., 8000, 16000, or 32000
* `--model_type`: model type. Choose from `unigram` (default), `bpe`, `char`, or `word`. The input sentence must be pre-tokenized when using `word` type.
Note that `spm_train` loads only the first `--input_sentence_size` sentences (default value is 10M).
Use the `--help` flag to display all parameters for training.
## Encode raw text into sentence pieces/ids
```
% spm_encode --model=<model_file> --output_format=piece < input > output
% spm_encode --model=<model_file> --output_format=id < input > output
```
Use the `--extra_options` flag to insert the BOS/EOS markers or to reverse the input sequence.
```
% spm_encode --extra_options=eos (add </s> only)
% spm_encode --extra_options=bos:eos (add <s> and </s>)
% spm_encode --extra_options=reverse:bos:eos (reverse input and add <s> and </s>)
```
## Decode sentence pieces/ids into raw text
```
% spm_decode --model=<model_file> --input_format=piece < input > output
% spm_decode --model=<model_file> --input_format=id < input > output
```
Use the `--extra_options` flag to decode the text in reverse order.
```
% spm_decode --extra_options=reverse < input > output
```
## End-to-End Example
```
% spm_train --input=data/botchan.txt --model_prefix=m --vocab_size=1000
unigram_model_trainer.cc(494) LOG(INFO) Starts training with :
input: "../data/botchan.txt"
... <snip>
unigram_model_trainer.cc(529) LOG(INFO) EM sub_iter=1 size=1100 obj=10.4973 num_tokens=37630 num_tokens/piece=34.2091
trainer_interface.cc(272) LOG(INFO) Saving model: m.model
trainer_interface.cc(281) LOG(INFO) Saving vocabs: m.vocab
% echo "I saw a girl with a telescope." | spm_encode --model=m.model
▁I ▁saw ▁a ▁girl ▁with ▁a ▁ te le s c o pe .
% echo "I saw a girl with a telescope." | spm_encode --model=m.model --output_format=id
9 459 11 939 44 11 4 142 82 8 28 21 132 6
% echo "9 459 11 939 44 11 4 142 82 8 28 21 132 6" | spm_decode --model=m.model --input_format=id
I saw a girl with a telescope.
```
You can see that the original input sentence is restored from the vocabulary id sequence.
## Export vocabulary list
```
% spm_export_vocab --model=<model_file> --output=<output file>
```
`<output file>` stores the vocabulary items together with their emission log probabilities. The vocabulary id corresponds to the line number in this file.
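Because the id is just the line number, a minimal lookup over the exported file can be sketched in Python (the file contents below are hypothetical; each line holds a piece and its log probability, tab-separated):

```python
import io

# Hypothetical vocab file contents: piece<TAB>log probability, one per line.
vocab = io.StringIO("<unk>\t0\n<s>\t0\n</s>\t0\n\u2581the\t-3.4\n")

id_to_piece = [line.rstrip("\n").split("\t")[0] for line in vocab]
piece_to_id = {piece: i for i, piece in enumerate(id_to_piece)}

assert id_to_piece[3] == "\u2581the"  # id == line number (0-based)
assert piece_to_id["</s>"] == 2
```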
## Experiments
### Experimental settings
We have evaluated SentencePiece segmentation with the following configurations.
* Segmentation algorithms:
* **BPE** (Byte Pair
Encoding) [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] (`--model_type=bpe`)
* **Unigram**. Language-model based segmentation. (`--model_type=unigram`)
* Pre-tokenization methods:
* **NoPretok**: No pre-tokenization. We train SentencePiece directly from
raw sentences (`--split_by_whitespace=false`).
* **WsPretok**: Trains SentencePiece model from the sentences tokenized by
whitespaces (`--split_by_whitespace=true`). When handling CJK, this setting is almost equivalent to **NoPretok**.
* **MosesPretok**: Trains SentencePiece model from sentences tokenized
by [Moses tokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl). We used [KyTea](http://www.phontron.com/kytea/) for
Japanese, and in-house segmenters for Korean and Chinese.
* NMT parameters: ([Google's Neural Machine Translation System](https://arxiv.org/pdf/1609.08144.pdf) is applied for all experiments.)
* 16k shared vocabulary (Shares the same vocabulary for source and
target. We train single SentencePiece model by concatenating raw source
and target sentences.)
* Dropout prob: 0.2
* num nodes: 512
* num lstms: 8
* Evaluation metrics:
* Case-sensitive BLEU on detokenized text with NIST scorer.
* For CJK, the same word segmenters are applied prior to NIST scorer.
* No detokenizer is applied for **NoPretok** and **WsPretok**, which can
directly emit detokenized sentences.
* Applied [Moses detokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/detokenizer.perl) and in-house rule-based detokenizer (CJK) for **MosesPretok**.
* Data sets:
* [KFTT](http://www.phontron.com/kftt/index.html)
* [MultiUN](http://opus.lingfil.uu.se/MultiUN.php) (First 5M and next
5k/5k sentences are used for training and development/testing respectively.)
* [WMT16](http://www.statmt.org/WMT16/)
* In-house: (Used 5M parallel sentences for training)
**NoPretok** and **WsPretok** do not use any language-dependent resources.
**BPE**+**MosesPretok** is almost the same configuration as used in [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] and [[Wu et al.](https://arxiv.org/pdf/1609.08144.pdf)].
### Results (BLEU scores)
|Language Pair|BPE(NoPretok)|BPE(WsPretok)|BPE(MosesPretok)|Unigram(NoPretok)|Unigram(WsPretok)|Unigram(MosesPretok)|
|---|---|---|---|---|---|---|
|KFTT en-ja| 0.2796| 0.281| 0.286| 0.2806| 0.280| 0.2871|
|KFTT ja-en| 0.1943| 0.208| 0.1967| 0.1985| 0.2148| 0.198|
|MultiUN ar-en| 0.5268| 0.5414| 0.5381| 0.5317| 0.5449| 0.5401|
|MultiUN en-ar| 0.4039| 0.4147| 0.4012| 0.4084| 0.4172| 0.3991|
|MultiUN en-zh| 0.4155| 0.4186| 0.395| 0.4214| 0.4165| 0.399|
|MultiUN zh-en| 0.46| 0.4716| 0.4806| 0.4644| 0.4711| 0.4759|
|In house en-ko| 0.178| 0.1851| 0.1893| 0.1846| 0.1872| 0.1890|
|In house ko-en| 0.1786| 0.1954| 0.1994| 0.1845| 0.1956| 0.2015|
|WMT16 cs-en| 0.1987| 0.2252| 0.2231| 0.2164| 0.2228| 0.2238|
|WMT16 de-en| 0.3194| 0.3348| 0.3374| 0.3261| 0.3375| 0.3398|
|WMT16 en-cs| 0.1607| 0.1827| 0.1812| 0.1722| 0.1778| 0.179|
|WMT16 en-de| 0.2847| 0.3029| 0.3013| 0.2946| 0.3000| 0.3053|
|WMT16 en-fi| 0.1434| 0.1528| 0.1499| 0.1472| 0.1568| 0.1517|
|WMT16 en-ru| 0.1884| 0.1973| 0.1989| 0.19| 0.1982| 0.1903|
|WMT16 fi-en| 0.1775| 0.1867| 0.1877| 0.182| 0.1882| 0.1865|
|WMT16 ru-en| 0.2042| 0.2229| 0.2194| 0.2087| 0.2201| 0.2155|
* **MosesPretok** does not always improve BLEU scores. Comparable
accuracy can be obtained without using language-dependent resources in many
language pairs.
* Whitespace pre-tokenization is a reasonable choice. It does not use language-specific resources.
* **NoPretok** shows poor BLEU scores. The unigram model is more robust than BPE when no pre-tokenizer is applied.
## Advanced topics
* [SentencePieceProcessor C++ API](doc/api.md)
* [Use custom text normalization rules](doc/normalization.md)
* [Use custom symbols](doc/special_symbols.md)
* [Segmentation and training algorithms in detail]

25
autogen.sh Executable file

@ -0,0 +1,25 @@
#!/bin/sh
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "Running aclocal ..."
aclocal -I .
echo "Running autoheader..."
autoheader
echo "Running libtoolize .."
libtoolize
echo "Running automake ..."
automake --add-missing --copy
echo "Running autoconf ..."
autoconf

54
configure.ac Normal file

@ -0,0 +1,54 @@
# -*- Autoconf -*-
# Process this file with autoconf to produce a configure script.
AC_PREREQ([2.69])
AC_INIT([sentencepiece], [0.1.0], [taku@google.com])
AM_INIT_AUTOMAKE()
AC_CONFIG_SRCDIR([src/normalizer.h])
AC_CONFIG_HEADERS([config.h])
# Checks for programs.
AC_LANG([C++])
AC_PROG_LIBTOOL
AC_PROG_CXX
AC_PROG_CC
CXXFLAGS="-std=c++11 -Wall -O3"
PKG_CHECK_MODULES(PROTOBUF, protobuf >= 2.4.0)
AC_SUBST(PROTOBUF_LIBS)
AC_SUBST(PROTOBUF_CFLAGS)
AC_SUBST(PROTOBUF_VERSION)
CXXFLAGS="$CXXFLAGS $PROTOBUF_CFLAGS"
LIBS="$LIBS $PROTOBUF_LIBS"
# Checks for header files.
AC_CHECK_HEADERS([unistd.h])
AC_CHECK_PROG([PROTOC], [protoc], [protoc])
AS_IF([test "x${PROTOC}" = "x"],
[AC_MSG_ERROR([ProtoBuf compiler "protoc" not found. You can install it with "sudo apt-get install libprotobuf-c++ protobuf-compiler"])])
AC_ARG_ENABLE([nfkc-compile],
[  --enable-nfkc-compile   compile NFKC normalizer mapping [default no]])
if test "$enable_nfkc_compile" = "yes"; then
AX_CHECK_ICU([40], ,AC_MSG_ERROR([Library requirements (ICU) not met.]))
CXXFLAGS="$CXXFLAGS -DENABLE_NFKC_COMPILE"
LIBS="$LIBS $ICU_LIBS"
fi
# Checks for typedefs, structures, and compiler characteristics.
AC_CHECK_HEADER_STDBOOL
AC_C_INLINE
AC_TYPE_SIZE_T
# Checks for library functions.
AC_FUNC_STRTOD
AC_CHECK_FUNCS([memchr memset])
AC_CONFIG_MACRO_DIR([third_party/m4])
AC_CONFIG_FILES([Makefile
src/Makefile])
AC_OUTPUT

2632
data/Scripts.txt Normal file

File diff suppressed because it is too large Load Diff

5119
data/botchan.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,45 @@
#!/usr/bin/perl
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generate unicode_script_data.h from Unicode Scripts.txt
#
# usage: ./gen_unicode_Scripts_code.pl < Scripts.txt > unicode_script_data.h
#
print "#ifndef UNICODE_SCRIPT_DATA_H_\n";
print "#define UNICODE_SCRIPT_DATA_H_\n";
print "namespace sentencepiece {\n";
print "namespace unicode_script {\n";
print "namespace {\n";
print "void InitTable(std::unordered_map<char32, ScriptType> *smap) {\n";
print " CHECK_NOTNULL(smap)->clear();\n";
while (<>) {
chomp;
if (/^([0-9A-F]+)\s+;\s+(\S+)\s+\#/) {
printf(" (*smap)[0x%s] = U_%s;\n", $1, $2);
} elsif (/^([0-9A-F]+)\.\.([0-9A-F]+)\s+;\s+(\S+)\s+\#/) {
printf(" for (char32 c = 0x%s; c <= 0x%s; ++c)\n", $1, $2);
printf(" (*smap)[c] = U_%s;\n", $3);
} else {
next;
}
}
print "}\n";
print "} // namespace\n";
print "} // namespace unicode_script\n";
print "} // namespace sentencepiece\n";
print "#endif // UNICODE_SCRIPT_DATA_H_\n";
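The parsing logic of the Perl script above can be sketched equivalently in Python (same regular expressions, applied here to a couple of sample Scripts.txt lines):

```python
import re

SINGLE = re.compile(r"^([0-9A-F]+)\s+;\s+(\S+)\s+#")
RANGE = re.compile(r"^([0-9A-F]+)\.\.([0-9A-F]+)\s+;\s+(\S+)\s+#")

def parse_scripts(lines):
    """Build a codepoint -> script-name map, as the Perl script emits."""
    smap = {}
    for line in lines:
        if m := RANGE.match(line):
            for c in range(int(m.group(1), 16), int(m.group(2), 16) + 1):
                smap[c] = m.group(3)
        elif m := SINGLE.match(line):
            smap[int(m.group(1), 16)] = m.group(2)
    return smap

sample = [
    "0041..005A    ; Latin # L& [26] LATIN CAPITAL LETTER A..LATIN CAPITAL LETTER Z",
    "00AA          ; Latin # Lo FEMININE ORDINAL INDICATOR",
]
assert parse_scripts(sample)[0x41] == "Latin"
```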

1
data/identity.tsv Normal file

@ -0,0 +1 @@
20 20 # =>

224654
data/nfkc.tsv Normal file

File diff suppressed because it is too large Load Diff

2374
data/wagahaiwa_nekodearu.txt Normal file

File diff suppressed because one or more lines are too long

96
doc/api.md Normal file

@ -0,0 +1,96 @@
# SentencePieceProcessor C++ API
## Load SentencePiece model
To start working with the SentencePiece model, you will want to include the `sentencepiece_processor.h` header file.
Then instantiate the `sentencepiece::SentencePieceProcessor` class and call the `Load` or `LoadOrDie` method to load the model file.
```
#include <sentencepiece_processor.h>
sentencepiece::SentencePieceProcessor processor;
processor.LoadOrDie("/path/to/model.model");
```
## Tokenize text (preprocessing)
Call the `SentencePieceProcessor::Encode` method to tokenize text.
```
std::vector<std::string> pieces;
processor.Encode("This is a test.", &pieces);
for (const std::string &token : pieces) {
std::cout << token << std::endl;
}
```
To obtain the sequence of vocab ids instead, pass a `std::vector<int>`:
```
std::vector<int> ids;
processor.Encode("This is a test.", &ids);
for (const int id : ids) {
std::cout << id << std::endl;
}
```
## Detokenize text (postprocessing)
Call the `SentencePieceProcessor::Decode` method to detokenize a sequence of pieces or ids into text. It is basically guaranteed that detokenization is the inverse of encoding, i.e., `Decode(Encode(Normalize(input))) == Normalize(input)`.
```
std::vector<std::string> pieces = { "▁This", "▁is", "▁a", "▁", "te", "st", "." }; // sequence of pieces
std::string text;
processor.Decode(pieces, &text);
std::cout << text << std::endl;
std::vector<int> ids = { 451, 26, 20, 3, 158, 128, 12 }; // sequence of ids
processor.Decode(ids, &text);
std::cout << text << std::endl;
```
## SentencePieceText proto
You can use the `SentencePieceText` class to obtain the pieces and ids at the same time. This proto also records the UTF-8 byte offsets of each piece over the user input or detokenized text.
```
#include <sentencepiece.pb.h>
sentencepiece::SentencePieceText spt;
// Encode
processor.Encode("This is a test.", &spt);
std::cout << spt.text() << std::endl; // This is the same as the input.
for (const auto &piece : spt.pieces()) {
std::cout << piece.begin() << std::endl; // beginning of byte offset
std::cout << piece.end() << std::endl; // end of byte offset
std::cout << piece.piece() << std::endl; // internal representation.
std::cout << piece.surface() << std::endl; // external representation. spt.text().substr(begin, end - begin) == surface().
std::cout << piece.id() << std::endl; // vocab id
}
// Decode
processor.Decode({10, 20, 30}, &spt);
std::cout << spt.text() << std::endl; // This is the same as the decoded string.
for (const auto &piece : spt.pieces()) {
// the same as above.
}
```
## Vocabulary management
You can use the following methods to convert between pieces and ids.
```
processor.GetPieceSize(); // returns the vocabulary size.
processor.PieceToId("foo"); // returns the vocab id of "foo"
processor.IdToPiece(10); // returns the string representation of id 10.
processor.IsUnknown(0); // returns true if the given id is an unknown token. e.g., <unk>
processor.IsControl(10); // returns true if the given id is a control token. e.g., <s>, </s>
```
## Extra Options
Use `SetEncodeExtraOptions` and `SetDecodeExtraOptions` methods to set extra options for encoding and decoding respectively. These methods need to be called just after `Load/LoadOrDie` methods.
```
processor.SetEncodeExtraOptions("bos:eos"); // add <s> and </s>.
processor.SetEncodeExtraOptions("reverse:bos:eos"); // reverse the input and then add <s> and </s>.
processor.SetDecodeExtraOptions("reverse"); // the decoder's output is reversed.
```

doc/normalization.md Normal file

@ -0,0 +1,45 @@
# Use custom normalization rule
By default, SentencePiece normalizes the input sentence with a variant of Unicode
[NFKC](https://en.wikipedia.org/wiki/Unicode_equivalence).
The SentencePiece framework also allows us to define a custom normalization rule, which is stored in the model file.
## Use pre-defined normalization rule
The SentencePiece framework provides the following pre-defined normalization rules. It is recommended to use one of them unless you have a special reason.
* **nfkc**: [NFKC](https://en.wikipedia.org/wiki/Unicode_equivalence) normalization (default)
* **identity**: no normalization
You can choose the normalization rule with `--normalization_rule_name` flag.
```
% spm_train --normalization_rule_name=identity --input=<input> --model_prefix=<model file> --vocab_size=8000
```
NOTE: Due to limitations of the normalization algorithm, full NFKC normalization is not implemented. [builder.h] describes example character sequences that are not normalized by our NFKC implementation.
## Use custom normalization rule
Normalization is performed with user-defined string-to-string mappings and leftmost-longest matching.
You can use a custom normalization rule by preparing a TSV file formatted as follows:
```
41 302 300 1EA6
41 302 301 1EA4
41 302 303 1EAA
...
```
In this sample, the UCS4 sequence [41 302 300] (hex) is converted into [1EA6] (hex). When there are ambiguities in the conversion, the longest rule is applied.
Note that a tab is used as the delimiter between the source and target sequences, and a space is used as the delimiter between UCS4 characters.
See data/nfkc.tsv as an example. Once a TSV file is prepared, you can specify it with the `--normalization_rule_tsv` flag.
```
% spm_train --normalization_rule_tsv=<rule tsv file> --input=<input> --model_prefix=<model file> --vocab_size=8000
```
`<model file>` embeds the normalization rule `<rule tsv file>`, so the same normalization rule is applied whenever `<model file>` is used.
## Command line tool to perform normalization
```
% spm_normalize --model=<model_file> file1 file2..
% spm_normalize --normalization_rule_tsv=custom.tsv file1 file2..
```
The first command uses the normalization rule embedded in the model file. The second uses the normalization rule in the TSV file, which is useful for developing normalization rules interactively.

doc/special_symbols.md Normal file

@ -0,0 +1,19 @@
# Use custom symbols
The SentencePiece model supports two types of special symbols.
## Control symbol
Control symbols are used to encode special indicators for the decoder to change the behavior dynamically.
Examples include the language indicators in multilingual models. `<s>` and `</s>` are reserved control symbols.
Control symbols must be inserted outside of the SentencePiece segmentation. Developers are responsible for inserting these symbols during data generation and decoding.
It is guaranteed that control symbols have no corresponding surface strings in the original user input. Control symbols are decoded into empty strings.
## User defined symbol
A user-defined symbol is handled as one piece in any context. If this symbol is included in the input text, it is always extracted as a single piece.
## Specify special symbols at training time
Use the `--control_symbols` and `--user_defined_symbols` flags as follows:
```
% spm_train --control_symbols=<foo>,<bar> --user_defined_symbols=<user1>,<user2> --input=<input file> --model_prefix=<model file> --vocab_size=8000
```

src/Makefile.am Normal file

@ -0,0 +1,103 @@
lib_LTLIBRARIES = libsentencepiece.la
AM_CXXFLAGS = -I$(srcdir)
libsentencepiece_la_SOURCES = \
error.cc \
flags.cc \
sentencepiece_processor.cc \
util.cc \
normalizer.cc \
stringpiece.h unicode_script_map.h util.h \
common.h \
flags.h normalizer.h sentencepiece_processor.h \
model_factory.h model_factory.cc \
model_interface.h model_interface.cc \
unigram_model.h unigram_model.cc \
word_model.h word_model.cc \
char_model.h char_model.cc \
bpe_model.h bpe_model.cc
noinst_LIBRARIES = libtrain.a
libtrain_a_SOURCES = builder.cc builder.h \
normalization_rule.h \
unicode_script.h unicode_script.cc \
trainer_factory.h trainer_factory.cc \
trainer_interface.h trainer_interface.cc \
unigram_model_trainer.h unigram_model_trainer.cc \
word_model_trainer.h word_model_trainer.cc \
char_model_trainer.h char_model_trainer.cc \
bpe_model_trainer.h bpe_model_trainer.cc
nodist_libsentencepiece_la_SOURCES = \
sentencepiece.pb.cc sentencepiece.pb.h \
sentencepiece_model.pb.cc sentencepiece_model.pb.h
BUILT_SOURCES = \
sentencepiece.pb.cc \
sentencepiece_model.pb.cc
EXTRA_DIST = sentencepiece.proto sentencepiece_model.proto
bin_PROGRAMS = spm_encode spm_decode spm_normalize spm_train spm_export_vocab
noinst_PROGRAMS = compile_charsmap
spm_encode_SOURCES = spm_encode_main.cc
spm_encode_LDADD = libsentencepiece.la
spm_decode_SOURCES = spm_decode_main.cc
spm_decode_LDADD = libsentencepiece.la
spm_normalize_SOURCES = spm_normalize_main.cc
spm_normalize_LDADD = libsentencepiece.la libtrain.a
spm_export_vocab_SOURCES = spm_export_vocab_main.cc
spm_export_vocab_LDADD = libsentencepiece.la
spm_train_SOURCES = spm_train_main.cc
spm_train_LDADD = libsentencepiece.la libtrain.a
compile_charsmap_SOURCES = compile_charsmap_main.cc
compile_charsmap_LDADD = libsentencepiece.la libtrain.a
check_PROGRAMS = spm_test
TESTS = spm_test
spm_test_SOURCES = testharness.h \
builder_test.cc \
flags_test.cc \
normalizer_test.cc \
sentencepiece_processor_test.cc \
unicode_script_test.cc \
model_interface_test.cc \
model_factory_test.cc \
trainer_interface_test.cc \
trainer_factory_test.cc \
word_model_test.cc \
word_model_trainer_test.cc \
bpe_model_test.cc \
bpe_model_trainer_test.cc \
char_model_test.cc \
char_model_trainer_test.cc \
unigram_model_test.cc\
unigram_model_trainer_test.cc \
util_test.cc \
test_main.cc \
testharness.cc
spm_test_LDADD = libsentencepiece.la libtrain.a
CLEANFILES = *.pb.cc *.pb.h *.gcda *.gcno *.info
clean-local:
-rm -rf lcov_html
%.pb.cc %.pb.h: %.proto
$(PROTOC) --cpp_out=$(srcdir) $(srcdir)/$<
coverage:
make clean
make -j10 CXXFLAGS+="-O0 -Wall -std=c++11 -coverage" LIBS+="-lgcov -lprotobuf" check
lcov -c -d . -o coverage.info
lcov --remove coverage.info "include*" "/c++" "_test*" "testharness*" "third_party*" ".pb.*" -o coverage.info
mkdir -p lcov_html
genhtml -o lcov_html coverage.info

src/bpe_model.cc Normal file

@ -0,0 +1,159 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bpe_model.h"
#include <queue>
#include "util.h"
namespace sentencepiece {
namespace bpe {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
CHECK(!sp.piece().empty());
if (sp.type() == ModelProto::SentencePiece::NORMAL) {
CHECK(sp.has_score());
port::InsertOrDie(&pieces_, sp.piece(), i);
} else if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
// TODO(taku): implement USER_DEFINED symbol.
LOG(FATAL) << "User defined symbol is not supported in BPE";
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
}
}
}
Model::~Model() {}
std::vector<std::pair<StringPiece, int>> Model::Encode(
StringPiece normalized) const {
if (normalized.empty()) {
return {};
}
struct SymbolPair {
int left; // left index of this pair
int right; // right index of this pair
float score; // score of this pair. large is better.
size_t size; // length of this piece
};
class SymbolPairComparator {
public:
bool operator()(SymbolPair *h1, SymbolPair *h2) const {
return (h1->score < h2->score ||
(h1->score == h2->score && h1->left > h2->left));
}
};
struct Symbol {
int prev; // prev index of this symbol. -1 for BOS.
int next;  // next index of this symbol. -1 for EOS.
StringPiece piece;
};
using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
SymbolPairComparator>;
Agenda agenda;
std::vector<Symbol> symbols;
symbols.reserve(normalized.size());
// Looks up a new symbol pair at [left, right] and inserts it into the agenda.
auto MaybeAddNewSymbolPair = [this, &symbols, &agenda](int left, int right) {
if (left == -1 || right == -1) return;
const StringPiece piece(
symbols[left].piece.data(),
symbols[left].piece.size() + symbols[right].piece.size());
const auto it = pieces_.find(piece);
if (it == pieces_.end()) {
return;
}
auto *h = new SymbolPair;
h->left = left;
h->right = right;
h->score = GetScore(it->second);
h->size = piece.size();
agenda.push(h);
};
// Splits the input into a character sequence.
const char *begin = normalized.data();
const char *end = normalized.data() + normalized.size();
int index = 0;
while (begin < end) {
int mblen = string_util::OneCharLen(begin);
if (mblen > end - begin) {
LOG(ERROR) << "Invalid character length.";
mblen = end - begin;
}
Symbol s;
s.piece = StringPiece(begin, mblen);
s.prev = begin == normalized.data() ? -1 : index - 1;
begin += mblen;
s.next = begin == end ? -1 : index + 1;
++index;
symbols.emplace_back(s);
}
CHECK(!symbols.empty());
// Lookup all bigrams.
for (size_t i = 1; i < symbols.size(); ++i) {
MaybeAddNewSymbolPair(i - 1, i);
}
// Main loop.
while (!agenda.empty()) {
std::unique_ptr<SymbolPair> top(agenda.top());
agenda.pop();
// |top| is no longer available.
if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
top->size) {
continue;
}
// Replaces symbols with |top| rule.
symbols[top->left].piece = StringPiece(
symbols[top->left].piece.data(),
symbols[top->left].piece.size() + symbols[top->right].piece.size());
// Updates prev/next pointers.
symbols[top->left].next = symbols[top->right].next;
if (symbols[top->right].next >= 0) {
symbols[symbols[top->right].next].prev = top->left;
}
symbols[top->right].piece = StringPiece("");
// Adds new symbol pairs which are newly added after symbol replacement.
MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
}
std::vector<std::pair<StringPiece, int>> output;
for (int index = 0; index != -1; index = symbols[index].next) {
CHECK_GE(index, 0);
CHECK_LT(index, static_cast<int>(symbols.size()));
output.emplace_back(symbols[index].piece, PieceToId(symbols[index].piece));
}
return output;
}
} // namespace bpe
} // namespace sentencepiece

src/bpe_model.h Normal file

@ -0,0 +1,40 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BPE_MODEL_H_
#define BPE_MODEL_H_
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
namespace sentencepiece {
namespace bpe {
// Segmentation model with BPE (Byte Pair Encoding)
// Details:
// Neural Machine Translation of Rare Words with Subword Units
// https://arxiv.org/abs/1508.07909
//
// https://en.wikipedia.org/wiki/Byte_pair_encoding
class Model : public ModelInterface {
public:
explicit Model(const ModelProto &model_proto);
~Model() override;
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override;
};
} // namespace bpe
} // namespace sentencepiece
#endif // BPE_MODEL_H_

src/bpe_model_test.cc Normal file

@ -0,0 +1,144 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bpe_model.h"
#include "testharness.h"
namespace sentencepiece {
namespace bpe {
namespace {
ModelProto MakeBaseModelProto() {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(BPEModelTest, EncodeTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "ab", -0.1);
AddPiece(&model_proto, "cd", -0.2);
AddPiece(&model_proto, "abc", -0.3);
AddPiece(&model_proto, "a", -0.4);
AddPiece(&model_proto, "b", -0.5);
AddPiece(&model_proto, "c", -0.6);
AddPiece(&model_proto, "d", -0.7);
const Model model(model_proto);
std::vector<std::pair<StringPiece, int>> result;
result = model.Encode("");
EXPECT_TRUE(result.empty());
result = model.Encode("abc");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abc", result[0].first);
result = model.Encode("AB");
EXPECT_EQ(2, result.size());
EXPECT_EQ("A", result[0].first);
EXPECT_EQ("B", result[1].first);
result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("cd", result[1].first);
result = model.Encode("abcc");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
EXPECT_EQ("c", result[1].first);
result = model.Encode("xabcabaabcdd");
EXPECT_EQ(7, result.size());
EXPECT_EQ("x", result[0].first);
EXPECT_EQ("abc", result[1].first);
EXPECT_EQ("ab", result[2].first);
EXPECT_EQ("a", result[3].first);
EXPECT_EQ("ab", result[4].first);
EXPECT_EQ("cd", result[5].first);
EXPECT_EQ("d", result[6].first);
// all unknown.
result = model.Encode("xyz東京");
EXPECT_EQ(5, result.size());
EXPECT_EQ("x", result[0].first);
EXPECT_EQ("y", result[1].first);
EXPECT_EQ("z", result[2].first);
EXPECT_EQ("東", result[3].first);
EXPECT_EQ("京", result[4].first);
}
TEST(BPEModelTest, EncodeAmbiguousTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "aa", -0.1);
AddPiece(&model_proto, "bb", -0.2);
AddPiece(&model_proto, "ab", -0.3);
AddPiece(&model_proto, "a", -0.4);
AddPiece(&model_proto, "b", -0.5);
const Model model(model_proto);
std::vector<std::pair<StringPiece, int>> result;
// leftmost symbols are merged first.
result = model.Encode("aaa");
EXPECT_EQ(2, result.size());
EXPECT_EQ("aa", result[0].first);
EXPECT_EQ("a", result[1].first);
// "bb" is replaced earlier than "ab".
result = model.Encode("aabb");
EXPECT_EQ(2, result.size());
EXPECT_EQ("aa", result[0].first);
EXPECT_EQ("bb", result[1].first);
// "bb" is replaced earlier than "ab".
result = model.Encode("aaabbb");
EXPECT_EQ(4, result.size());
EXPECT_EQ("aa", result[0].first);
EXPECT_EQ("a", result[1].first);
EXPECT_EQ("bb", result[2].first);
EXPECT_EQ("b", result[3].first);
result = model.Encode("aaaba");
EXPECT_EQ(3, result.size());
EXPECT_EQ("aa", result[0].first);
EXPECT_EQ("ab", result[1].first);
EXPECT_EQ("a", result[2].first);
}
} // namespace
} // namespace bpe
} // namespace sentencepiece

src/bpe_model_trainer.cc Normal file

@ -0,0 +1,323 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bpe_model_trainer.h"
#include <unordered_set>
#include "util.h"
namespace sentencepiece {
namespace bpe {
std::string Trainer::Symbol::ToString() const {
return string_util::UnicodeTextToUTF8(chars);
}
Trainer::Symbol *Trainer::GetCharSymbol(char32 c) {
const uint64 freq = port::FindWithDefault(required_chars_, c, 1);
CHECK_GT(freq, 0);
const auto it = symbols_cache_.find(c);
if (it != symbols_cache_.end()) {
return it->second;
}
Symbol *s = new Symbol;
allocated_.push_back(s);
s->is_unk = (kUNKChar == c);
s->fp = c;
s->chars.push_back(c);
s->freq = freq;
port::InsertOrDie(&symbols_cache_, s->fp, s);
return s;
}
Trainer::Symbol *Trainer::GetPairSymbol(const Symbol *left,
const Symbol *right) {
if (left == nullptr || right == nullptr || left->is_unk || right->is_unk) {
return nullptr;
}
const uint64 fp = port::FingerprintCat(left->fp, right->fp);
const auto it = symbols_cache_.find(fp);
if (it != symbols_cache_.end()) {
return it->second;
}
CHECK(!left->chars.empty());
CHECK(!right->chars.empty());
string_util::UnicodeText ut;
for (const char32 c : left->chars) ut.push_back(c);
for (const char32 c : right->chars) ut.push_back(c);
// Do not make an invalid piece.
if (!IsValidSentencePiece(ut)) {
return nullptr;
}
Symbol *s = new Symbol;
allocated_.push_back(s);
s->fp = fp;
s->left = left;
s->right = right;
s->chars = ut;
port::InsertOrDie(&symbols_cache_, s->fp, s);
return s;
}
void Trainer::ComputeFreq(Symbol *symbol) const {
if (symbol->freq > 0) { // if freq == 0, re-computation is required.
return;
}
// Avoids double-count. ("AAA" => only count the first "AA").
Position prev_pos = {-1, 0};
CHECK_EQ(0, symbol->freq);
for (auto it = symbol->positions.begin(); it != symbol->positions.end();) {
const Position pos = DecodePos(*it);
// There are two same bigrams in "AAA", [AA] [AA], and we want to
// remove the second one to avoid double counts.
// If the right symbol in the first bigram and the left symbol in the
// second bigram have the same position, (pos.left == prev_pos.right),
// a duplicated bigram exists.
// Also, symbols_[sid][left] and symbols_[sid][right] must store
// the same symbols as symbol->left and symbol->right.
if ((pos.sid == prev_pos.sid && pos.left == prev_pos.right) ||
symbol->left != symbols_[pos.sid][pos.left] ||
symbol->right != symbols_[pos.sid][pos.right]) {
it = symbol->positions.erase(it);
// Initializes prev_pos.
// In "AAAA", the last "AA" can be counted.
prev_pos = {-1, 0};
} else {
symbol->freq += sentences_[pos.sid].second;
prev_pos = pos;
++it;
}
}
}
int Trainer::GetNextIndex(int sid, int index) const {
for (size_t i = index + 1; i < symbols_[sid].size(); ++i) {
if (symbols_[sid][i] == nullptr) continue;
return i;
}
return -1;
}
int Trainer::GetPrevIndex(int sid, int index) const {
for (int i = index - 1; i >= 0; --i) {
if (symbols_[sid][i] == nullptr) continue;
return i;
}
return -1;
}
void Trainer::AddNewPair(int sid, int left, int right) {
if (left == -1 || right == -1) return;
auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
if (symbol != nullptr) {
active_symbols_.insert(symbol);
symbol->positions.insert(EncodePos(sid, left, right));
}
}
void Trainer::ResetFreq(int sid, int left, int right, const Symbol *best) {
if (left == -1 || right == -1) return;
auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
if (symbol != nullptr && symbol != best) {
symbol->freq = 0;
}
}
void Trainer::UpdateActiveSymbols() {
std::vector<Symbol *> symbols;
for (auto &it : symbols_cache_) {
Symbol *symbol = it.second;
if (symbol->IsBigram()) {
ComputeFreq(symbol);
symbols.push_back(symbol);
}
}
// At least kMinActiveSymbolsSize symbols must be in |active_symbols_|.
constexpr int kMinActiveSymbolsSize = 1000;
// Keeps top 5% frequent symbols.
constexpr float kTopFrequentRatio = 0.05;
const int size =
std::min<int>(std::max<int>(kMinActiveSymbolsSize,
symbols_cache_.size() * kTopFrequentRatio),
symbols.size());
std::partial_sort(symbols.begin(), symbols.begin() + size, symbols.end(),
[](Symbol *s1, Symbol *s2) { return s1->freq > s2->freq; });
LOG(INFO) << "Updating active symbols. max_freq=" << symbols[0]->freq
<< " min_freq=" << symbols[size - 1]->freq;
active_symbols_.clear();
active_symbols_.insert(symbols.begin(), symbols.begin() + size);
}
void Trainer::Train() {
#define CHECK_RANGE(variable, minval, maxval) \
CHECK(variable >= minval && variable <= maxval)
CHECK_GT(trainer_spec_.input().size(), 0);
CHECK(!trainer_spec_.model_prefix().empty());
CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
CHECK_RANGE(trainer_spec_.max_sentencepiece_length(), 1, 64);
CHECK_GT(trainer_spec_.vocab_size(), 0);
#undef CHECK_RANGE
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
CHECK(normalizer_spec_.escape_whitespaces());
CHECK_EQ(TrainerSpec::BPE, trainer_spec_.model_type());
symbols_.clear();
allocated_.clear();
symbols_cache_.clear();
active_symbols_.clear();
// Load all sentences
LoadSentences();
if (trainer_spec_.split_by_whitespace()) {
SplitSentencesByWhitespace();
}
// Initializes symbols_. symbols_[sid][i] stores a unary symbol.
symbols_.resize(sentences_.size());
for (size_t i = 0; i < sentences_.size(); ++i) {
for (const char32 c : string_util::UTF8ToUnicodeText(sentences_[i].first)) {
symbols_[i].push_back(GetCharSymbol(c));
}
}
// Makes all bigram symbols.
for (size_t sid = 0; sid < symbols_.size(); ++sid) {
for (size_t i = 1; i < symbols_[sid].size(); ++i) {
AddNewPair(sid, i - 1, i);
}
}
const int meta_symbols_size = trainer_spec_.control_symbols().size() +
trainer_spec_.user_defined_symbols().size() +
3; // <s>, </s>, <unk>
const int vocab_size =
trainer_spec_.vocab_size() - meta_symbols_size - required_chars_.size();
CHECK_GE(vocab_size, 0);
// We may see duplicated pieces that are extracted via different paths.
// In the real segmentation phase, we can consider them as one symbol.
// e.g., "aaa" => "aa" + "a" or "a" + "aa".
std::unordered_set<std::string> dup;
// Main loop.
CHECK(final_pieces_.empty());
while (final_pieces_.size() < static_cast<size_t>(vocab_size)) {
constexpr int kUpdateActiveSymbolsInterval = 100;
if (final_pieces_.size() % kUpdateActiveSymbolsInterval == 0) {
UpdateActiveSymbols();
}
// Scans the active symbols and finds the best_symbol with the highest freq.
Symbol *best_symbol = nullptr;
for (auto &it : active_symbols_) {
Symbol *symbol = it;
ComputeFreq(symbol);
// If the frequencies are the same, prefer the shorter symbol.
// If the lengths are also the same, use lexicographical comparison.
if (best_symbol == nullptr ||
(symbol->freq > best_symbol->freq ||
(symbol->freq == best_symbol->freq &&
(symbol->chars.size() < best_symbol->chars.size() ||
(symbol->chars.size() == best_symbol->chars.size() &&
symbol->ToString() < best_symbol->ToString()))))) {
best_symbol = symbol;
}
}
if (best_symbol == nullptr) {
LOG(WARNING) << "No valid symbol found";
break;
}
if (!dup.insert(best_symbol->ToString()).second) {
// Removes best_symbol so it is not selected again.
symbols_cache_.erase(best_symbol->fp);
active_symbols_.erase(best_symbol);
continue;
}
// Stores the best_symbol in the final output.
const float score = -static_cast<float>(final_pieces_.size());
final_pieces_.emplace_back(best_symbol->ToString(), score);
if (final_pieces_.size() % 20 == 0) {
LOG(INFO) << "Added: freq=" << best_symbol->freq
<< " size=" << final_pieces_.size()
<< " all=" << symbols_cache_.size()
<< " active=" << active_symbols_.size()
<< " piece=" << best_symbol->ToString();
}
// Adds new bigrams which are created after the symbol replacement.
// We do not need to scan all characters; it is enough to scan the
// neighbors of best_symbol.
for (const uint64 &encoded_pos : best_symbol->positions) {
const Position pos = DecodePos(encoded_pos);
if (symbols_[pos.sid][pos.left] == nullptr) {
// left index might be NULL (set in the previous iteration)
// when left_symbol == right_symbol.
continue;
}
CHECK_NOTNULL(symbols_[pos.sid][pos.right]);
// We have three bigrams [prev, left], [left, right], [right, next],
// which are affected by this symbol replacement.
const int next = GetNextIndex(pos.sid, pos.right);
const int prev = GetPrevIndex(pos.sid, pos.left);
// Resets the frequencies of bigrams [prev, left] and [right, next].
ResetFreq(pos.sid, prev, pos.left, best_symbol);
ResetFreq(pos.sid, pos.right, next, best_symbol);
// Merges two symbols.
symbols_[pos.sid][pos.left] = best_symbol;
symbols_[pos.sid][pos.right] = nullptr;
// Makes new symbol bigrams [prev, left] and [left, next].
AddNewPair(pos.sid, prev, pos.left);
AddNewPair(pos.sid, pos.left, next);
}
// Removes best_symbol so it is not selected again.
symbols_cache_.erase(best_symbol->fp);
active_symbols_.erase(best_symbol);
} // end of main loop
// Adds required_chars_
for (const auto &w : Sorted(required_chars_)) {
const Symbol *symbol = GetCharSymbol(w.first);
const float score = -static_cast<float>(final_pieces_.size());
final_pieces_.emplace_back(symbol->ToString(), score);
}
Save();
port::STLDeleteElements(&allocated_);
}
} // namespace bpe
} // namespace sentencepiece

src/bpe_model_trainer.h Normal file

@ -0,0 +1,121 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BPE_MODEL_TRAINER_H_
#define BPE_MODEL_TRAINER_H_
#include <set>
#include "sentencepiece_model.pb.h"
#include "trainer_interface.h"
namespace sentencepiece {
namespace bpe {
// Trainer class for BPE model.
class Trainer : public TrainerInterface {
public:
Trainer(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
void Train() override;
private:
// Symbol represents a character or symbol bigram.
struct Symbol {
const Symbol *left; // left symbol in bigram
const Symbol *right; // right symbol in bigram
string_util::UnicodeText chars;  // flattened character sequence
bool is_unk; // true if this symbol is unknown.
uint64 fp; // fingerprint of this symbol.
uint64 freq; // frequency of this symbol.
// Position list. Use set so that we can keep the order of occurrence.
// See EncodePos/DecodePos.
std::set<uint64> positions;
bool IsBigram() const { return left != nullptr && right != nullptr; }
std::string ToString() const;
Symbol() : left(nullptr), right(nullptr), is_unk(false), fp(0), freq(0) {}
};
struct Position {
int sid; // sentence id
int left; // left symbol index
int right; // right symbol index
};
// Encodes sid, left and right bigram index into uint64.
// Encoded value keeps the order of sid, left and right.
static uint64 EncodePos(int sid, int l, int r) {
CHECK_GE(l, 0);
CHECK_GE(r, 0);
CHECK_LE(l, kuint16max);
CHECK_LE(r, kuint16max);
const uint64 n = (static_cast<uint64>(sid) << 32 | (static_cast<uint64>(l) << 16 | r));
return n;
}
// Decodes sid, left and right bigram index from uint64.
static Position DecodePos(uint64 n) {
Position p;
p.sid = n >> 32;
p.left = (n >> 16) & 0xffff;
p.right = n & 0xffff;
return p;
}
// Gets unary (character) symbol from the char code |c|.
// The return value is cached.
Symbol *GetCharSymbol(char32 c);
// Gets symbol pair from left/right symbols. The return value is cached.
Symbol *GetPairSymbol(const Symbol *left, const Symbol *right);
// Computes the frequency of |symbol| and update symbol->freq field.
void ComputeFreq(Symbol *symbol) const;
// Returns the valid index after symbols_[sid][index].
int GetNextIndex(int sid, int index) const;
// Returns the valid index before symbols_[sid][index].
int GetPrevIndex(int sid, int index) const;
// Makes a new bigram from [symbols_[sid][left], symbols_[sid][right]] and
// adds it to symbols_cache_ and active_symbols_.
void AddNewPair(int sid, int left, int right);
// Resets the frequency of the bigram [symbols_[sid][left] symbols_[sid][right]],
// if this bigram is not |best|.
void ResetFreq(int sid, int left, int right, const Symbol *best);
// Updates |active_symbols_| by copying the top 5% frequent symbols in
// symbols_cache_.
void UpdateActiveSymbols();
// All unique symbols. Key is a fingerprint of Symbol.
std::unordered_map<uint64, Symbol *> symbols_cache_;
// Set of symbols from which we find the best symbol in each iteration.
std::set<Symbol *> active_symbols_;
// Stores symbols allocated on the heap so that we can delete them at once.
std::vector<Symbol *> allocated_;
// Sentences. symbols_[sid][index] stores a symbol in sentence_[sid][index].
std::vector<std::vector<Symbol *>> symbols_;
};
} // namespace bpe
} // namespace sentencepiece
#endif // BPE_MODEL_TRAINER_H_


@ -0,0 +1,122 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bpe_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace bpe {
// Space symbol
#define WS "\xe2\x96\x81"
namespace {
std::string RunTrainer(const std::vector<std::string> &input, int size) {
test::ScopedTempFile input_scoped_file("input");
test::ScopedTempFile model_scoped_file("model");
const std::string input_file = input_scoped_file.filename();
const std::string model_prefix = model_scoped_file.filename();
{
io::OutputBuffer output(input_file);
for (const auto &line : input) {
output.WriteLine(line);
}
}
TrainerSpec trainer_spec;
trainer_spec.set_model_type(TrainerSpec::BPE);
trainer_spec.add_input(input_file);
trainer_spec.set_vocab_size(size - 3); // remove <unk>, <s>, </s>
trainer_spec.set_model_prefix(model_prefix);
auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
normalizer_spec.set_add_dummy_prefix(false);
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
SentencePieceProcessor processor;
processor.Load(model_prefix + ".model");
const auto &model = processor.model_proto();
std::vector<std::string> pieces;
// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
}
return string_util::Join(pieces, " ");
}
} // namespace
TEST(BPETrainerTest, BasicTest) {
EXPECT_EQ("ab ra abra ad cad abracad abracadabra ac br a b r c d",
RunTrainer({"abracadabra"}, 20));
EXPECT_EQ("ap le app apple en in ine pen p e a l n i",
RunTrainer({"pen", "pineapple", "apple"}, 20));
EXPECT_EQ("he ll llo hello hellohe el lo oh hel ohe e h l o",
RunTrainer({"hellohe"}, 20));
}
TEST(BPETrainerTest, EndToEndTest) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
normalizer_spec = normalizer::Builder::GetNormalizerSpec("nfkc");
trainer_spec.add_input("../data/wagahaiwa_nekodearu.txt");
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);
trainer_spec.set_model_type(TrainerSpec::BPE);
trainer_spec.add_control_symbols("<ctrl>");
// trainer_spec.add_user_defined_symbols("<user>");
test::ScopedTempFile sf("tmp_model");
trainer_spec.set_model_prefix(sf.filename());
bpe::Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(std::string(sf.filename()) + ".model"));
EXPECT_EQ(kVocabSize, sp.GetPieceSize());
const int cid = sp.PieceToId("<ctrl>");
// const int uid = sp.PieceToId("<user>");
EXPECT_TRUE(sp.IsControl(cid));
// EXPECT_FALSE(sp.IsUnknown(uid));
std::vector<std::string> tok;
sp.Encode("", &tok);
EXPECT_TRUE(tok.empty());
sp.Encode(
"吾輩《わがはい》は猫である。名前はまだ無い。"
"どこで生れたかとんと見当《けんとう》がつかぬ。"
"何でも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している"
"。",
&tok);
EXPECT_EQ(WS
" 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 "
"どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 "
"何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた "
"事 だけは 記憶 している 。",
string_util::Join(tok, " "));
}
} // namespace bpe
} // namespace sentencepiece
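The expectations above list the merged pieces the BPE trainer is expected to learn from tiny corpora such as "abracadabra". As a rough, self-contained sketch of the idea being tested (not the actual trainer, which maintains per-sentence symbol links and an agenda of candidate pairs), a greedy merge loop over the most frequent adjacent pair could look like this; `BpeMerge` is a hypothetical helper name:

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Greedily merges the most frequent adjacent symbol pair, up to
// |num_merges| times. A simplified sketch of the BPE idea only.
std::vector<std::string> BpeMerge(std::vector<std::string> symbols,
                                  int num_merges) {
  for (int step = 0; step < num_merges; ++step) {
    // Counts the frequency of every adjacent symbol pair.
    std::map<std::pair<std::string, std::string>, int> freq;
    for (size_t i = 0; i + 1 < symbols.size(); ++i) {
      ++freq[{symbols[i], symbols[i + 1]}];
    }
    if (freq.empty()) break;
    auto best = freq.begin();
    for (auto it = freq.begin(); it != freq.end(); ++it) {
      if (it->second > best->second) best = it;
    }
    if (best->second < 2) break;  // No pair worth merging.
    // Rewrites the symbol sequence with the best pair merged.
    std::vector<std::string> merged;
    for (size_t i = 0; i < symbols.size();) {
      if (i + 1 < symbols.size() && symbols[i] == best->first.first &&
          symbols[i + 1] == best->first.second) {
        merged.push_back(symbols[i] + symbols[i + 1]);
        i += 2;
      } else {
        merged.push_back(symbols[i]);
        ++i;
      }
    }
    symbols = std::move(merged);
  }
  return symbols;
}
```

On the characters of "abracadabra", the first merge picks "a"+"b" (frequency 2), which matches the leading "ab" in the test's expected piece list.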

src/builder.cc Normal file

@@ -0,0 +1,345 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "builder.h"
#ifdef ENABLE_NFKC_COMPILE
#include <unicode/errorcode.h>
#include <unicode/locid.h>
#include <unicode/normlzr.h>
#include <unicode/numfmt.h>
#include <unicode/rbnf.h>
#include <unicode/utypes.h>
#endif
#include <set>
#include "normalization_rule.h"
#include "normalizer.h"
#include "third_party/darts_clone/darts.h"
#include "util.h"
namespace sentencepiece {
namespace normalizer {
namespace {
#ifdef ENABLE_NFKC_COMPILE
// Normalize |input| with ICU's normalizer with |mode|.
Builder::Chars UnicodeNormalize(UNormalizationMode mode,
const Builder::Chars &input) {
const std::string utf8 = string_util::UnicodeTextToUTF8(input);
CHECK(!utf8.empty());
icu::UnicodeString ustr;
const size_t utf8_length = utf8.size();
UChar *utf16 = ustr.getBuffer(utf8.size() + 1);
int32 utf16_length = 0;
icu::ErrorCode icuerrorcode;
u_strFromUTF8Lenient(utf16, ustr.getCapacity(), &utf16_length, utf8.data(),
utf8_length, icuerrorcode);
ustr.releaseBuffer(utf16_length);
UErrorCode status = U_ZERO_ERROR;
icu::UnicodeString dst;
icu::Normalizer::normalize(ustr, mode, 0, dst, status);
CHECK(U_SUCCESS(status));
std::string normalized;
normalized.reserve(dst.length() * 3);
dst.toUTF8String(normalized);
return string_util::UTF8ToUnicodeText(normalized);
}
Builder::Chars ToNFKD(const Builder::Chars &input) {
return UnicodeNormalize(UNORM_NFKD, input);
}
Builder::Chars ToNFKC(const Builder::Chars &input) {
return UnicodeNormalize(UNORM_NFKC, input);
}
Builder::Chars ToNFC(const Builder::Chars &input) {
return UnicodeNormalize(UNORM_NFC, input);
}
Builder::Chars ToNFD(const Builder::Chars &input) {
return UnicodeNormalize(UNORM_NFD, input);
}
// Given an NFKD-normalized string, returns a set of all strings which are
// normalized into the same |nfkd|. |norm2orig| is the normalized to
// un-normalized character mapping.
std::vector<Builder::Chars> ExpandUnnormalized(
const Builder::Chars &nfkd,
const std::map<char32, std::set<char32>> &norm2orig) {
CHECK(!nfkd.empty());
std::vector<Builder::Chars> results;
for (const auto c : port::FindOrDie(norm2orig, nfkd[0])) {
results.push_back({c});
}
for (size_t i = 1; i < nfkd.size(); ++i) {
const auto &orig = port::FindOrDie(norm2orig, nfkd[i]);
std::vector<Builder::Chars> new_results;
for (const auto &r : results) {
for (const auto c : orig) {
new_results.emplace_back(r);
new_results.back().push_back(c);
}
}
results = std::move(new_results);
}
CHECK_EQ(nfkd.size(), results[0].size());
return results;
}
#endif
// Normalizes |src| with |chars_map| and returns normalized Chars.
// |max_len| specifies the maximum length of the key in |chars_map|.
Builder::Chars Normalize(const Builder::CharsMap &chars_map,
const Builder::Chars &src, int max_len) {
CHECK_GE(max_len, 1);
Builder::Chars normalized;
for (size_t i = 0; i < src.size();) {
Builder::CharsMap::const_iterator it = chars_map.end();
const size_t slice = std::min<size_t>(i + max_len, src.size());
// starts with the longest prefix.
Builder::Chars key(src.begin() + i, src.begin() + slice);
while (!key.empty()) {
it = chars_map.find(key);
if (it != chars_map.end()) {
break;
}
key.pop_back(); // remove the last character.
}
// Consumes one character when no rule is found.
if (it == chars_map.end()) {
normalized.push_back(src[i]);
++i;
} else {
CHECK(!it->second.empty());
std::copy(it->second.begin(), it->second.end(),
std::back_inserter(normalized));
i += it->first.size();
}
}
return normalized;
}
} // namespace
// static
std::string Builder::CompileCharsMap(const CharsMap &chars_map) {
CHECK(!chars_map.empty());
LOG(INFO) << "Loading CharsMap of size " << chars_map.size();
// Aggregates the same target strings to save footprint.
std::map<Chars, int> normalized2pos;
for (const auto &p : chars_map) {
normalized2pos[p.second] = 0;
}
std::string normalized;
for (auto &p : normalized2pos) {
p.second = normalized.size(); // stores the pointer (position).
const std::string utf8_out = string_util::UnicodeTextToUTF8(p.first);
normalized += utf8_out;
normalized += '\0';
}
std::vector<std::pair<std::string, int>> kv; // key-value of Trie.
for (const auto &p : chars_map) {
// The value of Trie stores the pointer to the normalized string.
const std::string utf8_in = string_util::UnicodeTextToUTF8(p.first);
kv.emplace_back(utf8_in, port::FindOrDie(normalized2pos, p.second));
}
std::sort(kv.begin(), kv.end());
std::vector<const char *> key(kv.size());
std::vector<int> value(kv.size());
for (size_t i = 0; i < kv.size(); ++i) {
key[i] = kv[i].first.c_str();
value[i] = kv[i].second;
}
Darts::DoubleArray trie;
CHECK_EQ(
0,
trie.build(key.size(), const_cast<char **>(&key[0]), nullptr, &value[0]))
<< "cannot build double-array";
int max_nodes_size = 0;
std::vector<Darts::DoubleArray::result_pair_type> results(
2 * Normalizer::kMaxTrieResultsSize);
for (const char *str : key) {
const int num_nodes = trie.commonPrefixSearch(str, results.data(),
results.size(), strlen(str));
max_nodes_size = std::max(num_nodes, max_nodes_size);
}
CHECK_LT(max_nodes_size, Normalizer::kMaxTrieResultsSize)
<< "This charsmap contains too many shared prefixes. "
<< "The number of shared prefixes must be less than "
<< Normalizer::kMaxTrieResultsSize;
StringPiece trie_blob(static_cast<const char *>(trie.array()),
trie.size() * trie.unit_size());
const std::string blob =
Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);
LOG(INFO) << "Generated normalizer blob. size= " << blob.size();
return blob;
}
// static
std::string Builder::GetPrecompiledCharsMap(const std::string &name) {
std::string result;
for (size_t i = 0; i < kNormalizationRules_size; ++i) {
const auto *blob = &kNormalizationRules_blob[i];
if (blob->name == name) {
result.assign(blob->data, blob->size);
return result;
}
}
LOG(FATAL) << "No precompiled charsmap is found: " << name;
return result;
}
// static
NormalizerSpec Builder::GetNormalizerSpec(const std::string &name) {
NormalizerSpec spec;
spec.set_name(name);
spec.set_precompiled_charsmap(GetPrecompiledCharsMap(name));
return spec;
}
// static
Builder::CharsMap Builder::BuildNFKCMap() {
#ifdef ENABLE_NFKC_COMPILE
LOG(INFO) << "Running BuildNFKCMap";
// Set of fully NFKD decomposed characters.
std::set<Builder::Chars> nfkd_decomposed;
// Fully normalized one character to unnormalized one character map.
std::map<char32, std::set<char32>> norm2orig;
Builder::CharsMap nfkc_map; // The final NFKC mapping.
constexpr int kMaxUnicode = 0x110000;
for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
if (!U_IS_UNICODE_CHAR(cp)) {
continue;
}
// Aggregates single character to fully NFKC normalized characters.
const auto nfkc = ToNFKC({cp});
if (nfkc.size() >= 2 || (nfkc.size() == 1 && nfkc[0] != cp)) {
nfkc_map[{cp}] = nfkc;
}
const auto nfkd = ToNFKD({cp});
if (nfkd.size() == 1) {
// Aggregates reverse mapping from normalized to unnormalized character.
norm2orig[nfkd[0]].insert(cp);
} else {
// One character is decomposed into multiple characters.
nfkd_decomposed.insert(nfkd);
}
}
for (const auto &nfkd : nfkd_decomposed) {
const auto nfkc = ToNFC(nfkd);
// This case is already covered by single-character to NFKC mapping.
if (nfkc == nfkd) {
continue;
}
// Expand all possible sequences which are normalized into the same |nfkd|.
for (const auto &nfkd_orig : ExpandUnnormalized(nfkd, norm2orig)) {
if (nfkd_orig != nfkc) {
nfkc_map[nfkd_orig] = nfkc;
}
}
}
return RemoveRedundantMap(nfkc_map);
#else
LOG(FATAL) << "NFKC compile is not enabled."
<< " rebuild with ./configure --enable-nfkc-compile";
return {};
#endif
}
// static
Builder::CharsMap Builder::BuildIdentityMap() {
// Adds one dummy entry since empty rule is not allowed.
const CharsMap result = {{{0x0020}, {0x0020}}};
return result;
}
// static
Builder::CharsMap Builder::BuildMapFromFile(StringPiece filename) {
LOG(INFO) << "Loading mapping file: " << filename.data();
io::InputBuffer input(filename);
std::string line;
CharsMap chars_map;
while (input.ReadLine(&line)) {
const auto fields = string_util::SplitPiece(line, "\t");
CHECK_GE(fields.size(), 2);
std::vector<char32> src, trg;
for (const auto &s : string_util::SplitPiece(fields[0], " ")) {
src.push_back(string_util::HexToInt<char32>(s));
}
for (const auto &s : string_util::SplitPiece(fields[1], " ")) {
trg.push_back(string_util::HexToInt<char32>(s));
}
CHECK(!src.empty());
CHECK(!trg.empty());
chars_map[src] = trg;
}
return chars_map;
}
// static
Builder::CharsMap Builder::RemoveRedundantMap(const CharsMap &chars_map) {
CharsMap new_chars_map;
size_t max_len = 0;
for (const auto &p : chars_map) {
max_len = std::max(p.first.size(), max_len);
if (p.first.size() == 1) {
new_chars_map.insert(p);
}
}
CHECK_GT(max_len, 0);
// Checks whether the rules with size of |len| can be normalized by
// the rules with size of [1 .. len - 1].
for (size_t len = 2; len <= max_len; ++len) {
for (const auto &p : chars_map) {
if (p.first.size() == len &&
p.second != Normalize(new_chars_map, p.first, len - 1)) {
new_chars_map.insert(p);
}
}
}
// Verify all characters in |chars_map| are normalized by |new_chars_map|.
for (const auto &p : chars_map) {
CHECK_EQ(p.second, Normalize(new_chars_map, p.first, max_len));
}
return new_chars_map;
}
} // namespace normalizer
} // namespace sentencepiece
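Normalize() above scans the input once, trying the longest rule key at each position and copying a single character through when no rule matches. The same longest-prefix strategy can be sketched over plain byte strings (a simplification of the real code, which operates on char32 sequences and, once compiled, on a Darts double-array trie); `NormalizeLongestMatch` is a hypothetical name:

```cpp
#include <algorithm>
#include <map>
#include <string>

// Applies |rules| with longest-prefix matching, copying one byte verbatim
// whenever no rule matches. A byte-level sketch of Builder's Normalize().
std::string NormalizeLongestMatch(
    const std::map<std::string, std::string> &rules, const std::string &src,
    size_t max_len) {
  std::string out;
  for (size_t i = 0; i < src.size();) {
    const size_t slice = std::min(src.size(), i + max_len);
    std::string key = src.substr(i, slice - i);
    auto it = rules.end();
    while (!key.empty()) {
      it = rules.find(key);
      if (it != rules.end()) break;
      key.pop_back();  // Try a shorter prefix.
    }
    if (it == rules.end()) {
      out += src[i];
      ++i;  // No rule found: consume one byte.
    } else {
      out += it->second;
      i += it->first.size();  // Consume the matched key.
    }
  }
  return out;
}
```

With rules {"ab" => "X", "a" => "y"}, the input "aab" becomes "yX": at position 0 only "a" matches, while at position 1 the longer key "ab" wins.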

src/builder.h Normal file

@@ -0,0 +1,109 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BUILDER_H_
#define BUILDER_H_
#include <map>
#include <string>
#include <vector>
#include "common.h"
#include "sentencepiece_model.pb.h"
#include "stringpiece.h"
namespace sentencepiece {
namespace normalizer {
// Builder creates a text normalization rule from user-defined string
// to string mappings. The normalization mapping is compiled into
// a single and compact blob index which is stored into the model proto.
// This class also provides pre-defined rules based on Unicode NFKC.
// https://en.wikipedia.org/wiki/Unicode_equivalence#Normalization
class Builder {
public:
Builder() = delete;
~Builder() = delete;
// Basic Unicode character sequence.
using Chars = std::vector<char32>;
// String-to-string mapping.
using CharsMap = std::map<Chars, Chars>;
// Compiles |chars_map| into a binary index.
static std::string CompileCharsMap(const CharsMap &chars_map);
// Returns a pre-compiled binary index with |name|.
static std::string GetPrecompiledCharsMap(const std::string &name);
// Returns a normalizer spec with a binary index |name|.
static NormalizerSpec GetNormalizerSpec(const std::string &name);
// Makes a normalization mapping based on NFKC.
//
// Note that Normalizer/Builder classes do not support
// full NFKC normalization, since full NFKC normalization cannot
// be implemented with a simple longest matching string-to-string
// replacement. One unsupported normalization is multiple combining
// marks.
//
// Strings with multiple combining marks cannot be normalized
// correctly, because normalization needs to sort the combining
// marks by Canonical_Combining_Class (CCC).
// http://unicode.org/reports/tr15/#Multiple_Mark_Figure
//
// Example:
// Original: U+1E0B U+0323
// Decomposed: U+0064 U+0307 U+0323
// NFKD: U+0064 U+0323 U+0307 (Combining characters are sorted by CCC)
// NFKC: U+1E0D U+0307 (U+0064 U+0323 => U+1E0D)
//
// To support the normalization above with a longest matching, we need to
// enumerate all possible permutations of combining marks in advance,
// which is not feasible. For example, suppose there are three
// combining marks X, Y and Z, which are sorted into one canonical order
// Z, Y, X with NFK(D|C). In this case, all permutations (XYZ, XZY, YXZ...)
// are normalized into ZYX. When we implement this normalization with
// a longest matching, we need 3! rules: XYZ=>ZYX, XZY=>ZYX, etc.
// Since Unicode has more than 100 combining characters, it is not possible
// to expand all permutations.
//
// We will not implement the full NFKC in SentencePiece because
// 1) It is unusual to see decomposed Unicode characters in real text.
// 2) Providing a flexible, user-customizable, and self-contained
// normalizer is the goal of SentencePiece.
//
// TODO(taku): Make NFC, NFD, and NFKD mapping if necessary.
static CharsMap BuildNFKCMap();
// Returns an identity mapping, which does not perform any normalization.
static CharsMap BuildIdentityMap();
// Builds a CharsMap from the rules saved in |filename|.
// Format:
// src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2...
// (src|trg)_ucharX must be a hex of UCS4.
static CharsMap BuildMapFromFile(StringPiece filename);
private:
FRIEND_TEST(BuilderTest, RemoveRedundantMapTest);
// Removes redundant rules from |chars_map|.
// When |chars_map| has "aa" => "bb" and "a" => "b", the first
// rule is redundant since the second rule already covers it.
static CharsMap RemoveRedundantMap(const CharsMap &chars_map);
};
} // namespace normalizer
} // namespace sentencepiece
#endif // BUILDER_H_
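The header comment above argues that supporting multiple combining marks via longest-match replacement would need one rule per ordering of the marks, i.e. n! rules. That factorial blow-up is easy to confirm by enumeration (marks represented here as plain integers; `CountMarkPermutations` is a hypothetical helper):

```cpp
#include <algorithm>
#include <vector>

// Counts how many replacement rules longest-match normalization would need
// to canonicalize every ordering of |n| distinct combining marks: one rule
// per permutation, i.e. n!.
int CountMarkPermutations(int n) {
  std::vector<int> marks(n);
  for (int i = 0; i < n; ++i) marks[i] = i;  // Sorted start.
  int count = 0;
  do {
    ++count;  // Each iteration visits one distinct ordering.
  } while (std::next_permutation(marks.begin(), marks.end()));
  return count;
}
```

Three marks already need 6 rules and five need 120; with the more than 100 combining characters in Unicode, enumerating the rules is clearly infeasible, which is why BuildNFKCMap() does not attempt it.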

src/builder_test.cc Normal file

@@ -0,0 +1,129 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "builder.h"
#include "common.h"
#include "normalizer.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace normalizer {
// Space symbol
#define WS "\xe2\x96\x81"
TEST(BuilderTest, RemoveRedundantMapTest) {
Builder::CharsMap chars_map;
// ab => AB, a => A, b => B, abc => BCA
chars_map[{0x0061}] = {0x0041};
chars_map[{0x0062}] = {0x0042};
chars_map[{0x0061, 0x0062}] = {0x0041, 0x0042};
chars_map[{0x0061, 0x0062, 0x0063}] = {0x0043, 0x0042, 0x0041};
const auto new_chars_map = Builder::RemoveRedundantMap(chars_map);
EXPECT_EQ(3, new_chars_map.size());
EXPECT_EQ(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0062}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062, 0x0063}));
}
TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
EXPECT_DEATH(Builder::GetPrecompiledCharsMap(""));
EXPECT_DEATH(Builder::GetPrecompiledCharsMap("__UNKNOWN__"));
}
TEST(BuilderTest, BuildIdentityMapTest) {
const auto m = Builder::BuildIdentityMap();
EXPECT_EQ(1, m.size());
}
TEST(BuilderTest, BuildNFKCMapTest) {
#ifdef ENABLE_NFKC_COMPILE
const auto m = Builder::BuildNFKCMap();
EXPECT_TRUE(!m.empty());
#else
EXPECT_DEATH(Builder::BuildNFKCMap());
#endif
}
TEST(BuilderTest, GetPrecompiledCharsMapTest) {
{
const NormalizerSpec spec = Builder::GetNormalizerSpec("nfkc");
const Normalizer normalizer(spec);
EXPECT_EQ(WS "ABC", normalizer.Normalize("ＡＢＣ"));
EXPECT_EQ(WS "(株)", normalizer.Normalize("㈱"));
EXPECT_EQ(WS "グーグル", normalizer.Normalize("グーグル"));
}
{
const NormalizerSpec spec = Builder::GetNormalizerSpec("identity");
const Normalizer normalizer(spec);
EXPECT_EQ(WS "ＡＢＣ", normalizer.Normalize("ＡＢＣ"));
EXPECT_EQ(WS "㈱", normalizer.Normalize("㈱"));
EXPECT_EQ(WS "グーグル", normalizer.Normalize("グーグル"));
}
}
TEST(BuilderTest, CompileCharsMap) {
Builder::CharsMap chars_map;
// Lowercase => Uppercase
for (char32 lc = static_cast<char32>('a'); lc <= static_cast<char32>('z');
++lc) {
const char32 uc = lc + 'A' - 'a';
chars_map[{lc}] = {uc};
}
// あいう => abc
chars_map[{0x3042, 0x3044, 0x3046}] = {0x0061, 0x0062, 0x0063};
NormalizerSpec spec;
spec.set_precompiled_charsmap(Builder::CompileCharsMap(chars_map));
spec.set_add_dummy_prefix(false);
const Normalizer normalizer(spec);
EXPECT_EQ("ABC", normalizer.Normalize("abc"));
EXPECT_EQ("ABC", normalizer.Normalize("ABC"));
EXPECT_EQ("XY" WS "Z", normalizer.Normalize("xy z"));
EXPECT_EQ("", normalizer.Normalize(""));
EXPECT_EQ("abc", normalizer.Normalize("あいう"));
EXPECT_EQ("abcえ", normalizer.Normalize("あいうえ"));
EXPECT_EQ("ABCabcD", normalizer.Normalize("abcあいうd"));
}
TEST(BuilderTest, BuildMapFromFileTest) {
const auto cmap = Builder::BuildMapFromFile("../data/nfkc.tsv");
const auto precompiled = Builder::CompileCharsMap(cmap);
EXPECT_EQ(Builder::GetPrecompiledCharsMap("nfkc"), precompiled);
}
TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
Builder::CharsMap chars_map;
std::vector<char32> keys;
// chars_map contains too many shared prefix ("aaaa...");
for (int i = 0; i < 100; ++i) {
keys.push_back('a');
chars_map[keys] = {'b'};
}
EXPECT_DEATH(Builder::CompileCharsMap(chars_map));
}
} // namespace normalizer
} // namespace sentencepiece

src/char_model.cc Normal file

@@ -0,0 +1,65 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "char_model.h"
#include "util.h"
namespace sentencepiece {
namespace character {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
CHECK(!sp.piece().empty());
if (sp.type() == ModelProto::SentencePiece::NORMAL ||
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
CHECK(sp.has_score());
port::InsertOrDie(&pieces_, sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
}
}
}
Model::~Model() {}
std::vector<std::pair<StringPiece, int>> Model::Encode(
StringPiece normalized) const {
if (normalized.empty()) {
return {};
}
// Splits the input into character sequence
const char *begin = normalized.data();
const char *end = normalized.data() + normalized.size();
std::vector<std::pair<StringPiece, int>> output;
while (begin < end) {
int mblen = string_util::OneCharLen(begin);
if (mblen > end - begin) {
LOG(ERROR) << "Invalid character length.";
mblen = end - begin;
}
StringPiece w(begin, mblen);
output.emplace_back(w, PieceToId(w));
begin += mblen;
}
return output;
}
} // namespace character
} // namespace sentencepiece
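Encode() above advances one UTF-8 character at a time using string_util::OneCharLen and clamps the length when the input is truncated. A standalone sketch of that lead-byte decoding and splitting, assuming the standard UTF-8 lead-byte patterns (the helper names here are hypothetical; the real helper lives in util.h):

```cpp
#include <string>
#include <vector>

// Returns the byte length of a UTF-8 character from its lead byte.
// Invalid or continuation lead bytes are treated as length 1.
int OneCharLenSketch(unsigned char lead) {
  if (lead < 0x80) return 1;          // 0xxxxxxx: ASCII
  if ((lead >> 5) == 0x06) return 2;  // 110xxxxx
  if ((lead >> 4) == 0x0E) return 3;  // 1110xxxx
  if ((lead >> 3) == 0x1E) return 4;  // 11110xxx
  return 1;
}

// Splits |s| into UTF-8 character substrings, as the character model does.
std::vector<std::string> SplitChars(const std::string &s) {
  std::vector<std::string> out;
  for (size_t i = 0; i < s.size();) {
    size_t len = OneCharLenSketch(static_cast<unsigned char>(s[i]));
    if (len > s.size() - i) len = s.size() - i;  // Clamp truncated input.
    out.push_back(s.substr(i, len));
    i += len;
  }
  return out;
}
```

For example, "a" followed by the 3-byte space symbol U+2581 and "z" splits into three pieces, matching how Encode() emits one (piece, id) pair per character.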

src/char_model.h Normal file

@@ -0,0 +1,35 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CHAR_MODEL_H_
#define CHAR_MODEL_H_
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
namespace sentencepiece {
namespace character {
// Tokenizes text into a character sequence.
class Model : public ModelInterface {
public:
explicit Model(const ModelProto &model_proto);
~Model() override;
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override;
};
} // namespace character
} // namespace sentencepiece
#endif // CHAR_MODEL_H_

src/char_model_test.cc Normal file

@@ -0,0 +1,94 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "char_model.h"
#include <unordered_map>
#include <unordered_set>
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace character {
namespace {
// Space symbol (U+2581)
#define WS "\xe2\x96\x81"
ModelProto MakeBaseModelProto() {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(ModelTest, EncodeTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, WS, 0.0);
AddPiece(&model_proto, "a", 0.1);
AddPiece(&model_proto, "b", 0.2);
AddPiece(&model_proto, "c", 0.3);
AddPiece(&model_proto, "d", 0.4);
const Model model(model_proto);
std::vector<std::pair<StringPiece, int>> result;
result = model.Encode("");
EXPECT_TRUE(result.empty());
result = model.Encode(WS "a" WS "b" WS "c");
EXPECT_EQ(6, result.size());
EXPECT_EQ(WS, result[0].first);
EXPECT_EQ("a", result[1].first);
EXPECT_EQ(WS, result[2].first);
EXPECT_EQ("b", result[3].first);
EXPECT_EQ(WS, result[4].first);
EXPECT_EQ("c", result[5].first);
result = model.Encode(WS "ab" WS "cd" WS "abc");
EXPECT_EQ(10, result.size());
EXPECT_EQ(WS, result[0].first);
EXPECT_EQ("a", result[1].first);
EXPECT_EQ("b", result[2].first);
EXPECT_EQ(WS, result[3].first);
EXPECT_EQ("c", result[4].first);
EXPECT_EQ("d", result[5].first);
EXPECT_EQ(WS, result[6].first);
EXPECT_EQ("a", result[7].first);
EXPECT_EQ("b", result[8].first);
EXPECT_EQ("c", result[9].first);
}
} // namespace
} // namespace character
} // namespace sentencepiece

src/char_model_trainer.cc Normal file

@@ -0,0 +1,67 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "char_model_trainer.h"
#include "char_model.h"
#include "util.h"
namespace sentencepiece {
namespace character {
void Trainer::Train() {
#define CHECK_RANGE(variable, minval, maxval) \
CHECK(variable >= minval && variable <= maxval)
CHECK_GT(trainer_spec_.input().size(), 0);
CHECK(!trainer_spec_.model_prefix().empty());
CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
CHECK_GT(trainer_spec_.vocab_size(), 0);
#undef CHECK_RANGE
LOG(INFO) << "Starts training with: \n" << trainer_spec_.Utf8DebugString();
CHECK(normalizer_spec_.escape_whitespaces());
CHECK_EQ(TrainerSpec::CHAR, trainer_spec_.model_type());
LoadSentences();
const int meta_symbols_size = trainer_spec_.control_symbols().size() +
trainer_spec_.user_defined_symbols().size() +
3; // <s>, </s>, <unk>
const int vocab_size = trainer_spec_.vocab_size() - meta_symbols_size;
CHECK_GE(vocab_size, 0);
uint64 sum = 0;
for (const auto &it : required_chars_) {
sum += it.second;
}
const float logsum = log(sum);
CHECK(final_pieces_.empty());
for (const auto &it : Sorted(required_chars_)) {
if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
break;
}
final_pieces_.emplace_back(string_util::UnicodeCharToUTF8(it.first),
log(it.second) - logsum);
}
Save();
}
} // namespace character
} // namespace sentencepiece
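The character trainer above assigns each piece the score log(count) - log(sum), i.e. its log relative frequency over the required characters. That scoring step in isolation (`LogFreqScores` is a hypothetical helper, shown with toy counts):

```cpp
#include <cmath>
#include <map>
#include <string>

// Computes log relative frequencies, as the character trainer does:
// score(piece) = log(count(piece)) - log(sum of all counts).
std::map<std::string, double> LogFreqScores(
    const std::map<std::string, int> &counts) {
  double sum = 0;
  for (const auto &kv : counts) sum += kv.second;
  std::map<std::string, double> scores;
  for (const auto &kv : counts) {
    scores[kv.first] = std::log(kv.second) - std::log(sum);
  }
  return scores;
}
```

With counts {a: 1, b: 1, c: 2}, the score of "c" is log(2) - log(4) = -log(2), so more frequent characters get higher (less negative) scores, exactly as in the Train() loop above.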

src/char_model_trainer.h Normal file

@@ -0,0 +1,35 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CHAR_MODEL_TRAINER_H_
#define CHAR_MODEL_TRAINER_H_
#include "sentencepiece_model.pb.h"
#include "trainer_interface.h"
namespace sentencepiece {
namespace character {
// Trainer class for character model.
class Trainer : public TrainerInterface {
public:
Trainer(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
void Train() override;
};
} // namespace character
} // namespace sentencepiece
#endif // CHAR_MODEL_TRAINER_H_


@@ -0,0 +1,72 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "char_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace character {
namespace {
// Space symbol (U+2581)
#define WS "\xE2\x96\x81"
std::string RunTrainer(const std::vector<std::string> &input, int size) {
test::ScopedTempFile input_scoped_file("input");
test::ScopedTempFile model_scoped_file("model");
const std::string input_file = input_scoped_file.filename();
const std::string model_prefix = model_scoped_file.filename();
{
io::OutputBuffer output(input_file);
for (const auto &line : input) {
output.WriteLine(line);
}
}
TrainerSpec trainer_spec;
trainer_spec.set_model_type(TrainerSpec::CHAR);
trainer_spec.add_input(input_file);
trainer_spec.set_vocab_size(size);
trainer_spec.set_model_prefix(model_prefix);
auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
normalizer_spec.set_add_dummy_prefix(true);
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
SentencePieceProcessor processor;
processor.Load(model_prefix + ".model");
const auto &model = processor.model_proto();
std::vector<std::string> pieces;
// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
}
return string_util::Join(pieces, " ");
}
} // namespace
TEST(TrainerTest, BasicTest) {
EXPECT_EQ(WS " a e p n I h l v",
RunTrainer({"I have a pen", "I have an apple", "apple pen"}, 100));
}
} // namespace character
} // namespace sentencepiece

src/common.h Normal file

@@ -0,0 +1,177 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef COMMON_H_
#define COMMON_H_
#include <setjmp.h>
#include <stdint.h>
#include <stdlib.h>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
#if defined(_WIN32) && !defined(__CYGWIN__)
#define OS_WIN
#else
#define OS_UNIX
#endif
#ifdef OS_WIN
#define NOMINMAX
#include <windows.h>
#endif
typedef int8_t int8;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t char32;
typedef uint32_t uint32;
typedef uint64_t uint64;
static const uint8 kuint8max = ((uint8)0xFF);
static const uint16 kuint16max = ((uint16)0xFFFF);
static const uint32 kuint32max = ((uint32)0xFFFFFFFF);
static const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFF));
static const int8 kint8min = ((int8)~0x7F);
static const int8 kint8max = ((int8)0x7F);
static const int16 kint16min = ((int16)~0x7FFF);
static const int16 kint16max = ((int16)0x7FFF);
static const int32 kint32min = ((int32)~0x7FFFFFFF);
static const int32 kint32max = ((int32)0x7FFFFFFF);
static const int64 kint64min = ((int64)(~0x7FFFFFFFFFFFFFFF));
static const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFF));
#ifdef OS_WIN
#define OUTPUT_MODE std::ios::binary | std::ios::out
#else
#define OUTPUT_MODE std::ios::out
#endif
#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (sentencepiece::win32::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path)
#endif
template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];
#ifndef _MSC_VER
template <typename T, size_t N>
char (&ArraySizeHelper(const T (&array)[N]))[N];
#endif // !_MSC_VER
#define arraysize(array) (sizeof(ArraySizeHelper(array)))
namespace sentencepiece {
#ifdef OS_WIN
namespace win32 {
std::wstring Utf8ToWide(const std::string &input);
std::string WideToUtf8(const std::wstring &input);
} // namespace win32
#endif
namespace error {
extern jmp_buf gTestJmp;
extern bool gTestMode;
inline void Abort() {
if (error::gTestMode) {
longjmp(error::gTestJmp, 0);
} else {
abort();
}
}
inline void Exit(int code) {
if (error::gTestMode) {
longjmp(error::gTestJmp, 0);
} else {
exit(code);
}
}
class Die {
public:
explicit Die(bool die) : die_(die) {}
~Die() {
std::cerr << std::endl;
if (die_) {
Abort();
}
}
int operator&(std::ostream &) { return 0; }
private:
bool die_;
};
template <typename T>
T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) {
if (t == nullptr) {
std::cerr << file << "(" << line << ") " << exprtext;
Abort();
}
return std::forward<T>(t);
}
} // namespace error
namespace logging {
enum LogSeverity {
LOG_INFO = 0,
LOG_WARNING = 1,
LOG_ERROR = 2,
LOG_FATAL = 3,
LOG_SEVERITY_SIZE = 4,
};
} // namespace logging
} // namespace sentencepiece
#define LOG(severity) \
sentencepiece::error::Die(sentencepiece::logging::LOG_##severity >= \
sentencepiece::logging::LOG_FATAL) & \
std::cerr << __FILE__ << "(" << __LINE__ << ") " \
<< "LOG(" << #severity << ") "
#define CHECK(condition) \
(condition) ? 0 \
: sentencepiece::error::Die(true) & \
std::cerr << __FILE__ << "(" << __LINE__ << ") [" \
<< #condition << "] "
#define CHECK_IFS(a, b) CHECK((a)) << "No such file or directory: [" << b << "]"
#define CHECK_OFS(a, b) CHECK((a)) << "Permission denied: [" << b << "]"
#define CHECK_STREQ(a, b) CHECK_EQ(std::string(a), std::string(b))
#define CHECK_EQ(a, b) CHECK((a) == (b))
#define CHECK_NE(a, b) CHECK((a) != (b))
#define CHECK_GE(a, b) CHECK((a) >= (b))
#define CHECK_LE(a, b) CHECK((a) <= (b))
#define CHECK_GT(a, b) CHECK((a) > (b))
#define CHECK_LT(a, b) CHECK((a) < (b))
#define CHECK_NOTNULL(val) \
sentencepiece::error::CheckNotNull(__FILE__, __LINE__, \
"'" #val "' Must be non NULL", (val))
#define FRIEND_TEST(a, b) friend class a##_Test_##b;
#endif // COMMON_H_
View File
@ -0,0 +1,129 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <sstream>
#include <string>
#include "builder.h"
#include "flags.h"
#include "stringpiece.h"
#include "util.h"
using sentencepiece::normalizer::Builder;
DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file");
namespace sentencepiece {
namespace {
void WriteTSV(const Builder::CharsMap &cmap, StringPiece filename) {
sentencepiece::io::OutputBuffer output(filename);
for (const auto &c : cmap) {
std::vector<std::string> src, trg;
for (char32 v : c.first) {
src.push_back(string_util::IntToHex(v));
}
for (char32 v : c.second) {
trg.push_back(string_util::IntToHex(v));
}
std::string line = string_util::Join(src, " ") + "\t" +
string_util::Join(trg, " ") + "\t# " +
string_util::UnicodeTextToUTF8(c.first) + " => " +
string_util::UnicodeTextToUTF8(c.second);
line = string_util::StringReplace(line, "\n", " ", true);
line = string_util::StringReplace(line, "\r", " ", true);
output.WriteLine(line);
}
}
std::string ToHexData(StringPiece data) {
const char *begin = data.data();
const char *end = data.data() + data.size();
constexpr char kHex[] = "0123456789ABCDEF";
constexpr size_t kNumOfBytesOnOneLine = 20;
size_t output_count = 0;
std::stringstream os;
while (begin < end) {
const size_t bucket_size =
std::min<size_t>(end - begin, kNumOfBytesOnOneLine -
output_count % kNumOfBytesOnOneLine);
if (output_count % kNumOfBytesOnOneLine == 0) {
os << "\"";
}
for (size_t i = 0; i < bucket_size; ++i) {
os << "\\x" << kHex[(*begin & 0xF0) >> 4] << kHex[(*begin & 0x0F) >> 0];
++begin;
}
output_count += bucket_size;
if (output_count % kNumOfBytesOnOneLine == 0) {
os << "\"\n";
}
}
os << "\"\n";
return os.str();
}
} // namespace
} // namespace sentencepiece
int main(int argc, char **argv) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
const std::vector<std::pair<std::string, std::function<Builder::CharsMap()>>>
kRuleList = {{"nfkc", Builder::BuildNFKCMap},
{"identity", Builder::BuildIdentityMap}};
constexpr char kHeader[] =
R"(#ifndef NORMALIZATION_RULE_H_
#define NORMALIZATION_RULE_H_
#include <cstdio>
namespace sentencepiece {
namespace {
struct BinaryBlob {
const char *name;
size_t size;
const char *data;
};
constexpr BinaryBlob kNormalizationRules_blob[] = {)";
constexpr char kFooter[] = R"(
} // namespace
} // namespace sentencepiece
#endif // NORMALIZATION_RULE_H_)";
std::stringstream os;
os << kHeader;
for (const auto &p : kRuleList) {
const auto normalized_map = p.second();
const auto index = Builder::CompileCharsMap(normalized_map);
os << "{ \"" << p.first << "\", " << index.size() << ",\n";
os << sentencepiece::ToHexData(index);
os << " },";
sentencepiece::WriteTSV(normalized_map, p.first + ".tsv");
}
os << "};\n";
os << "constexpr size_t kNormalizationRules_size = " << kRuleList.size()
<< ";\n";
os << kFooter;
if (FLAGS_output_precompiled_header) {
constexpr char kPrecompiledHeaderFileName[] = "normalization_rule.h";
sentencepiece::io::OutputBuffer output(kPrecompiledHeaderFileName);
output.Write(os.str());
}
return 0;
}
23
src/error.cc Normal file
View File
@ -0,0 +1,23 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include "common.h"
namespace sentencepiece {
namespace error {
jmp_buf gTestJmp;
bool gTestMode = false;
} // namespace error
} // namespace sentencepiece
238
src/flags.cc Normal file
View File
@ -0,0 +1,238 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flags.h"
#include "common.h"
#include <algorithm>
#include <cctype>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif
namespace sentencepiece {
namespace flags {
struct Flag {
int type;
void *storage;
const void *default_storage;
std::string help;
};
namespace {
using FlagMap = std::map<std::string, Flag *>;
FlagMap *GetFlagMap() {
static FlagMap flag_map;
return &flag_map;
}
bool IsTrue(const std::string &value) {
const char *kTrue[] = {"1", "t", "true", "y", "yes"};
const char *kFalse[] = {"0", "f", "false", "n", "no"};
std::string lower_value = value;
std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(),
::tolower);
for (size_t i = 0; i < 5; ++i) {
if (lower_value == kTrue[i]) {
return true;
} else if (lower_value == kFalse[i]) {
return false;
}
}
LOG(FATAL) << "cannot parse boolean value: " << value;
return false;
}
bool SetFlag(const std::string &name, const std::string &value) {
auto it = GetFlagMap()->find(name);
if (it == GetFlagMap()->end()) {
return false;
}
std::string v = value;
Flag *flag = it->second;
// If an empty value is given, assume "true" for boolean options and the
// empty string for string options. Setting fails for other types.
if (value.empty()) {
switch (flag->type) {
case B:
v = "true";
break;
case S:
v = "";
break;
default:
return false;
}
}
switch (flag->type) {
case I:
*reinterpret_cast<int32 *>(flag->storage) = atoi(v.c_str());
break;
case B:
*(reinterpret_cast<bool *>(flag->storage)) = IsTrue(v);
break;
case I64:
*reinterpret_cast<int64 *>(flag->storage) = atoll(v.c_str());
break;
case U64:
*reinterpret_cast<uint64 *>(flag->storage) = atoll(v.c_str());
break;
case D:
*reinterpret_cast<double *>(flag->storage) = strtod(v.c_str(), nullptr);
break;
case S:
*reinterpret_cast<std::string *>(flag->storage) = v;
break;
default:
break;
}
return true;
}
bool CommandLineGetFlag(int argc, char **argv, std::string *key,
std::string *value, int *used_args) {
key->clear();
value->clear();
*used_args = 1;
const char *start = argv[0];
if (start[0] != '-') {
return false;
}
++start;
if (start[0] == '-') ++start;
const std::string arg = start;
const size_t n = arg.find("=");
if (n != std::string::npos) {
*key = arg.substr(0, n);
*value = arg.substr(n + 1, arg.size() - n);
return true;
}
key->assign(arg);
value->clear();
if (argc == 1) {
return true;
}
start = argv[1];
if (start[0] == '-') {
return true;
}
*used_args = 2;
value->assign(start);
return true;
}
} // namespace
FlagRegister::FlagRegister(const char *name, void *storage,
const void *default_storage, int shorttype,
const char *help)
: flag_(new Flag) {
flag_->type = shorttype;
flag_->storage = storage;
flag_->default_storage = default_storage;
flag_->help = help;
GetFlagMap()->insert(std::make_pair(std::string(name), flag_.get()));
}
FlagRegister::~FlagRegister() {}
std::string PrintHelp(const char *programname) {
std::ostringstream os;
os << PACKAGE_STRING << "\n\n";
os << "Usage: " << programname << " [options] files\n\n";
for (const auto &it : *GetFlagMap()) {
os << " --" << it.first << " (" << it.second->help << ")";
const Flag *flag = it.second;
switch (flag->type) {
case I:
os << " type: int32 default: "
<< *(reinterpret_cast<const int *>(flag->default_storage)) << '\n';
break;
case B:
os << " type: bool default: "
<< (*(reinterpret_cast<const bool *>(flag->default_storage))
? "true"
: "false")
<< '\n';
break;
case I64:
os << " type: int64 default: "
<< *(reinterpret_cast<const int64 *>(flag->default_storage)) << '\n';
break;
case U64:
os << " type: uint64 default: "
<< *(reinterpret_cast<const uint64 *>(flag->default_storage))
<< '\n';
break;
case D:
os << " type: double default: "
<< *(reinterpret_cast<const double *>(flag->default_storage))
<< '\n';
break;
case S:
os << " type: string default: "
<< *(reinterpret_cast<const std::string *>(flag->default_storage))
<< '\n';
break;
default:
break;
}
}
os << "\n\n";
return os.str();
}
void ParseCommandLineFlags(int argc, char **argv,
std::vector<std::string> *rest_flags) {
int used_argc = 0;
std::string key, value;
for (int i = 1; i < argc; i += used_argc) {
if (!CommandLineGetFlag(argc - i, argv + i, &key, &value, &used_argc)) {
if (rest_flags) rest_flags->push_back(std::string(argv[i]));
continue;
}
if (key == "help") {
std::cout << PrintHelp(argv[0]);
error::Exit(0);
} else if (key == "version") {
std::cout << PACKAGE_STRING << " " << VERSION << std::endl;
error::Exit(0);
} else if (!SetFlag(key, value)) {
std::cerr << "Unknown/Invalid flag " << key << "\n\n"
<< PrintHelp(argv[0]);
error::Exit(1);
}
}
}
} // namespace flags
} // namespace sentencepiece
95
src/flags.h Normal file
View File
@ -0,0 +1,95 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FLAGS_H_
#define FLAGS_H_
#include <memory>
#include <string>
#include <vector>
namespace sentencepiece {
namespace flags {
enum { I, B, I64, U64, D, S };
struct Flag;
class FlagRegister {
public:
FlagRegister(const char *name, void *storage, const void *default_storage,
int shorttype, const char *help);
~FlagRegister();
private:
std::unique_ptr<Flag> flag_;
};
std::string PrintHelp(const char *programname);
void ParseCommandLineFlags(int argc, char **argv,
std::vector<std::string> *rest_args = nullptr);
} // namespace flags
} // namespace sentencepiece
#define DEFINE_VARIABLE(type, shorttype, name, value, help) \
namespace sentencepiece_flags_fL##shorttype { \
using namespace sentencepiece::flags; \
type FLAGS_##name = value; \
static const type FLAGS_DEFAULT_##name = value; \
static const sentencepiece::flags::FlagRegister fL##name( \
#name, reinterpret_cast<void *>(&FLAGS_##name), \
reinterpret_cast<const void *>(&FLAGS_DEFAULT_##name), shorttype, \
help); \
} \
using sentencepiece_flags_fL##shorttype::FLAGS_##name
#define DECLARE_VARIABLE(type, shorttype, name) \
namespace sentencepiece_flags_fL##shorttype { \
extern type FLAGS_##name; \
} \
using sentencepiece_flags_fL##shorttype::FLAGS_##name
#define DEFINE_int32(name, value, help) \
DEFINE_VARIABLE(int32, I, name, value, help)
#define DECLARE_int32(name) DECLARE_VARIABLE(int32, I, name)
#define DEFINE_int64(name, value, help) \
DEFINE_VARIABLE(int64, I64, name, value, help)
#define DECLARE_int64(name) DECLARE_VARIABLE(int64, I64, name)
#define DEFINE_uint64(name, value, help) \
DEFINE_VARIABLE(uint64, U64, name, value, help)
#define DECLARE_uint64(name) DECLARE_VARIABLE(uint64, U64, name)
#define DEFINE_double(name, value, help) \
DEFINE_VARIABLE(double, D, name, value, help)
#define DECLARE_double(name) DECLARE_VARIABLE(double, D, name)
#define DEFINE_bool(name, value, help) \
DEFINE_VARIABLE(bool, B, name, value, help)
#define DECLARE_bool(name) DECLARE_VARIABLE(bool, B, name)
#define DEFINE_string(name, value, help) \
DEFINE_VARIABLE(std::string, S, name, value, help)
#define DECLARE_string(name) DECLARE_VARIABLE(std::string, S, name)
#define CHECK_OR_HELP(flag) \
if (FLAGS_##flag.empty()) { \
std::cout << "ERROR: --" << #flag << " must not be empty\n\n"; \
std::cout << sentencepiece::flags::PrintHelp(PACKAGE_STRING); \
sentencepiece::error::Exit(0); \
}
#endif // FLAGS_H_
138
src/flags_test.cc Normal file
View File
@ -0,0 +1,138 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flags.h"
#include "common.h"
#include "testharness.h"
DEFINE_int32(int32_f, 10, "int32_flags");
DEFINE_bool(bool_f, false, "bool_flags");
DEFINE_int64(int64_f, 20, "int64_flags");
DEFINE_uint64(uint64_f, 30, "uint64_flags");
DEFINE_double(double_f, 40.0, "double_flags");
DEFINE_string(string_f, "str", "string_flags");
namespace sentencepiece {
namespace flags {
TEST(FlagsTest, DefaultValueTest) {
EXPECT_EQ(10, FLAGS_int32_f);
EXPECT_EQ(false, FLAGS_bool_f);
EXPECT_EQ(20, FLAGS_int64_f);
EXPECT_EQ(30, FLAGS_uint64_f);
EXPECT_EQ(40.0, FLAGS_double_f);
EXPECT_EQ("str", FLAGS_string_f);
}
TEST(FlagsTest, PrintHelpTest) {
const std::string help = PrintHelp("foobar");
EXPECT_NE(std::string::npos, help.find("foobar"));
EXPECT_NE(std::string::npos, help.find("int32_flags"));
EXPECT_NE(std::string::npos, help.find("bool_flags"));
EXPECT_NE(std::string::npos, help.find("int64_flags"));
EXPECT_NE(std::string::npos, help.find("uint64_flags"));
EXPECT_NE(std::string::npos, help.find("double_flags"));
EXPECT_NE(std::string::npos, help.find("string_flags"));
}
TEST(FlagsTest, ParseCommandLineFlagsTest) {
const char *kFlags[] = {"program", "--int32_f=100", "other1",
"--bool_f=true", "--int64_f=200", "--uint64_f=300",
"--double_f=400", "--string_f=foo", "other2",
"other3"};
std::vector<std::string> rest;
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags), &rest);
EXPECT_EQ(100, FLAGS_int32_f);
EXPECT_EQ(true, FLAGS_bool_f);
EXPECT_EQ(200, FLAGS_int64_f);
EXPECT_EQ(300, FLAGS_uint64_f);
EXPECT_EQ(400.0, FLAGS_double_f);
EXPECT_EQ("foo", FLAGS_string_f);
EXPECT_EQ(3, rest.size());
EXPECT_EQ("other1", rest[0]);
EXPECT_EQ("other2", rest[1]);
EXPECT_EQ("other3", rest[2]);
}
TEST(FlagsTest, ParseCommandLineFlagsTest2) {
const char *kFlags[] = {"program", "--int32_f", "500",
"-int64_f=600", "-uint64_f", "700",
"--bool_f=FALSE"};
std::vector<std::string> rest;
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags), &rest);
EXPECT_EQ(500, FLAGS_int32_f);
EXPECT_EQ(600, FLAGS_int64_f);
EXPECT_EQ(700, FLAGS_uint64_f);
EXPECT_FALSE(FLAGS_bool_f);
EXPECT_TRUE(rest.empty());
}
TEST(FlagsTest, ParseCommandLineFlagsTest3) {
const char *kFlags[] = {"program", "--bool_f", "--int32_f", "800"};
std::vector<std::string> rest;
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags), &rest);
EXPECT_TRUE(FLAGS_bool_f);
EXPECT_EQ(800, FLAGS_int32_f);
EXPECT_TRUE(rest.empty());
}
TEST(FlagsTest, ParseCommandLineFlagsHelpTest) {
const char *kFlags[] = {"program", "--help"};
EXPECT_DEATH(
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags)));
}
TEST(FlagsTest, ParseCommandLineFlagsVersionTest) {
const char *kFlags[] = {"program", "--version"};
EXPECT_DEATH(
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags)));
}
TEST(FlagsTest, ParseCommandLineFlagsUnknownTest) {
const char *kFlags[] = {"program", "--foo"};
EXPECT_DEATH(
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags)));
}
TEST(FlagsTest, ParseCommandLineFlagsInvalidBoolTest) {
const char *kFlags[] = {"program", "--bool_f=X"};
EXPECT_DEATH(
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags)));
}
TEST(FlagsTest, ParseCommandLineFlagsEmptyStringArgs) {
const char *kFlags[] = {"program", "--string_f="};
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags));
EXPECT_EQ("", FLAGS_string_f);
}
TEST(FlagsTest, ParseCommandLineFlagsEmptyBoolArgs) {
const char *kFlags[] = {"program", "--bool_f"};
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags));
EXPECT_TRUE(FLAGS_bool_f);
}
TEST(FlagsTest, ParseCommandLineFlagsEmptyIntArgs) {
const char *kFlags[] = {"program", "--int32_f"};
EXPECT_DEATH(
ParseCommandLineFlags(arraysize(kFlags), const_cast<char **>(kFlags)));
}
} // namespace flags
} // namespace sentencepiece
50
src/model_factory.cc Normal file
View File
@ -0,0 +1,50 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_factory.h"
#include "bpe_model.h"
#include "char_model.h"
#include "unigram_model.h"
#include "util.h"
#include "word_model.h"
namespace sentencepiece {
// Instantiates a Model instance from |model_proto|.
std::unique_ptr<ModelInterface> ModelFactory::Create(
const ModelProto& model_proto) {
const auto& trainer_spec = model_proto.trainer_spec();
switch (trainer_spec.model_type()) {
case TrainerSpec::UNIGRAM:
return port::MakeUnique<unigram::Model>(model_proto);
break;
case TrainerSpec::BPE:
return port::MakeUnique<bpe::Model>(model_proto);
break;
case TrainerSpec::WORD:
return port::MakeUnique<word::Model>(model_proto);
break;
case TrainerSpec::CHAR:
return port::MakeUnique<character::Model>(model_proto);
break;
default:
LOG(FATAL) << "Unknown model_type: " << trainer_spec.model_type();
break;
}
return port::MakeUnique<unigram::Model>(model_proto);
}
} // namespace sentencepiece
30
src/model_factory.h Normal file
View File
@ -0,0 +1,30 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MODEL_FACTORY_H_
#define MODEL_FACTORY_H_
#include <memory>
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
namespace sentencepiece {
class ModelFactory {
public:
// Creates Model instance from |model_proto|.
static std::unique_ptr<ModelInterface> Create(const ModelProto &model_proto);
};
} // namespace sentencepiece
#endif // MODEL_FACTORY_H_
58
src/model_factory_test.cc Normal file
View File
@ -0,0 +1,58 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_factory.h"
#include "testharness.h"
namespace sentencepiece {
TEST(ModelFactoryTest, BasicTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
auto *sp4 = model_proto.add_pieces();
sp4->set_piece("test");
sp4->set_score(1.0);
{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::UNIGRAM);
auto m = ModelFactory::Create(model_proto);
}
{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::BPE);
auto m = ModelFactory::Create(model_proto);
}
{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::WORD);
auto m = ModelFactory::Create(model_proto);
}
{
model_proto.mutable_trainer_spec()->set_model_type(TrainerSpec::CHAR);
auto m = ModelFactory::Create(model_proto);
}
}
} // namespace sentencepiece
101
src/model_interface.cc Normal file
View File
@ -0,0 +1,101 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
#include "util.h"
namespace sentencepiece {
const uint32 ModelInterface::kUnkID = 0;
ModelInterface::ModelInterface(const ModelProto &model_proto)
: model_proto_(&model_proto) {}
ModelInterface::~ModelInterface() {}
int ModelInterface::PieceToId(StringPiece piece) const {
auto it = reserved_id_map_.find(piece);
if (it != reserved_id_map_.end()) {
return it->second;
}
auto it2 = pieces_.find(piece);
if (it2 != pieces_.end()) {
return it2->second;
}
return kUnkID;
}
int ModelInterface::GetPieceSize() const {
return CHECK_NOTNULL(model_proto_)->pieces_size();
}
std::string ModelInterface::IdToPiece(int id) const {
return CHECK_NOTNULL(model_proto_)->pieces(id).piece();
}
float ModelInterface::GetScore(int id) const {
return CHECK_NOTNULL(model_proto_)->pieces(id).score();
}
bool ModelInterface::IsControl(int id) const {
return (CHECK_NOTNULL(model_proto_)->pieces(id).type() ==
ModelProto::SentencePiece::CONTROL);
}
bool ModelInterface::IsUnknown(int id) const {
return (CHECK_NOTNULL(model_proto_)->pieces(id).type() ==
ModelProto::SentencePiece::UNKNOWN);
}
void ModelInterface::CheckControlSymbols() const {
CHECK_NOTNULL(model_proto_);
CHECK_GE(model_proto_->pieces_size(), 3); // <unk>, <s>, </s>
// Verify the reserved control symbols and the unknown symbol.
CHECK_EQ(ModelProto::SentencePiece::UNKNOWN, // <unk>
model_proto_->pieces(0).type());
CHECK_EQ("<unk>", model_proto_->pieces(0).piece());
CHECK_EQ(ModelProto::SentencePiece::CONTROL, // <s>
model_proto_->pieces(1).type());
CHECK_EQ("<s>", model_proto_->pieces(1).piece());
CHECK_EQ(ModelProto::SentencePiece::CONTROL, // </s>
model_proto_->pieces(2).type());
CHECK_EQ("</s>", model_proto_->pieces(2).piece());
}
std::vector<StringPiece> SplitIntoWords(StringPiece text) {
const char *begin = text.data();
const char *end = text.data() + text.size();
// Space symbol (U+2581)
const StringPiece kSpaceSymbol = "\xe2\x96\x81";
std::vector<StringPiece> result;
while (begin < end) {
const int mblen =
std::min<int>(string_util::OneCharLen(begin), end - begin);
if (begin == text.data() || StringPiece(begin, mblen) == kSpaceSymbol) {
result.emplace_back(begin, 0); // add empty string piece.
}
CHECK(!result.empty());
result.back() =
StringPiece(result.back().data(), result.back().size() + mblen);
begin += mblen;
}
return result;
}
} // namespace sentencepiece
89
src/model_interface.h Normal file
View File
@ -0,0 +1,89 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MODEL_INTERFACE_H_
#define MODEL_INTERFACE_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "common.h"
#include "stringpiece.h"
namespace sentencepiece {
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
std::vector<StringPiece> SplitIntoWords(StringPiece text);
class ModelProto;
// Underlying model interface.
// Given a normalized string, returns a sequence of sentence pieces with ids.
class ModelInterface {
public:
using PieceToIdMap = std::unordered_map<StringPiece, int, StringPieceHash>;
static const uint32 kUnkID;
// |model_proto| should not be deleted until ModelInterface is destroyed.
explicit ModelInterface(const ModelProto &model_proto);
ModelInterface() {}
virtual ~ModelInterface();
virtual const ModelProto &model_proto() const { return *model_proto_; }
// Given a normalized string, returns a sequence of sentence pieces with ids.
// The concatenation of pieces must be the same as |normalized|.
virtual std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const = 0;
// Returns the size of sentence pieces, which is the same
// as the size of vocabulary for NMT.
virtual int GetPieceSize() const;
// Returns the vocab id of |piece|.
// Returns UNK(0) if |piece| is unknown
virtual int PieceToId(StringPiece piece) const;
// Returns the string representation of vocab with |id|.
// id must be 0 <= id < GetPieceSize().
virtual std::string IdToPiece(int id) const;
// Returns the score of |id|.
// Score represents a log probability of the piece.
// We can roughly estimate the unigram frequency of the piece.
virtual float GetScore(int id) const;
// Returns true if |id| is unknown symbol.
virtual bool IsUnknown(int id) const;
// Returns true if |id| is control symbol.
virtual bool IsControl(int id) const;
protected:
void CheckControlSymbols() const;
const ModelProto *model_proto_ = nullptr;
// piece -> id map for normal pieces
PieceToIdMap pieces_;
// piece -> id map for control and unknown
PieceToIdMap reserved_id_map_;
};
} // namespace sentencepiece
#endif // MODEL_INTERFACE_H_
206
src/model_interface_test.cc Normal file
View File
@ -0,0 +1,206 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_interface.h"
#include "model_factory.h"
#include "testharness.h"
namespace sentencepiece {
namespace {
#define WS "\xe2\x96\x81"
const std::vector<TrainerSpec::ModelType> kModelTypes = {
TrainerSpec::UNIGRAM, TrainerSpec::BPE, TrainerSpec::WORD,
TrainerSpec::CHAR};
ModelProto MakeBaseModelProto(TrainerSpec::ModelType type) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
model_proto.mutable_trainer_spec()->set_model_type(type);
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(ModelInterfaceTest, SetModelInterfaceTest) {
for (const auto type : kModelTypes) {
ModelProto model_proto = MakeBaseModelProto(type);
AddPiece(&model_proto, "a");
AddPiece(&model_proto, "b");
AddPiece(&model_proto, "c");
AddPiece(&model_proto, "d");
auto model = ModelFactory::Create(model_proto);
EXPECT_EQ(model_proto.SerializeAsString(),
model->model_proto().SerializeAsString());
}
}
TEST(ModelInterfaceTest, PieceToIdTest) {
for (const auto type : kModelTypes) {
ModelProto model_proto = MakeBaseModelProto(type);
AddPiece(&model_proto, "a", 0.1);
AddPiece(&model_proto, "b", 0.2);
AddPiece(&model_proto, "c", 0.3);
AddPiece(&model_proto, "d", 0.4);
auto model = ModelFactory::Create(model_proto);
EXPECT_EQ(model_proto.SerializeAsString(),
model->model_proto().SerializeAsString());
EXPECT_EQ(0, model->PieceToId("<unk>"));
EXPECT_EQ(1, model->PieceToId("<s>"));
EXPECT_EQ(2, model->PieceToId("</s>"));
EXPECT_EQ(3, model->PieceToId("a"));
EXPECT_EQ(4, model->PieceToId("b"));
EXPECT_EQ(5, model->PieceToId("c"));
EXPECT_EQ(6, model->PieceToId("d"));
EXPECT_EQ(0, model->PieceToId("e")); // unk
EXPECT_EQ(0, model->PieceToId("")); // unk
EXPECT_EQ("<unk>", model->IdToPiece(0));
EXPECT_EQ("<s>", model->IdToPiece(1));
EXPECT_EQ("</s>", model->IdToPiece(2));
EXPECT_EQ("a", model->IdToPiece(3));
EXPECT_EQ("b", model->IdToPiece(4));
EXPECT_EQ("c", model->IdToPiece(5));
EXPECT_EQ("d", model->IdToPiece(6));
EXPECT_TRUE(model->IsUnknown(0));
EXPECT_FALSE(model->IsUnknown(1));
EXPECT_FALSE(model->IsUnknown(2));
EXPECT_FALSE(model->IsUnknown(3));
EXPECT_FALSE(model->IsUnknown(4));
EXPECT_FALSE(model->IsUnknown(5));
EXPECT_FALSE(model->IsUnknown(6));
EXPECT_FALSE(model->IsControl(0));
EXPECT_TRUE(model->IsControl(1));
EXPECT_TRUE(model->IsControl(2));
EXPECT_FALSE(model->IsControl(3));
EXPECT_FALSE(model->IsControl(4));
EXPECT_FALSE(model->IsControl(5));
EXPECT_FALSE(model->IsControl(6));
EXPECT_NEAR(0, model->GetScore(0), 0.0001);
EXPECT_NEAR(0, model->GetScore(1), 0.0001);
EXPECT_NEAR(0, model->GetScore(2), 0.0001);
EXPECT_NEAR(0.1, model->GetScore(3), 0.0001);
EXPECT_NEAR(0.2, model->GetScore(4), 0.0001);
EXPECT_NEAR(0.3, model->GetScore(5), 0.0001);
EXPECT_NEAR(0.4, model->GetScore(6), 0.0001);
}
}
std::string RandomString(int length) {
const char kAlphaNum[] =
"0123456789"
"!@#$%^&*"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
const int kAlphaSize = sizeof(kAlphaNum) - 1;
const int size = rand() % length + 1;
std::string result;
for (int i = 0; i < size; ++i) {
result += kAlphaNum[rand() % kAlphaSize];
}
return result;
}
TEST(ModelInterfaceTest, PieceToIdStressTest) {
for (const auto type : kModelTypes) {
for (int i = 0; i < 100; ++i) {
std::unordered_map<std::string, int> expected_p2i;
std::unordered_map<int, std::string> expected_i2p;
ModelProto model_proto = MakeBaseModelProto(type);
for (int n = 0; n < 1000; ++n) {
const std::string piece = RandomString(10);
if (expected_p2i.find(piece) != expected_p2i.end()) {
continue;
}
expected_p2i[piece] = model_proto.pieces_size();
expected_i2p[model_proto.pieces_size()] = piece;
AddPiece(&model_proto, piece);
}
auto model = ModelFactory::Create(model_proto);
for (const auto &it : expected_p2i) {
EXPECT_EQ(it.second, model->PieceToId(it.first));
}
for (const auto &it : expected_i2p) {
EXPECT_EQ(it.second, model->IdToPiece(it.first));
}
}
}
}
TEST(ModelInterfaceTest, SplitIntoWordsTest) {
{
const auto v = SplitIntoWords(WS "this" WS "is" WS "a" WS "pen");
EXPECT_EQ(4, v.size());
EXPECT_EQ(WS "this", v[0]);
EXPECT_EQ(WS "is", v[1]);
EXPECT_EQ(WS "a", v[2]);
EXPECT_EQ(WS "pen", v[3]);
}
{
const auto v = SplitIntoWords("this" WS "is" WS "a" WS "pen");
EXPECT_EQ(4, v.size());
EXPECT_EQ("this", v[0]);
EXPECT_EQ(WS "is", v[1]);
EXPECT_EQ(WS "a", v[2]);
EXPECT_EQ(WS "pen", v[3]);
}
{
const auto v = SplitIntoWords(WS "this" WS WS "is");
EXPECT_EQ(3, v.size());
EXPECT_EQ(WS "this", v[0]);
EXPECT_EQ(WS, v[1]);
EXPECT_EQ(WS "is", v[2]);
}
{
const auto v = SplitIntoWords("");
EXPECT_TRUE(v.empty());
}
{
const auto v = SplitIntoWords("hello");
EXPECT_EQ(1, v.size());
EXPECT_EQ("hello", v[0]);
}
}
} // namespace
} // namespace sentencepiece
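The SplitIntoWords behavior exercised by the last test above can be approximated by a short stand-alone sketch: split immediately before each occurrence of the U+2581 meta symbol, keeping the symbol attached to the chunk that follows. This is an illustrative re-implementation, not the library's function:

```cpp
#include <cassert>
#include <string>
#include <vector>

const char kWS[] = "\xe2\x96\x81";  // U+2581 LOWER ONE EIGHTH BLOCK

// Splits |text| before every meta symbol; the symbol stays attached to
// the following chunk, and a leading chunk without the symbol is kept.
std::vector<std::string> SplitIntoWords(const std::string &text) {
  std::vector<std::string> words;
  size_t begin = 0;
  while (begin < text.size()) {
    size_t end = text.find(kWS, begin + 1);  // next word boundary
    if (end == std::string::npos) end = text.size();
    words.push_back(text.substr(begin, end - begin));
    begin = end;
  }
  return words;
}
```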

11958
src/normalization_rule.h Normal file

File diff suppressed because it is too large

228
src/normalizer.cc Normal file

@ -0,0 +1,228 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "normalizer.h"
#include "common.h"
#include "stringpiece.h"
#include "third_party/darts_clone/darts.h"
#include "util.h"
namespace sentencepiece {
namespace normalizer {
Normalizer::Normalizer(const NormalizerSpec &spec) : spec_(&spec) {
StringPiece index = spec.precompiled_charsmap();
CHECK(!index.empty());
StringPiece trie_blob, normalized;
DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
// Reads the body of double array.
trie_ = port::MakeUnique<Darts::DoubleArray>();
// The second arg of set_array is not the size of blob,
// but the number of double array units.
trie_->set_array(const_cast<char *>(trie_blob.data()),
trie_blob.size() / trie_->unit_size());
normalized_ = normalized.data();
}
Normalizer::~Normalizer() {}
void Normalizer::Normalize(StringPiece input, std::string *normalized,
std::vector<size_t> *norm_to_orig) const {
CHECK_NOTNULL(norm_to_orig)->clear();
CHECK_NOTNULL(normalized)->clear();
if (input.empty()) {
return;
}
int consumed = 0;
// Ignores leading spaces.
if (spec_->remove_extra_whitespaces()) {
while (!input.empty()) {
const auto p = NormalizePrefix(input);
if (p.first != " ") {
break;
}
input.remove_prefix(p.second);
consumed += p.second;
}
}
// all chars are whitespace.
if (input.empty()) {
return;
}
// Reserves the output buffer to avoid re-allocations.
const size_t kReservedSize = input.size() * 3;
normalized->reserve(kReservedSize);
norm_to_orig->reserve(kReservedSize);
// Replaces white space with U+2581 (LOWER ONE EIGHTH BLOCK)
// if escape_whitespaces() is set (default = true).
const StringPiece kSpaceSymbol = "\xe2\x96\x81";
// Adds a space symbol as a prefix (default is true)
// With this prefix, "world" and "hello world" are converted into
// "_world" and "_hello_world", which help the trainer to extract
// "_world" as one symbol.
if (spec_->add_dummy_prefix()) {
if (spec_->escape_whitespaces()) {
normalized->append(kSpaceSymbol.data(), kSpaceSymbol.size());
for (size_t n = 0; n < kSpaceSymbol.size(); ++n) {
norm_to_orig->push_back(consumed);
}
} else {
normalized->append(" ");
norm_to_orig->push_back(consumed);
}
}
bool is_prev_space = spec_->remove_extra_whitespaces();
while (!input.empty()) {
auto p = NormalizePrefix(input);
StringPiece sp = p.first;
// Removes leading spaces in the sentence piece,
// if the previous sentence piece ends with whitespace.
while (is_prev_space && sp.Consume(" ")) {
}
if (!sp.empty()) {
const char *data = sp.data();
for (size_t n = 0; n < sp.size(); ++n) {
if (spec_->escape_whitespaces() && data[n] == ' ') {
// replace ' ' with kSpaceSymbol.
normalized->append(kSpaceSymbol.data(), kSpaceSymbol.size());
for (size_t m = 0; m < kSpaceSymbol.size(); ++m) {
norm_to_orig->push_back(consumed);
}
} else {
*normalized += data[n];
norm_to_orig->push_back(consumed);
}
}
// Checks whether the last character of sp is whitespace.
is_prev_space = sp.ends_with(" ");
}
consumed += p.second;
input.remove_prefix(p.second);
if (!spec_->remove_extra_whitespaces()) {
is_prev_space = false;
}
}
// Ignores trailing spaces.
if (spec_->remove_extra_whitespaces()) {
const StringPiece space = spec_->escape_whitespaces() ? kSpaceSymbol : " ";
while (string_util::EndsWith(*normalized, space)) {
const int length = normalized->size() - space.size();
CHECK_GE(length, 0);
consumed = (*norm_to_orig)[length];
normalized->resize(length);
norm_to_orig->resize(length);
}
}
norm_to_orig->push_back(consumed);
CHECK_EQ(norm_to_orig->size(), normalized->size() + 1);
}
std::string Normalizer::Normalize(StringPiece input) const {
std::vector<size_t> norm_to_orig;
std::string normalized;
Normalize(input, &normalized, &norm_to_orig);
return normalized;
}
std::pair<StringPiece, int> Normalizer::NormalizePrefix(
StringPiece input) const {
CHECK(!input.empty());
// Allocates trie_results on the stack, which makes the encoding speed 36% faster.
// (38k sentences/sec => 60k sentences/sec).
// Builder checks that the result size never exceeds kMaxTrieResultsSize.
// This array consumes 0.5kByte of stack, which is less than
// the default stack size (16kByte).
Darts::DoubleArray::result_pair_type
trie_results[Normalizer::kMaxTrieResultsSize];
const size_t num_nodes = CHECK_NOTNULL(trie_)->commonPrefixSearch(
input.data(), trie_results, Normalizer::kMaxTrieResultsSize,
input.size());
// Finds the longest rule.
size_t longest_length = 0;
int longest_value = 0;
for (size_t k = 0; k < num_nodes; ++k) {
if (longest_length == 0 || trie_results[k].length > longest_length) {
longest_length = trie_results[k].length; // length of prefix
longest_value = trie_results[k].value; // pointer to |normalized_|.
}
}
std::pair<StringPiece, int> result;
if (longest_length == 0) {
result.second = std::min<int>(
input.size(), std::max<int>(1, string_util::OneCharLen(input.data())));
result.first.set(input.data(), result.second);
} else {
result.second = longest_length;
// No need to pass the size of the normalized sentence,
// since |normalized_| is delimited by "\0".
result.first.set(&normalized_[longest_value]);
}
CHECK(!result.first.empty());
CHECK_GT(result.second, 0);
return result;
}
// static
std::string Normalizer::EncodePrecompiledCharsMap(StringPiece trie_blob,
StringPiece normalized) {
// <trie size(4byte)><double array trie><normalized string>
std::string blob;
blob.append(string_util::EncodePOD<uint32>(trie_blob.size()));
blob.append(trie_blob.data(), trie_blob.size());
blob.append(normalized.data(), normalized.size());
return blob;
}
// static
void Normalizer::DecodePrecompiledCharsMap(StringPiece blob,
StringPiece *trie_blob,
StringPiece *normalized) {
uint32 trie_blob_size = 0;
CHECK_GT(blob.size(), sizeof(trie_blob_size));
CHECK(string_util::DecodePOD<uint32>(
StringPiece(blob.data(), sizeof(trie_blob_size)), &trie_blob_size));
CHECK_LT(trie_blob_size, blob.size());
blob.remove_prefix(sizeof(trie_blob_size));
CHECK_NOTNULL(trie_blob)->set(blob.data(), trie_blob_size);
blob.remove_prefix(trie_blob_size);
CHECK_NOTNULL(normalized)->set(blob.data(), blob.size());
}
} // namespace normalizer
} // namespace sentencepiece
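The blob layout used by EncodePrecompiledCharsMap and DecodePrecompiledCharsMap above, `<trie size (4 bytes)><double-array trie><normalized string>`, can be sketched stand-alone. The sketch below assumes a little-endian host for the raw 4-byte size header and uses invented function names; it is not the library's code:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>

// Layout: <trie size (4 bytes)><double-array trie blob><normalized string>.
std::string EncodeBlob(const std::string &trie_blob,
                       const std::string &normalized) {
  const uint32_t size = static_cast<uint32_t>(trie_blob.size());
  std::string blob(reinterpret_cast<const char *>(&size), sizeof(size));
  blob += trie_blob;
  blob += normalized;
  return blob;
}

bool DecodeBlob(const std::string &blob, std::string *trie_blob,
                std::string *normalized) {
  uint32_t size = 0;
  if (blob.size() < sizeof(size)) return false;
  std::memcpy(&size, blob.data(), sizeof(size));
  if (sizeof(size) + size > blob.size()) return false;
  trie_blob->assign(blob, sizeof(size), size);
  normalized->assign(blob, sizeof(size) + size,
                     blob.size() - sizeof(size) - size);
  return true;
}
```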

105
src/normalizer.h Normal file

@ -0,0 +1,105 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef NORMALIZER_NORMALIZER_H_
#define NORMALIZER_NORMALIZER_H_
#include <memory>
#include <string>
#include "common.h"
#include "sentencepiece_model.pb.h"
#include "stringpiece.h"
#include "third_party/darts_clone/darts.h"
namespace sentencepiece {
namespace normalizer {
// Normalizer implements a simple text normalizer with
// user-defined string-to-string rules and leftmost longest
// matching. The rules of Normalizer are built with the
// Builder::CompileCharsMap() method. Pre-compiled rules are
// also available via the Builder::GetPrecompiledCharsMap(<name>) method.
//
// The motivation of Normalizer is to provide a flexible,
// user-customizable and self-contained normalizer. All the
// normalization logic is encoded in the model proto, which allows us
// to define language/task dependent normalization rules without
// breaking the default rule.
class Normalizer {
public:
// Instantiates Normalizer with |spec|.
// |spec| should not be deleted until Normalizer is destroyed.
explicit Normalizer(const NormalizerSpec &spec);
virtual ~Normalizer();
// Normalizes a plain utf8 string into an internal representation for
// Sentencepiece model. |norm_to_orig| stores the byte-alignment from
// normalized string to the original input.
// This function can perform the following normalizations:
// - Character normalization
// (NFKC / full-width to half-width conversion etc.).
// - Adding a prefix space.
// - Replacing a space with a meta symbol.
// - Removing leading, trailing and other redundant spaces.
virtual void Normalize(StringPiece input, std::string *normalized,
std::vector<size_t> *norm_to_orig) const;
// Returns a normalized string without alignments.
// This function is used in sentencepiece training.
virtual std::string Normalize(StringPiece input) const;
friend class Builder;
private:
FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest);
// Normalizes the prefix of |input| and returns the pair of
// normalized prefix and length we must consume after
// normalization.
// Here's the sample code for the full text normalization.
//
// string output;
// StringPiece input = "...";
// while (!input.empty()) {
// const auto p = normalizer.NormalizePrefix(input);
// output.append(p.first.data(), p.first.size());
// input.remove_prefix(p.second);
// }
std::pair<StringPiece, int> NormalizePrefix(StringPiece input) const;
// Encodes trie_blob and the normalized string and returns the compiled blob.
static std::string EncodePrecompiledCharsMap(StringPiece trie_blob,
StringPiece normalized);
// Decodes blob into trie_blob and normalized string.
static void DecodePrecompiledCharsMap(StringPiece blob,
StringPiece *trie_blob,
StringPiece *normalized);
// Maximum size of the return value of Trie, which corresponds
// to the maximum size of shared common prefix in the chars map.
static const int kMaxTrieResultsSize = 32;
// Internal trie for efficient longest matching.
std::unique_ptr<Darts::DoubleArray> trie_;
// "\0" delimitered output string.
// the value of |trie_| stores pointers to this string.
const char *normalized_;
// Spec for normalization.
const NormalizerSpec *spec_;
};
} // namespace normalizer
} // namespace sentencepiece
#endif // NORMALIZER_NORMALIZER_H_
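The NormalizePrefix loop shown in the header comment above can be illustrated with a self-contained sketch in which a std::map of rules stands in for the Darts trie. The rule set and names are invented for illustration; unmatched input is copied through one byte at a time here, whereas the real fallback copies one UTF-8 character:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

using CharsMap = std::map<std::string, std::string>;

// Returns the normalized prefix and the number of input bytes consumed,
// preferring the longest matching rule (leftmost-longest matching).
std::pair<std::string, int> NormalizePrefix(const CharsMap &rules,
                                            const std::string &input) {
  std::string best_out;
  size_t best_len = 0;
  for (size_t len = 1; len <= input.size(); ++len) {
    const auto it = rules.find(input.substr(0, len));
    if (it != rules.end()) {
      best_out = it->second;  // longest match so far wins
      best_len = len;
    }
  }
  if (best_len == 0) {
    return {input.substr(0, 1), 1};  // no rule: copy one byte through
  }
  return {best_out, static_cast<int>(best_len)};
}

std::string Normalize(const CharsMap &rules, std::string input) {
  std::string output;
  while (!input.empty()) {
    const auto p = NormalizePrefix(rules, input);
    output += p.first;
    input = input.substr(p.second);
  }
  return output;
}
```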

293
src/normalizer_test.cc Normal file

@ -0,0 +1,293 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "normalizer.h"
#include "builder.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace normalizer {
namespace {
// Space symbol
#define WS "\xe2\x96\x81"
NormalizerSpec MakeDefaultSpec() { return Builder::GetNormalizerSpec("nfkc"); }
} // namespace
TEST(NormalizerTest, NormalizeTest) {
auto spec = MakeDefaultSpec();
const Normalizer normalizer(spec);
// Empty strings.
EXPECT_EQ("", normalizer.Normalize(""));
EXPECT_EQ("", normalizer.Normalize(" "));
EXPECT_EQ("", normalizer.Normalize(" "));
// Sentence with heading/tailing/redundant spaces.
EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
EXPECT_EQ(WS "ABC", normalizer.Normalize(" ABC "));
EXPECT_EQ(WS "A" WS "B" WS "C", normalizer.Normalize(" A B C "));
EXPECT_EQ(WS "ABC", normalizer.Normalize(" ABC "));
EXPECT_EQ(WS "ABC", normalizer.Normalize(" "));
EXPECT_EQ(WS "ABC", normalizer.Normalize("  ABC"));
EXPECT_EQ(WS "ABC", normalizer.Normalize("  ABC  "));
// NFKC char to char normalization.
EXPECT_EQ(WS "123", normalizer.Normalize("①②③"));
// NFKC char to multi-char normalization.
EXPECT_EQ(WS "株式会社", normalizer.Normalize(""));
// Half width katakana, character composition happens.
EXPECT_EQ(WS "グーグル", normalizer.Normalize(" グーグル "));
EXPECT_EQ(WS "I" WS "saw" WS "a" WS "girl",
normalizer.Normalize(" I saw a   girl  "));
}
TEST(NormalizerTest, NormalizeWithoutDummyPrefixTest) {
auto spec = MakeDefaultSpec();
spec.set_add_dummy_prefix(false);
const Normalizer normalizer(spec);
// Empty strings.
EXPECT_EQ("", normalizer.Normalize(""));
EXPECT_EQ("", normalizer.Normalize(" "));
EXPECT_EQ("", normalizer.Normalize(" "));
// Sentence with heading/tailing/redundant spaces.
EXPECT_EQ("ABC", normalizer.Normalize("ABC"));
EXPECT_EQ("ABC", normalizer.Normalize(" ABC "));
EXPECT_EQ("A" WS "B" WS "C", normalizer.Normalize(" A B C "));
EXPECT_EQ("ABC", normalizer.Normalize(" ABC "));
EXPECT_EQ("ABC", normalizer.Normalize(" "));
EXPECT_EQ("ABC", normalizer.Normalize("  ABC"));
EXPECT_EQ("ABC", normalizer.Normalize("  ABC  "));
}
TEST(NormalizerTest, NormalizeWithoutRemoveExtraWhitespacesTest) {
auto spec = MakeDefaultSpec();
spec.set_remove_extra_whitespaces(false);
const Normalizer normalizer(spec);
// Empty strings.
EXPECT_EQ("", normalizer.Normalize(""));
EXPECT_EQ(WS WS WS WS WS WS WS, normalizer.Normalize(" "));
EXPECT_EQ(WS WS, normalizer.Normalize(" "));
// Sentence with heading/tailing/redundant spaces.
EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
EXPECT_EQ(WS WS "ABC" WS, normalizer.Normalize(" ABC "));
EXPECT_EQ(WS WS WS "A" WS WS "B" WS WS "C" WS WS,
normalizer.Normalize(" A B C "));
}
TEST(NormalizerTest, NormalizeWithoutEscapeWhitespacesTest) {
auto spec = MakeDefaultSpec();
spec.set_add_dummy_prefix(false);
spec.set_remove_extra_whitespaces(true);
spec.set_escape_whitespaces(false);
const Normalizer normalizer(spec);
// Empty strings.
EXPECT_EQ("", normalizer.Normalize(""));
EXPECT_EQ("", normalizer.Normalize(" "));
EXPECT_EQ("", normalizer.Normalize(" "));
// Sentence with heading/tailing/redundant spaces.
EXPECT_EQ("ABC", normalizer.Normalize("ABC"));
EXPECT_EQ("ABC", normalizer.Normalize(" ABC "));
EXPECT_EQ("A B C", normalizer.Normalize(" A B C "));
EXPECT_EQ("A B C", normalizer.Normalize("A  B  C"));
}
TEST(NormalizerTest, NormalizeWithSpaceContainedRules) {
Builder::CharsMap charsmap;
auto AddRule = [&](const std::string &src, const std::string &trg) {
Builder::Chars src_chars, trg_chars;
for (const char32 c : string_util::UTF8ToUnicodeText(src)) {
src_chars.push_back(c);
}
for (const char32 c : string_util::UTF8ToUnicodeText(trg)) {
trg_chars.push_back(c);
}
charsmap[src_chars] = trg_chars;
};
// Adds rules containing whitespaces.
AddRule("a", " A");
AddRule("b", "B");
AddRule("c", "D E");
AddRule("d", " F G ");
NormalizerSpec spec;
spec.set_precompiled_charsmap(Builder::CompileCharsMap(charsmap));
// Test default behavior
{
const Normalizer normalizer(spec);
EXPECT_EQ(WS "A", normalizer.Normalize("a"));
EXPECT_EQ(WS "B" WS "A", normalizer.Normalize("ba"));
EXPECT_EQ(WS "D" WS "E", normalizer.Normalize("c"));
EXPECT_EQ(WS "F" WS "G" WS "A", normalizer.Normalize("da"));
EXPECT_EQ(WS "A" WS "F" WS "G", normalizer.Normalize("ad"));
EXPECT_EQ(WS "A" WS "F" WS "G" WS "B", normalizer.Normalize("adb"));
}
spec.set_escape_whitespaces(false);
{
spec.set_add_dummy_prefix(false);
spec.set_remove_extra_whitespaces(true);
const Normalizer normalizer(spec);
EXPECT_EQ("A", normalizer.Normalize("a"));
EXPECT_EQ("B A", normalizer.Normalize("ba"));
EXPECT_EQ("D E", normalizer.Normalize("c"));
EXPECT_EQ("F G A", normalizer.Normalize("da"));
EXPECT_EQ("A F G", normalizer.Normalize("ad"));
EXPECT_EQ("A F G B", normalizer.Normalize("adb"));
}
{
spec.set_add_dummy_prefix(false);
spec.set_remove_extra_whitespaces(false);
const Normalizer normalizer(spec);
EXPECT_EQ(" A", normalizer.Normalize("a"));
EXPECT_EQ("B A", normalizer.Normalize("ba"));
EXPECT_EQ("D E", normalizer.Normalize("c"));
EXPECT_EQ(" F G A", normalizer.Normalize("da"));
EXPECT_EQ(" A F G ", normalizer.Normalize("ad"));
EXPECT_EQ(" A F G B", normalizer.Normalize("adb"));
}
{
spec.set_add_dummy_prefix(true);
spec.set_remove_extra_whitespaces(true);
const Normalizer normalizer(spec);
EXPECT_EQ(" A", normalizer.Normalize("a"));
EXPECT_EQ(" B A", normalizer.Normalize("ba"));
EXPECT_EQ(" D E", normalizer.Normalize("c"));
EXPECT_EQ(" F G A", normalizer.Normalize("da"));
EXPECT_EQ(" A F G", normalizer.Normalize("ad"));
EXPECT_EQ(" A F G B", normalizer.Normalize("adb"));
}
{
spec.set_add_dummy_prefix(true);
spec.set_remove_extra_whitespaces(false);
const Normalizer normalizer(spec);
EXPECT_EQ(" A", normalizer.Normalize("a"));
EXPECT_EQ(" B A", normalizer.Normalize("ba"));
EXPECT_EQ(" D E", normalizer.Normalize("c"));
EXPECT_EQ(" F G A", normalizer.Normalize("da"));
EXPECT_EQ(" A F G ", normalizer.Normalize("ad"));
EXPECT_EQ(" A F G B", normalizer.Normalize("adb"));
}
}
TEST(NormalizerTest, NormalizeFullTest) {
std::vector<size_t> n2i;
std::string output;
auto spec = MakeDefaultSpec();
const Normalizer normalizer(spec);
{
const std::string input = "I saw a girl";
normalizer.Normalize(input, &output, &n2i);
EXPECT_EQ(WS "I" WS "saw" WS "a" WS "girl", output);
const std::vector<size_t> expected = {0, 0, 0, // WS (3byte)
0, // I
1, 1, 1, // WS
2, 3, 4, // saw
5, 5, 5, // WS
6, // a
7, 7, 7, // WS
8, 9, 10, 11, // girl
12};
EXPECT_EQ(expected, n2i);
}
{
const std::string input = " I saw a   girl  ";
normalizer.Normalize(input, &output, &n2i);
EXPECT_EQ(WS "I" WS "saw" WS "a" WS "girl", output);
const std::vector<size_t> expected = {1, 1, 1, // WS (3byte)
1, // I
2, 2, 2, // WS
5, 6, 7, // saw
8, 8, 8, // WS
9, // a
10, 10, 10, // WS
17, 18, 19, 20, // girl
21};
EXPECT_EQ(expected, n2i);
}
{
const std::string input = " グーグル "; // halfwidth katakana
normalizer.Normalize(input, &output, &n2i);
EXPECT_EQ(WS "グーグル", output);
const std::vector<size_t> expected = {1, 1, 1, // WS (3byte)
1, 1, 1, // グ
7, 7, 7, // ー
10, 10, 10, // グ
16, 16, 16, // ル
19};
EXPECT_EQ(expected, n2i);
}
{
const std::string input = "①②③";
normalizer.Normalize(input, &output, &n2i);
EXPECT_EQ(WS "123", output);
const std::vector<size_t> expected = {0, 0, 0, // WS (3byte)
0, // 1
3, // 2
6, // 3
9};
EXPECT_EQ(expected, n2i);
}
{
const std::string input = "";
normalizer.Normalize(input, &output, &n2i);
EXPECT_EQ(WS "株式会社", output);
const std::vector<size_t> expected = {0, 0, 0, // WS (3byte)
0, 0, 0, // 株
0, 0, 0, // 式
0, 0, 0, // 会
0, 0, 0, // 社
3};
// When "株式" is one piece, this has no alignment to the input.
// Sentencepieces which includes the last character ("会社" or "社")
// have the alignment to the input.
EXPECT_EQ(expected, n2i);
}
}
TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) {
const std::string blob = Normalizer::EncodePrecompiledCharsMap("foo", "bar");
StringPiece trie_blob, normalized_blob;
Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, &normalized_blob);
EXPECT_EQ("foo", trie_blob);
EXPECT_EQ("bar", normalized_blob);
}
} // namespace normalizer
} // namespace sentencepiece
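The whitespace handling these tests pin down (space squeezing, the dummy prefix, and escaping ' ' to U+2581) can be summarized in a simplified sketch that skips NFKC normalization and alignment tracking. The function name is invented and the flags merely mirror the spec fields; this is not the library's code:

```cpp
#include <cassert>
#include <string>

const char kWS[] = "\xe2\x96\x81";  // U+2581, the whitespace meta symbol

// Simplified sketch: squeeze/trim spaces, optionally add the dummy
// prefix, then escape every remaining ' ' to the meta symbol.
std::string EscapeWhitespace(std::string input, bool add_dummy_prefix,
                             bool remove_extra_whitespaces) {
  if (remove_extra_whitespaces) {
    std::string squeezed;
    bool prev_space = true;  // true drops leading spaces
    for (const char c : input) {
      if (c == ' ' && prev_space) continue;
      squeezed += c;
      prev_space = (c == ' ');
    }
    while (!squeezed.empty() && squeezed.back() == ' ') squeezed.pop_back();
    input = squeezed;
  }
  std::string output;
  if (add_dummy_prefix && !input.empty()) output += kWS;
  for (const char c : input) {
    if (c == ' ') {
      output += kWS;  // replace ' ' with the meta symbol
    } else {
      output += c;
    }
  }
  return output;
}
```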

63
src/sentencepiece.proto Normal file

@ -0,0 +1,63 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package sentencepiece;
// SentencePieceText manages a user-facing source sentence,
// postprocessed target sentence, and internal segmentation
// with byte offsets.
message SentencePieceText {
message SentencePiece {
// Internal representation for the decoder.
// - Decoder can use |piece| as a basic token.
// - The piece must be non-empty.
// - A whitespace is replaced with a meta symbol.
// - Concatenation of pieces is not always the same as the |text|.
optional string piece = 1;
// Vocabulary id.
optional uint32 id = 2;
// External representation for the client.
// - It is always guaranteed that
// text.substr(begin, end - begin) == surface.
// - Concatenation of surface is always the same as the |text|.
// - |surface| may contain whitespaces.
// - |surface| may be empty if the piece encodes
// a control vocabulary. e.g., <s>, </s>, <unk>.
// - When |surface| is empty, always begin == end. (zero-length span).
optional string surface = 3;
optional uint32 begin = 4;
optional uint32 end = 5;
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
}
// User input or postprocessed text. This should be immutable
// since the byte range in SentencePiece is pointing to a span over this
// text. Meta symbols for whitespaces are not included.
optional string text = 1;
// A sequence of sentence pieces.
repeated SentencePiece pieces = 2;
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
}
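The surface/offset guarantees documented above (text.substr(begin, end - begin) == surface, surfaces concatenating to |text|, and an empty surface only for a zero-length span) can be checked mechanically. The plain struct below is a stand-in for the proto message, invented for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Plain stand-in for the SentencePieceText.SentencePiece message.
struct Piece {
  std::string surface;
  uint32_t begin, end;  // byte offsets into |text|
};

// Verifies the documented invariants: each surface equals
// text.substr(begin, end - begin), surfaces concatenate to |text|,
// and an empty surface implies a zero-length span.
bool CheckSurfaces(const std::string &text, const std::vector<Piece> &pieces) {
  size_t expected_begin = 0;
  for (const auto &p : pieces) {
    if (p.surface.empty() && p.begin != p.end) return false;
    if (text.substr(p.begin, p.end - p.begin) != p.surface) return false;
    if (p.begin != expected_begin) return false;
    expected_begin = p.end;
  }
  return expected_begin == text.size();
}
```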

src/sentencepiece_model.proto Normal file

@ -0,0 +1,196 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package sentencepiece;
// TrainerSpec encodes the various parameters for SentencePiece training.
message TrainerSpec {
///////////////////////////////////////////////////////////////////
// General parameters
//
// Input corpus files.
// Trainer accepts the following two formats:
// A) Monolingual: plain text, one sentence per line.
// B) Bilingual: TSV, source sentence <tab> target sentence
// When bilingual data is passed, a shared vocabulary model is built.
// Note that the input file must be a raw corpus, not a preprocessed corpus.
// Trainer only loads the first |input_sentence_size| sentences specified
// with this parameter.
repeated string input = 1;
// Output model file prefix.
// <model_prefix>.model and <model_prefix>.vocab are generated.
optional string model_prefix = 2;
// Model type. UNIGRAM is the default.
enum ModelType {
UNIGRAM = 1; // Unigram language model with dynamic algorithm
BPE = 2; // Byte Pair Encoding
WORD = 3; // Delimited by whitespace.
CHAR = 4; // Tokenizes into character sequence.
}
optional ModelType model_type = 3 [ default = UNIGRAM ];
// Vocabulary size. 8000 is the default size.
optional int32 vocab_size = 4 [ default = 8000 ];
// List of the languages this model can accept.
// Since the model is language-agnostic, this field is used as a reference.
repeated string accept_language = 5;
///////////////////////////////////////////////////////////////////
// Training parameters.
//
// Uses characters which cover the corpus with the ratio of |character_coverage|.
// This parameter determines the basic alphabet of sentence pieces.
// The remaining 1.0 - |character_coverage| characters are treated as UNK.
optional float character_coverage = 10 [ default = 0.9995 ];
// Maximum size of sentences the trainer loads from the |input| parameter.
// The trainer simply loads the |input| files in sequence.
// It is better to shuffle the input corpus randomly.
optional int32 input_sentence_size = 11 [ default = 10000000 ];
// Maximum size of sentences to make seed sentence pieces.
// Extended suffix array is constructed to extract frequent
// sub-strings from the corpus. This uses 20N working space,
// where N is the size of the corpus.
optional int32 mining_sentence_size = 12 [ default = 2000000 ];
// Maximum size of sentences to train sentence pieces.
optional int32 training_sentence_size = 13 [ default = 10000000 ];
// The size of seed sentencepieces.
// |seed_sentencepiece_size| must be larger than |vocab_size|.
optional int32 seed_sentencepiece_size = 14 [ default = 1000000 ];
// In every EM sub-iteration, keeps the top
// |shrinking_factor| * |current sentencepieces size| with respect to
// the loss of the sentence piece. This value should be smaller than 1.0.
optional float shrinking_factor = 15 [ default = 0.75 ];
// Number of threads in the training.
optional int32 num_threads = 16 [ default = 16 ];
// Number of EM sub iterations.
optional int32 num_sub_iterations = 17 [ default = 2 ];
///////////////////////////////////////////////////////////////////
// SentencePiece parameters which control the shapes of sentence piece.
//
// Maximum length of sentencepiece.
optional int32 max_sentencepiece_length = 20 [ default = 16 ];
// Uses Unicode script to split sentence pieces.
// When |split_by_unicode_script| is true, we do not allow sentence piece to
// include multiple Unicode scripts, e.g. "F1" is not a valid piece.
// Exception: CJ characters (Hiragana/Katakana/Han) are all handled
// as one script type, since a Japanese word can consist of multiple scripts.
// This exception is always applied regardless of the accept-language
// parameter.
optional bool split_by_unicode_script = 21 [ default = true ];
// Uses whitespace to split sentence pieces.
// When |split_by_whitespace| is false, a piece may contain
// a whitespace in the middle. e.g., "in_the".
optional bool split_by_whitespace = 22 [ default = true ];
///////////////////////////////////////////////////////////////////
// Vocabulary management
//
// Defines control symbols used as an indicator to
// change the behavior of the decoder. <s> and </s> are pre-defined.
// We can use this field to encode various meta information,
// including language indicator in multilingual model.
// These symbols are not visible to users, but visible to
// the decoder. Note that when the input sentence contains control symbols,
// they are not treated as one token, but segmented into normal pieces.
// Control symbols must be inserted independently from the segmentation.
repeated string control_symbols = 30;
// Defines user defined symbols.
// These symbols are added with an extremely high score
// so they are always treated as one unique symbol in any context.
// A typical usage of user_defined_symbols is as placeholders for named entities.
repeated string user_defined_symbols = 31;
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
};
// NormalizerSpec encodes the parameters for string normalization.
message NormalizerSpec {
// Name of the normalization rule.
optional string name = 1;
// Pre-compiled normalization rule created by
// Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method.
// Usually this field is set by Builder::GetNormalizerSpec() method.
optional bytes precompiled_charsmap = 2;
// Adds dummy whitespace at the beginning of text in order to
// treat "world" in "world" and "hello world" in the same way.
optional bool add_dummy_prefix = 3 [ default = true ];
// Removes leading, trailing, and duplicate internal whitespace.
optional bool remove_extra_whitespaces = 4 [ default = true ];
// Replaces whitespace with meta symbol.
// This field must be true to train a sentence piece model.
optional bool escape_whitespaces = 5 [ default = true ];
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
}
// ModelProto stores model parameters.
// SentencePieceProcessor is supposed to be self-contained.
// All settings/parameters which may change the behavior must be encoded
// in ModelProto.
message ModelProto {
message SentencePiece {
enum Type {
NORMAL = 1; // normal symbol
UNKNOWN = 2; // unknown symbol. only <unk> for now.
CONTROL = 3; // control symbols. </s>, <s>, <2ja> etc.
USER_DEFINED = 4; // user defined symbols.
// A typical usage of a USER_DEFINED symbol
// is as a placeholder.
};
optional string piece = 1; // piece must not be empty.
optional float score = 2;
optional Type type = 3 [ default = NORMAL ];
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
}
// Sentence pieces with scores.
repeated SentencePiece pieces = 1;
// Spec used to generate this model file.
optional TrainerSpec trainer_spec = 2;
// Spec for text normalization.
optional NormalizerSpec normalizer_spec = 3;
// Customized extensions: the range of field numbers
// is open to third-party extensions.
extensions 200 to max;
}

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sentencepiece_processor.h"
#include "common.h"
#include "model_factory.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "unigram_model.h"
#include "util.h"
namespace sentencepiece {
namespace {
// Replaces white space with U+2581 (LOWER ONE EIGHTH BLOCK).
const char kSpaceSymbol[] = "\xe2\x96\x81";
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character is useful for both users and
// developers: it makes it easy to see that <unk> was emitted.
const char kUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace
SentencePieceProcessor::SentencePieceProcessor() {}
SentencePieceProcessor::~SentencePieceProcessor() {}
bool SentencePieceProcessor::Load(const std::string &filename) {
std::ifstream ifs(filename.c_str(), std::ios::binary | std::ios::in);
if (!ifs) {
LOG(WARNING) << "Cannot open " << filename;
return false;
}
model_proto_ = port::MakeUnique<ModelProto>();
if (!model_proto_->ParseFromIstream(&ifs)) {
LOG(WARNING) << "Model file is broken: " << filename;
return false;
}
model_ = ModelFactory::Create(*model_proto_);
normalizer_ =
port::MakeUnique<normalizer::Normalizer>(model_proto_->normalizer_spec());
return true;
}
void SentencePieceProcessor::LoadOrDie(const std::string &filename) {
CHECK(Load(filename)) << "failed to load model: " << filename;
}
void SentencePieceProcessor::SetEncodeExtraOptions(
const std::string &extra_options) {
encode_extra_options_ = ParseExtraOptions(extra_options);
}
void SentencePieceProcessor::SetDecodeExtraOptions(
const std::string &extra_options) {
decode_extra_options_ = ParseExtraOptions(extra_options);
}
//////////////////////////////////////////////////////////////
// Simple API.
void SentencePieceProcessor::Encode(const std::string &input,
std::vector<std::string> *pieces) const {
CHECK_NOTNULL(pieces)->clear();
SentencePieceText spt;
Encode(input, &spt);
for (const auto &sp : spt.pieces()) {
pieces->emplace_back(sp.piece());
}
}
void SentencePieceProcessor::Encode(const std::string &input,
std::vector<int> *ids) const {
CHECK_NOTNULL(ids)->clear();
SentencePieceText spt;
Encode(input, &spt);
for (const auto &sp : spt.pieces()) {
ids->emplace_back(sp.id());
}
}
void SentencePieceProcessor::Decode(const std::vector<std::string> &pieces,
std::string *detokenized) const {
CHECK_NOTNULL(detokenized);
SentencePieceText spt;
Decode(pieces, &spt);
*detokenized = std::move(spt.text());
}
void SentencePieceProcessor::Decode(const std::vector<int> &ids,
std::string *detokenized) const {
CHECK_NOTNULL(detokenized);
SentencePieceText spt;
Decode(ids, &spt);
*detokenized = std::move(spt.text());
}
//////////////////////////////////////////////////////////////
// Advanced API with SentencePieceText proto.
void SentencePieceProcessor::Encode(const std::string &input,
SentencePieceText *spt) const {
CHECK_NOTNULL(spt)->Clear();
std::string normalized;
std::vector<size_t> norm_to_orig;
CHECK_NOTNULL(normalizer_)->Normalize(input, &normalized, &norm_to_orig);
size_t consumed = 0;
bool is_prev_unk = false;
for (const auto &p : CHECK_NOTNULL(model_)->Encode(normalized)) {
const StringPiece w = p.first; // piece
const int id = p.second; // id
CHECK(!w.empty());
const bool is_unk = IsUnknown(id);
if (IsControl(id)) {
// Control symbol has no corresponding source surface, so begin == end.
auto *sp = spt->add_pieces();
sp->set_piece(w.to_string());
sp->set_id(id);
sp->set_begin(norm_to_orig[consumed]);
sp->set_end(norm_to_orig[consumed]);
} else {
const size_t begin = consumed;
const size_t end = consumed + w.size();
CHECK_LE(begin, norm_to_orig.size());
CHECK_LE(end, norm_to_orig.size());
const size_t orig_begin = norm_to_orig[begin];
const size_t orig_end = norm_to_orig[end];
CHECK_LE(orig_begin, input.size());
CHECK_LE(orig_end, input.size());
CHECK_LE(orig_begin, orig_end);
const auto surface = input.substr(orig_begin, orig_end - orig_begin);
// Merges continuous run of unknown pieces so that decoder
// can copy or generate unknown tokens easily.
// Note that merged tokens are still unknown,
// since known pieces never consist of unknown characters.
if (is_prev_unk && is_unk) {
auto *sp = spt->mutable_pieces(spt->pieces_size() - 1);
sp->set_piece(sp->piece() + w.to_string());
sp->set_surface(sp->surface() + surface);
sp->set_end(orig_end);
} else {
auto *sp = spt->add_pieces();
sp->set_piece(w.to_string());
sp->set_id(id);
sp->set_surface(surface);
sp->set_begin(orig_begin);
sp->set_end(orig_end);
}
consumed += w.size();
}
is_prev_unk = is_unk;
}
CHECK_EQ(consumed, normalized.size());
ApplyExtraOptions(encode_extra_options_, spt);
spt->set_text(input);
}
void SentencePieceProcessor::Decode(const std::vector<std::string> &pieces,
SentencePieceText *spt) const {
CHECK_NOTNULL(spt)->Clear();
auto DecodeSentencePiece = [&](StringPiece piece, int id,
bool is_bos_ws) -> std::string {
if (IsControl(id)) { // <s>, </s>
return ""; // invisible symbol.
} else if (IsUnknown(id)) {
if (IdToPiece(id) == piece) { // <unk>
return kUnknownSymbol;
} else { // return piece when piece is not <unk>.
return piece.to_string();
}
}
if (is_bos_ws) {
// Consume if the current position is bos and
// piece starts with kSpaceSymbol.
piece.Consume(kSpaceSymbol);
}
return string_util::StringReplace(piece, kSpaceSymbol, " ", true);
};
for (const std::string &w : pieces) {
auto *sp = spt->add_pieces();
sp->set_piece(w);
sp->set_id(PieceToId(w));
}
ApplyExtraOptions(decode_extra_options_, spt);
std::string *text = spt->mutable_text();
for (auto &sp : *(spt->mutable_pieces())) {
sp.set_surface(DecodeSentencePiece(sp.piece(), sp.id(), text->empty()));
sp.set_begin(text->size());
sp.set_end(text->size() + sp.surface().size());
*text += sp.surface();
}
}
void SentencePieceProcessor::Decode(const std::vector<int> &ids,
SentencePieceText *spt) const {
std::vector<std::string> pieces;
for (const int id : ids) {
pieces.emplace_back(IdToPiece(id));
}
return Decode(pieces, spt);
}
int SentencePieceProcessor::GetPieceSize() const {
return CHECK_NOTNULL(model_)->GetPieceSize();
}
int SentencePieceProcessor::PieceToId(const std::string &piece) const {
return CHECK_NOTNULL(model_)->PieceToId(piece);
}
std::string SentencePieceProcessor::IdToPiece(int id) const {
return CHECK_NOTNULL(model_)->IdToPiece(id);
}
float SentencePieceProcessor::GetScore(int id) const {
return CHECK_NOTNULL(model_)->GetScore(id);
}
bool SentencePieceProcessor::IsControl(int id) const {
return CHECK_NOTNULL(model_)->IsControl(id);
}
bool SentencePieceProcessor::IsUnknown(int id) const {
return CHECK_NOTNULL(model_)->IsUnknown(id);
}
// static
void SentencePieceProcessor::ApplyExtraOptions(
const std::vector<ExtraOption> &extra_options,
SentencePieceText *spt) const {
constexpr int kBOS = 1;
constexpr int kEOS = 2;
for (const auto &extra_option : extra_options) {
switch (extra_option) {
case REVERSE:
std::reverse(spt->mutable_pieces()->begin(),
spt->mutable_pieces()->end());
break;
case EOS: {
auto *piece = spt->add_pieces();
piece->set_id(kEOS);
piece->set_piece(IdToPiece(kEOS));
} break;
case BOS: {
auto *array = spt->mutable_pieces();
array->Add();
for (int i = array->size() - 1; i > 0; --i) {
array->SwapElements(i - 1, i);
}
auto *piece = array->Mutable(0);
piece->set_id(kBOS);
piece->set_piece(IdToPiece(kBOS));
} break;
default:
LOG(FATAL) << "Unknown extra_option type: "
<< static_cast<int>(extra_option);
}
}
}
// static
std::vector<SentencePieceProcessor::ExtraOption>
SentencePieceProcessor::ParseExtraOptions(const std::string &extra_option) {
static std::map<std::string, SentencePieceProcessor::ExtraOption>
extra_option_map = {{"bos", SentencePieceProcessor::BOS},
{"eos", SentencePieceProcessor::EOS},
{"reverse", SentencePieceProcessor::REVERSE}};
std::vector<SentencePieceProcessor::ExtraOption> extra_options;
for (const auto &s : string_util::Split(extra_option, ":")) {
extra_options.push_back(port::FindOrDie(extra_option_map, s));
}
return extra_options;
}
void SentencePieceProcessor::SetModel(std::unique_ptr<ModelInterface> &&model) {
model_ = std::move(model);
}
void SentencePieceProcessor::SetNormalizer(
std::unique_ptr<normalizer::Normalizer> &&normalizer) {
normalizer_ = std::move(normalizer);
}
const ModelProto &SentencePieceProcessor::model_proto() const {
CHECK_NOTNULL(model_proto_);
return *model_proto_;
}
} // namespace sentencepiece

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SENTENCEPIECE_PROCESSOR_H_
#define SENTENCEPIECE_PROCESSOR_H_
#include <memory>
#include <string>
#include <vector>
namespace sentencepiece {
// SentencePieceProcessor:
// Simple and language-independent tokenizer and de-tokenizer for
// Neural Network Machine Translation.
//
// SentencePieceProcessor provides Encode() and Decode() methods,
// which correspond to tokenization and de-tokenization respectively.
//
// - Encode:
// Given a raw source sentence, encode it into a sequence
// of pieces or vocabulary ids.
//
// - Decode:
// Given a sequence of pieces or vocabulary ids, decode it
// into a de-tokenized raw sentence.
//
// SentencePieceProcessor provides a lossless data conversion
// that allows the original raw sentence to be perfectly reconstructed
// from the encoded data, i.e., Decode(Encode(input)) == input.
// This characteristic is useful, as it makes the de-tokenization
// completely language independent.
//
// Usage:
// SentencePieceProcessor sp;
// sp.Load("//path/to/model");
//
// vector<string> sps;
// sp.Encode("hello world.", &sps);
//
// vector<int> ids;
// sp.Encode("hello world.", &ids);
//
// string detok;
// sp.Decode(sps, &detok);
// CHECK_EQ("hello world.", detok);
//
// sp.Decode(ids, &detok);
// CHECK_EQ("hello world.", detok);
//
// We can also use SentencePieceText which manages the byte-offsets
// between user input (output) and internal sentence pieces.
//
// SentencePieceText spt;
// sp.Encode("hello world.", &spt);
// // Emits the byte range of each piece.
// for (const auto &piece : spt.pieces()) {
// LOG(INFO) << piece.begin() << " " << piece.end();
// }
//
// sp.Decode({0, 1, 2, 3..}, &spt);
// for (const auto &piece : spt.pieces()) {
// LOG(INFO) << piece.begin() << " " << piece.end();
// }
//
class SentencePieceText;
class ModelInterface;
class ModelProto;
namespace normalizer {
class Normalizer;
} // namespace normalizer
class SentencePieceProcessor {
public:
SentencePieceProcessor();
virtual ~SentencePieceProcessor();
// Loads model from |filename|.
// Returns false if |filename| cannot be loaded.
virtual bool Load(const std::string &filename);
// Loads model from |filename|.
// Dies if |filename| cannot be loaded.
virtual void LoadOrDie(const std::string &filename);
// Sets encode extra_option sequence.
virtual void SetEncodeExtraOptions(const std::string &extra_option);
// Sets decode extra_option sequence.
virtual void SetDecodeExtraOptions(const std::string &extra_option);
//////////////////////////////////////////////////////////////
// Simple API.
//
// Given a UTF8 input, encodes it into a sequence of sentence pieces.
virtual void Encode(const std::string &input,
std::vector<std::string> *pieces) const;
// Given a UTF8 input, encodes it into a sequence of ids.
virtual void Encode(const std::string &input, std::vector<int> *ids) const;
// Given a sequence of pieces, decodes it into a detokenized output.
virtual void Decode(const std::vector<std::string> &pieces,
std::string *detokenized) const;
// Given a sequence of ids, decodes it into a detokenized output.
virtual void Decode(const std::vector<int> &ids,
std::string *detokenized) const;
//////////////////////////////////////////////////////////////
// Advanced API returning SentencePieceText, which manages
// utf8-byte alignments between user-input/detokenized text
// and internal sentencepiece sequence.
//
// Given a UTF8 input, encodes it into SentencePieceText.
virtual void Encode(const std::string &input, SentencePieceText *spt) const;
// Given a sequence of pieces, decodes it into SentencePieceText.
virtual void Decode(const std::vector<std::string> &pieces,
SentencePieceText *spt) const;
// Given a sequence of ids, decodes it into SentencePieceText.
virtual void Decode(const std::vector<int> &ids,
SentencePieceText *spt) const;
//////////////////////////////////////////////////////////////
// Vocabulary management methods.
//
// Returns the number of sentence pieces, which is the same as
// the vocabulary size for NMT.
virtual int GetPieceSize() const;
// Returns the vocab id of |piece|.
// Returns UNK(0) if |piece| is unknown.
virtual int PieceToId(const std::string &piece) const;
// Returns the string representation of vocab with |id|.
virtual std::string IdToPiece(int id) const;
// Returns the score of |id|.
// Usually score is an emission log probability of unigram language model.
virtual float GetScore(int id) const;
// Returns true if |id| is unknown symbol.
virtual bool IsUnknown(int id) const;
// Returns true if |id| is control symbol.
virtual bool IsControl(int id) const;
//////////////////////////////////////////////////////////////
// Model management.
//
// Allows injection of a mock model instance. |model| is moved.
void SetModel(std::unique_ptr<ModelInterface> &&model);
// Allows injection of a normalizer instance. |normalizer| is moved.
void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
// Returns immutable model proto. Useful to obtain extended
// or experimental parameters encoded in model_proto.
const ModelProto &model_proto() const;
private:
enum ExtraOption { REVERSE, BOS, EOS };
static std::vector<ExtraOption> ParseExtraOptions(
const std::string &extra_option);
void ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
SentencePieceText *spt) const;
std::unique_ptr<ModelInterface> model_;
std::unique_ptr<normalizer::Normalizer> normalizer_;
// Underlying model protocol buffer. The same lifetime as model_.
std::unique_ptr<ModelProto> model_proto_;
std::vector<ExtraOption> encode_extra_options_;
std::vector<ExtraOption> decode_extra_options_;
};
} // namespace sentencepiece
#endif // SENTENCEPIECE_PROCESSOR_H_

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sentencepiece_processor.h"
#include <unordered_map>
#include "builder.h"
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "stringpiece.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
using port::MakeUnique;
// Space symbol
#define WS "\xe2\x96\x81"
class MockModel : public ModelInterface {
public:
void SetEncodeResult(StringPiece input,
const std::vector<std::pair<StringPiece, int>> &output) {
input_ = input;
output_ = output;
}
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const {
EXPECT_EQ(normalized, input_);
return output_;
}
bool IsControl(int id) const { return id == 1 || id == 2; }
bool IsUnknown(int id) const { return id == 0; }
int GetPieceSize() const { return 10; }
int PieceToId(StringPiece piece) const { return 0; }
std::string IdToPiece(int id) const { return ""; }
float GetScore(int id) const { return 0.0; }
private:
StringPiece input_;
std::vector<std::pair<StringPiece, int>> output_;
};
std::vector<std::string> GetSpVec(
const std::vector<std::pair<StringPiece, int>> &pieces) {
std::vector<std::string> sps;
for (const auto &p : pieces) {
sps.emplace_back(p.first.to_string());
}
return sps;
}
std::vector<std::string> GetSpVec(const SentencePieceText &spt) {
std::vector<std::string> sps;
for (auto &sp : spt.pieces()) {
sps.emplace_back(sp.piece());
}
return sps;
}
NormalizerSpec MakeDefaultNormalizerSpec() {
return normalizer::Builder::GetNormalizerSpec("nfkc");
}
TEST(SentencepieceProcessorTest, EncodeTest) {
const StringPiece kInput = WS "ABC" WS "DEF";
SentencePieceProcessor sp;
const auto normalization_spec = MakeDefaultNormalizerSpec();
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
sp.Encode("ABC DEF", &output);
EXPECT_EQ(GetSpVec(result), output);
SentencePieceText spt;
sp.Encode("ABC DEF", &spt);
EXPECT_EQ(4, spt.pieces_size());
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("ABC", spt.pieces(0).surface());
EXPECT_EQ(" DE", spt.pieces(1).surface());
EXPECT_EQ("F", spt.pieces(2).surface());
EXPECT_EQ("", spt.pieces(3).surface()); // </s>
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(0, spt.pieces(2).id());
EXPECT_EQ(2, spt.pieces(3).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(3, spt.pieces(0).end());
EXPECT_EQ(3, spt.pieces(1).begin());
EXPECT_EQ(6, spt.pieces(1).end());
EXPECT_EQ(6, spt.pieces(2).begin());
EXPECT_EQ(7, spt.pieces(2).end());
EXPECT_EQ(7, spt.pieces(3).begin());
EXPECT_EQ(7, spt.pieces(3).end());
}
// Unknown sequences.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "ABC", 3}, {WS "D", 4}, {"E", 0}, {"F", 0}, {"</s>", 2}};
const std::vector<std::pair<StringPiece, int>> expected = {
{WS "ABC", 3}, {WS "D", 4}, {"EF", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
sp.Encode("ABC DEF", &output);
EXPECT_EQ(GetSpVec(expected), output);
SentencePieceText spt;
sp.Encode("ABC DEF", &spt);
EXPECT_EQ(4, spt.pieces_size());
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(expected[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("ABC", spt.pieces(0).surface());
EXPECT_EQ(" D", spt.pieces(1).surface());
EXPECT_EQ("EF", spt.pieces(2).surface());
EXPECT_EQ("", spt.pieces(3).surface()); // </s>
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(0, spt.pieces(2).id());
EXPECT_EQ(2, spt.pieces(3).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(3, spt.pieces(0).end());
EXPECT_EQ(3, spt.pieces(1).begin());
EXPECT_EQ(5, spt.pieces(1).end());
EXPECT_EQ(5, spt.pieces(2).begin());
EXPECT_EQ(7, spt.pieces(2).end());
EXPECT_EQ(7, spt.pieces(3).begin());
EXPECT_EQ(7, spt.pieces(3).end());
}
// Crashes if ModelInterface::Encode() returns shorter results.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {{WS "ABC", 3}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_DEATH(sp.Encode("ABC DEF", &spt));
}
// Crashes if ModelInterface::Encode() returns longer results.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_DEATH(sp.Encode("ABC DEF", &spt));
}
// Crashes if ModelInterface::Encode() returns an empty piece.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "ABC", 3}, {WS "DE", 4}, {"", 5}, {"F", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_DEATH(sp.Encode("ABC DEF", &spt));
}
// Halfwidth to fullwidth katakana normalization.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "グー", 3}, {"グル", 4}, {"</s>", 2}};
const StringPiece input = WS "グーグル";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
std::vector<std::string> output;
sp.Encode("グーグル", &output);
EXPECT_EQ(GetSpVec(result), output);
SentencePieceText spt;
sp.Encode("グーグル", &spt);
EXPECT_EQ(3, spt.pieces_size());
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("グー", spt.pieces(0).surface());
EXPECT_EQ("グル", spt.pieces(1).surface());
EXPECT_EQ("", spt.pieces(2).surface());
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(2, spt.pieces(2).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(9, spt.pieces(0).end());
EXPECT_EQ(9, spt.pieces(1).begin());
EXPECT_EQ(18, spt.pieces(1).end());
EXPECT_EQ(18, spt.pieces(2).begin()); // </s>
EXPECT_EQ(18, spt.pieces(2).end());
}
// One to many normalization.
{
auto mock = MakeUnique<MockModel>();
const std::vector<std::pair<StringPiece, int>> result = {
{WS "株式", 3}, {"会社", 4}, {"</s>", 2}};
const StringPiece input = WS "株式会社";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
std::vector<std::string> output;
sp.Encode("", &output);
EXPECT_EQ(GetSpVec(result), output);
SentencePieceText spt;
sp.Encode("", &spt);
EXPECT_EQ(3, spt.pieces_size());
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("", spt.pieces(0).surface());
EXPECT_EQ("", spt.pieces(1).surface());
EXPECT_EQ("", spt.pieces(2).surface());
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(2, spt.pieces(2).id());
EXPECT_EQ(0, spt.pieces(0).begin()); // 株式
EXPECT_EQ(0, spt.pieces(0).end());
EXPECT_EQ(0, spt.pieces(1).begin()); // 会社
EXPECT_EQ(3, spt.pieces(1).end());
EXPECT_EQ(3, spt.pieces(2).begin()); // </s>
EXPECT_EQ(3, spt.pieces(2).end());
}
}
TEST(SentencepieceProcessorTest, DecodeTest) {
class DecodeMockModel : public ModelInterface {
public:
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override {
return {};
}
int GetPieceSize() const override { return 7; }
int PieceToId(StringPiece piece) const override {
static std::unordered_map<StringPiece, int, StringPieceHash> kMap = {
{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
{WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}};
return port::FindWithDefault(kMap, piece, 0);
}
std::string IdToPiece(int id) const override {
static std::vector<std::string> kMap = {
"<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H"};
return kMap[id];
}
bool IsUnknown(int id) const override { return (id == 0); }
bool IsControl(int id) const override { return (id == 1 || id == 2); }
float GetScore(int id) const override { return 0.0; }
};
SentencePieceProcessor sp;
auto mock = MakeUnique<DecodeMockModel>();
// std::unique_ptr<ModelInterface> mock(new DecodeMockModel);
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(MakeUnique<normalizer::Normalizer>(normalization_spec));
const std::vector<std::string> input = {"<s>", WS "ABC", "<unk>", WS "DE",
"F", "G" WS "H", "I", "</s>"};
SentencePieceText spt;
sp.Decode(input, &spt);
EXPECT_EQ("ABC \xE2\x81\x87 DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(input[i], spt.pieces(i).piece());
}
EXPECT_EQ("", spt.pieces(0).surface());
EXPECT_EQ("ABC", spt.pieces(1).surface());
EXPECT_EQ(" \xE2\x81\x87 ", spt.pieces(2).surface());
EXPECT_EQ(" DE", spt.pieces(3).surface());
EXPECT_EQ("F", spt.pieces(4).surface());
EXPECT_EQ("G H", spt.pieces(5).surface());
EXPECT_EQ("I", spt.pieces(6).surface());
EXPECT_EQ("", spt.pieces(7).surface());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(0, spt.pieces(0).end());
EXPECT_EQ(0, spt.pieces(1).begin());
EXPECT_EQ(3, spt.pieces(1).end());
EXPECT_EQ(3, spt.pieces(2).begin());
EXPECT_EQ(8, spt.pieces(2).end());
EXPECT_EQ(8, spt.pieces(3).begin());
EXPECT_EQ(11, spt.pieces(3).end());
EXPECT_EQ(11, spt.pieces(4).begin());
EXPECT_EQ(12, spt.pieces(4).end());
EXPECT_EQ(12, spt.pieces(5).begin());
EXPECT_EQ(15, spt.pieces(5).end());
EXPECT_EQ(15, spt.pieces(6).begin());
EXPECT_EQ(16, spt.pieces(6).end());
EXPECT_EQ(16, spt.pieces(7).begin());
EXPECT_EQ(16, spt.pieces(7).end());
}
void AddPiece(ModelProto *model_proto, StringPiece piece, float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece.to_string());
sp->set_score(score);
}
TEST(SentencePieceProcessorTest, LoadInvalidModelTest) {
SentencePieceProcessor sp;
EXPECT_DEATH(sp.LoadOrDie(""));
EXPECT_DEATH(sp.LoadOrDie("__UNKNOWN_FILE__"));
}
TEST(SentencePieceProcessorTest, EndToEndTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
AddPiece(&model_proto, "c", 0.2);
AddPiece(&model_proto, "ab", 1.0);
AddPiece(&model_proto, "\xE2\x96\x81", 3.0); // kSpaceSymbol
*(model_proto.mutable_normalizer_spec()) = MakeDefaultNormalizerSpec();
test::ScopedTempFile sf("model");
{
std::ofstream ofs(sf.filename(), OUTPUT_MODE);
CHECK(model_proto.SerializeToOstream(&ofs));
}
SentencePieceProcessor sp;
sp.Load(sf.filename());
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
EXPECT_EQ(8, sp.GetPieceSize());
EXPECT_EQ(0, sp.PieceToId("<unk>"));
EXPECT_EQ(1, sp.PieceToId("<s>"));
EXPECT_EQ(2, sp.PieceToId("</s>"));
EXPECT_EQ(3, sp.PieceToId("a"));
EXPECT_EQ(4, sp.PieceToId("b"));
EXPECT_EQ(5, sp.PieceToId("c"));
EXPECT_EQ(6, sp.PieceToId("ab"));
EXPECT_EQ(7, sp.PieceToId("\xE2\x96\x81"));
EXPECT_EQ("<unk>", sp.IdToPiece(0));
EXPECT_EQ("<s>", sp.IdToPiece(1));
EXPECT_EQ("</s>", sp.IdToPiece(2));
EXPECT_EQ("a", sp.IdToPiece(3));
EXPECT_EQ("b", sp.IdToPiece(4));
EXPECT_EQ("c", sp.IdToPiece(5));
EXPECT_EQ("ab", sp.IdToPiece(6));
EXPECT_EQ("\xE2\x96\x81", sp.IdToPiece(7));
EXPECT_NEAR(0.0, sp.GetScore(0), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(1), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(2), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(3), 0.001);
EXPECT_NEAR(0.3, sp.GetScore(4), 0.001);
EXPECT_NEAR(0.2, sp.GetScore(5), 0.001);
EXPECT_NEAR(1.0, sp.GetScore(6), 0.001);
EXPECT_NEAR(3.0, sp.GetScore(7), 0.001);
EXPECT_TRUE(sp.IsUnknown(0));
EXPECT_FALSE(sp.IsUnknown(1));
EXPECT_FALSE(sp.IsUnknown(2));
EXPECT_FALSE(sp.IsUnknown(3));
EXPECT_FALSE(sp.IsUnknown(4));
EXPECT_FALSE(sp.IsUnknown(5));
EXPECT_FALSE(sp.IsUnknown(6));
EXPECT_FALSE(sp.IsUnknown(7));
EXPECT_FALSE(sp.IsControl(0));
EXPECT_TRUE(sp.IsControl(1));
EXPECT_TRUE(sp.IsControl(2));
EXPECT_FALSE(sp.IsControl(3));
EXPECT_FALSE(sp.IsControl(4));
EXPECT_FALSE(sp.IsControl(5));
EXPECT_FALSE(sp.IsControl(6));
EXPECT_FALSE(sp.IsControl(7));
{
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 6, 5};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("bos");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", WS, "ab", "c"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 7, 6, 5};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("eos");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c", "</s>"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 6, 5, 2};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("reverse");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"c", "ab", WS};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {5, 6, 7};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("bos:eos");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", WS, "ab", "c",
"</s>"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 7, 6, 5, 2};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("reverse:bos:eos");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", "c", "ab", WS,
"</s>"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 5, 6, 7, 2};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
sp.SetEncodeExtraOptions("bos:eos:reverse");
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"</s>", "c", "ab", WS,
"<s>"};
sp.Encode("abc", &sps);
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {2, 5, 6, 7, 1};
sp.Encode("abc", &ids);
EXPECT_EQ(expected_id, ids);
}
{
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("abc", output);
}
{
sp.SetDecodeExtraOptions("bos");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("abc", output);
}
{
sp.SetDecodeExtraOptions("eos");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("abc", output);
}
{
sp.SetDecodeExtraOptions("reverse");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("cba", output);
}
{
sp.SetDecodeExtraOptions("bos:eos");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("abc", output);
}
{
sp.SetDecodeExtraOptions("reverse:bos:eos");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("cba", output);
}
{
sp.SetDecodeExtraOptions("bos:eos:reverse");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("cba", output);
}
{
sp.SetDecodeExtraOptions("reverse:reverse");
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
sp.Decode(sps, &output);
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
sp.Decode(ids, &output);
EXPECT_EQ("abc", output);
}
EXPECT_DEATH(sp.SetEncodeExtraOptions("foo"));
EXPECT_DEATH(sp.SetDecodeExtraOptions("foo"));
}
} // namespace sentencepiece
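The extra-option tests above show that the ':'-separated options compose left to right: "reverse:bos:eos" reverses first and then brackets, while "bos:eos:reverse" brackets first and then reverses the whole sequence. A standalone sketch of that semantics (not the library implementation; `ApplyExtraOptions` is a name invented here for illustration):

```cpp
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

// Sketch of the extra-option semantics exercised by the tests above
// (illustration only, not the library code). Options in the ':'-separated
// string apply left to right: "reverse" flips the piece sequence,
// "bos" prepends "<s>", and "eos" appends "</s>".
std::vector<std::string> ApplyExtraOptions(const std::string &options,
                                           std::vector<std::string> pieces) {
  std::stringstream ss(options);
  std::string opt;
  while (std::getline(ss, opt, ':')) {
    if (opt == "reverse") {
      std::reverse(pieces.begin(), pieces.end());
    } else if (opt == "bos") {
      pieces.insert(pieces.begin(), "<s>");
    } else if (opt == "eos") {
      pieces.push_back("</s>");
    }
  }
  return pieces;
}
```

With pieces {WS, "ab", "c"}, "bos:eos:reverse" first brackets the sequence and then reverses everything, which is why "</s>" ends up first in the expected output of that test case.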
src/spm_decode_main.cc (new file, 97 lines)
@@ -0,0 +1,97 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common.h"
#include "flags.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"
#include "util.h"
DEFINE_string(model, "", "model file name");
DEFINE_string(output, "", "output filename");
DEFINE_string(input_format, "piece", "choose from piece or id");
DEFINE_string(output_format, "string", "choose from string or proto");
DEFINE_string(extra_options, "",
"':' separated decoder extra options, e.g., \"reverse:bos:eos\"");
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
CHECK_OR_HELP(model);
sentencepiece::SentencePieceProcessor sp;
sp.LoadOrDie(FLAGS_model);
sp.SetDecodeExtraOptions(FLAGS_extra_options);
sentencepiece::io::OutputBuffer output(FLAGS_output);
if (rest_args.empty()) {
rest_args.push_back("");  // an empty filename means reading from stdin.
}
std::string detok, line;
sentencepiece::SentencePieceText spt;
std::function<void(const std::vector<std::string> &pieces)> process;
auto ToIds = [&](const std::vector<std::string> &pieces) {
std::vector<int> ids;
for (const auto &s : pieces) {
ids.push_back(atoi(s.c_str()));
}
return ids;
};
if (FLAGS_input_format == "piece") {
if (FLAGS_output_format == "string") {
process = [&](const std::vector<std::string> &pieces) {
sp.Decode(pieces, &detok);
output.WriteLine(detok);
};
} else if (FLAGS_output_format == "proto") {
process = [&](const std::vector<std::string> &pieces) {
sp.Decode(pieces, &spt);
output.WriteLine(spt.Utf8DebugString());
};
} else {
LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
}
} else if (FLAGS_input_format == "id") {
if (FLAGS_output_format == "string") {
process = [&](const std::vector<std::string> &pieces) {
sp.Decode(ToIds(pieces), &detok);
output.WriteLine(detok);
};
} else if (FLAGS_output_format == "proto") {
process = [&](const std::vector<std::string> &pieces) {
sp.Decode(ToIds(pieces), &spt);
output.WriteLine(spt.Utf8DebugString());
};
} else {
LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
}
} else {
LOG(FATAL) << "Unknown input format: " << FLAGS_input_format;
}
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
while (input.ReadLine(&line)) {
const auto pieces = sentencepiece::string_util::Split(line, " ");
process(pieces);
}
}
return 0;
}
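The id path above splits each line on spaces and converts every token with atoi via the ToIds lambda. A self-contained equivalent of that conversion (`ParseIds` is a name invented here, standing in for the string_util::Split + ToIds pair; stream extraction collapses repeated delimiters, which Split may not):

```cpp
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

// Standalone sketch of the id-input path above: split a line on whitespace
// and convert each token with atoi, as the ToIds lambda does. The name
// ParseIds is invented for illustration.
std::vector<int> ParseIds(const std::string &line) {
  std::vector<int> ids;
  std::istringstream iss(line);
  std::string token;
  while (iss >> token) {
    ids.push_back(atoi(token.c_str()));
  }
  return ids;
}
```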
src/spm_encode_main.cc (new file, 79 lines)
@@ -0,0 +1,79 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common.h"
#include "flags.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"
#include "util.h"
DEFINE_string(model, "", "model file name");
DEFINE_string(output_format, "piece", "choose from piece, id, or proto");
DEFINE_string(output, "", "output filename");
DEFINE_string(extra_options, "",
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
CHECK_OR_HELP(model);
sentencepiece::SentencePieceProcessor sp;
sp.LoadOrDie(FLAGS_model);
sp.SetEncodeExtraOptions(FLAGS_extra_options);
sentencepiece::io::OutputBuffer output(FLAGS_output);
if (rest_args.empty()) {
rest_args.push_back("");  // an empty filename means reading from stdin.
}
std::string line;
std::vector<std::string> sps;
std::vector<int> ids;
sentencepiece::SentencePieceText spt;
std::function<void(const std::string &line)> process;
if (FLAGS_output_format == "piece") {
process = [&](const std::string &line) {
sp.Encode(line, &sps);
output.WriteLine(sentencepiece::string_util::Join(sps, " "));
};
} else if (FLAGS_output_format == "id") {
process = [&](const std::string &line) {
sp.Encode(line, &ids);
output.WriteLine(sentencepiece::string_util::Join(ids, " "));
};
} else if (FLAGS_output_format == "proto") {
process = [&](const std::string &line) {
sp.Encode(line, &spt);
output.WriteLine(spt.Utf8DebugString());
};
} else {
LOG(FATAL) << "Unknown output format: " << FLAGS_output_format;
}
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
while (input.ReadLine(&line)) {
if (line.empty()) {
continue;
}
process(line);
}
}
return 0;
}
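Both command-line tools above resolve the format flags into a single std::function once, before the input loop, so the per-line work never re-branches on strings. A minimal standalone sketch of that dispatch pattern (the function name and format values here are invented for illustration):

```cpp
#include <functional>
#include <string>

// Minimal sketch of the dispatch pattern used by the mains above: pick one
// processing closure per run from a format string, then call it per line.
// MakeProcessor and the format names are invented for illustration.
std::function<std::string(const std::string &)> MakeProcessor(
    const std::string &format) {
  if (format == "angle") {
    return [](const std::string &line) { return "<" + line + ">"; };
  }
  if (format == "square") {
    return [](const std::string &line) { return "[" + line + "]"; };
  }
  // Unknown format; the real tools LOG(FATAL) here instead of falling back.
  return [](const std::string &line) { return line; };
}
```

Hoisting the branch out of the loop keeps the hot path to a single indirect call per line.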
(new file, 44 lines)
@@ -0,0 +1,44 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "common.h"
#include "flags.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "util.h"
DEFINE_string(output, "", "Output filename");
DEFINE_string(model, "", "input model file name");
DEFINE_string(output_format, "txt", "output format. choose from txt or proto");
int main(int argc, char *argv[]) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
sentencepiece::SentencePieceProcessor sp;
sp.LoadOrDie(FLAGS_model);
sentencepiece::io::OutputBuffer output(FLAGS_output);
if (FLAGS_output_format == "txt") {
for (const auto &piece : sp.model_proto().pieces()) {
std::ostringstream os;
os << piece.piece() << "\t" << piece.score();
output.WriteLine(os.str());
}
} else if (FLAGS_output_format == "proto") {
output.Write(sp.model_proto().Utf8DebugString());
}
return 0;
}
src/spm_normalize_main.cc (new file, 74 lines)
@@ -0,0 +1,74 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "builder.h"
#include "common.h"
#include "flags.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "util.h"
DEFINE_string(model, "", "Model file name");
DEFINE_bool(use_internal_normalization, false,
"Use NormalizerSpec \"as-is\" to run the normalizer "
"for SentencePiece segmentation");
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file.");
DEFINE_bool(remove_extra_whitespaces, true, "Remove extra whitespaces");
DEFINE_string(output, "", "Output filename");
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
sentencepiece::NormalizerSpec spec;
if (FLAGS_normalization_rule_tsv.empty() && !FLAGS_model.empty()) {
sentencepiece::SentencePieceProcessor sp;
sp.LoadOrDie(FLAGS_model);
spec = sp.model_proto().normalizer_spec();
} else if (!FLAGS_normalization_rule_tsv.empty() && FLAGS_model.empty()) {
const auto cmap = sentencepiece::normalizer::Builder::BuildMapFromFile(
FLAGS_normalization_rule_tsv);
spec.set_precompiled_charsmap(
sentencepiece::normalizer::Builder::CompileCharsMap(cmap));
} else {
LOG(FATAL) << "Sets --model or normalization_rule_tsv flag";
}
// Uses the normalizer spec encoded in the model_pb.
if (!FLAGS_use_internal_normalization) {
spec.set_add_dummy_prefix(false); // do not add dummy prefix.
spec.set_escape_whitespaces(false); // do not output meta symbol.
spec.set_remove_extra_whitespaces(FLAGS_remove_extra_whitespaces);
}
sentencepiece::normalizer::Normalizer normalizer(spec);
sentencepiece::io::OutputBuffer output(FLAGS_output);
if (rest_args.empty()) {
rest_args.push_back("");  // an empty filename means reading from stdin.
}
std::string line;
for (const auto &filename : rest_args) {
sentencepiece::io::InputBuffer input(filename);
while (input.ReadLine(&line)) {
output.WriteLine(normalizer.Normalize(line));
}
}
return 0;
}
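When --use_internal_normalization is off, the spec above disables the dummy prefix and whitespace escaping, leaving the charsmap rules plus optional extra-whitespace removal. The whitespace part alone can be sketched as follows (`CollapseWhitespace` is an invented name; the real Normalizer additionally applies the precompiled character map):

```cpp
#include <string>

// Conceptual sketch of remove_extra_whitespaces: trim leading and trailing
// spaces and collapse internal runs of spaces to a single space. The function
// name is invented here; the real Normalizer also applies the precompiled
// character-normalization map.
std::string CollapseWhitespace(const std::string &in) {
  std::string out;
  bool in_run = false;  // true while we are inside a run of spaces
  for (char c : in) {
    if (c == ' ') {
      in_run = true;
    } else {
      // Emit a single separating space only between non-space segments.
      if (in_run && !out.empty()) out += ' ';
      in_run = false;
      out += c;
    }
  }
  return out;
}
```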
src/spm_train_main.cc (new file, 153 lines)
@@ -0,0 +1,153 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "builder.h"
#include "flags.h"
#include "trainer_factory.h"
using sentencepiece::TrainerSpec;
using sentencepiece::NormalizerSpec;
using sentencepiece::normalizer::Builder;
namespace {
static sentencepiece::TrainerSpec kDefaultTrainerSpec;
static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
} // namespace
DEFINE_string(input, "", "comma separated list of input sentences");
DEFINE_string(model_prefix, "", "output model prefix");
DEFINE_string(model_type, "unigram",
"model algorithm: unigram, bpe, word or char");
DEFINE_int32(vocab_size, kDefaultTrainerSpec.vocab_size(), "vocabulary size");
DEFINE_string(accept_language, "",
"comma-separated list of languages this model can accept");
DEFINE_double(character_coverage, kDefaultTrainerSpec.character_coverage(),
"character coverage to determine the minimum symbols");
DEFINE_int32(input_sentence_size, kDefaultTrainerSpec.input_sentence_size(),
"maximum size of sentences the trainer loads");
DEFINE_int32(mining_sentence_size, kDefaultTrainerSpec.mining_sentence_size(),
"maximum size of sentences used to make seed sentence pieces");
DEFINE_int32(training_sentence_size,
kDefaultTrainerSpec.training_sentence_size(),
"maximum size of sentences to train sentence pieces");
DEFINE_int32(seed_sentencepiece_size,
kDefaultTrainerSpec.seed_sentencepiece_size(),
"the size of seed sentencepieces");
DEFINE_double(shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
"Keeps top shrinking_factor pieces with respect to the loss");
DEFINE_int32(num_threads, kDefaultTrainerSpec.num_threads(),
"number of threads for training");
DEFINE_int32(num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
"number of EM sub-iterations");
DEFINE_int32(max_sentencepiece_length,
kDefaultTrainerSpec.max_sentencepiece_length(),
"maximum length of sentence piece");
DEFINE_bool(split_by_unicode_script,
kDefaultTrainerSpec.split_by_unicode_script(),
"use Unicode script to split sentence pieces");
DEFINE_bool(split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
"use a white space to split sentence pieces");
DEFINE_string(control_symbols, "", "comma separated list of control symbols");
DEFINE_string(user_defined_symbols, "",
"comma separated list of user defined symbols");
DEFINE_string(normalization_rule_name, "nfkc",
"Normalization rule name. "
"Choose from nfkc or identity");
DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file.");
DEFINE_bool(add_dummy_prefix, kDefaultNormalizerSpec.add_dummy_prefix(),
"Add dummy whitespace at the beginning of text");
DEFINE_bool(remove_extra_whitespaces,
kDefaultNormalizerSpec.remove_extra_whitespaces(),
"Removes leading, trailing, and "
"duplicate internal whitespace");
namespace {
sentencepiece::NormalizerSpec MakeNormalizerSpec() {
if (!FLAGS_normalization_rule_tsv.empty()) {
const auto chars_map = sentencepiece::normalizer::Builder::BuildMapFromFile(
FLAGS_normalization_rule_tsv);
sentencepiece::NormalizerSpec spec;
spec.set_name("user_defined");
spec.set_precompiled_charsmap(
sentencepiece::normalizer::Builder::CompileCharsMap(chars_map));
return spec;
}
return sentencepiece::normalizer::Builder::GetNormalizerSpec(
FLAGS_normalization_rule_name);
}
} // namespace
int main(int argc, char *argv[]) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
sentencepiece::TrainerSpec trainer_spec;
sentencepiece::NormalizerSpec normalizer_spec;
CHECK_OR_HELP(input);
CHECK_OR_HELP(model_prefix);
// Populates the value from flags to spec.
#define SetTrainerSpecFromFlag(name) trainer_spec.set_##name(FLAGS_##name);
#define SetNormalizerSpecFromFlag(name) \
normalizer_spec.set_##name(FLAGS_##name);
#define SetRepeatedTrainerSpecFromFlag(name) \
if (!FLAGS_##name.empty()) { \
for (const auto &v : \
sentencepiece::string_util::Split(FLAGS_##name, ",")) { \
trainer_spec.add_##name(v); \
} \
}
SetTrainerSpecFromFlag(model_prefix);
SetTrainerSpecFromFlag(vocab_size);
SetTrainerSpecFromFlag(character_coverage);
SetTrainerSpecFromFlag(input_sentence_size);
SetTrainerSpecFromFlag(mining_sentence_size);
SetTrainerSpecFromFlag(training_sentence_size);
SetTrainerSpecFromFlag(seed_sentencepiece_size);
SetTrainerSpecFromFlag(shrinking_factor);
SetTrainerSpecFromFlag(num_threads);
SetTrainerSpecFromFlag(num_sub_iterations);
SetTrainerSpecFromFlag(max_sentencepiece_length);
SetTrainerSpecFromFlag(split_by_unicode_script);
SetTrainerSpecFromFlag(split_by_whitespace);
SetRepeatedTrainerSpecFromFlag(accept_language);
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
normalizer_spec = MakeNormalizerSpec();
SetNormalizerSpecFromFlag(add_dummy_prefix);
SetNormalizerSpecFromFlag(remove_extra_whitespaces);
for (const auto &filename :
sentencepiece::string_util::Split(FLAGS_input, ",")) {
trainer_spec.add_input(filename);
}
const std::map<std::string, TrainerSpec::ModelType> kModelTypeMap = {
{"unigram", TrainerSpec::UNIGRAM},
{"bpe", TrainerSpec::BPE},
{"word", TrainerSpec::WORD},
{"char", TrainerSpec::CHAR}};
trainer_spec.set_model_type(
sentencepiece::port::FindOrDie(kModelTypeMap, FLAGS_model_type));
auto trainer =
sentencepiece::TrainerFactory::Create(trainer_spec, normalizer_spec);
trainer->Train();
return 0;
}
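The model-type dispatch above uses port::FindOrDie, which aborts on a missing key. With the standard library alone, the same lookup can be written with std::map::at, which throws std::out_of_range instead; a sketch (the enum here is an invented stand-in for TrainerSpec::ModelType):

```cpp
#include <map>
#include <string>

// Sketch of the model-type lookup above using only the standard library.
// port::FindOrDie aborts on a missing key; std::map::at throws instead.
// This enum is a stand-in for TrainerSpec::ModelType.
enum ModelType { UNIGRAM, BPE, WORD, CHAR };

ModelType LookupModelType(const std::string &name) {
  static const std::map<std::string, ModelType> kMap = {
      {"unigram", UNIGRAM}, {"bpe", BPE}, {"word", WORD}, {"char", CHAR}};
  return kMap.at(name);  // throws std::out_of_range for unknown names
}
```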
src/stringpiece.h (new file, 234 lines)
@@ -0,0 +1,234 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef STRINGPIECE_H_
#define STRINGPIECE_H_
#include <cstring>
#include <string>
namespace sentencepiece {
class StringPiece {
public:
typedef size_t size_type;
// Create an empty slice.
StringPiece() : data_(""), size_(0) {}
// Create a slice that refers to d[0,n-1].
StringPiece(const char *d, size_t n) : data_(d), size_(n) {}
// Create a slice that refers to the contents of "s"
StringPiece(const std::string &s) : data_(s.data()), size_(s.size()) {}
// Create a slice that refers to s[0,strlen(s)-1]
StringPiece(const char *s) : data_(s), size_(strlen(s)) {}
void set(const void *data, size_t len) {
data_ = reinterpret_cast<const char *>(data);
size_ = len;
}
void set(const char *data) {
data_ = data;
size_ = strlen(data);
}
// Return a pointer to the beginning of the referenced data
const char *data() const { return data_; }
// Return the length (in bytes) of the referenced data
size_t size() const { return size_; }
// Return true iff the length of the referenced data is zero
bool empty() const { return size_ == 0; }
typedef const char *const_iterator;
typedef const char *iterator;
iterator begin() const { return data_; }
iterator end() const { return data_ + size_; }
static const size_type npos = static_cast<size_type>(-1);
char operator[](size_t n) const { return data_[n]; }
// Change this slice to refer to an empty array
void clear() {
data_ = "";
size_ = 0;
}
// Drop the first "n" bytes from this slice.
void remove_prefix(size_t n) {
data_ += n;
size_ -= n;
}
void remove_suffix(size_t n) { size_ -= n; }
size_type find(StringPiece s, size_type pos = 0) const {
if (size_ <= 0 || pos > static_cast<size_type>(size_)) {
if (size_ == 0 && pos == 0 && s.size_ == 0) {
return 0;
}
return npos;
}
const char *result = memmatch(data_ + pos, size_ - pos, s.data_, s.size_);
return result ? result - data_ : npos;
}
size_type find(char c, size_type pos) const {
if (size_ <= 0 || pos >= static_cast<size_type>(size_)) {
return npos;
}
const char *result =
static_cast<const char *>(memchr(data_ + pos, c, size_ - pos));
return result != nullptr ? result - data_ : npos;
}
size_type find_first_of(char c, size_type pos = 0) const {
return find(c, pos);
}
size_type find_first_of(StringPiece s, size_type pos = 0) const {
if (size_ <= 0 || s.size_ <= 0) {
return npos;
}
if (s.size_ == 1) {
return find_first_of(s.data_[0], pos);
}
bool lookup[256] = {false};
for (size_t i = 0; i < s.size_; ++i) {
lookup[static_cast<unsigned char>(s.data_[i])] = true;
}
for (size_t i = pos; i < size_; ++i) {
if (lookup[static_cast<unsigned char>(data_[i])]) {
return i;
}
}
return npos;
}
bool Consume(StringPiece x) {
if (starts_with(x)) {
remove_prefix(x.size_);
return true;
}
return false;
}
StringPiece substr(size_type pos, size_type n = npos) const {
size_type size = static_cast<size_type>(size_);
if (pos > size) pos = size;
if (n > size - pos) n = size - pos;
return StringPiece(data_ + pos, n);
}
// Return a string that contains the copy of the referenced data.
std::string ToString() const { return std::string(data_, size_); }
std::string to_string() const { return std::string(data_, size_); }
// Three-way comparison. Returns value:
// < 0 iff "*this" < "b",
// == 0 iff "*this" == "b",
// > 0 iff "*this" > "b"
int compare(StringPiece b) const;
// Return true iff "x" is a prefix of "*this"
bool starts_with(StringPiece x) const {
return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0));
}
// Return true iff "x" is a suffix of "*this"
bool ends_with(StringPiece x) const {
return ((size_ >= x.size_) &&
(memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0));
}
private:
static const char *memmatch(const char *phaystack, size_t haylen,
const char *pneedle, size_t neelen) {
if (0 == neelen) {
return phaystack; // even if haylen is 0
}
if (haylen < neelen) {
return nullptr;
}
const char *match;
const char *hayend = phaystack + haylen - neelen + 1;
while ((match = (const char *)(memchr(phaystack, pneedle[0],
hayend - phaystack)))) {
if (memcmp(match, pneedle, neelen) == 0) {
return match;
} else {
phaystack = match + 1;
}
}
return nullptr;
}
const char *data_;
size_t size_;
};
inline bool operator==(StringPiece x, StringPiece y) {
return ((x.size() == y.size()) &&
(memcmp(x.data(), y.data(), x.size()) == 0));
}
inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
inline bool operator<=(StringPiece x, StringPiece y) {
return x.compare(y) <= 0;
}
inline bool operator>=(StringPiece x, StringPiece y) {
return x.compare(y) >= 0;
}
inline int StringPiece::compare(StringPiece b) const {
const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
int r = memcmp(data_, b.data_, min_len);
if (r == 0) {
if (size_ < b.size_) {
r = -1;
} else if (size_ > b.size_) {
r = +1;
}
}
return r;
}
inline std::ostream &operator<<(std::ostream &o, StringPiece piece) {
// StringPiece is not guaranteed to be NUL-terminated, so write exactly
// size() bytes rather than streaming data() as a C string.
o.write(piece.data(), piece.size());
return o;
}
struct StringPieceHash {
// DJB hash function.
inline size_t operator()(const StringPiece &sp) const {
size_t hash = 5381;
for (size_t i = 0; i < sp.size(); ++i) {
hash = ((hash << 5) + hash) + sp[i];
}
return hash;
}
};
} // namespace sentencepiece
#endif // STRINGPIECE_H_
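The StringPieceHash functor above is the classic DJB (Bernstein) hash: start from 5381 and fold each byte in as hash * 33 + byte, with the multiply written as a shift plus add. Isolated for clarity (`DjbHash` is a name invented here):

```cpp
#include <cstddef>
#include <string>

// The DJB (Bernstein) hash used by StringPieceHash above, isolated:
// start from 5381 and, for each byte, multiply by 33 (hash << 5 is hash * 32,
// plus hash once more) and add the byte value.
size_t DjbHash(const std::string &s) {
  size_t hash = 5381;
  for (unsigned char c : s) {
    hash = ((hash << 5) + hash) + c;
  }
  return hash;
}
```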
src/test_main.cc (new file, 20 lines)
@@ -0,0 +1,20 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "testharness.h"
int main(int argc, char **argv) {
sentencepiece::test::RunAllTests();
return 0;
}
src/testharness.cc (new file, 76 lines)
@@ -0,0 +1,76 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "testharness.h"
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "common.h"
#include "util.h"
#include <unistd.h>
namespace sentencepiece {
namespace test {
namespace {
struct Test {
const char *base;
const char *name;
void (*func)();
};
std::vector<Test> *tests;
} // namespace
bool RegisterTest(const char *base, const char *name, void (*func)()) {
if (tests == nullptr) {
tests = new std::vector<Test>;
}
Test t;
t.base = base;
t.name = name;
t.func = func;
tests->emplace_back(t);
return true;
}
int RunAllTests() {
int num = 0;
if (tests == nullptr) {
std::cerr << "No tests found" << std::endl;
return 0;
}
for (const Test &t : *(tests)) {
std::cerr << "[ RUN ] " << t.base << "." << t.name << std::endl;
(*t.func)();
std::cerr << "[ OK ] " << t.base << "." << t.name << std::endl;
++num;
}
std::cerr << "==== PASSED " << num << " tests" << std::endl;
return 0;
}
ScopedTempFile::ScopedTempFile(const std::string &filename) {
char pid[64];
snprintf(pid, sizeof(pid), "%u", static_cast<unsigned>(getpid()));
filename_ = "/tmp/.XXX.tmp." + filename + "." + pid;
}
ScopedTempFile::~ScopedTempFile() { ::unlink(filename_.c_str()); }
} // namespace test
} // namespace sentencepiece
src/testharness.h (new file, 171 lines)
@@ -0,0 +1,171 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TESTHARNESS_H_
#define TESTHARNESS_H_
#include <cmath>
#include <iostream>
#include <sstream>
#include <string>
#include "common.h"
namespace sentencepiece {
namespace test {
// Runs all tests registered by the TEST() macro.
// TEST(Foo, Hello) { ... }
// TEST(Foo, World) { ... }
//
// Returns 0 if all tests pass.
// Dies or returns a non-zero value if some test fails.
int RunAllTests();
class ScopedTempFile {
public:
explicit ScopedTempFile(const std::string &filename);
~ScopedTempFile();
const char *filename() const { return filename_.c_str(); }
private:
std::string filename_;
};
// An instance of Tester is allocated to hold temporary state during
// the execution of an assertion.
class Tester {
public:
Tester(const char *fname, int line) : ok_(true), fname_(fname), line_(line) {}
~Tester() {
if (!ok_) {
std::cerr << "[ NG ] " << fname_ << ":" << line_ << ":" << ss_.str()
<< std::endl;
exit(-1);
}
}
Tester &Is(bool b, const char *msg) {
if (!b) {
ss_ << " failed: " << msg;
ok_ = false;
}
return *this;
}
Tester &IsNear(double val1, double val2, double abs_error, const char *msg1,
const char *msg2) {
const double diff = std::fabs(val1 - val2);
if (diff > abs_error) {
ss_ << "The difference between (" << msg1 << ") and (" << msg2 << ") is "
<< diff << ", which exceeds " << abs_error << ", where\n"
<< msg1 << " evaluates to " << val1 << ",\n"
<< msg2 << " evaluates to " << val2;
ok_ = false;
}
return *this;
}
#define BINARY_OP(name, op) \
template <class X, class Y> \
Tester &name(const X &x, const Y &y, const char *msg1, const char *msg2) { \
if (!(x op y)) { \
ss_ << " failed: " << msg1 << (" " #op " ") << msg2; \
ok_ = false; \
} \
return *this; \
}
BINARY_OP(IsEq, ==)
BINARY_OP(IsNe, !=)
BINARY_OP(IsGe, >=)
BINARY_OP(IsGt, >)
BINARY_OP(IsLe, <=)
BINARY_OP(IsLt, <)
#undef BINARY_OP
// Attach the specified value to the error message if an error has occurred
template <class V>
Tester &operator<<(const V &value) {
if (!ok_) {
ss_ << " " << value;
}
return *this;
}
private:
bool ok_;
const char *fname_;
int line_;
std::stringstream ss_;
};
#define EXPECT_TRUE(c) \
sentencepiece::test::Tester(__FILE__, __LINE__).Is((c), #c)
#define EXPECT_FALSE(c) \
sentencepiece::test::Tester(__FILE__, __LINE__).Is((!(c)), #c)
#define EXPECT_STREQ(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__) \
.IsEq(std::string(a), std::string(b), #a, #b)
#define EXPECT_EQ(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsEq((a), (b), #a, #b)
#define EXPECT_NE(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsNe((a), (b), #a, #b)
#define EXPECT_GE(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsGe((a), (b), #a, #b)
#define EXPECT_GT(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsGt((a), (b), #a, #b)
#define EXPECT_LE(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsLe((a), (b), #a, #b)
#define EXPECT_LT(a, b) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsLt((a), (b), #a, #b)
#define EXPECT_NEAR(a, b, c) \
sentencepiece::test::Tester(__FILE__, __LINE__).IsNear((a), (b), (c), #a, #b)
#define EXPECT_DEATH(statement) \
{ \
error::gTestMode = true; \
if (setjmp(error::gTestJmp) == 0) { \
do { \
statement; \
} while (false); \
EXPECT_TRUE(false); \
} else { \
error::gTestMode = false; \
} \
};
#define TCONCAT(a, b, c) TCONCAT1(a, b, c)
#define TCONCAT1(a, b, c) a##b##c
#define TEST(base, name) \
class TCONCAT(base, _Test_, name) { \
public: \
void _Run(); \
static void _RunIt() { \
TCONCAT(base, _Test_, name) t; \
t._Run(); \
} \
}; \
bool TCONCAT(base, _Test_ignored_, name) = \
sentencepiece::test::RegisterTest(#base, #name, \
&TCONCAT(base, _Test_, name)::_RunIt); \
void TCONCAT(base, _Test_, name)::_Run()
// Register the specified test. Typically not used directly, but
// invoked via the macro expansion of TEST.
extern bool RegisterTest(const char *base, const char *name, void (*func)());
} // namespace test
} // namespace sentencepiece
#endif // TESTHARNESS_H_
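The TCONCAT macro above goes through a second macro, TCONCAT1, so that its arguments are macro-expanded before ## pastes them; pasting directly would glue the unexpanded argument tokens instead. A minimal demonstration of the same two-level pattern (macro and variable names invented here):

```cpp
// Two-level token pasting, as in the TCONCAT/TCONCAT1 pair above: the extra
// indirection forces macro arguments to be expanded before ## joins them.
#define CONCAT(a, b, c) CONCAT1(a, b, c)
#define CONCAT1(a, b, c) a##b##c
#define NAME bar

// Calling CONCAT1 directly would paste the literal token "NAME" (## suppresses
// argument expansion); with two levels, NAME expands to bar first.
int CONCAT(foo_, NAME, _x) = 42;  // declares foo_bar_x
```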
src/trainer_factory.cc (new file, 49 lines)
@@ -0,0 +1,49 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "trainer_factory.h"
#include "bpe_model_trainer.h"
#include "char_model_trainer.h"
#include "unigram_model_trainer.h"
#include "util.h"
#include "word_model_trainer.h"
namespace sentencepiece {
// Instantiates a Trainer instance from the trainer_spec and normalizer_spec.
std::unique_ptr<TrainerInterface> TrainerFactory::Create(
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
switch (trainer_spec.model_type()) {
case TrainerSpec::UNIGRAM:
return port::MakeUnique<unigram::Trainer>(trainer_spec, normalizer_spec);
case TrainerSpec::BPE:
return port::MakeUnique<bpe::Trainer>(trainer_spec, normalizer_spec);
case TrainerSpec::WORD:
return port::MakeUnique<word::Trainer>(trainer_spec, normalizer_spec);
case TrainerSpec::CHAR:
return port::MakeUnique<character::Trainer>(trainer_spec,
normalizer_spec);
default:
LOG(FATAL) << "Unknown model_type: " << trainer_spec.model_type();
break;
}
return port::MakeUnique<unigram::Trainer>(trainer_spec, normalizer_spec);
}
} // namespace sentencepiece

31
src/trainer_factory.h Normal file

@ -0,0 +1,31 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TRAINER_FACTORY_H_
#define TRAINER_FACTORY_H_
#include <memory>
#include "sentencepiece_model.pb.h"
#include "trainer_interface.h"
namespace sentencepiece {
class TrainerFactory {
public:
// Creates Trainer instance from |trainer_spec| and |normalizer_spec|.
static std::unique_ptr<TrainerInterface> Create(
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec);
};
} // namespace sentencepiece
#endif // TRAINER_FACTORY_H_

44
src/trainer_factory_test.cc Normal file

@ -0,0 +1,44 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "trainer_factory.h"
#include "testharness.h"
namespace sentencepiece {
TEST(TrainerFactoryTest, BasicTest) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
{
trainer_spec.set_model_type(TrainerSpec::UNIGRAM);
auto m = TrainerFactory::Create(trainer_spec, normalizer_spec);
}
{
trainer_spec.set_model_type(TrainerSpec::BPE);
auto m = TrainerFactory::Create(trainer_spec, normalizer_spec);
}
{
trainer_spec.set_model_type(TrainerSpec::WORD);
auto m = TrainerFactory::Create(trainer_spec, normalizer_spec);
}
{
trainer_spec.set_model_type(TrainerSpec::CHAR);
auto m = TrainerFactory::Create(trainer_spec, normalizer_spec);
}
}
} // namespace sentencepiece

297
src/trainer_interface.cc Normal file

@ -0,0 +1,297 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "trainer_interface.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "model_factory.h"
#include "normalizer.h"
#include "sentencepiece_processor.h"
#include "unicode_script.h"
#include "util.h"
namespace sentencepiece {
const char32 TrainerInterface::kWSChar = L'\u2581';
const char TrainerInterface::kWSStr[] = "\xe2\x96\x81";
const char32 TrainerInterface::kUNKChar = L'\u2585';
const char TrainerInterface::kUNKStr[] = "\xe2\x96\x85";
const char32 TrainerInterface::kUPPBoundaryChar = L'\u0009';
const char TrainerInterface::kUPPBoundaryStr[] = "\t";
TrainerInterface::TrainerInterface(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {}
TrainerInterface::~TrainerInterface() {}
bool TrainerInterface::IsValidSentencePiece(
const string_util::UnicodeText &sentencepiece) const {
// Returns false if the length of piece is invalid.
if (sentencepiece.empty() ||
sentencepiece.size() >
static_cast<size_t>(trainer_spec_.max_sentencepiece_length())) {
return false;
}
size_t pos = 0;
unicode_script::ScriptType prev_script =
static_cast<unicode_script::ScriptType>(-1);
for (auto it = sentencepiece.begin(); it != sentencepiece.end(); ++it) {
if (*it == kUNKChar) { // UNK must not be included
return false;
}
// kUPPBoundaryChar is included when split_by_upp_for_training is true.
if (*it == kUPPBoundaryChar) {
return false;
}
if (*it == 0x0020) {
LOG(WARNING) << "space must not be included in normalized string.";
return false;
}
if (*it == kWSChar) {
// Only allows whitespace to appear as a prefix of piece.
// When split_by_whitespace is false, we allow whitespaces to
// appear in the middle, "foo_bar", but do not allow them
// to appear as suffix, "foo_bar_".
// Regardless of the setting of split_by_whitespace,
// whitespace is treated as a prefix/infix of symbol or
// independent symbol.
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
pos == sentencepiece.size() - 1)) {
return false;
}
} else {
auto s = unicode_script::GetScript(*it);
// Merge Hiragana/Katakana into Han.
if (s == unicode_script::U_Hiragana || s == unicode_script::U_Katakana ||
*it == 0x30FC) { // long vowel sound (Katakana) should be Katakana
s = unicode_script::U_Han;
}
// Do not allow a piece to include multiple Unicode scripts
// when split_by_unicode_script() is true (default = true).
if (prev_script != -1 && prev_script != s &&
trainer_spec_.split_by_unicode_script()) {
return false;
}
prev_script = s;
}
++pos;
}
return true;
}
void TrainerInterface::LoadSentences() {
CHECK(sentences_.empty());
CHECK(required_chars_.empty());
const normalizer::Normalizer normalizer(normalizer_spec_);
for (const auto &filename : trainer_spec_.input()) {
LOG(INFO) << "Loading corpus: " << filename;
std::string sentence;
io::InputBuffer input(filename);
while (input.ReadLine(&sentence)) {
      // Skips overly long lines (the limit is in bytes, not lines).
      constexpr size_t kMaxLineLength = 2048;
      if (sentence.size() > kMaxLineLength) {
        continue;
      }
if (sentence.find(kUNKStr) != std::string::npos) {
LOG(INFO) << "Reserved chars are found. Skipped: " << sentence;
continue;
}
// Normalizes sentence with Normalizer.
// whitespaces are replaced with kWSChar.
const std::string normalized = normalizer.Normalize(sentence);
if (sentences_.size() % 100000 == 0) {
LOG(INFO) << "Loading: " << normalized
<< "\tsize=" << sentences_.size();
}
CHECK(normalized.find(" ") == std::string::npos)
<< "Normalized string must not include spaces";
if (normalized.empty()) {
        LOG(WARNING) << "Empty string found. Skipped.";
continue;
}
      // TODO(taku): We assume that the sentence frequency is always 1.
      // Support sentences with explicit frequencies.
sentences_.emplace_back(normalized, 1);
if (sentences_.size() ==
static_cast<size_t>(trainer_spec_.input_sentence_size())) {
goto END;
}
}
}
END:
LOG(INFO) << "Loaded " << sentences_.size() << " sentences";
// Count character frequencies.
int64 all_chars_count = 0;
std::unordered_map<char32, int64> chars_count;
for (const auto &w : sentences_) {
for (const char32 c : string_util::UTF8ToUnicodeText(w.first)) {
if (c == 0x0020) {
// UTF8ToUnicodeText returns a white space if the text
// contains an interchange-invalid character.
CHECK(w.first.find(" ") == std::string::npos)
<< "space must not be included in normalized string.";
continue;
}
chars_count[c] += w.second;
all_chars_count += w.second;
}
}
LOG(INFO) << "all chars count=" << all_chars_count;
// Determines required_chars which must be included in the vocabulary.
int64 accumulated_chars_count = 0;
for (const auto &w : Sorted(chars_count)) {
const float coverage = 1.0 * accumulated_chars_count / all_chars_count;
if (coverage >= trainer_spec_.character_coverage()) {
LOG(INFO) << "Done: " << 100.0 * coverage << "% characters are covered.";
break;
}
accumulated_chars_count += w.second;
CHECK_NE(w.first, 0x0020)
<< "space must not be included in normalized string.";
required_chars_.insert(w);
}
LOG(INFO) << "alphabet size=" << required_chars_.size();
CHECK(!port::ContainsKey(required_chars_, kUNKChar));
// Replaces rare characters (characters not included in required_chars_)
// with kUNKChar.
for (auto &w : sentences_) {
string_util::UnicodeText uw2;
for (const char32 c : string_util::UTF8ToUnicodeText(w.first)) {
if (port::ContainsKey(required_chars_, c)) {
uw2.push_back(c);
} else {
uw2.push_back(kUNKChar);
}
}
w.first = string_util::UnicodeTextToUTF8(uw2);
}
LOG(INFO) << "Done! " << sentences_.size() << " sentences are loaded";
}
void TrainerInterface::SplitSentencesByWhitespace() {
LOG(INFO) << "Tokenizing input sentences with whitespace: "
<< sentences_.size();
std::unordered_map<std::string, int64> tokens;
for (const auto &s : sentences_) {
for (const auto &w : SplitIntoWords(s.first)) {
tokens[w.to_string()] += s.second;
}
}
sentences_ = Sorted(tokens);
LOG(INFO) << "Done! " << sentences_.size();
}
void TrainerInterface::Serialize(ModelProto *model_proto) const {
// Duplicated sentencepiece is not allowed.
std::unordered_set<std::string> dup;
auto CheckPiece = [&dup](const std::string &piece) {
CHECK(!piece.empty());
CHECK(dup.insert(piece).second) << piece << " is already defined";
};
auto *unk = model_proto->add_pieces();
unk->set_piece("<unk>");
unk->set_type(ModelProto::SentencePiece::UNKNOWN);
CheckPiece(unk->piece());
for (const auto &w : {"<s>", "</s>"}) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w);
sp->set_type(ModelProto::SentencePiece::CONTROL);
CheckPiece(sp->piece());
}
for (const auto &w : trainer_spec_.control_symbols()) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w);
sp->set_type(ModelProto::SentencePiece::CONTROL);
CheckPiece(sp->piece());
}
for (const auto &w : trainer_spec_.user_defined_symbols()) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w);
sp->set_type(ModelProto::SentencePiece::USER_DEFINED);
sp->set_score(0.0);
CheckPiece(sp->piece());
}
for (const auto &w : final_pieces_) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_score(w.second);
CheckPiece(sp->piece());
}
if (trainer_spec_.model_type() == TrainerSpec::CHAR) {
CHECK_GE(trainer_spec_.vocab_size(), model_proto->pieces_size());
CHECK_GE(static_cast<size_t>(trainer_spec_.vocab_size()), dup.size());
} else {
CHECK_EQ(trainer_spec_.vocab_size(), model_proto->pieces_size());
CHECK_EQ(static_cast<size_t>(trainer_spec_.vocab_size()), dup.size());
}
*(model_proto->mutable_trainer_spec()) = trainer_spec_;
*(model_proto->mutable_normalizer_spec()) = normalizer_spec_;
}
void TrainerInterface::SaveModel(StringPiece filename) const {
LOG(INFO) << "Saving model: " << filename;
ModelProto model_proto;
Serialize(&model_proto);
std::ofstream ofs(filename.data(), OUTPUT_MODE);
CHECK_OFS(ofs, filename.to_string());
CHECK(model_proto.SerializeToOstream(&ofs));
}
void TrainerInterface::SaveVocab(StringPiece filename) const {
LOG(INFO) << "Saving vocabs: " << filename;
ModelProto model_proto;
Serialize(&model_proto);
io::OutputBuffer output(filename);
for (const auto &piece : model_proto.pieces()) {
std::ostringstream os;
os << piece.piece() << "\t" << piece.score();
output.WriteLine(os.str());
}
}
void TrainerInterface::Save() const {
SaveModel(trainer_spec_.model_prefix() + ".model");
SaveVocab(trainer_spec_.model_prefix() + ".vocab");
// SaveSplits(trainer_spec_.model_prefix() + ".splits");
}
} // namespace sentencepiece

117
src/trainer_interface.h Normal file

@ -0,0 +1,117 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TRAINER_INTERFACE_H_
#define TRAINER_INTERFACE_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "common.h"
#include "sentencepiece_model.pb.h"
#include "util.h"
namespace sentencepiece {
template <typename K, typename V>
std::vector<std::pair<K, V>> Sorted(const std::vector<std::pair<K, V>> &m) {
std::vector<std::pair<K, V>> v = m;
std::sort(v.begin(), v.end(),
[](const std::pair<K, V> &p1, const std::pair<K, V> &p2) {
return (p1.second > p2.second ||
(p1.second == p2.second && p1.first < p2.first));
});
return v;
}
template <typename K, typename V>
std::vector<std::pair<K, V>> Sorted(const std::unordered_map<K, V> &m) {
std::vector<std::pair<K, V>> v(m.begin(), m.end());
return Sorted(v);
}
// Base trainer class
class TrainerInterface {
public:
using Sentences = std::vector<std::pair<std::string, int64>>;
static const char32 kWSChar;
static const char32 kUNKChar;
static const char32 kUPPBoundaryChar;
static const char kWSStr[];
static const char kUNKStr[];
static const char kUPPBoundaryStr[];
TrainerInterface(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec);
virtual ~TrainerInterface();
virtual void Train() {}
FRIEND_TEST(TrainerInterfaceTest, IsValidSentencePieceTest);
protected:
  // Returns true if |piece| is a valid sentence piece.
  // The result is affected by
  // max_sentencepiece_length, split_by_whitespace, and split_by_unicode_script.
bool IsValidSentencePiece(const string_util::UnicodeText &piece) const;
// Loads all sentences from spec.input().
// It loads at most input_sentence_size sentences.
void LoadSentences();
  // Splits all sentences by whitespace and
  // replaces |sentences_| with the tokenized strings.
  // e.g.,
  // [ ["hello world ", 1], ["hi world", 1] ] =>
  // [ ["hello", 1], ["hi", 1], ["world", 2] ]
void SplitSentencesByWhitespace();
  // Saves model files into spec.model_prefix().
void Save() const;
// Set of characters which must be included in the final vocab.
// The value of this map stores the frequency.
std::unordered_map<char32, int64> required_chars_;
// Final output pieces
std::vector<std::pair<std::string, float>> final_pieces_;
// All sentences.
Sentences sentences_;
// Trainer spec.
TrainerSpec trainer_spec_;
// Normalizer spec
NormalizerSpec normalizer_spec_;
private:
// Serialize final_pieces_ to |model_proto|.
void Serialize(ModelProto *model_proto) const;
// Saves the best sentence split with the current model for debugging.
void SaveSplits(StringPiece filename) const;
// Saves model file.
void SaveModel(StringPiece filename) const;
// Saves vocabulary file for NMT.
void SaveVocab(StringPiece filename) const;
};
} // namespace sentencepiece
#endif // TRAINER_INTERFACE_H_

76
src/trainer_interface_test.cc Normal file

@ -0,0 +1,76 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "trainer_interface.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
// Space symbol
#define WS "\xe2\x96\x81"
TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
auto IsValid = [&trainer_spec, &normalizer_spec](const std::string &str) {
TrainerInterface trainer(trainer_spec, normalizer_spec);
const string_util::UnicodeText text = string_util::UTF8ToUnicodeText(str);
return trainer.IsValidSentencePiece(text);
};
// Default trainer spec.
EXPECT_FALSE(IsValid(""));
EXPECT_FALSE(IsValid("12345678912345678")); // too long
EXPECT_TRUE(IsValid("a"));
EXPECT_TRUE(IsValid(WS));
EXPECT_TRUE(IsValid(WS "a"));
EXPECT_FALSE(IsValid("a" WS));
EXPECT_FALSE(IsValid(WS "a" WS));
EXPECT_FALSE(IsValid("a" WS "b"));
EXPECT_FALSE(IsValid("a" WS "b" WS));
EXPECT_TRUE(IsValid("あいう"));
EXPECT_TRUE(IsValid("グーグル")); // "ー" is a part of Katakana
EXPECT_TRUE(IsValid("食べる"));
EXPECT_FALSE(IsValid("漢字ABC")); // mixed CJK scripts
EXPECT_FALSE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10")); // $ and 1 are both "common" script.
EXPECT_FALSE(IsValid("$ABC"));
EXPECT_FALSE(IsValid("ab\tbc")); // "\t" is UPP boundary.
trainer_spec.set_split_by_whitespace(false);
EXPECT_TRUE(IsValid(WS));
EXPECT_TRUE(IsValid(WS "a"));
EXPECT_FALSE(IsValid("a" WS));
EXPECT_FALSE(IsValid(WS "a" WS));
EXPECT_TRUE(IsValid("a" WS "b")); // "a b" is a valid piece.
EXPECT_TRUE(IsValid(WS "a" WS "b"));
EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c"));
EXPECT_FALSE(IsValid("a" WS "b" WS));
trainer_spec.set_split_by_unicode_script(false);
EXPECT_TRUE(IsValid("あいう"));
EXPECT_TRUE(IsValid("グーグル"));
EXPECT_TRUE(IsValid("食べる"));
EXPECT_TRUE(IsValid("漢字ABC"));
EXPECT_TRUE(IsValid("F1"));
EXPECT_TRUE(IsValid("$10"));
EXPECT_TRUE(IsValid("$ABC"));
trainer_spec.set_max_sentencepiece_length(4);
EXPECT_TRUE(IsValid("1234"));
EXPECT_FALSE(IsValid("12345"));
}
} // namespace sentencepiece

41
src/unicode_script.cc Normal file

@ -0,0 +1,41 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unicode_script.h"
#include <unordered_map>
#include "unicode_script_map.h"
#include "util.h"
namespace sentencepiece {
namespace unicode_script {
namespace {
class GetScriptInternal {
public:
GetScriptInternal() { InitTable(&smap_); }
ScriptType GetScript(char32 c) const {
return port::FindWithDefault(smap_, c, ScriptType::U_Common);
}
private:
std::unordered_map<char32, ScriptType> smap_;
};
} // namespace
ScriptType GetScript(char32 c) {
static GetScriptInternal sc;
return sc.GetScript(c);
}
} // namespace unicode_script
} // namespace sentencepiece

165
src/unicode_script.h Normal file

@ -0,0 +1,165 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UNICODE_SCRIPT_H_
#define UNICODE_SCRIPT_H_
#include "common.h"
namespace sentencepiece {
namespace unicode_script {
enum ScriptType {
U_Adlam,
U_Ahom,
U_Anatolian_Hieroglyphs,
U_Arabic,
U_Armenian,
U_Avestan,
U_Balinese,
U_Bamum,
U_Bassa_Vah,
U_Batak,
U_Bengali,
U_Bhaiksuki,
U_Bopomofo,
U_Brahmi,
U_Braille,
U_Buginese,
U_Buhid,
U_Canadian_Aboriginal,
U_Carian,
U_Caucasian_Albanian,
U_Chakma,
U_Cham,
U_Cherokee,
U_Common,
U_Coptic,
U_Cuneiform,
U_Cypriot,
U_Cyrillic,
U_Deseret,
U_Devanagari,
U_Duployan,
U_Egyptian_Hieroglyphs,
U_Elbasan,
U_Ethiopic,
U_Georgian,
U_Glagolitic,
U_Gothic,
U_Grantha,
U_Greek,
U_Gujarati,
U_Gurmukhi,
U_Han,
U_Hangul,
U_Hanunoo,
U_Hatran,
U_Hebrew,
U_Hiragana,
U_Imperial_Aramaic,
U_Inherited,
U_Inscriptional_Pahlavi,
U_Inscriptional_Parthian,
U_Javanese,
U_Kaithi,
U_Kannada,
U_Katakana,
U_Kayah_Li,
U_Kharoshthi,
U_Khmer,
U_Khojki,
U_Khudawadi,
U_Lao,
U_Latin,
U_Lepcha,
U_Limbu,
U_Linear_A,
U_Linear_B,
U_Lisu,
U_Lycian,
U_Lydian,
U_Mahajani,
U_Malayalam,
U_Mandaic,
U_Manichaean,
U_Marchen,
U_Meetei_Mayek,
U_Mende_Kikakui,
U_Meroitic_Cursive,
U_Meroitic_Hieroglyphs,
U_Miao,
U_Modi,
U_Mongolian,
U_Mro,
U_Multani,
U_Myanmar,
U_Nabataean,
U_New_Tai_Lue,
U_Newa,
U_Nko,
U_Ogham,
U_Ol_Chiki,
U_Old_Hungarian,
U_Old_Italic,
U_Old_North_Arabian,
U_Old_Permic,
U_Old_Persian,
U_Old_South_Arabian,
U_Old_Turkic,
U_Oriya,
U_Osage,
U_Osmanya,
U_Pahawh_Hmong,
U_Palmyrene,
U_Pau_Cin_Hau,
U_Phags_Pa,
U_Phoenician,
U_Psalter_Pahlavi,
U_Rejang,
U_Runic,
U_Samaritan,
U_Saurashtra,
U_Sharada,
U_Shavian,
U_Siddham,
U_SignWriting,
U_Sinhala,
U_Sora_Sompeng,
U_Sundanese,
U_Syloti_Nagri,
U_Syriac,
U_Tagalog,
U_Tagbanwa,
U_Tai_Le,
U_Tai_Tham,
U_Tai_Viet,
U_Takri,
U_Tamil,
U_Tangut,
U_Telugu,
U_Thaana,
U_Thai,
U_Tibetan,
U_Tifinagh,
U_Tirhuta,
U_Ugaritic,
U_Vai,
U_Warang_Citi,
U_Yi
};
ScriptType GetScript(char32 c);
} // namespace unicode_script
} // namespace sentencepiece
#endif  // UNICODE_SCRIPT_H_

1955
src/unicode_script_map.h Normal file

File diff suppressed because it is too large

43
src/unicode_script_test.cc Normal file

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unicode_script.h"
#include "common.h"
#include "stringpiece.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace unicode_script {
ScriptType GetScriptType(StringPiece s) {
const auto ut = string_util::UTF8ToUnicodeText(s);
CHECK_EQ(1, ut.size());
return GetScript(ut[0]);
}
TEST(UnicodeScript, GetScriptTypeTest) {
  EXPECT_EQ(U_Han, GetScriptType("漢"));
  EXPECT_EQ(U_Han, GetScriptType("字"));
  EXPECT_EQ(U_Hiragana, GetScriptType("あ"));
  EXPECT_EQ(U_Katakana, GetScriptType("ア"));
  EXPECT_EQ(U_Common, GetScriptType("ー"));
EXPECT_EQ(U_Latin, GetScriptType("a"));
EXPECT_EQ(U_Latin, GetScriptType("A"));
EXPECT_EQ(U_Common, GetScriptType("0"));
EXPECT_EQ(U_Common, GetScriptType("$"));
EXPECT_EQ(U_Common, GetScriptType("@"));
EXPECT_EQ(U_Common, GetScriptType("-"));
}
} // namespace unicode_script
} // namespace sentencepiece

458
src/unigram_model.cc Normal file

@ -0,0 +1,458 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model.h"
#include <cfloat>
#include <map>
#include <queue>
#include <string>
#include <vector>
#include "stringpiece.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
namespace {
constexpr size_t kNodeChunkSize = 512;
}
Lattice::Lattice() {}
Lattice::~Lattice() { Clear(); }
const std::vector<Lattice::Node *> &Lattice::begin_nodes(int pos) const {
return begin_nodes_[pos];
}
const std::vector<Lattice::Node *> &Lattice::end_nodes(int pos) const {
return end_nodes_[pos];
}
int Lattice::size() const {
// -1 because surface_ may include the EOS.
return std::max<int>(0, surface_.size() - 1);
}
int Lattice::utf8_size() const { return sentence_.size(); }
const char *Lattice::sentence() const { return sentence_.data(); }
const char *Lattice::surface(int pos) const { return surface_[pos]; }
Lattice::Node *Lattice::bos_node() const { return end_nodes_[0][0]; }
Lattice::Node *Lattice::eos_node() const { return begin_nodes_[size()][0]; }
Lattice::Node *Lattice::NewNode() {
Node *node = new Node;
memset(node, 0, sizeof(*node));
node->node_id = all_nodes_.size();
all_nodes_.push_back(node);
return node;
}
void Lattice::Clear() {
  begin_nodes_.clear();
  end_nodes_.clear();
  sentence_.clear();
  surface_.clear();
  // Deletes the owned nodes before clearing the vector; clearing first
  // would leak every node.
  port::STLDeleteElements(&all_nodes_);
  all_nodes_.clear();
}
void Lattice::SetSentence(StringPiece sentence) {
Clear();
sentence_ = sentence;
CHECK(!sentence_.empty());
const char *begin = sentence_.data();
const char *end = sentence_.data() + sentence_.size();
while (begin < end) {
const int mblen =
std::min<int>(string_util::OneCharLen(begin), end - begin);
surface_.push_back(begin);
begin += mblen;
}
surface_.push_back(end);
const int len = size();
begin_nodes_.resize(len + 1);
end_nodes_.resize(len + 1);
for (int i = 0; i <= len; ++i) {
begin_nodes_[i].reserve(16);
end_nodes_[i].reserve(16);
}
Node *bos = NewNode();
bos->id = -1;
bos->pos = 0;
end_nodes_[0].push_back(bos);
Node *eos = NewNode();
eos->id = -1;
eos->pos = len;
begin_nodes_[len].push_back(eos);
}
Lattice::Node *Lattice::Insert(int pos, int length) {
Node *node = NewNode();
node->pos = pos;
node->length = length;
const int utf8_length =
static_cast<int>(surface(pos + length) - surface(pos));
node->piece.set(surface(pos), utf8_length);
begin_nodes_[pos].push_back(node);
end_nodes_[pos + node->length].push_back(node);
return node;
}
std::vector<Lattice::Node *> Lattice::Viterbi() {
const int len = size();
CHECK_GT(len, 0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
rnode->prev = nullptr;
float best_score = 0.0;
Node *best_node = nullptr;
for (Node *lnode : end_nodes_[pos]) {
const float score = lnode->backtrace_score + rnode->score;
if (best_node == nullptr || score > best_score) {
best_node = lnode;
best_score = score;
}
}
CHECK(best_node);
rnode->prev = best_node;
rnode->backtrace_score = best_score;
}
}
// backtrace
std::vector<Node *> results;
for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr;
node = node->prev) {
results.push_back(node);
}
std::reverse(results.begin(), results.end());
return results;
}
float Lattice::PopulateMarginal(float freq,
std::vector<float> *expected) const {
CHECK_NOTNULL(expected);
// Returns log(exp(x) + exp(y)).
// if flg is true, returns log(exp(y)) == y.
// log(\sum_i exp(a[i])) can be computed as
// for (int i = 0; i < a.size(); ++i)
// x = LogSumExp(x, a[i], i == 0);
auto LogSumExp = [](float x, float y, bool init_mode) -> float {
if (init_mode) {
return y;
}
const float vmin = std::min(x, y);
const float vmax = std::max(x, y);
constexpr float kMinusLogEpsilon = 50;
if (vmax > vmin + kMinusLogEpsilon) {
return vmax;
} else {
return vmax + log(exp(vmin - vmax) + 1.0);
}
};
const int len = size();
CHECK_GT(len, 0);
// alpha and beta (accumulative log prob) in Forward Backward.
// the index of alpha/beta is Node::node_id.
std::vector<float> alpha(all_nodes_.size(), 0.0);
std::vector<float> beta(all_nodes_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id],
lnode->score + alpha[lnode->node_id],
lnode == end_nodes_[pos][0]);
}
}
}
for (int pos = len; pos >= 0; --pos) {
for (Node *lnode : end_nodes_[pos]) {
for (Node *rnode : begin_nodes_[pos]) {
beta[lnode->node_id] =
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
rnode == begin_nodes_[pos][0]);
}
}
}
const float Z = alpha[begin_nodes_[len][0]->node_id];
for (int pos = 0; pos < len; ++pos) {
for (Node *node : begin_nodes_[pos]) {
if (node->id >= 0) {
// the index of |expected| is a Node::id, which is a vocabulary id.
(*expected)[node->id] += freq * exp(alpha[node->node_id] + node->score +
beta[node->node_id] - Z);
}
}
}
return freq * Z;
}
std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
CHECK_GT(size(), 0);
CHECK_GE(nbest_size, 1);
// Uses A* search to enumerate N-bests.
// Given a lattice, enumerates hypotheses (paths) from EOS.
// At each partial path x, compute f(x) as follows
// f(x) = g(x) + h(x).
// g(x): the sum of scores from EOS to the left-most node in x.
// h(x): a heuristic that estimates the largest score from x to BOS.
// f(x): the priority to pop a new hypothesis from the priority queue.
//
// As left-to-right Viterbi search can tell the *exact* value of h(x),
// we can obtain the exact n-best results with A*.
struct Hypothesis {
Node *node;
Hypothesis *next;
float fx;
float gx;
};
class HypothesisComparator {
public:
    bool operator()(Hypothesis *h1, Hypothesis *h2) const {
return (h1->fx < h2->fx);
}
};
using Agenda = std::priority_queue<Hypothesis *, std::vector<Hypothesis *>,
HypothesisComparator>;
Agenda agenda;
std::vector<Hypothesis *> allocated;
std::vector<std::vector<Node *>> results;
auto NewHypothesis = [&allocated]() {
Hypothesis *h = new Hypothesis;
memset(h, 0, sizeof(*h));
allocated.push_back(h);
return h;
};
auto *eos = NewHypothesis();
eos->node = eos_node();
eos->next = nullptr;
eos->fx = eos->node->score;
eos->gx = eos->node->score;
agenda.push(eos);
// Run Viterbi first to fill backtrace score.
Viterbi();
while (!agenda.empty()) {
auto *top = agenda.top();
agenda.pop();
auto *node = top->node;
// Reaches to BOS
if (node == bos_node()) {
results.resize(results.size() + 1);
for (auto *n = top->next; n->next != nullptr; n = n->next) {
results.back().push_back(n->node);
}
if (results.size() == nbest_size) {
break;
}
continue;
}
// Expands new node ending at node->pos
for (Node *lnode : end_nodes(node->pos)) {
auto *hyp = NewHypothesis();
hyp->node = lnode;
hyp->gx = lnode->score + top->gx; // just adds node->score
hyp->fx =
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
hyp->next = top;
agenda.push(hyp);
}
}
port::STLDeleteElements(&allocated);
return results;
}
ModelBase::ModelBase() {}
ModelBase::~ModelBase() {}
void ModelBase::PopulateNodes(Lattice *lattice) const {
CHECK_NOTNULL(lattice);
CHECK_NOTNULL(trie_);
auto GetCharsLength = [](const char *begin, int len) {
const char *end = begin + len;
int result = 0;
while (begin < end) {
begin += std::min<int>(string_util::OneCharLen(begin), end - begin);
++result;
}
return result;
};
constexpr float kUnkPenalty = 10.0;
const float unk_score = min_score() - kUnkPenalty;
const int len = lattice->size();
const char *end = lattice->sentence() + lattice->utf8_size();
// Initializes the buffer for return values.
CHECK_GT(trie_results_size_, 0);
// +1 just in case.
std::vector<Darts::DoubleArray::result_pair_type> trie_results(
trie_results_size_ + 1);
for (int begin_pos = 0; begin_pos < len; ++begin_pos) {
const char *begin = lattice->surface(begin_pos);
// Finds all pieces which are prefix of surface(begin_pos).
const size_t num_nodes = trie_->commonPrefixSearch(
begin, trie_results.data(), trie_results.size(),
static_cast<int>(end - begin));
CHECK_LT(num_nodes, trie_results.size());
bool has_single_node = false;
// Inserts pieces to the lattice.
for (size_t k = 0; k < num_nodes; ++k) {
const int length = GetCharsLength(begin, trie_results[k].length);
Lattice::Node *node = lattice->Insert(begin_pos, length);
node->id = trie_results[k].value; // the value of Trie stores vocab_id.
node->score = GetScore(node->id); // calls method defined in subclass.
if (!has_single_node && node->length == 1) {
has_single_node = true;
}
}
if (!has_single_node) {
Lattice::Node *node = lattice->Insert(begin_pos, 1);
node->id = kUnkID; // add UNK node.
node->score = unk_score;
}
}
}
int ModelBase::PieceToId(StringPiece piece) const {
auto it = reserved_id_map_.find(piece);
if (it != reserved_id_map_.end()) {
return it->second;
}
int id = 0;
trie_->exactMatchSearch(piece.data(), id);
return id == -1 ? kUnkID : id;
}
void ModelBase::BuildTrie(std::vector<std::pair<std::string, int>> *pieces) {
CHECK_NOTNULL(pieces);
CHECK(!pieces->empty());
// Sorts the pieces lexicographically, since DoubleArray::build()
// only accepts sorted keys.
sort(pieces->begin(), pieces->end());
// Makes key/value set for DoubleArrayTrie.
std::vector<const char *> key(pieces->size());
std::vector<int> value(pieces->size());
for (size_t i = 0; i < pieces->size(); ++i) {
key[i] = (*pieces)[i].first.c_str(); // sorted piece.
value[i] = (*pieces)[i].second; // vocab_id
}
trie_ = port::MakeUnique<Darts::DoubleArray>();
CHECK_EQ(0,
trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
&value[0]))
<< "cannot build double-array";
// Computes the maximum number of shared prefixes in the trie.
const int kMaxTrieResultsSize = 1024;
std::vector<Darts::DoubleArray::result_pair_type> results(
kMaxTrieResultsSize);
trie_results_size_ = 0;
for (const auto &p : *pieces) {
const int num_nodes = trie_->commonPrefixSearch(
p.first.data(), results.data(), results.size(), p.first.size());
trie_results_size_ = std::max(trie_results_size_, num_nodes);
}
CHECK_GT(trie_results_size_, 0);
}
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
min_score_ = FLT_MAX;
CheckControlSymbols();
std::vector<std::pair<std::string, int>> pieces; // <piece, vocab_id>
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
CHECK(!sp.piece().empty());
if (sp.type() == ModelProto::SentencePiece::NORMAL ||
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
CHECK(sp.has_score());
pieces.emplace_back(sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
}
if (sp.type() == ModelProto::SentencePiece::NORMAL) {
min_score_ = std::min(min_score_, sp.score());
}
}
BuildTrie(&pieces);
}
Model::~Model() {}
std::vector<std::pair<StringPiece, int>> Model::Encode(
StringPiece normalized) const {
if (normalized.empty()) {
return {};
}
Lattice lattice;
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
std::vector<std::pair<StringPiece, int>> results;
for (const auto *node : lattice.Viterbi()) {
results.emplace_back(node->piece, node->id);
}
return results;
}
} // namespace unigram
} // namespace sentencepiece

149
src/unigram_model.h Normal file
@@ -0,0 +1,149 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UNIGRAM_MODEL_H_
#define UNIGRAM_MODEL_H_
#include "common.h"
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
#include "third_party/darts_clone/darts.h"
namespace sentencepiece {
namespace unigram {
// Lattice represents a search space of sentence piece segmentation.
class Lattice {
public:
Lattice();
virtual ~Lattice();
struct Node {
StringPiece piece; // Sentence piece representation.
uint32 pos; // Unicode position in the sentence.
uint32 length; // Unicode length, not UTF8 byte.
uint32 node_id; // unique id in the current lattice.
int id; // vocab id. (maybe -1 for UNK)
float score; // logprob of this sentencepiece.
float backtrace_score; // backtrace info used in Viterbi.
Node *prev; // best previous node on Viterbi path.
std::string DebugString() const;
};
// Returns bos node.
Node *bos_node() const;
// Returns eos node.
Node *eos_node() const;
// Returns nodes starting at |pos|.
const std::vector<Node *> &begin_nodes(int pos) const;
// Returns nodes ending at |pos|.
const std::vector<Node *> &end_nodes(int pos) const;
// Returns Unicode character length.
int size() const;
// Returns multi-byte (utf8) length.
int utf8_size() const;
// Returns the substring of sentence. sentence[pos:]
const char *surface(int pos) const;
// Returns immutable sentence. The same as surface(0)
const char *sentence() const;
// Clears the lattice.
void Clear();
// Sets new sentence.
void SetSentence(StringPiece sentence);
// Inserts a new node at [pos, pos + length - 1].
// After calling this method, the caller must set Node::score and Node::id.
Node *Insert(int pos, int length);
// Returns Viterbi path. All nodes must be populated in advance.
std::vector<Node *> Viterbi();
// Returns n-best results.
std::vector<std::vector<Node *>> NBest(size_t nbest_size);
// Populates marginal probability of every node in this lattice.
// |freq| is the frequency of the sentence.
// for (auto *node : all_nodes_) {
// (*expected)[node->id] += marginal_prob_of_node * freq;
// }
// Returns the log-likelihood of this sentence.
float PopulateMarginal(float freq, std::vector<float> *expected) const;
private:
// Returns new node.
// Lattice class has the ownership of the returned value.
Node *NewNode();
StringPiece sentence_;
std::vector<const char *> surface_;
std::vector<std::vector<Node *>> begin_nodes_;
std::vector<std::vector<Node *>> end_nodes_;
std::vector<Node *> all_nodes_;
};
// Base class for Unigram Model.
// We have base Model class because we will have different
// implementations for training and testing.
// Trie management part is shared by training and testing.
class ModelBase : public ModelInterface {
public:
ModelBase();
~ModelBase() override;
// Returns the minimum score in sentence pieces.
// min_score() - 10 is used as the score of unknown sentence pieces.
float min_score() const { return min_score_; }
// Populates all sentence pieces to the |lattice|.
// After calling this function, lattice.Viterbi() returns the
// best segmentation.
void PopulateNodes(Lattice *lattice) const;
// Returns a vocab id of |piece|.
int PieceToId(StringPiece piece) const override;
protected:
// Builds a Trie index.
void BuildTrie(std::vector<std::pair<std::string, int>> *pieces);
float min_score_;
std::unique_ptr<Darts::DoubleArray> trie_;
// Maximum size of the return value of Trie, which corresponds
// to the maximum size of shared common prefix in the sentence pieces.
int trie_results_size_;
};
// Unigram model class for decoding.
class Model : public ModelBase {
public:
explicit Model(const ModelProto &model_proto);
~Model() override;
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override;
};
} // namespace unigram
} // namespace sentencepiece
#endif // UNIGRAM_MODEL_H_

403
src/unigram_model_test.cc Normal file
@@ -0,0 +1,403 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model.h"
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include "sentencepiece_model.pb.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
TEST(LatticeTest, SetSentenceTest) {
Lattice lattice;
EXPECT_EQ(0, lattice.size());
EXPECT_EQ(0, lattice.utf8_size());
lattice.SetSentence("test");
EXPECT_EQ(4, lattice.size());
EXPECT_EQ(4, lattice.utf8_size());
EXPECT_STREQ("test", lattice.sentence());
EXPECT_STREQ("test", lattice.surface(0));
EXPECT_STREQ("est", lattice.surface(1));
EXPECT_STREQ("st", lattice.surface(2));
EXPECT_STREQ("t", lattice.surface(3));
Lattice::Node *bos = lattice.bos_node();
Lattice::Node *eos = lattice.eos_node();
EXPECT_EQ(-1, bos->id);
EXPECT_EQ(-1, eos->id);
EXPECT_EQ(bos, lattice.end_nodes(0).front());
EXPECT_EQ(eos, lattice.begin_nodes(4).front());
lattice.SetSentence("テストab");
EXPECT_EQ(5, lattice.size());
EXPECT_EQ(11, lattice.utf8_size());
EXPECT_STREQ("テストab", lattice.sentence());
EXPECT_STREQ("テストab", lattice.surface(0));
EXPECT_STREQ("ストab", lattice.surface(1));
EXPECT_STREQ("トab", lattice.surface(2));
EXPECT_STREQ("ab", lattice.surface(3));
EXPECT_STREQ("b", lattice.surface(4));
lattice.Clear();
EXPECT_EQ(0, lattice.size());
EXPECT_EQ(0, lattice.utf8_size());
}
TEST(LatticeTest, InsertTest) {
Lattice lattice;
lattice.SetSentence("ABあい");
Lattice::Node *node[7];
node[0] = lattice.Insert(0, 1);
node[1] = lattice.Insert(1, 1);
node[2] = lattice.Insert(2, 1);
node[3] = lattice.Insert(3, 1);
node[4] = lattice.Insert(0, 2);
node[5] = lattice.Insert(1, 2);
node[6] = lattice.Insert(2, 2);
EXPECT_EQ("A", node[0]->piece);
EXPECT_EQ("B", node[1]->piece);
EXPECT_EQ("", node[2]->piece);
EXPECT_EQ("", node[3]->piece);
EXPECT_EQ("AB", node[4]->piece);
EXPECT_EQ("Bあ", node[5]->piece);
EXPECT_EQ("あい", node[6]->piece);
EXPECT_EQ("A", node[0]->piece);
EXPECT_EQ("B", node[1]->piece);
EXPECT_EQ("", node[2]->piece);
EXPECT_EQ("", node[3]->piece);
EXPECT_EQ("AB", node[4]->piece);
EXPECT_EQ("Bあ", node[5]->piece);
EXPECT_EQ("あい", node[6]->piece);
EXPECT_EQ(0, node[0]->pos);
EXPECT_EQ(1, node[1]->pos);
EXPECT_EQ(2, node[2]->pos);
EXPECT_EQ(3, node[3]->pos);
EXPECT_EQ(0, node[4]->pos);
EXPECT_EQ(1, node[5]->pos);
EXPECT_EQ(2, node[6]->pos);
EXPECT_EQ(1, node[0]->length);
EXPECT_EQ(1, node[1]->length);
EXPECT_EQ(1, node[2]->length);
EXPECT_EQ(1, node[3]->length);
EXPECT_EQ(2, node[4]->length);
EXPECT_EQ(2, node[5]->length);
EXPECT_EQ(2, node[6]->length);
EXPECT_EQ(0, lattice.bos_node()->node_id);
EXPECT_EQ(1, lattice.eos_node()->node_id);
EXPECT_EQ(2, node[0]->node_id);
EXPECT_EQ(3, node[1]->node_id);
EXPECT_EQ(4, node[2]->node_id);
EXPECT_EQ(5, node[3]->node_id);
EXPECT_EQ(6, node[4]->node_id);
EXPECT_EQ(7, node[5]->node_id);
EXPECT_EQ(8, node[6]->node_id);
EXPECT_EQ(2, lattice.begin_nodes(0).size());
EXPECT_EQ(2, lattice.begin_nodes(1).size());
EXPECT_EQ(2, lattice.begin_nodes(2).size());
EXPECT_EQ(1, lattice.begin_nodes(3).size());
EXPECT_EQ(1, lattice.begin_nodes(4).size()); // EOS
EXPECT_EQ(1, lattice.end_nodes(0).size()); // BOS
EXPECT_EQ(1, lattice.end_nodes(1).size());
EXPECT_EQ(2, lattice.end_nodes(2).size());
EXPECT_EQ(2, lattice.end_nodes(3).size());
EXPECT_EQ(2, lattice.end_nodes(4).size());
EXPECT_EQ(node[0], lattice.begin_nodes(0)[0]);
EXPECT_EQ(node[4], lattice.begin_nodes(0)[1]);
EXPECT_EQ(node[1], lattice.begin_nodes(1)[0]);
EXPECT_EQ(node[5], lattice.begin_nodes(1)[1]);
EXPECT_EQ(node[2], lattice.begin_nodes(2)[0]);
EXPECT_EQ(node[6], lattice.begin_nodes(2)[1]);
EXPECT_EQ(node[3], lattice.begin_nodes(3)[0]);
EXPECT_EQ(lattice.eos_node(), lattice.begin_nodes(4)[0]);
EXPECT_EQ(lattice.bos_node(), lattice.end_nodes(0)[0]);
EXPECT_EQ(node[0], lattice.end_nodes(1)[0]);
EXPECT_EQ(node[1], lattice.end_nodes(2)[0]);
EXPECT_EQ(node[4], lattice.end_nodes(2)[1]);
EXPECT_EQ(node[2], lattice.end_nodes(3)[0]);
EXPECT_EQ(node[5], lattice.end_nodes(3)[1]);
EXPECT_EQ(node[3], lattice.end_nodes(4)[0]);
EXPECT_EQ(node[6], lattice.end_nodes(4)[1]);
}
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
Lattice lattice;
lattice.SetSentence("ABC");
EXPECT_DEATH(lattice.Viterbi());
// Still incomplete
lattice.Insert(0, 1);
EXPECT_DEATH(lattice.Viterbi());
lattice.Insert(1, 1);
lattice.Insert(2, 1);
lattice.Viterbi();
}
std::string GetTokenized(const std::vector<Lattice::Node *> &nodes) {
std::vector<std::string> tokens;
for (auto *node : nodes) {
tokens.push_back(node->piece.to_string());
}
return string_util::Join(tokens, " ");
}
void InsertWithScore(Lattice *lattice, int pos, int length, float score) {
lattice->Insert(pos, length)->score = score;
}
void InsertWithScoreAndId(Lattice *lattice, int pos, int length, float score,
int id) {
auto *node = lattice->Insert(pos, length);
node->score = score;
node->id = id;
}
TEST(LatticeTest, ViterbiTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.0); // C
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 0, 2, 2.0); // AB
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 1, 2, 5.0); // BC
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi()));
}
TEST(LatticeTest, NBestTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.0); // C
InsertWithScore(&lattice, 0, 2, 2.0); // AB
InsertWithScore(&lattice, 1, 2, 5.0); // BC
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
auto nbests = lattice.NBest(10);
EXPECT_EQ(4, nbests.size());
EXPECT_EQ("ABC", GetTokenized(nbests[0]));
EXPECT_EQ("A BC", GetTokenized(nbests[1]));
EXPECT_EQ("AB C", GetTokenized(nbests[2]));
EXPECT_EQ("A B C", GetTokenized(nbests[3]));
}
TEST(LatticeTest, PopulateMarginalTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScoreAndId(&lattice, 0, 1, 1.0, 0); // A
InsertWithScoreAndId(&lattice, 1, 1, 1.2, 1); // B
InsertWithScoreAndId(&lattice, 2, 1, 2.5, 2); // C
InsertWithScoreAndId(&lattice, 0, 2, 3.0, 3); // AB
InsertWithScoreAndId(&lattice, 1, 2, 4.0, 4); // BC
InsertWithScoreAndId(&lattice, 0, 3, 2.0, 5); // ABC
std::vector<float> probs(6, 0.0);
// Expand all paths:
// A B C : exp(1.0 + 1.2 + 2.5) => path1
// AB C : exp(3.0 + 2.5) => path2
// A BC : exp(1.0 + 4.0) => path3
// ABC : exp(2.0) => path4
const float p1 = exp(1.0 + 1.2 + 2.5);
const float p2 = exp(3.0 + 2.5);
const float p3 = exp(1.0 + 4.0);
const float p4 = exp(2.0);
const float Z = p1 + p2 + p3 + p4;
const float logZ = lattice.PopulateMarginal(1.0, &probs);
EXPECT_NEAR((p1 + p3) / Z, probs[0], 0.001); // A
EXPECT_NEAR(p1 / Z, probs[1], 0.001); // B
EXPECT_NEAR((p1 + p2) / Z, probs[2], 0.001); // C
EXPECT_NEAR(p2 / Z, probs[3], 0.001); // AB
EXPECT_NEAR(p3 / Z, probs[4], 0.001); // BC
EXPECT_NEAR(p4 / Z, probs[5], 0.001); // ABC
EXPECT_NEAR(log(Z), logZ, 0.001);
}
ModelProto MakeBaseModelProto() {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(UnigramModelTest, SetUnigramModelTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a");
AddPiece(&model_proto, "b");
AddPiece(&model_proto, "c");
AddPiece(&model_proto, "d");
const Model model(model_proto);
EXPECT_EQ(model_proto.SerializeAsString(),
model.model_proto().SerializeAsString());
}
TEST(UnigramModelTest, PieceToIdTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.1);
AddPiece(&model_proto, "b", 0.2);
AddPiece(&model_proto, "c", 0.3);
AddPiece(&model_proto, "d", 0.4);
const Model model(model_proto);
EXPECT_EQ(model_proto.SerializeAsString(),
model.model_proto().SerializeAsString());
EXPECT_NEAR(0.1, model.min_score(), 0.001);
EXPECT_EQ(0, model.PieceToId("<unk>"));
EXPECT_EQ(1, model.PieceToId("<s>"));
EXPECT_EQ(2, model.PieceToId("</s>"));
EXPECT_EQ(3, model.PieceToId("a"));
EXPECT_EQ(4, model.PieceToId("b"));
EXPECT_EQ(5, model.PieceToId("c"));
EXPECT_EQ(6, model.PieceToId("d"));
EXPECT_EQ(0, model.PieceToId("e")); // unk
EXPECT_EQ(0, model.PieceToId("")); // unk
EXPECT_EQ("<unk>", model.IdToPiece(0));
EXPECT_EQ("<s>", model.IdToPiece(1));
EXPECT_EQ("</s>", model.IdToPiece(2));
EXPECT_EQ("a", model.IdToPiece(3));
EXPECT_EQ("b", model.IdToPiece(4));
EXPECT_EQ("c", model.IdToPiece(5));
EXPECT_EQ("d", model.IdToPiece(6));
EXPECT_TRUE(model.IsUnknown(0));
EXPECT_FALSE(model.IsUnknown(1));
EXPECT_FALSE(model.IsUnknown(2));
EXPECT_FALSE(model.IsUnknown(3));
EXPECT_FALSE(model.IsUnknown(4));
EXPECT_FALSE(model.IsUnknown(5));
EXPECT_FALSE(model.IsUnknown(6));
EXPECT_FALSE(model.IsControl(0));
EXPECT_TRUE(model.IsControl(1));
EXPECT_TRUE(model.IsControl(2));
EXPECT_FALSE(model.IsControl(3));
EXPECT_FALSE(model.IsControl(4));
EXPECT_FALSE(model.IsControl(5));
EXPECT_FALSE(model.IsControl(6));
EXPECT_NEAR(0, model.GetScore(0), 0.0001);
EXPECT_NEAR(0, model.GetScore(1), 0.0001);
EXPECT_NEAR(0, model.GetScore(2), 0.0001);
EXPECT_NEAR(0.1, model.GetScore(3), 0.0001);
EXPECT_NEAR(0.2, model.GetScore(4), 0.0001);
EXPECT_NEAR(0.3, model.GetScore(5), 0.0001);
EXPECT_NEAR(0.4, model.GetScore(6), 0.0001);
EXPECT_TRUE(model.Encode("").empty());
}
TEST(UnigramModelTest, PopulateNodesAllUnknownsTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "x");
const Model model(model_proto);
Lattice lattice;
lattice.SetSentence("abc");
model.PopulateNodes(&lattice);
EXPECT_EQ(1, lattice.begin_nodes(0).size());
EXPECT_EQ(1, lattice.begin_nodes(1).size());
EXPECT_EQ(1, lattice.begin_nodes(2).size());
EXPECT_EQ(0, lattice.begin_nodes(0)[0]->id);
EXPECT_EQ(0, lattice.begin_nodes(1)[0]->id);
EXPECT_EQ(0, lattice.begin_nodes(2)[0]->id);
}
TEST(UnigramModelTest, PopulateNodesTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.1); // 3
AddPiece(&model_proto, "b", 0.2); // 4
AddPiece(&model_proto, "ab", 0.3); // 5
AddPiece(&model_proto, "bc", 0.4); // 6
const Model model(model_proto);
Lattice lattice;
lattice.SetSentence("abc");
model.PopulateNodes(&lattice);
EXPECT_EQ(2, lattice.begin_nodes(0).size()); // a,ab
EXPECT_EQ(2, lattice.begin_nodes(1).size()); // b,bc
EXPECT_EQ(1, lattice.begin_nodes(2).size()); // c(unk)
EXPECT_EQ(3, lattice.begin_nodes(0)[0]->id);
EXPECT_EQ(5, lattice.begin_nodes(0)[1]->id);
EXPECT_EQ(4, lattice.begin_nodes(1)[0]->id);
EXPECT_EQ(6, lattice.begin_nodes(1)[1]->id);
EXPECT_EQ(0, lattice.begin_nodes(2)[0]->id);
EXPECT_NEAR(0.1, lattice.begin_nodes(0)[0]->score, 0.001);
EXPECT_NEAR(0.3, lattice.begin_nodes(0)[1]->score, 0.001);
EXPECT_NEAR(0.2, lattice.begin_nodes(1)[0]->score, 0.001);
EXPECT_NEAR(0.4, lattice.begin_nodes(1)[1]->score, 0.001);
}
} // namespace unigram
} // namespace sentencepiece

@@ -0,0 +1,552 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model_trainer.h"
#include "normalizer.h"
#include <cfloat>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <queue>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "third_party/esaxx/esa.hxx" // Suffix array library.
#include "unicode_script.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
namespace {
constexpr char32 kWSChar = L'\u2581';
constexpr char32 kUNKChar = L'\u2585';
constexpr char kUNKStr[] = "\xe2\x96\x85";
double Digamma(double x) {
double result = 0.0;
for (; x < 7; ++x) result -= 1 / x;
x -= 1.0 / 2.0;
const double xx = 1.0 / x;
const double xx2 = xx * xx;
const double xx4 = xx2 * xx2;
result += log(x) + (1.0 / 24.0) * xx2 - (7.0 / 960.0) * xx4 +
(31.0 / 8064.0) * xx4 * xx2 - (127.0 / 30720.0) * xx4 * xx4;
return result;
}
template <typename IT>
void ToLogProb(IT begin, IT end) {
float sum = 0.0;
for (auto it = begin; it != end; ++it) {
sum += it->second;
}
float logsum = log(sum);
for (auto it = begin; it != end; ++it) {
it->second = log(it->second) - logsum;
}
}
class ThreadPool {
public:
ThreadPool() {}
virtual ~ThreadPool() {
for (auto &task : tasks_) {
task.join();
}
}
void Schedule(std::function<void()> closure) { tasks_.emplace_back(closure); }
private:
std::vector<std::thread> tasks_;
};
} // namespace
TrainerModel::TrainerModel(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {}
TrainerModel::~TrainerModel() {}
const TrainerModel::SentencePieces &TrainerModel::GetSentencePieces() const {
return sentencepieces_;
}
void TrainerModel::SetSentencePieces(SentencePieces &&sentencepieces) {
sentencepieces_ = std::move(sentencepieces);
CHECK(!sentencepieces_.empty());
min_score_ = FLT_MAX;
std::vector<std::pair<std::string, int>> pieces;
for (size_t i = 0; i < sentencepieces_.size(); ++i) {
const std::string &w = sentencepieces_[i].first; // piece
const float score = sentencepieces_[i].second; // score.
CHECK(!std::isnan(score));
pieces.emplace_back(w, i);
min_score_ = std::min(min_score_, score);
}
BuildTrie(&pieces);
}
// Returns seed sentencepieces for EM training.
TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() const {
CHECK(!sentences_.empty());
CHECK(!required_chars_.empty());
CHECK(port::ContainsKey(required_chars_, kWSChar));
// Merges all sentences into one array with 0x0000 delimiter.
std::vector<char32> array;
std::unordered_map<std::string, int64> all_chars, substrs;
constexpr char32 kSentenceBoundary = 0x0000;
const size_t mining_size =
std::min<size_t>(sentences_.size(), trainer_spec_.mining_sentence_size());
LOG(INFO) << "Using " << mining_size
<< " sentences for making seed sentencepieces";
std::vector<std::string> mining_sentences(mining_size);
for (size_t i = 0; i < mining_size; ++i) {
mining_sentences[i] = sentences_[i].first;
}
for (const auto &w : mining_sentences) {
for (const auto &c : string_util::UTF8ToUnicodeText(w)) {
array.push_back(c);
if (c != kUNKChar && c != kSentenceBoundary) {
++all_chars[string_util::UnicodeCharToUTF8(c)];
}
}
array.push_back(kSentenceBoundary); // sentence boundary marker.
}
const int n = array.size();
std::vector<int> SA(n); // suffix array
std::vector<int> L(n); // left boundaries of internal node
std::vector<int> R(n); // right boundaries of internal node
std::vector<int> D(n); // depths of internal node
// Makes a suffix array to extract all substrings occurring
// at least twice in the sentences.
constexpr int kAlphabetSize = 0x110000; // All UCS4 range.
int node_num = 0;
LOG(INFO) << "Making suffix array...";
CHECK_EQ(0, esaxx(array.begin(), SA.begin(), L.begin(), R.begin(), D.begin(),
n, kAlphabetSize, node_num));
LOG(INFO) << "Extracting frequent sub strings...";
std::vector<std::pair<int, int>> substr_index;
for (int i = 0; i < node_num; ++i) {
const int offset = SA[L[i]];
const int len = D[i];
if (len <= 1) {
continue;
}
const char32 *begin = &array[0] + offset;
const char32 *end = &array[0] + offset + len;
// Skips if a substring contains a sentence boundary.
if (std::find(begin, end, kSentenceBoundary) != end) {
continue;
}
const UnicodeText uw(begin, end);
if (!IsValidSentencePiece(uw)) {
continue;
}
// character-wise coverage is the default score.
const int freq = R[i] - L[i];
const int score = freq * len;
substr_index.emplace_back(i, score);
}
// all_chars must be included in the seed sentencepieces.
TrainerModel::SentencePieces seed_sentencepieces;
for (const auto &it : Sorted(all_chars)) {
seed_sentencepieces.emplace_back(it);
}
// Sort by the coverage of sub strings.
for (const auto &p : Sorted(substr_index)) {
const int offset = SA[L[p.first]];
const int len = D[p.first];
CHECK_GT(len, 0);
const char32 *begin = &array[offset];
const char32 *end = &array[offset + len];
const UnicodeText uw(begin, end);
CHECK(IsValidSentencePiece(uw)); // just in case.
const std::string w = string_util::UnicodeTextToUTF8(uw);
if (seed_sentencepieces.size() ==
static_cast<size_t>(trainer_spec_.seed_sentencepiece_size())) {
break;
}
CHECK(!port::ContainsKey(all_chars, w));
seed_sentencepieces.emplace_back(w, p.second);
}
ToLogProb(seed_sentencepieces.begin(), seed_sentencepieces.end());
LOG(INFO) << "Initialized " << seed_sentencepieces.size()
<< " seed sentencepieces";
return seed_sentencepieces;
}
std::vector<float> Trainer::RunEStep(const TrainerModel &model, float *obj,
int64 *num_tokens) const {
std::vector<std::vector<float>> expected(trainer_spec_.num_threads());
std::vector<float> objs(trainer_spec_.num_threads(), 0.0);
std::vector<int64> ntokens(trainer_spec_.num_threads(), 0);
auto pool = port::MakeUnique<ThreadPool>();
int64 all_sentence_freq = 0;
for (const auto &w : sentences_) {
all_sentence_freq += w.second;
}
// Executes E step in parallel
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
pool->Schedule([&, n]() {
Lattice lattice;
expected[n].resize(model.GetPieceSize(), 0.0);
for (size_t i = n; i < sentences_.size();
i += trainer_spec_.num_threads()) {
const std::string &w = sentences_[i].first;
const int64 freq = sentences_[i].second;
lattice.SetSentence(w);
model.PopulateNodes(&lattice);
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
ntokens[n] += lattice.Viterbi().size();
CHECK(!std::isnan(Z))
<< "likelihood is NAN. Input sentence may be too long";
objs[n] -= Z / all_sentence_freq;
}
});
}
pool.reset(nullptr);
// Merges expectations
for (int n = 1; n < trainer_spec_.num_threads(); ++n) {
objs[0] += objs[n];
ntokens[0] += ntokens[n];
for (size_t k = 0; k < expected[0].size(); ++k) {
expected[0][k] += expected[n][k];
}
}
*obj = objs[0];
*num_tokens = ntokens[0];
CHECK(!std::isnan(*obj));
return expected[0];
}
TrainerModel::SentencePieces Trainer::RunMStep(
const TrainerModel &model, const std::vector<float> &expected) const {
const auto &sentencepieces = model.GetSentencePieces();
CHECK_EQ(sentencepieces.size(), expected.size());
TrainerModel::SentencePieces new_sentencepieces;
float sum = 0.0;
for (size_t i = 0; i < expected.size(); ++i) {
const float freq = expected[i];
// Filter infrequent sentencepieces here.
constexpr float kExpectedFrequencyThreshold = 0.5;
if (freq < kExpectedFrequencyThreshold) {
continue;
}
new_sentencepieces.emplace_back(sentencepieces[i].first, freq);
sum += freq;
}
// Here we do not use the original EM, but use the
// Bayesianified/DPified EM algorithm.
// https://cs.stanford.edu/~pliang/papers/tutorial-acl2007-talk.pdf
// This modification will act as a sparse prior.
const float logsum = Digamma(sum);
for (auto &w : new_sentencepieces) {
w.second = Digamma(w.second) - logsum;
}
return new_sentencepieces;
}
TrainerModel::SentencePieces Trainer::PruneSentencePieces(
const TrainerModel &model) const {
const auto &sentencepieces = model.GetSentencePieces();
Lattice lattice;
std::vector<bool> always_keep(sentencepieces.size(), true);
std::vector<std::vector<int>> alternatives(sentencepieces.size());
// First, segments the current sentencepieces to know
// how each sentencepiece is resegmented if this sentencepiece is removed
// from the vocabulary.
// To do so, we take the second best segmentation of sentencepiece[i].
// alternatives[i] stores the sequence of second best sentencepieces.
for (size_t i = 0; i < sentencepieces.size(); ++i) {
const auto &w = sentencepieces[i];
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
const auto nbests = lattice.NBest(2);
if (nbests.size() == 1) {
// No second-best result is found. always keep this sentencepiece.
always_keep[i] = true;
continue;
} else if (nbests[0].size() >= 2) {
// Can safely remove this sentencepiece if its Viterbi path is split.
always_keep[i] = false;
} else if (nbests[0].size() == 1) {
always_keep[i] = true;
for (const auto *node : nbests[1]) {
alternatives[i].push_back(node->id);
}
}
}
// Second, segments all sentences to compute likelihood
// with a unigram language model. inverted[i] stores
// the set of sentence index where the sentencepieces[i] appears.
float vsum = 0.0;
std::vector<float> freq(sentencepieces.size(), 0.0);
std::vector<std::vector<int>> inverted(sentencepieces.size());
{
std::vector<float> vsums(trainer_spec_.num_threads(), 0.0);
std::vector<std::vector<float>> freqs(trainer_spec_.num_threads());
std::vector<std::vector<std::vector<int>>> inverteds(
trainer_spec_.num_threads());
auto pool = port::MakeUnique<ThreadPool>();
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
freqs[n].resize(sentencepieces.size(), 0.0);
inverteds[n].resize(sentencepieces.size());
pool->Schedule([&, n]() {
Lattice lattice;
for (size_t i = n; i < sentences_.size();
i += trainer_spec_.num_threads()) {
const auto &w = sentences_[i];
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
vsums[n] += w.second;
for (const auto *node : lattice.Viterbi()) {
if (node->id >= 0) {
freqs[n][node->id] += w.second;
inverteds[n][node->id].push_back(i);
}
}
}
});
}
pool.reset(nullptr);
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
vsum += vsums[n];
for (size_t i = 0; i < sentencepieces.size(); ++i) {
freq[i] += freqs[n][i];
std::copy(inverteds[n][i].begin(), inverteds[n][i].end(),
std::back_inserter(inverted[i]));
}
}
}
const float sum = std::accumulate(freq.begin(), freq.end(), 0.0);
const float logsum = log(sum);
std::vector<std::pair<int, float>> candidates;
TrainerModel::SentencePieces new_sentencepieces;
// Finally, computes how likely the LM likelihood is reduced if
// the sentencepiece[i] is removed from the vocabulary.
// Since the exact computation of loss is difficult, we compute the
// loss approximately by assuming that all sentencepiece[i] in the sentences
// are replaced with alternatives[i] when sentencepiece[i] is removed.
for (size_t i = 0; i < sentencepieces.size(); ++i) {
if (freq[i] == 0 || !always_keep[i]) {
// not found in Viterbi path. Can remove this entry safely.
continue;
} else if (alternatives[i].empty()) {
// no alternatives. Keeps this entry.
new_sentencepieces.push_back(sentencepieces[i]);
} else {
float F = 0.0; // the frequency of sentencepieces[i].
for (const int n : inverted[i]) {
F += sentences_[n].second;
}
F /= vsum; // normalizes by all sentence frequency.
// The logprob with the sentencepiece[i].
const float logprob_sp = log(freq[i]) - logsum;
// After removing the sentencepiece[i], its frequency freq[i] is
// re-assigned to alternatives.
// new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
//         = current_sum + freq[i] * (alternatives[i].size() - 1)
const float logsum_alt = log(sum + freq[i] * (alternatives[i].size() - 1));
// The frequencies of alternatives are increased by freq[i].
float logprob_alt = 0.0;
for (const int n : alternatives[i]) {
logprob_alt += (log(freq[n] + freq[i]) - logsum_alt);
}
// loss: the diff of likelihood after removing the sentencepieces[i].
const float loss = F * (logprob_sp - logprob_alt);
candidates.emplace_back(i, loss);
}
}
const int pruned_size =
std::max<int>(desired_vocab_size_,
trainer_spec_.shrinking_factor() * sentencepieces.size());
// Keeps trainer_spec_.shrinking_factor * sentencepieces.size() pieces.
// shrinking_factor is 0.75 by default.
for (const auto &w : Sorted(candidates)) {
if (new_sentencepieces.size() == static_cast<size_t>(pruned_size)) {
break;
}
new_sentencepieces.emplace_back(sentencepieces[w.first]);
}
return new_sentencepieces;
}
TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
const TrainerModel &model) const {
const auto &sentencepieces = model.GetSentencePieces();
std::unordered_map<std::string, float> final_sentencepieces;
std::unordered_map<std::string, float> sp(sentencepieces.begin(),
sentencepieces.end());
// required_chars_ must be included in the final sentencepieces.
float min_score_penalty = 0.0;
constexpr float kMinScorePenaltyDelta = 0.0001;
for (const auto &w : Sorted(required_chars_)) {
const std::string s = string_util::UnicodeCharToUTF8(w.first);
if (port::ContainsKey(sp, s)) {
final_sentencepieces[s] = sp[s];
} else {
// Adds a penalty so that required pieces do not end up with the same score.
// Since required_chars_ is sorted by frequency, more frequent pieces
// receive smaller penalties.
final_sentencepieces[s] = model.min_score() + min_score_penalty;
min_score_penalty += kMinScorePenaltyDelta;
}
}
const int meta_symbols_size = trainer_spec_.control_symbols().size() +
trainer_spec_.user_defined_symbols().size() +
3; // <s>, </s>, <unk>
const int vocab_size_size = trainer_spec_.vocab_size() - meta_symbols_size;
CHECK_GT(vocab_size_size, 0);
// Then keeps sentencepieces with higher scores.
for (const auto &w : Sorted(sentencepieces)) {
if (port::ContainsKey(final_sentencepieces, w.first)) {
continue;
}
if (static_cast<size_t>(vocab_size_size) == final_sentencepieces.size()) {
break;
}
final_sentencepieces[w.first] = w.second;
}
return Sorted(final_sentencepieces);
}
void Trainer::Train() {
#define CHECK_RANGE(variable, minval, maxval) \
CHECK(variable >= minval && variable <= maxval)
CHECK_GT(trainer_spec_.input().size(), 0);
CHECK(!trainer_spec_.model_prefix().empty());
CHECK_RANGE(trainer_spec_.vocab_size(), 100, 320000);
CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
CHECK_RANGE(trainer_spec_.mining_sentence_size(), 100, 5000000);
CHECK_RANGE(trainer_spec_.training_sentence_size(), 100, 100000000);
CHECK_RANGE(trainer_spec_.max_sentencepiece_length(), 1, 64);
CHECK_RANGE(trainer_spec_.seed_sentencepiece_size(), 1000, 5000000);
CHECK_RANGE(trainer_spec_.shrinking_factor(), 0.5, 0.95);
CHECK_RANGE(trainer_spec_.num_threads(), 1, 128);
CHECK_RANGE(trainer_spec_.num_sub_iterations(), 1, 10);
#undef CHECK_RANGE
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
CHECK(normalizer_spec_.escape_whitespaces());
TrainerModel model(trainer_spec_, normalizer_spec_);
LoadSentences();
auto seed_sentencepieces = MakeSeedSentencePieces();
model.SetSentencePieces(std::move(seed_sentencepieces));
if (trainer_spec_.split_by_whitespace()) {
SplitSentencesByWhitespace();
} else {
const int training_size = std::min<size_t>(
sentences_.size(), trainer_spec_.training_sentence_size());
sentences_.resize(training_size);
}
LOG(INFO) << "Using " << sentences_.size() << " sentences for EM training";
desired_vocab_size_ = static_cast<size_t>(trainer_spec_.vocab_size() * 1.1);
while (true) {
// Sub-EM iteration.
for (int iter = 0; iter < trainer_spec_.num_sub_iterations(); ++iter) {
// Executes E step
float objective = 0.0;
int64 num_tokens = 0;
const auto expected = RunEStep(model, &objective, &num_tokens);
// Executes M step.
auto new_sentencepieces = RunMStep(model, expected);
model.SetSentencePieces(std::move(new_sentencepieces));
LOG(INFO) << "EM sub_iter=" << iter << " size=" << model.GetPieceSize()
<< " obj=" << objective << " num_tokens=" << num_tokens
<< " num_tokens/piece="
<< 1.0 * num_tokens / model.GetPieceSize();
} // end of Sub EM iteration
// Stops the iteration when the number of pieces reaches the
// desired vocabulary size.
if (model.GetPieceSize() <= desired_vocab_size_) {
break;
}
// Prunes pieces.
auto new_sentencepieces = PruneSentencePieces(model);
model.SetSentencePieces(std::move(new_sentencepieces));
} // end of EM iteration
// Finally, adjusts the size of sentencepieces to be |vocab_size|.
final_pieces_ = FinalizeSentencePieces(model);
Save();
}
} // namespace unigram
} // namespace sentencepiece

113
src/unigram_model_trainer.h Normal file
View File

@ -0,0 +1,113 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UNIGRAM_MODEL_TRAINER_H_
#define UNIGRAM_MODEL_TRAINER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "sentencepiece_model.pb.h"
#include "stringpiece.h"
#include "trainer_interface.h"
#include "unigram_model.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
using string_util::UnicodeText;
class TrainerModel : public ModelBase {
public:
using SentencePieces = std::vector<std::pair<std::string, float>>;
TrainerModel() = delete;
TrainerModel(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalization_spec);
~TrainerModel() override;
// Returns the sentencepieces.
// The meta symbols, e.g., </s> are NOT included.
const SentencePieces &GetSentencePieces() const;
// Sets sentencepieces. The sentencepieces are moved.
// The meta symbols, e.g., </s> are NOT included.
void SetSentencePieces(SentencePieces &&sentencepieces);
int GetPieceSize() const override { return sentencepieces_.size(); }
float GetScore(int index) const override {
return sentencepieces_[index].second;
}
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override {
return {};
}
private:
SentencePieces sentencepieces_;
TrainerSpec trainer_spec_;
NormalizerSpec normalizer_spec_;
};
class Trainer : public TrainerInterface {
public:
Trainer(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
void Train() override;
private:
FRIEND_TEST(TrainerTest, IsValidSentencePieceTest);
// Makes seed pieces from the training corpus.
// The size of seed pieces is determined by seed_sentencepiece_size.
TrainerModel::SentencePieces MakeSeedSentencePieces() const;
// Executes the E step of EM and returns expected count.
// The index of return array is the vocab id.
// |objective| is a negative likelihood of the current model.
// |num_tokens| is the total number of tokens obtained by tokenizing
// the training corpus.
std::vector<float> RunEStep(const TrainerModel &model, float *objective,
int64 *num_tokens) const;
// Executes the M step of EM with the expected frequency and
// returns new pieces.
TrainerModel::SentencePieces RunMStep(
const TrainerModel &model, const std::vector<float> &expected) const;
// Heuristically prunes the current pieces.
// This is called after each EM sub-iteration.
TrainerModel::SentencePieces PruneSentencePieces(
const TrainerModel &model) const;
// Makes the final sentence pieces by incorporating the required characters
// and control/user defined symbols.
TrainerModel::SentencePieces FinalizeSentencePieces(
const TrainerModel &model) const;
// When the size of SentencePieces becomes less than desired_vocab_size_,
// break the main training loop. desired_vocab_size_ = 1.1 * vocab_size_
// for now.
int desired_vocab_size_;
};
} // namespace unigram
} // namespace sentencepiece
#endif // UNIGRAM_MODEL_TRAINER_H_

75
src/unigram_model_trainer_test.cc Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model_trainer.h"
#include "builder.h"
#include "normalizer.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
// Space symbol
#define WS "\xe2\x96\x81"
TEST(UnigramTrainerTest, EndToEndTest) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
normalizer_spec = normalizer::Builder::GetNormalizerSpec("nfkc");
trainer_spec.add_input("../data/wagahaiwa_nekodearu.txt");
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);
trainer_spec.set_model_type(TrainerSpec::UNIGRAM);
trainer_spec.add_control_symbols("<ctrl>");
trainer_spec.add_user_defined_symbols("<user>");
test::ScopedTempFile sf("tmp_model");
trainer_spec.set_model_prefix(sf.filename());
unigram::Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(std::string(sf.filename()) + ".model"));
EXPECT_EQ(kVocabSize, sp.GetPieceSize());
const int cid = sp.PieceToId("<ctrl>");
const int uid = sp.PieceToId("<user>");
EXPECT_TRUE(sp.IsControl(cid));
EXPECT_FALSE(sp.IsUnknown(uid));
std::vector<std::string> tok;
sp.Encode("", &tok);
EXPECT_TRUE(tok.empty());
sp.Encode(
"吾輩《わがはい》は猫である。名前はまだ無い。"
"どこで生れたかとんと見当《けんとう》がつかぬ。"
"何でも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している"
"",
&tok);
EXPECT_EQ(WS
" 吾輩 《 わが はい 》 は 猫 である 。 名前 はまだ 無い 。 "
"どこ で 生 れた か とん と 見当 《 けん とう 》 が つか ぬ 。 "
"何でも 薄 暗 い じめ じめ した 所で ニャーニャー "
"泣 い ていた 事 だけは 記憶 している 。",
string_util::Join(tok, " "));
}
} // namespace unigram
} // namespace sentencepiece

258
src/util.cc Normal file
View File

@ -0,0 +1,258 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util.h"
#include <iostream>
namespace sentencepiece {
namespace string_util {
template <typename T>
std::vector<T> SplitInternal(const T &str, const T &delim) {
std::vector<T> result;
size_t current_pos = 0;
size_t found_pos = 0;
while ((found_pos = str.find_first_of(delim, current_pos)) != T::npos) {
if (found_pos > current_pos) {
result.push_back(str.substr(current_pos, found_pos - current_pos));
}
current_pos = found_pos + 1;
}
if (str.size() > current_pos) {
result.push_back(str.substr(current_pos, str.size() - current_pos));
}
return result;
}
std::vector<std::string> Split(const std::string &str,
const std::string &delim) {
return SplitInternal<std::string>(str, delim);
}
std::vector<StringPiece> SplitPiece(StringPiece str, StringPiece delim) {
return SplitInternal<StringPiece>(str, delim);
}
std::string Join(const std::vector<std::string> &tokens, StringPiece delim) {
std::string result;
if (!tokens.empty()) {
result.append(tokens[0]);
}
for (size_t i = 1; i < tokens.size(); ++i) {
result.append(delim.data(), delim.size());
result.append(tokens[i]);
}
return result;
}
std::string Join(const std::vector<int> &tokens, StringPiece delim) {
std::string result;
char buf[32];
if (!tokens.empty()) {
const size_t len = Itoa(tokens[0], buf);
result.append(buf, len);
}
for (size_t i = 1; i < tokens.size(); ++i) {
result.append(delim.data(), delim.size());
const size_t len = Itoa(tokens[i], buf);
result.append(buf, len);
}
return result;
}
std::string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all) {
std::string ret;
StringReplace(s, oldsub, newsub, replace_all, &ret);
return ret;
}
void StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all, std::string *res) {
if (oldsub.empty()) {
res->append(s.data(), s.size());
return;
}
StringPiece::size_type start_pos = 0;
do {
const StringPiece::size_type pos = s.find(oldsub, start_pos);
if (pos == StringPiece::npos) {
break;
}
res->append(s.data() + start_pos, pos - start_pos);
res->append(newsub.data(), newsub.size());
start_pos = pos + oldsub.size();
} while (replace_all);
res->append(s.data() + start_pos, s.size() - start_pos);
}
// mblen stores the number of bytes consumed after decoding.
// DecodeUTF8 is optimized for speed. It doesn't check
// the following malformed UTF-8:
// 1) Overlong (redundant) encodings.
// 2) BOM (the return value is undefined).
// 3) A continuation byte (c & 0xc0 == 0x80) appearing where a leading
//    byte is expected.
char32 DecodeUTF8(const char *begin, const char *end, size_t *mblen) {
const size_t len = end - begin;
if (len >= 3 && (begin[0] & 0xf0) == 0xe0) {
*mblen = 3;
return (((begin[0] & 0x0f) << 12) | ((begin[1] & 0x3f) << 6) |
((begin[2] & 0x3f)));
} else if (static_cast<unsigned char>(begin[0]) < 0x80) {
*mblen = 1;
return static_cast<unsigned char>(begin[0]);
} else if (len >= 2 && (begin[0] & 0xe0) == 0xc0) {
*mblen = 2;
return (((begin[0] & 0x1f) << 6) | ((begin[1] & 0x3f)));
} else if (len >= 4 && (begin[0] & 0xf8) == 0xf0) {
*mblen = 4;
return (((begin[0] & 0x07) << 18) | ((begin[1] & 0x3f) << 12) |
((begin[2] & 0x3f) << 6) | ((begin[3] & 0x3f)));
} else if (len >= 5 && (begin[0] & 0xfc) == 0xf8) {
*mblen = 5;
return (((begin[0] & 0x03) << 24) | ((begin[1] & 0x3f) << 18) |
((begin[2] & 0x3f) << 12) | ((begin[3] & 0x3f) << 6) |
((begin[4] & 0x3f)));
} else if (len >= 6 && (begin[0] & 0xfe) == 0xfc) {
*mblen = 6;
return (((begin[0] & 0x01) << 30) | ((begin[1] & 0x3f) << 24) |
((begin[2] & 0x3f) << 18) | ((begin[3] & 0x3f) << 12) |
((begin[4] & 0x3f) << 6) | ((begin[5] & 0x3f)));
}
*mblen = 1;
return 0;
}
size_t EncodeUTF8(char32 c, char *output) {
if (c == 0) {
// Do nothing if |c| is NUL. Previous implementation of UCS4ToUTF8Append
// worked like this.
output[0] = '\0';
return 0;
}
if (c < 0x00080) {
output[0] = static_cast<char>(c & 0xFF);
output[1] = '\0';
return 1;
}
if (c < 0x00800) {
output[0] = static_cast<char>(0xC0 + ((c >> 6) & 0x1F));
output[1] = static_cast<char>(0x80 + (c & 0x3F));
output[2] = '\0';
return 2;
}
if (c < 0x10000) {
output[0] = static_cast<char>(0xE0 + ((c >> 12) & 0x0F));
output[1] = static_cast<char>(0x80 + ((c >> 6) & 0x3F));
output[2] = static_cast<char>(0x80 + (c & 0x3F));
output[3] = '\0';
return 3;
}
if (c < 0x200000) {
output[0] = static_cast<char>(0xF0 + ((c >> 18) & 0x07));
output[1] = static_cast<char>(0x80 + ((c >> 12) & 0x3F));
output[2] = static_cast<char>(0x80 + ((c >> 6) & 0x3F));
output[3] = static_cast<char>(0x80 + (c & 0x3F));
output[4] = '\0';
return 4;
}
// The ranges below are not valid UCS4, but fit in a 32-bit int.
if (c < 0x8000000) {
output[0] = static_cast<char>(0xF8 + ((c >> 24) & 0x03));
output[1] = static_cast<char>(0x80 + ((c >> 18) & 0x3F));
output[2] = static_cast<char>(0x80 + ((c >> 12) & 0x3F));
output[3] = static_cast<char>(0x80 + ((c >> 6) & 0x3F));
output[4] = static_cast<char>(0x80 + (c & 0x3F));
output[5] = '\0';
return 5;
}
output[0] = static_cast<char>(0xFC + ((c >> 30) & 0x01));
output[1] = static_cast<char>(0x80 + ((c >> 24) & 0x3F));
output[2] = static_cast<char>(0x80 + ((c >> 18) & 0x3F));
output[3] = static_cast<char>(0x80 + ((c >> 12) & 0x3F));
output[4] = static_cast<char>(0x80 + ((c >> 6) & 0x3F));
output[5] = static_cast<char>(0x80 + (c & 0x3F));
output[6] = '\0';
return 6;
}
std::string UnicodeCharToUTF8(const char32 c) { return UnicodeTextToUTF8({c}); }
UnicodeText UTF8ToUnicodeText(StringPiece utf8) {
UnicodeText uc;
const char *begin = utf8.data();
const char *end = utf8.data() + utf8.size();
while (begin < end) {
size_t mblen;
const char32 c = DecodeUTF8(begin, end, &mblen);
uc.push_back(c);
begin += mblen;
}
return uc;
}
std::string UnicodeTextToUTF8(const UnicodeText &utext) {
char buf[8];
std::string result;
for (const char32 c : utext) {
const size_t mblen = EncodeUTF8(c, buf);
result.append(buf, mblen);
}
return result;
}
} // namespace string_util
namespace io {
InputBuffer::InputBuffer(StringPiece filename)
: is_(filename.empty() ? &std::cin
: new std::ifstream(WPATH(filename.data()))) {
CHECK_IFS(*is_, filename.data());
}
InputBuffer::~InputBuffer() {
if (is_ != &std::cin) {
delete is_;
}
}
bool InputBuffer::ReadLine(std::string *line) {
return static_cast<bool>(std::getline(*is_, *line));
}
OutputBuffer::OutputBuffer(StringPiece filename)
: os_(filename.empty()
? &std::cout
: new std::ofstream(WPATH(filename.data()), OUTPUT_MODE)) {
CHECK_OFS(*os_, filename.data());
}
OutputBuffer::~OutputBuffer() {
if (os_ != &std::cout) {
delete os_;
}
}
void OutputBuffer::Write(StringPiece text) {
os_->write(text.data(), text.size());
}
void OutputBuffer::WriteLine(StringPiece text) {
Write(text);
Write("\n");
}
} // namespace io
} // namespace sentencepiece

318
src/util.h Normal file
View File

@ -0,0 +1,318 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTIL_H_
#define UTIL_H_
#include <algorithm>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "common.h"
#include "stringpiece.h"
namespace sentencepiece {
template <typename T>
std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
for (const auto n : v) {
out << " " << n;
}
return out;
}
// String utilities
namespace string_util {
std::vector<std::string> Split(const std::string &str,
const std::string &delim);
std::vector<StringPiece> SplitPiece(StringPiece str, StringPiece delim);
std::string Join(const std::vector<std::string> &tokens, StringPiece delim);
std::string Join(const std::vector<int> &tokens, StringPiece delim);
std::string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all);
void StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all, std::string *res);
template <typename T>
inline bool DecodePOD(StringPiece str, T *result) {
CHECK_NOTNULL(result);
if (sizeof(*result) != str.size()) {
return false;
}
memcpy(result, str.data(), sizeof(T));
return true;
}
template <typename T>
inline std::string EncodePOD(const T &value) {
std::string s;
s.resize(sizeof(T));
memcpy(const_cast<char *>(s.data()), &value, sizeof(T));
return s;
}
inline bool StartsWith(const StringPiece str, StringPiece prefix) {
return str.starts_with(prefix);
}
inline bool EndsWith(const StringPiece str, StringPiece suffix) {
return str.ends_with(suffix);
}
template <typename T>
inline std::string IntToHex(T value) {
std::ostringstream os;
os << std::hex << std::uppercase << value;
return os.str();
}
template <typename T>
inline T HexToInt(StringPiece value) {
T n;
std::istringstream is(value.data());
is >> std::hex >> n;
return n;
}
template <typename T>
inline size_t Itoa(T val, char *s) {
char *org = s;
if (val < 0) {
*s++ = '-';
val = -val;
}
char *t = s;
T mod = 0;
while (val) {
mod = val % 10;
*t++ = static_cast<char>(mod) + '0';
val /= 10;
}
if (s == t) {
*t++ = '0';
}
*t = '\0';
std::reverse(s, t);
return static_cast<size_t>(t - org);
}
// Return length of a single UTF-8 source character
inline size_t OneCharLen(const char *src) {
// Table of UTF-8 character lengths, based on first byte
constexpr unsigned char kUTF8LenTable[256] = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 1, 1};
return kUTF8LenTable[*reinterpret_cast<const unsigned char *>(src)];
}
using UnicodeText = std::vector<char32>;
char32 DecodeUTF8(const char *begin, const char *end, size_t *mblen);
size_t EncodeUTF8(char32 c, char *output);
std::string UnicodeCharToUTF8(const char32 c);
UnicodeText UTF8ToUnicodeText(StringPiece utf8);
std::string UnicodeTextToUTF8(const UnicodeText &utext);
} // namespace string_util
// IO related utilities.
namespace io {
class InputBuffer {
public:
explicit InputBuffer(StringPiece filename);
~InputBuffer();
bool ReadLine(std::string *line);
private:
std::istream *is_;
};
class OutputBuffer {
public:
explicit OutputBuffer(StringPiece filename);
~OutputBuffer();
void Write(StringPiece text);
void WriteLine(StringPiece text);
private:
std::ostream *os_;
};
} // namespace io
// other map/ptr utilities
namespace port {
template <class Collection, class Key>
bool ContainsKey(const Collection &collection, const Key &key) {
return collection.find(key) != collection.end();
}
template <class Collection>
const typename Collection::value_type::second_type &FindOrDie(
const Collection &collection,
const typename Collection::value_type::first_type &key) {
typename Collection::const_iterator it = collection.find(key);
CHECK(it != collection.end()) << "Map key not found: " << key;
return it->second;
}
template <class Collection>
const typename Collection::value_type::second_type &FindWithDefault(
const Collection &collection,
const typename Collection::value_type::first_type &key,
const typename Collection::value_type::second_type &value) {
typename Collection::const_iterator it = collection.find(key);
if (it == collection.end()) {
return value;
}
return it->second;
}
template <class Collection>
bool InsertIfNotPresent(Collection *const collection,
const typename Collection::value_type &vt) {
return collection->insert(vt).second;
}
template <class Collection>
bool InsertIfNotPresent(
Collection *const collection,
const typename Collection::value_type::first_type &key,
const typename Collection::value_type::second_type &value) {
return InsertIfNotPresent(collection,
typename Collection::value_type(key, value));
}
template <class Collection>
void InsertOrDie(Collection *const collection,
const typename Collection::value_type::first_type &key,
const typename Collection::value_type::second_type &data) {
CHECK(InsertIfNotPresent(collection, key, data)) << "duplicate key";
}
// hash
inline void mix(uint64 &a, uint64 &b, uint64 &c) { // 64bit version
a -= b;
a -= c;
a ^= (c >> 43);
b -= c;
b -= a;
b ^= (a << 9);
c -= a;
c -= b;
c ^= (b >> 8);
a -= b;
a -= c;
a ^= (c >> 38);
b -= c;
b -= a;
b ^= (a << 23);
c -= a;
c -= b;
c ^= (b >> 5);
a -= b;
a -= c;
a ^= (c >> 35);
b -= c;
b -= a;
b ^= (a << 49);
c -= a;
c -= b;
c ^= (b >> 11);
a -= b;
a -= c;
a ^= (c >> 12);
b -= c;
b -= a;
b ^= (a << 18);
c -= a;
c -= b;
c ^= (b >> 22);
}
inline uint64 FingerprintCat(uint64 x, uint64 y) {
uint64 b = 0xe08c1d668b756f82; // more of the golden ratio
mix(x, b, y);
return y;
}
// Trait to select overloads and return types for MakeUnique.
template <typename T>
struct MakeUniqueResult {
using scalar = std::unique_ptr<T>;
};
template <typename T>
struct MakeUniqueResult<T[]> {
using array = std::unique_ptr<T[]>;
};
template <typename T, size_t N>
struct MakeUniqueResult<T[N]> {
using invalid = void;
};
// MakeUnique<T>(...) is an early implementation of C++14 std::make_unique.
// It is designed to be 100% compatible with std::make_unique so that the
// eventual switchover will be a simple renaming operation.
template <typename T, typename... Args>
typename MakeUniqueResult<T>::scalar MakeUnique(Args &&... args) { // NOLINT
return std::unique_ptr<T>(
new T(std::forward<Args>(args)...)); // NOLINT(build/c++11)
}
// Overload for array of unknown bound.
// The allocation of arrays needs to use the array form of new,
// and cannot take element constructor arguments.
template <typename T>
typename MakeUniqueResult<T>::array MakeUnique(size_t n) {
return std::unique_ptr<T>(new typename std::remove_extent<T>::type[n]());
}
// Reject arrays of known bound.
template <typename T, typename... Args>
typename MakeUniqueResult<T>::invalid MakeUnique(Args &&... /* args */) =
delete; // NOLINT
template <typename T>
void STLDeleteElements(std::vector<T *> *vec) {
for (auto item : *vec) {
delete item;
}
vec->clear();
}
} // namespace port
} // namespace sentencepiece
#endif // UTIL_H_

449
src/util_test.cc Normal file
View File

@ -0,0 +1,449 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util.h"
#include <map>
#include "testharness.h"
namespace sentencepiece {
TEST(UtilTest, CheckNotNullTest) {
int a = 0;
CHECK_NOTNULL(&a);
EXPECT_DEATH(CHECK_NOTNULL(nullptr));
}
TEST(UtilTest, StartsWith) {
const std::string str = "abcdefg";
EXPECT_TRUE(string_util::StartsWith(str, ""));
EXPECT_TRUE(string_util::StartsWith(str, "a"));
EXPECT_TRUE(string_util::StartsWith(str, "abc"));
EXPECT_TRUE(string_util::StartsWith(str, "abcdefg"));
EXPECT_FALSE(string_util::StartsWith(str, "abcdefghi"));
EXPECT_FALSE(string_util::StartsWith(str, "foobar"));
}
TEST(UtilTest, EndsWith) {
const std::string str = "abcdefg";
EXPECT_TRUE(string_util::EndsWith(str, ""));
EXPECT_TRUE(string_util::EndsWith(str, "g"));
EXPECT_TRUE(string_util::EndsWith(str, "fg"));
EXPECT_TRUE(string_util::EndsWith(str, "abcdefg"));
EXPECT_FALSE(string_util::EndsWith(str, "aaabcdefg"));
EXPECT_FALSE(string_util::EndsWith(str, "foobar"));
EXPECT_FALSE(string_util::EndsWith(str, "foobarbuzbuz"));
}
TEST(UtilTest, Hex) {
for (char32 a = 0; a < 100000; ++a) {
const std::string s = string_util::IntToHex<char32>(a);
CHECK_EQ(a, string_util::HexToInt<char32>(s));
}
const int n = 151414;
CHECK_EQ("24F76", string_util::IntToHex(n));
CHECK_EQ(n, string_util::HexToInt<int>("24F76"));
}
TEST(UtilTest, SplitTest) {
std::vector<std::string> tokens;
tokens = string_util::Split("this is a\ttest", " \t");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::Split("this is a \t test", " \t");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::Split("this is a\ttest", " ");
EXPECT_EQ(3, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a\ttest");
tokens = string_util::Split(" this is a test ", " ");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::Split("", "");
EXPECT_TRUE(tokens.empty());
}
TEST(UtilTest, SplitPieceTest) {
std::vector<StringPiece> tokens;
tokens = string_util::SplitPiece("this is a\ttest", " \t");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::SplitPiece("this is a \t test", " \t");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::SplitPiece("this is a\ttest", " ");
EXPECT_EQ(3, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a\ttest");
tokens = string_util::SplitPiece(" this is a test ", " ");
EXPECT_EQ(4, tokens.size());
EXPECT_EQ(tokens[0], "this");
EXPECT_EQ(tokens[1], "is");
EXPECT_EQ(tokens[2], "a");
EXPECT_EQ(tokens[3], "test");
tokens = string_util::SplitPiece("", "");
EXPECT_TRUE(tokens.empty());
}
TEST(UtilTest, JoinTest) {
std::vector<std::string> tokens;
tokens.push_back("this");
tokens.push_back("is");
tokens.push_back("a");
tokens.push_back("test");
EXPECT_EQ(string_util::Join(tokens, " "), "this is a test");
EXPECT_EQ(string_util::Join(tokens, ":"), "this:is:a:test");
EXPECT_EQ(string_util::Join(tokens, ""), "thisisatest");
tokens[2] = "";
EXPECT_EQ(string_util::Join(tokens, " "), "this is test");
}
TEST(UtilTest, JoinIntTest) {
std::vector<int> tokens;
tokens.push_back(10);
tokens.push_back(2);
tokens.push_back(-4);
tokens.push_back(5);
EXPECT_EQ(string_util::Join(tokens, " "), "10 2 -4 5");
EXPECT_EQ(string_util::Join(tokens, ":"), "10:2:-4:5");
EXPECT_EQ(string_util::Join(tokens, ""), "102-45");
}
TEST(UtilTest, StringPieceTest) {
StringPiece s;
EXPECT_EQ(0, s.find("", 0));
}
TEST(UtilTest, StringReplaceTest) {
EXPECT_EQ("fbb", string_util::StringReplace("foo", "o", "b", true));
EXPECT_EQ("fbo", string_util::StringReplace("foo", "o", "b", false));
EXPECT_EQ("abcDEf", string_util::StringReplace("abcdef", "de", "DE", true));
EXPECT_EQ("abcf", string_util::StringReplace("abcdef", "de", "", true));
EXPECT_EQ("aBCaBC", string_util::StringReplace("abcabc", "bc", "BC", true));
EXPECT_EQ("aBCabc", string_util::StringReplace("abcabc", "bc", "BC", false));
EXPECT_EQ("", string_util::StringReplace("", "bc", "BC", false));
EXPECT_EQ("", string_util::StringReplace("", "bc", "", false));
EXPECT_EQ("", string_util::StringReplace("", "", "", false));
EXPECT_EQ("abc", string_util::StringReplace("abc", "", "b", false));
}
TEST(UtilTest, EncodePODTest) {
std::string tmp;
{
float v = 0.0;
tmp = string_util::EncodePOD<float>(10.0);
EXPECT_TRUE(string_util::DecodePOD<float>(tmp, &v));
EXPECT_EQ(10.0, v);
}
{
double v = 0.0;
tmp = string_util::EncodePOD<double>(10.0);
EXPECT_TRUE(string_util::DecodePOD<double>(tmp, &v));
EXPECT_EQ(10.0, v);
}
{
int32 v = 0;
tmp = string_util::EncodePOD<int32>(10);
EXPECT_TRUE(string_util::DecodePOD<int32>(tmp, &v));
EXPECT_EQ(10, v);
}
{
int16 v = 0;
tmp = string_util::EncodePOD<int16>(10);
EXPECT_TRUE(string_util::DecodePOD<int16>(tmp, &v));
EXPECT_EQ(10, v);
}
{
int64 v = 0;
tmp = string_util::EncodePOD<int64>(10);
EXPECT_TRUE(string_util::DecodePOD<int64>(tmp, &v));
EXPECT_EQ(10, v);
}
// Invalid data
{
int32 v = 0;
tmp = string_util::EncodePOD<int64>(10);
EXPECT_FALSE(string_util::DecodePOD<int32>(tmp, &v));
}
}
TEST(UtilTest, ItoaTest) {
auto Itoa = [](int v) {
char buf[16];
string_util::Itoa(v, buf);
return std::string(buf);
};
EXPECT_EQ("0", Itoa(0));
EXPECT_EQ("10", Itoa(10));
EXPECT_EQ("-10", Itoa(-10));
EXPECT_EQ("718", Itoa(718));
EXPECT_EQ("-522", Itoa(-522));
}
TEST(UtilTest, OneCharLenTest) {
EXPECT_EQ(1, string_util::OneCharLen("abc"));
EXPECT_EQ(3, string_util::OneCharLen("テスト"));
}
TEST(UtilTest, DecodeUTF8Test) {
size_t mblen = 0;
{
const std::string input = "";
EXPECT_EQ(0, string_util::DecodeUTF8(input.data(),
input.data() + input.size(), &mblen));
EXPECT_EQ(1, mblen); // mblen always returns >= 1
}
{
const std::string input = "\x01";
EXPECT_EQ(1, string_util::DecodeUTF8(input.data(),
input.data() + input.size(), &mblen));
EXPECT_EQ(1, mblen);
}
{
const std::string input = "\x7F";
EXPECT_EQ(0x7F, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(1, mblen);
}
{
const std::string input = "\xC2\x80 ";
EXPECT_EQ(0x80, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(2, mblen);
}
{
const std::string input = "\xDF\xBF ";
EXPECT_EQ(0x7FF, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(2, mblen);
}
{
const std::string input = "\xE0\xA0\x80 ";
EXPECT_EQ(0x800, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(3, mblen);
}
{
const std::string input = "\xF0\x90\x80\x80 ";
EXPECT_EQ(0x10000, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(4, mblen);
}
{
const std::string input = "\xF7\xBF\xBF\xBF ";
EXPECT_EQ(0x1FFFFF, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(4, mblen);
}
{
const std::string input = "\xF8\x88\x80\x80\x80 ";
EXPECT_EQ(0x200000, string_util::DecodeUTF8(
input.data(), input.data() + input.size(), &mblen));
EXPECT_EQ(5, mblen);
}
{
const std::string input = "\xFC\x84\x80\x80\x80\x80 ";
EXPECT_EQ(0x4000000,
string_util::DecodeUTF8(input.data(), input.data() + input.size(),
&mblen));
EXPECT_EQ(6, mblen);
}
{
const char *kInvalidData[] = {
"\xC2",      // must be 2 bytes.
"\xE0\xE0",  // must be 3 bytes.
"\xFF", // BOM
"\xFE" // BOM
};
for (size_t i = 0; i < 4; ++i) {
// The return value of string_util::DecodeUTF8 is undefined here.
// TODO(taku): implement a workaround.
string_util::DecodeUTF8(
kInvalidData[i], kInvalidData[i] + strlen(kInvalidData[i]), &mblen);
EXPECT_EQ(1, mblen);
}
}
}
TEST(UtilTest, EncodeUTF8Test) {
constexpr int kMaxUnicode = 0x110000;
char buf[16];
for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
const size_t mblen = string_util::EncodeUTF8(cp, buf);
size_t mblen2;
char32 c = string_util::DecodeUTF8(buf, buf + 16, &mblen2);
EXPECT_EQ(mblen2, mblen);
EXPECT_EQ(cp, c);
}
EXPECT_EQ(0, string_util::EncodeUTF8(0, buf));
EXPECT_EQ('\0', buf[0]);
// non UCS4
size_t mblen;
EXPECT_EQ(5, string_util::EncodeUTF8(0x7000000, buf));
string_util::DecodeUTF8(buf, buf + 16, &mblen);
EXPECT_EQ(5, mblen);
EXPECT_EQ(6, string_util::EncodeUTF8(0x8000001, buf));
string_util::DecodeUTF8(buf, buf + 16, &mblen);
EXPECT_EQ(6, mblen);
}
TEST(UtilTest, UnicodeCharToUTF8Test) {
constexpr int kMaxUnicode = 0x110000;
for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
const auto s = string_util::UnicodeCharToUTF8(cp);
const auto ut = string_util::UTF8ToUnicodeText(s);
EXPECT_EQ(1, ut.size());
EXPECT_EQ(cp, ut[0]);
}
}
TEST(UtilTest, UnicodeTextToUTF8Test) {
string_util::UnicodeText ut;
ut = string_util::UTF8ToUnicodeText("test");
EXPECT_EQ("test", string_util::UnicodeTextToUTF8(ut));
ut = string_util::UTF8ToUnicodeText("テスト");
EXPECT_EQ("テスト", string_util::UnicodeTextToUTF8(ut));
ut = string_util::UTF8ToUnicodeText("これはtest");
EXPECT_EQ("これはtest", string_util::UnicodeTextToUTF8(ut));
}
TEST(UtilTest, MapUtilTest) {
const std::map<std::string, std::string> kMap = {
{"a", "A"}, {"b", "B"}, {"c", "C"}};
EXPECT_TRUE(port::ContainsKey(kMap, "a"));
EXPECT_TRUE(port::ContainsKey(kMap, "b"));
EXPECT_FALSE(port::ContainsKey(kMap, ""));
EXPECT_FALSE(port::ContainsKey(kMap, "x"));
EXPECT_EQ("A", port::FindOrDie(kMap, "a"));
EXPECT_EQ("B", port::FindOrDie(kMap, "b"));
EXPECT_DEATH(port::FindOrDie(kMap, "x"));
EXPECT_EQ("A", port::FindWithDefault(kMap, "a", "x"));
EXPECT_EQ("B", port::FindWithDefault(kMap, "b", "x"));
EXPECT_EQ("x", port::FindWithDefault(kMap, "d", "x"));
EXPECT_EQ("A", port::FindOrDie(kMap, "a"));
EXPECT_DEATH(port::FindOrDie(kMap, "d"));
}
TEST(UtilTest, MapUtilVecTest) {
const std::map<std::vector<int>, std::string> kMap = {{{0, 1}, "A"}};
EXPECT_DEATH(port::FindOrDie(kMap, {0, 2}));
}
TEST(UtilTest, InputOutputBufferTest) {
test::ScopedTempFile sf("test_file");
const char *kData[] = {
"This",
"is",
"a",
"test"};
{
io::OutputBuffer output(sf.filename());
for (size_t i = 0; i < arraysize(kData); ++i) {
output.WriteLine(kData[i]);
}
}
{
io::InputBuffer input(sf.filename());
std::string line;
for (size_t i = 0; i < arraysize(kData); ++i) {
EXPECT_TRUE(input.ReadLine(&line));
EXPECT_EQ(kData[i], line);
}
EXPECT_FALSE(input.ReadLine(&line));
}
}
TEST(UtilTest, InputOutputBufferInvalidFileTest) {
EXPECT_DEATH(io::InputBuffer input("__UNKNOWN__FILE__"));
}
TEST(UtilTest, STLDeleteElementsTest) {
class Item {
public:
explicit Item(int *counter) : counter_(counter) {}
~Item() { ++*counter_; }
private:
int *counter_;
};
std::vector<Item *> data;
int counter = 0;
for (int i = 0; i < 10; ++i) {
data.push_back(new Item(&counter));
}
port::STLDeleteElements(&data);
CHECK_EQ(10, counter);
EXPECT_EQ(0, data.size());
}
} // namespace sentencepiece

56
src/word_model.cc Normal file

@@ -0,0 +1,56 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "word_model.h"
#include "util.h"
namespace sentencepiece {
namespace word {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
CheckControlSymbols();
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
CHECK(!sp.piece().empty());
if (sp.type() == ModelProto::SentencePiece::NORMAL ||
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
CHECK(sp.has_score());
port::InsertOrDie(&pieces_, sp.piece(), i);
} else {
port::InsertOrDie(&reserved_id_map_, sp.piece(), i);
}
}
}
Model::~Model() {}
std::vector<std::pair<StringPiece, int>> Model::Encode(
StringPiece normalized) const {
if (normalized.empty()) {
return {};
}
std::vector<std::pair<StringPiece, int>> output;
for (const auto &w : SplitIntoWords(normalized)) {
output.emplace_back(w, PieceToId(w));
}
return output;
}
} // namespace word
} // namespace sentencepiece

35
src/word_model.h Normal file

@@ -0,0 +1,35 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef WORD_MODEL_H_
#define WORD_MODEL_H_
#include "model_interface.h"
#include "sentencepiece_model.pb.h"
namespace sentencepiece {
namespace word {
// Tokenize text with whitespaces.
class Model : public ModelInterface {
public:
explicit Model(const ModelProto &model_proto);
~Model() override;
std::vector<std::pair<StringPiece, int>> Encode(
StringPiece normalized) const override;
};
} // namespace word
} // namespace sentencepiece
#endif // WORD_MODEL_H_

87
src/word_model_test.cc Normal file

@@ -0,0 +1,87 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "word_model.h"
#include <unordered_map>
#include <unordered_set>
#include "sentencepiece_model.pb.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace word {
namespace {
// Space symbol (U+2581)
#define WS "\xe2\x96\x81"
ModelProto MakeBaseModelProto() {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(WordModelTest, EncodeTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, WS "ab");
AddPiece(&model_proto, WS "cd");
AddPiece(&model_proto, WS "abc");
AddPiece(&model_proto, WS "a", 0.1);
AddPiece(&model_proto, WS "b", 0.2);
AddPiece(&model_proto, WS "c", 0.3);
AddPiece(&model_proto, WS "d", 0.4);
const Model model(model_proto);
std::vector<std::pair<StringPiece, int>> result;
result = model.Encode("");
EXPECT_TRUE(result.empty());
result = model.Encode(WS "a" WS "b" WS "c");
EXPECT_EQ(3, result.size());
EXPECT_EQ(WS "a", result[0].first);
EXPECT_EQ(WS "b", result[1].first);
EXPECT_EQ(WS "c", result[2].first);
result = model.Encode(WS "ab" WS "cd" WS "abc");
EXPECT_EQ(3, result.size());
EXPECT_EQ(WS "ab", result[0].first);
EXPECT_EQ(WS "cd", result[1].first);
EXPECT_EQ(WS "abc", result[2].first);
}
} // namespace
} // namespace word
} // namespace sentencepiece

76
src/word_model_trainer.cc Normal file

@@ -0,0 +1,76 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "word_model_trainer.h"
#include "stringpiece.h"
#include "util.h"
#include "word_model.h"
namespace sentencepiece {
namespace word {
void Trainer::Train() {
#define CHECK_RANGE(variable, minval, maxval) \
CHECK(variable >= minval && variable <= maxval)
CHECK_GT(trainer_spec_.input().size(), 0);
CHECK(!trainer_spec_.model_prefix().empty());
CHECK_RANGE(trainer_spec_.character_coverage(), 0.98, 1.0);
CHECK_RANGE(trainer_spec_.input_sentence_size(), 100, 100000000);
CHECK_GT(trainer_spec_.vocab_size(), 0);
#undef CHECK_RANGE
LOG(INFO) << "Starts training with : \n" << trainer_spec_.Utf8DebugString();
CHECK(normalizer_spec_.escape_whitespaces());
CHECK_EQ(TrainerSpec::WORD, trainer_spec_.model_type());
LoadSentences();
std::unordered_map<std::string, uint64> freq;
for (const auto &it : sentences_) {
for (const auto &s : SplitIntoWords(it.first)) {
freq[s.to_string()] += it.second;
}
}
const int meta_symbols_size = trainer_spec_.control_symbols().size() +
trainer_spec_.user_defined_symbols().size() +
3; // <s>, </s>, <unk>
const int vocab_size = trainer_spec_.vocab_size() - meta_symbols_size;
CHECK_GE(vocab_size, 0);
uint64 sum = 0;
for (const auto &it : freq) {
sum += it.second;
}
const float logsum = log(sum);
CHECK(final_pieces_.empty());
for (const auto &it : Sorted(freq)) {
if (it.first.find(kUNKStr) != std::string::npos) {
continue;
}
if (final_pieces_.size() == static_cast<size_t>(vocab_size)) {
break;
}
final_pieces_.emplace_back(it.first, log(it.second) - logsum);
}
Save();
}
} // namespace word
} // namespace sentencepiece

39
src/word_model_trainer.h Normal file

@@ -0,0 +1,39 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef WORD_MODEL_TRAINER_H_
#define WORD_MODEL_TRAINER_H_
#include "sentencepiece_model.pb.h"
#include "trainer_interface.h"
namespace sentencepiece {
namespace word {
// Trainer class for word model.
//
// The word model simply counts the frequency of
// space-delimited tokens, then keeps the top
// |vocab_size| most frequent tokens.
class Trainer : public TrainerInterface {
public:
Trainer(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec) {}
void Train() override;
};
} // namespace word
} // namespace sentencepiece
#endif // WORD_MODEL_TRAINER_H_


@@ -0,0 +1,73 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "word_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace word {
namespace {
// Space symbol (U+2581)
#define WS "\xE2\x96\x81"
std::string RunTrainer(const std::vector<std::string> &input, int size) {
test::ScopedTempFile input_scoped_file("input");
test::ScopedTempFile model_scoped_file("model");
const std::string input_file = input_scoped_file.filename();
const std::string model_prefix = model_scoped_file.filename();
{
io::OutputBuffer output(input_file);
for (const auto &line : input) {
output.WriteLine(line);
}
}
TrainerSpec trainer_spec;
trainer_spec.set_model_type(TrainerSpec::WORD);
trainer_spec.add_input(input_file);
trainer_spec.set_vocab_size(size - 3); // remove <unk>, <s>, </s>
trainer_spec.set_model_prefix(model_prefix);
auto normalizer_spec = normalizer::Builder::GetNormalizerSpec("identity");
normalizer_spec.set_add_dummy_prefix(true);
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
SentencePieceProcessor processor;
processor.Load(model_prefix + ".model");
const auto &model = processor.model_proto();
std::vector<std::string> pieces;
// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
}
return string_util::Join(pieces, " ");
}
} // namespace
TEST(TrainerTest, BasicTest) {
EXPECT_EQ(WS "I " WS "apple " WS "have " WS "pen",
RunTrainer({"I have a pen", "I have an apple", "apple pen"}, 10));
}
} // namespace word
} // namespace sentencepiece

10
third_party/darts_clone/LICENSE vendored Normal file

@@ -0,0 +1,10 @@
Copyright (c) 2008-2011, Susumu Yata
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
- Neither the name of the <ORGANIZATION> nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

1926
third_party/darts_clone/darts.h vendored Normal file

File diff suppressed because it is too large

24
third_party/esaxx/LICENSE vendored Normal file

@@ -0,0 +1,24 @@
This is the esaxx copyright.
Copyright (c) 2010 Daisuke Okanohara All Rights Reserved.
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

125
third_party/esaxx/esa.hxx vendored Normal file

@@ -0,0 +1,125 @@
/*
* esa.hxx
* Copyright (c) 2010 Daisuke Okanohara All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef _ESA_HXX
#define _ESA_HXX
#include <vector>
#include <utility>
#include <cassert>
#include "sais.hxx"
namespace esaxx_private {
template<typename string_type, typename sarray_type, typename index_type>
index_type suffixtree(string_type T, sarray_type SA, sarray_type L, sarray_type R, sarray_type D, index_type n){
if (n == 0){
return 0;
}
sarray_type Psi = L;
Psi[SA[0]] = SA[n-1];
for (index_type i = 1; i < n; ++i){
Psi[SA[i]] = SA[i-1];
}
// Compare at most 2n log n characters. Practically fastest
// "Permuted Longest-Common-Prefix Array", Juha Karkkainen, CPM 09
sarray_type PLCP = R;
index_type h = 0;
for (index_type i = 0; i < n; ++i){
index_type j = Psi[i];
while (i+h < n && j+h < n &&
T[i+h] == T[j+h]){
++h;
}
PLCP[i] = h;
if (h > 0) --h;
}
sarray_type H = L;
for (index_type i = 0; i < n; ++i){
H[i] = PLCP[SA[i]];
}
H[0] = -1;
std::vector<std::pair<index_type, index_type> > S;
S.push_back(std::make_pair((index_type)-1, (index_type)-1));
size_t nodeNum = 0;
for (index_type i = 0; ; ++i){
std::pair<index_type, index_type> cur (i, (i == n) ? -1 : H[i]);
std::pair<index_type, index_type> cand(S.back());
while (cand.second > cur.second){
if (i - cand.first > 1){
L[nodeNum] = cand.first;
R[nodeNum] = i;
D[nodeNum] = cand.second;
++nodeNum;
}
cur.first = cand.first;
S.pop_back();
cand = S.back();
}
if (cand.second < cur.second){
S.push_back(cur);
}
if (i == n) break;
S.push_back(std::make_pair(i, n - SA[i] + 1));
}
return nodeNum;
}
}  // namespace esaxx_private
/**
* @brief Build an enhanced suffix array of a given string in linear time
* For an input text T, esaxx() builds an enhanced suffix array in linear time.
* The i-th internal node is represented as a triple (L[i], R[i], D[i]);
* L[i] and R[i] are the left/right boundaries of the suffix-array interval SA[L[i]...R[i]-1],
* and D[i] is the depth of the internal node.
* The number of internal nodes is at most N-1; the actual number is returned via nodeNum.
* @param T[0...n-1] The input string. (random access iterator)
* @param SA[0...n-1] The output suffix array (random access iterator)
* @param L[0...n-1] The output left boundary of internal node (random access iterator)
* @param R[0...n-1] The output right boundary of internal node (random access iterator)
* @param D[0...n-1] The output depth of internal node (random access iterator)
* @param n The length of the input string
* @param k The alphabet size
* @param nodeNum The output number of internal nodes
* @return 0 if succeeded, -1 or -2 otherwise
*/
template<typename string_type, typename sarray_type, typename index_type>
int esaxx(string_type T, sarray_type SA, sarray_type L, sarray_type R, sarray_type D,
index_type n, index_type k, index_type& nodeNum) {
if ((n < 0) || (k <= 0)) return -1;
int err = saisxx(T, SA, n, k);
if (err != 0){
return err;
}
nodeNum = esaxx_private::suffixtree(T, SA, L, R, D, n);
return 0;
}
#endif // _ESA_HXX

364
third_party/esaxx/sais.hxx vendored Normal file

@@ -0,0 +1,364 @@
/*
* sais.hxx for sais-lite
* Copyright (c) 2008-2009 Yuta Mori All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
#ifndef _SAIS_HXX
#define _SAIS_HXX 1
#ifdef __cplusplus
#ifdef __INTEL_COMPILER
#pragma warning(disable : 383 981 1418)
// for icc 64-bit
//#define __builtin_vsnprintf(a, b, c, d) __builtin_vsnprintf(a, b, c, (char *)d)
#endif
#include <iterator>
#ifdef _OPENMP
# include <omp.h>
#endif
namespace saisxx_private {
/* find the start or end of each bucket */
template<typename string_type, typename bucket_type, typename index_type>
void
getCounts(const string_type T, bucket_type C, index_type n, index_type k) {
#ifdef _OPENMP
bucket_type D;
index_type i, j, p, sum, first, last;
int thnum, maxthreads = omp_get_max_threads();
#pragma omp parallel default(shared) private(D, i, thnum, first, last)
{
thnum = omp_get_thread_num();
D = C + thnum * k;
first = n / maxthreads * thnum;
last = (thnum < (maxthreads - 1)) ? n / maxthreads * (thnum + 1) : n;
for(i = 0; i < k; ++i) { D[i] = 0; }
for(i = first; i < last; ++i) { ++D[T[i]]; }
}
if(1 < maxthreads) {
#pragma omp parallel for default(shared) private(i, j, p, sum)
for(i = 0; i < k; ++i) {
for(j = 1, p = i + k, sum = C[i]; j < maxthreads; ++j, p += k) {
sum += C[p];
}
C[i] = sum;
}
}
#else
index_type i;
for(i = 0; i < k; ++i) { C[i] = 0; }
for(i = 0; i < n; ++i) { ++C[T[i]]; }
#endif
}
template<typename bucket_type, typename index_type>
void
getBuckets(const bucket_type C, bucket_type B, index_type k, bool end) {
index_type i, sum = 0;
if(end) { for(i = 0; i < k; ++i) { sum += C[i]; B[i] = sum; } }
else { for(i = 0; i < k; ++i) { sum += C[i]; B[i] = sum - C[i]; } }
}
/* compute SA and BWT */
template<typename string_type, typename sarray_type,
typename bucket_type, typename index_type>
void
induceSA(string_type T, sarray_type SA, bucket_type C, bucket_type B,
index_type n, index_type k) {
typedef typename std::iterator_traits<string_type>::value_type char_type;
sarray_type b;
index_type i, j;
char_type c0, c1;
/* compute SAl */
if(C == B) { getCounts(T, C, n, k); }
getBuckets(C, B, k, false); /* find starts of buckets */
b = SA + B[c1 = T[j = n - 1]];
*b++ = ((0 < j) && (T[j - 1] < c1)) ? ~j : j;
for(i = 0; i < n; ++i) {
j = SA[i], SA[i] = ~j;
if(0 < j) {
if((c0 = T[--j]) != c1) { B[c1] = b - SA; b = SA + B[c1 = c0]; }
*b++ = ((0 < j) && (T[j - 1] < c1)) ? ~j : j;
}
}
/* compute SAs */
if(C == B) { getCounts(T, C, n, k); }
getBuckets(C, B, k, true); /* find ends of buckets */
for(i = n - 1, b = SA + B[c1 = 0]; 0 <= i; --i) {
if(0 < (j = SA[i])) {
if((c0 = T[--j]) != c1) { B[c1] = b - SA; b = SA + B[c1 = c0]; }
*--b = ((j == 0) || (T[j - 1] > c1)) ? ~j : j;
} else {
SA[i] = ~j;
}
}
}
template<typename string_type, typename sarray_type,
typename bucket_type, typename index_type>
int
computeBWT(string_type T, sarray_type SA, bucket_type C, bucket_type B,
index_type n, index_type k) {
typedef typename std::iterator_traits<string_type>::value_type char_type;
sarray_type b;
index_type i, j, pidx = -1;
char_type c0, c1;
/* compute SAl */
if(C == B) { getCounts(T, C, n, k); }
getBuckets(C, B, k, false); /* find starts of buckets */
b = SA + B[c1 = T[j = n - 1]];
*b++ = ((0 < j) && (T[j - 1] < c1)) ? ~j : j;
for(i = 0; i < n; ++i) {
if(0 < (j = SA[i])) {
SA[i] = ~(c0 = T[--j]);
if(c0 != c1) { B[c1] = b - SA; b = SA + B[c1 = c0]; }
*b++ = ((0 < j) && (T[j - 1] < c1)) ? ~j : j;
} else if(j != 0) {
SA[i] = ~j;
}
}
/* compute SAs */
if(C == B) { getCounts(T, C, n, k); }
getBuckets(C, B, k, true); /* find ends of buckets */
for(i = n - 1, b = SA + B[c1 = 0]; 0 <= i; --i) {
if(0 < (j = SA[i])) {
SA[i] = (c0 = T[--j]);
if(c0 != c1) { B[c1] = b - SA; b = SA + B[c1 = c0]; }
*--b = ((0 < j) && (T[j - 1] > c1)) ? ~((index_type)T[j - 1]) : j;
} else if(j != 0) {
SA[i] = ~j;
} else {
pidx = i;
}
}
return pidx;
}
/* find the suffix array SA of T[0..n-1] in {0..k}^n
use a working space (excluding s and SA) of at most 2n+O(1) for a constant alphabet */
template<typename string_type, typename sarray_type, typename index_type>
int
suffixsort(string_type T, sarray_type SA,
index_type fs, index_type n, index_type k,
bool isbwt) {
typedef typename std::iterator_traits<string_type>::value_type char_type;
sarray_type RA;
index_type i, j, m, p, q, plen, qlen, name, pidx = 0;
bool diff;
int c;
#ifdef _OPENMP
int maxthreads = omp_get_max_threads();
#else
# define maxthreads 1
#endif
char_type c0, c1;
/* stage 1: reduce the problem by at least 1/2
sort all the S-substrings */
if(fs < (maxthreads * k)) {
index_type *C, *B;
if((C = new index_type[maxthreads * k]) == 0) { return -2; }
B = (1 < maxthreads) ? C + k : C;
getCounts(T, C, n, k); getBuckets(C, B, k, true); /* find ends of buckets */
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = 0; i < n; ++i) { SA[i] = 0; }
for(i = n - 2, c = 0, c1 = T[n - 1]; 0 <= i; --i, c1 = c0) {
if((c0 = T[i]) < (c1 + c)) { c = 1; }
else if(c != 0) { SA[--B[c1]] = i + 1, c = 0; }
}
induceSA(T, SA, C, B, n, k);
delete [] C;
} else {
sarray_type C, B;
C = SA + n;
B = ((1 < maxthreads) || (k <= (fs - k))) ? C + k : C;
getCounts(T, C, n, k); getBuckets(C, B, k, true); /* find ends of buckets */
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = 0; i < n; ++i) { SA[i] = 0; }
for(i = n - 2, c = 0, c1 = T[n - 1]; 0 <= i; --i, c1 = c0) {
if((c0 = T[i]) < (c1 + c)) { c = 1; }
else if(c != 0) { SA[--B[c1]] = i + 1, c = 0; }
}
induceSA(T, SA, C, B, n, k);
}
/* compact all the sorted substrings into the first m items of SA
2*m must be not larger than n (proveable) */
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i, j, p, c0, c1)
for(i = 0; i < n; ++i) {
p = SA[i];
if((0 < p) && (T[p - 1] > (c0 = T[p]))) {
for(j = p + 1; (j < n) && (c0 == (c1 = T[j])); ++j) { }
if((j < n) && (c0 < c1)) { SA[i] = ~p; }
}
}
for(i = 0, m = 0; i < n; ++i) { if((p = SA[i]) < 0) { SA[m++] = ~p; } }
#else
for(i = 0, m = 0; i < n; ++i) {
p = SA[i];
if((0 < p) && (T[p - 1] > (c0 = T[p]))) {
for(j = p + 1; (j < n) && (c0 == (c1 = T[j])); ++j) { }
if((j < n) && (c0 < c1)) { SA[m++] = p; }
}
}
#endif
j = m + (n >> 1);
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = m; i < j; ++i) { SA[i] = 0; } /* init the name array buffer */
/* store the length of all substrings */
for(i = n - 2, j = n, c = 0, c1 = T[n - 1]; 0 <= i; --i, c1 = c0) {
if((c0 = T[i]) < (c1 + c)) { c = 1; }
else if(c != 0) { SA[m + ((i + 1) >> 1)] = j - i - 1; j = i + 1; c = 0; }
}
/* find the lexicographic names of all substrings */
for(i = 0, name = 0, q = n, qlen = 0; i < m; ++i) {
p = SA[i], plen = SA[m + (p >> 1)], diff = true;
if(plen == qlen) {
for(j = 0; (j < plen) && (T[p + j] == T[q + j]); ++j) { }
if(j == plen) { diff = false; }
}
if(diff != false) { ++name, q = p, qlen = plen; }
SA[m + (p >> 1)] = name;
}
/* stage 2: solve the reduced problem
recurse if names are not yet unique */
if(name < m) {
RA = SA + n + fs - m;
for(i = m + (n >> 1) - 1, j = m - 1; m <= i; --i) {
if(SA[i] != 0) { RA[j--] = SA[i] - 1; }
}
if(suffixsort(RA, SA, fs + n - m * 2, m, name, false) != 0) { return -2; }
for(i = n - 2, j = m - 1, c = 0, c1 = T[n - 1]; 0 <= i; --i, c1 = c0) {
if((c0 = T[i]) < (c1 + c)) { c = 1; }
else if(c != 0) { RA[j--] = i + 1, c = 0; } /* get p1 */
}
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = 0; i < m; ++i) { SA[i] = RA[SA[i]]; } /* get index in s */
}
/* stage 3: induce the result for the original problem */
if(fs < (maxthreads * k)) {
index_type *B, *C;
if((C = new index_type[maxthreads * k]) == 0) { return -2; }
B = (1 < maxthreads) ? C + k : C;
/* put all left-most S characters into their buckets */
getCounts(T, C, n, k); getBuckets(C, B, k, true); /* find ends of buckets */
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = m; i < n; ++i) { SA[i] = 0; } /* init SA[m..n-1] */
for(i = m - 1; 0 <= i; --i) {
j = SA[i], SA[i] = 0;
SA[--B[T[j]]] = j;
}
if(isbwt == false) { induceSA(T, SA, C, B, n, k); }
else { pidx = computeBWT(T, SA, C, B, n, k); }
delete [] C;
} else {
sarray_type C, B;
C = SA + n;
B = ((1 < maxthreads) || (k <= (fs - k))) ? C + k : C;
/* put all left-most S characters into their buckets */
getCounts(T, C, n, k); getBuckets(C, B, k, true); /* find ends of buckets */
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(i)
#endif
for(i = m; i < n; ++i) { SA[i] = 0; } /* init SA[m..n-1] */
for(i = m - 1; 0 <= i; --i) {
j = SA[i], SA[i] = 0;
SA[--B[T[j]]] = j;
}
if(isbwt == false) { induceSA(T, SA, C, B, n, k); }
else { pidx = computeBWT(T, SA, C, B, n, k); }
}
return pidx;
#ifndef _OPENMP
# undef maxthreads
#endif
}
} /* namespace saisxx_private */
/**
* @brief Constructs the suffix array of a given string in linear time.
* @param T[0..n-1] The input string. (random access iterator)
* @param SA[0..n-1] The output array of suffixes. (random access iterator)
* @param n The length of the given string.
* @param k The alphabet size.
* @return 0 if no error occurred, -1 or -2 otherwise.
*/
template<typename string_type, typename sarray_type, typename index_type>
int
saisxx(string_type T, sarray_type SA, index_type n, index_type k = 256) {
int err;
if((n < 0) || (k <= 0)) { return -1; }
if(n <= 1) { if(n == 1) { SA[0] = 0; } return 0; }
try { err = saisxx_private::suffixsort(T, SA, 0, n, k, false); }
catch(...) { err = -2; }
return err;
}
/**
 * @brief Constructs the Burrows-Wheeler transformed string of a given string in linear time.
* @param T[0..n-1] The input string. (random access iterator)
* @param U[0..n-1] The output string. (random access iterator)
* @param A[0..n-1] The temporary array. (random access iterator)
* @param n The length of the given string.
* @param k The alphabet size.
* @return The primary index if no error occurred, -1 or -2 otherwise.
*/
template<typename string_type, typename sarray_type, typename index_type>
index_type
saisxx_bwt(string_type T, string_type U, sarray_type A, index_type n, index_type k = 256) {
typedef typename std::iterator_traits<string_type>::value_type char_type;
index_type i, pidx;
if((n < 0) || (k <= 0)) { return -1; }
if(n <= 1) { if(n == 1) { U[0] = T[0]; } return n; }
try {
pidx = saisxx_private::suffixsort(T, A, 0, n, k, true);
if(0 <= pidx) {
U[0] = T[n - 1];
for(i = 0; i < pidx; ++i) { U[i + 1] = (char_type)A[i]; }
for(i += 1; i < n; ++i) { U[i] = (char_type)A[i]; }
pidx += 1;
}
} catch(...) { pidx = -2; }
return pidx;
}
#endif /* __cplusplus */
#endif /* _SAIS_HXX */

6
third_party/m4/LICENSE vendored Normal file
View File

@ -0,0 +1,6 @@
# Copyright (c) 2008 Akos Maroy <darkeye@tyrell.hu>
#
# Copying and distribution of this file, with or without modification, are
# permitted in any medium without royalty provided the copyright notice
# and this notice are preserved. This file is offered as-is, without any
# warranty.

80
third_party/m4/ax_check_icu.m4 vendored Normal file
View File

@ -0,0 +1,80 @@
# ===========================================================================
# http://www.gnu.org/software/autoconf-archive/ax_check_icu.html
# ===========================================================================
#
# SYNOPSIS
#
# AX_CHECK_ICU(version, action-if, action-if-not)
#
# DESCRIPTION
#
# Defines ICU_LIBS, ICU_CFLAGS, ICU_CXXFLAGS. See icu-config(1) man page.
#
# LICENSE
#
# Copyright (c) 2008 Akos Maroy <darkeye@tyrell.hu>
#
# Copying and distribution of this file, with or without modification, are
# permitted in any medium without royalty provided the copyright notice
# and this notice are preserved. This file is offered as-is, without any
# warranty.
#serial 6
AU_ALIAS([AC_CHECK_ICU], [AX_CHECK_ICU])
AC_DEFUN([AX_CHECK_ICU], [
succeeded=no
if test -z "$ICU_CONFIG"; then
AC_PATH_PROG(ICU_CONFIG, icu-config, no)
fi
if test "$ICU_CONFIG" = "no" ; then
echo "*** The icu-config script could not be found. Make sure it is"
    echo "*** in your path, and that ICU is properly installed."
echo "*** Or see http://ibm.com/software/globalization/icu/"
else
ICU_VERSION=`$ICU_CONFIG --version`
AC_MSG_CHECKING(for ICU >= $1)
VERSION_CHECK=`expr $ICU_VERSION \>\= $1`
if test "$VERSION_CHECK" = "1" ; then
AC_MSG_RESULT(yes)
succeeded=yes
AC_MSG_CHECKING(ICU_CPPFLAGS)
ICU_CPPFLAGS=`$ICU_CONFIG --cppflags`
AC_MSG_RESULT($ICU_CPPFLAGS)
AC_MSG_CHECKING(ICU_CFLAGS)
ICU_CFLAGS=`$ICU_CONFIG --cflags`
AC_MSG_RESULT($ICU_CFLAGS)
AC_MSG_CHECKING(ICU_CXXFLAGS)
ICU_CXXFLAGS=`$ICU_CONFIG --cxxflags`
AC_MSG_RESULT($ICU_CXXFLAGS)
AC_MSG_CHECKING(ICU_LIBS)
ICU_LIBS=`$ICU_CONFIG --ldflags`
AC_MSG_RESULT($ICU_LIBS)
else
ICU_CPPFLAGS=""
ICU_CFLAGS=""
ICU_CXXFLAGS=""
ICU_LIBS=""
## If we have a custom action on failure, don't print errors, but
## do set a variable so people can do so.
ifelse([$3], ,echo "can't find ICU >= $1",)
fi
AC_SUBST(ICU_CPPFLAGS)
AC_SUBST(ICU_CFLAGS)
AC_SUBST(ICU_CXXFLAGS)
AC_SUBST(ICU_LIBS)
fi
if test $succeeded = yes; then
ifelse([$2], , :, [$2])
else
ifelse([$3], , AC_MSG_ERROR([Library requirements (ICU) not met.]), [$3])
fi
])
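A hypothetical `configure.ac` fragment using the macro above, following the `AX_CHECK_ICU(version, action-if, action-if-not)` synopsis; the version requirement, the defined symbol, and the failure message are illustrative only.

```m4
dnl Hypothetical usage sketch of AX_CHECK_ICU; version and actions
dnl are placeholders, not values this project actually requires.
AX_CHECK_ICU([4.0],
  [AC_DEFINE([HAVE_ICU], [1], [Define if ICU is available])],
  [AC_MSG_ERROR([this package requires ICU >= 4.0])])
```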