// 2011-09-21 20:06:48 +04:00
# include "lm/left.hh"
# include "lm/model.hh"
# include "util/tokenize_piece.hh"
# include <vector>
# define BOOST_TEST_MODULE LeftTest
# include <boost/test/unit_test.hpp>
# include <boost/test/floating_point_comparison.hpp>
namespace lm {
namespace ngram {
namespace {
// Test shorthands. Both expand to plain expressions that use identifiers from
// the call site: Term needs a RuleScore named `score` and a model named `m`;
// VCheck needs `m`. Neither includes a trailing semicolon, so each use
// (`Term("word");`) is exactly one statement and is safe in unbraced if/else.
#define Term(word) score.Terminal(m.GetVocabulary().Index(word))
#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value)
// Scores the rule "more loin", then reuses its state as a non-terminal that is
// extended on the left by "little" and, separately, by "to". Each step checks
// the rule score and the left/right lengths, rightmost word and `full` flag
// recorded in the resulting ChartState.
// Vocabulary lookups use unpadded words: a literal such as " more " would be
// looked up verbatim, spaces included, and would not match the model's word.
template <class M> void Short(const M &m) {
  ChartState base;
  {
    RuleScore<M> score(m, base);
    Term("more");
    Term("loin");
    BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001);
  }
  BOOST_CHECK(base.full);
  BOOST_CHECK_EQUAL(2, base.left.length);
  BOOST_CHECK_EQUAL(1, base.right.length);
  VCheck("loin", base.right.words[0]);

  ChartState more_left;
  {
    RuleScore<M> score(m, more_left);
    Term("little");
    score.NonTerminal(base, -1.206319 - 0.3561665);
    // p(little more loin | null context)
    BOOST_CHECK_CLOSE(-1.56538, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(3, more_left.left.length);
  BOOST_CHECK_EQUAL(1, more_left.right.length);
  VCheck("loin", more_left.right.words[0]);
  BOOST_CHECK(more_left.full);

  ChartState shorter;
  {
    RuleScore<M> score(m, shorter);
    Term("to");
    score.NonTerminal(base, -1.206319 - 0.3561665);
    BOOST_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01);
  }
  BOOST_CHECK_EQUAL(1, shorter.left.length);
  BOOST_CHECK_EQUAL(1, shorter.right.length);
  VCheck("loin", shorter.right.words[0]);
  BOOST_CHECK(shorter.full);
}
// Scores the rule "on more", prepends "looking" to that state, then prepends
// the sentence-begin context (BeginSentence) to the result. Each step checks
// the rule score and the resulting ChartState's lengths, rightmost word and
// `full` flag.
// Vocabulary lookups use unpadded words: a literal such as " on " would be
// looked up verbatim, spaces included, and would not match the model's word.
template <class M> void Charge(const M &m) {
  ChartState base;
  {
    RuleScore<M> score(m, base);
    Term("on");
    Term("more");
    BOOST_CHECK_CLOSE(-1.509559 - 0.4771212 - 1.206319, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(1, base.left.length);
  BOOST_CHECK_EQUAL(1, base.right.length);
  VCheck("more", base.right.words[0]);
  BOOST_CHECK(base.full);

  ChartState extend;
  {
    RuleScore<M> score(m, extend);
    Term("looking");
    score.NonTerminal(base, -1.509559 - 0.4771212 - 1.206319);
    BOOST_CHECK_CLOSE(-3.91039, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, extend.left.length);
  BOOST_CHECK_EQUAL(1, extend.right.length);
  VCheck("more", extend.right.words[0]);
  BOOST_CHECK(extend.full);

  ChartState tobos;
  {
    RuleScore<M> score(m, tobos);
    score.BeginSentence();
    score.NonTerminal(extend, -3.91039);
    BOOST_CHECK_CLOSE(-3.471169, score.Finish(), 0.001);
  }
  // After BeginSentence the test expects an empty left state.
  BOOST_CHECK_EQUAL(0, tobos.left.length);
  BOOST_CHECK_EQUAL(1, tobos.right.length);
}
template < class M > float LeftToRight ( const M & m , const std : : vector < WordIndex > & words ) {
float ret = 0.0 ;
State right = m . NullContextState ( ) ;
for ( std : : vector < WordIndex > : : const_iterator i = words . begin ( ) ; i ! = words . end ( ) ; + + i ) {
State copy ( right ) ;
ret + = m . Score ( copy , * i , right ) ;
}
return ret ;
}
template < class M > float RightToLeft ( const M & m , const std : : vector < WordIndex > & words ) {
float ret = 0.0 ;
ChartState state ;
state . left . length = 0 ;
state . right . length = 0 ;
state . full = false ;
for ( std : : vector < WordIndex > : : const_reverse_iterator i = words . rbegin ( ) ; i ! = words . rend ( ) ; + + i ) {
ChartState copy ( state ) ;
RuleScore < M > score ( m , state ) ;
score . Terminal ( * i ) ;
score . NonTerminal ( copy , ret ) ;
ret = score . Finish ( ) ;
}
return ret ;
}
template < class M > float TreeMiddle ( const M & m , const std : : vector < WordIndex > & words ) {
std : : vector < std : : pair < ChartState , float > > states ( words . size ( ) ) ;
for ( unsigned int i = 0 ; i < words . size ( ) ; + + i ) {
RuleScore < M > score ( m , states [ i ] . first ) ;
score . Terminal ( words [ i ] ) ;
states [ i ] . second = score . Finish ( ) ;
}
while ( states . size ( ) > 1 ) {
std : : vector < std : : pair < ChartState , float > > upper ( ( states . size ( ) + 1 ) / 2 ) ;
for ( unsigned int i = 0 ; i < states . size ( ) / 2 ; + + i ) {
RuleScore < M > score ( m , upper [ i ] . first ) ;
score . NonTerminal ( states [ i * 2 ] . first , states [ i * 2 ] . second ) ;
score . NonTerminal ( states [ i * 2 + 1 ] . first , states [ i * 2 + 1 ] . second ) ;
upper [ i ] . second = score . Finish ( ) ;
}
if ( states . size ( ) % 2 ) {
upper . back ( ) = states . back ( ) ;
}
std : : swap ( states , upper ) ;
}
return states . empty ( ) ? 0 : states . back ( ) . second ;
}
template < class M > void LookupVocab ( const M & m , const StringPiece & str , std : : vector < WordIndex > & out ) {
out . clear ( ) ;
2011-11-17 23:12:19 +04:00
for ( util : : TokenIter < util : : SingleCharacter , true > i ( str , ' ' ) ; i ; + + i ) {
2011-09-21 20:06:48 +04:00
out . push_back ( m . GetVocabulary ( ) . Index ( * i ) ) ;
}
}
// TEXT_TEST(str): turns `str` into word ids with LookupVocab on the model `m`
// in scope. It then checks that RightToLeft and TreeMiddle scoring of those
// words each agree with LeftToRight scoring (BOOST_CHECK_CLOSE, tolerance
// 0.001). The body is wrapped in do/while(0) so `TEXT_TEST(x);` is a single
// statement, safe inside unbraced if/else.
#define TEXT_TEST(str) \
do { \
  std::vector<WordIndex> words; \
  LookupVocab(m, str, words); \
  float expect = LeftToRight(m, words); \
  BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \
  BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \
} while (0)
// Consistency checks on one long sentence and a series of shorter spans taken
// from it: for each, left-to-right, right-to-left and tree-shaped scoring must
// agree (see TEXT_TEST).
template <class M> void GrowBig(const M &m) {
  // The whole sentence, ending in the end-of-sentence marker.
  TEXT_TEST(" in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s> ");
  // Drop the opening clause.
  TEXT_TEST(" on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s> ");
  // Spans from the middle of the sentence, getting shorter.
  TEXT_TEST(" on a little more loin also would consider higher to look good ");
  TEXT_TEST(" more loin also would consider higher to look good ");
  TEXT_TEST(" more loin also would consider higher to look ");
  TEXT_TEST(" also would consider higher to look ");
  TEXT_TEST(" also would consider higher ");
  TEXT_TEST(" would consider higher to look ");
  TEXT_TEST(" consider higher to look ");
  TEXT_TEST(" consider higher to ");
  TEXT_TEST(" consider higher ");
}
template < class M > void AlsoWouldConsiderHigher ( const M & m ) {
ChartState also ;
{
RuleScore < M > score ( m , also ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " also " ) ) ;
BOOST_CHECK_CLOSE ( - 1.687872 , score . Finish ( ) , 0.001 ) ;
}
ChartState would ;
{
RuleScore < M > score ( m , would ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " would " ) ) ;
BOOST_CHECK_CLOSE ( - 1.687872 , score . Finish ( ) , 0.001 ) ;
}
ChartState combine_also_would ;
{
RuleScore < M > score ( m , combine_also_would ) ;
score . NonTerminal ( also , - 1.687872 ) ;
score . NonTerminal ( would , - 1.687872 ) ;
BOOST_CHECK_CLOSE ( - 1.687872 - 2.0 , score . Finish ( ) , 0.001 ) ;
}
BOOST_CHECK_EQUAL ( 2 , combine_also_would . right . length ) ;
ChartState also_would ;
{
RuleScore < M > score ( m , also_would ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " also " ) ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " would " ) ) ;
BOOST_CHECK_CLOSE ( - 1.687872 - 2.0 , score . Finish ( ) , 0.001 ) ;
}
BOOST_CHECK_EQUAL ( 2 , also_would . right . length ) ;
ChartState consider ;
{
RuleScore < M > score ( m , consider ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " consider " ) ) ;
BOOST_CHECK_CLOSE ( - 1.687872 , score . Finish ( ) , 0.001 ) ;
}
BOOST_CHECK_EQUAL ( 1 , consider . left . length ) ;
BOOST_CHECK_EQUAL ( 1 , consider . right . length ) ;
BOOST_CHECK ( ! consider . full ) ;
ChartState higher ;
float higher_score ;
{
RuleScore < M > score ( m , higher ) ;
score . Terminal ( m . GetVocabulary ( ) . Index ( " higher " ) ) ;
higher_score = score . Finish ( ) ;
}
BOOST_CHECK_CLOSE ( - 1.509559 , higher_score , 0.001 ) ;
BOOST_CHECK_EQUAL ( 1 , higher . left . length ) ;
BOOST_CHECK_EQUAL ( 1 , higher . right . length ) ;
BOOST_CHECK ( ! higher . full ) ;
VCheck ( " higher " , higher . right . words [ 0 ] ) ;
BOOST_CHECK_CLOSE ( - 0.30103 , higher . right . backoff [ 0 ] , 0.001 ) ;
ChartState consider_higher ;
{
RuleScore < M > score ( m , consider_higher ) ;
score . NonTerminal ( consider , - 1.687872 ) ;
score . NonTerminal ( higher , higher_score ) ;
BOOST_CHECK_CLOSE ( - 1.509559 - 1.687872 - 0.30103 , score . Finish ( ) , 0.001 ) ;
}
BOOST_CHECK_EQUAL ( 2 , consider_higher . left . length ) ;
BOOST_CHECK ( ! consider_higher . full ) ;
ChartState full ;
{
RuleScore < M > score ( m , full ) ;
score . NonTerminal ( combine_also_would , - 1.687872 - 2.0 ) ;
score . NonTerminal ( consider_higher , - 1.509559 - 1.687872 - 0.30103 ) ;
BOOST_CHECK_CLOSE ( - 10.6879 , score . Finish ( ) , 0.001 ) ;
}
BOOST_CHECK_EQUAL ( 4 , full . right . length ) ;
}
// Short consistency checks (see TEXT_TEST) on a sentence with and without its
// end-of-sentence marker, and on its first two words alone.
template <class M> void GrowSmall(const M &m) {
  TEXT_TEST(" in biarritz watching considering looking . </s> ");
  TEXT_TEST(" in biarritz watching considering looking . ");
  TEXT_TEST(" in biarritz ");
}
// CHECK_SCORE(str, val): evaluates `val` exactly once, before anything else.
// This matters because callers pass expressions with side effects, such as
// `x = score.Finish()`. It then checks the result against LeftToRight scoring
// of the words in `str`, using the model `m` in scope. The body is wrapped in
// do/while(0) so `CHECK_SCORE(a, b);` is a single statement, safe inside
// unbraced if/else.
#define CHECK_SCORE(str, val) \
do { \
  float got = val; \
  std::vector<WordIndex> indices; \
  LookupVocab(m, str, indices); \
  BOOST_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \
} while (0)
template < class M > void FullGrow ( const M & m ) {
std : : vector < WordIndex > words ;
LookupVocab ( m , " in biarritz watching considering looking . </s> " , words ) ;
ChartState lexical [ 7 ] ;
float lexical_scores [ 7 ] ;
for ( unsigned int i = 0 ; i < 7 ; + + i ) {
RuleScore < M > score ( m , lexical [ i ] ) ;
score . Terminal ( words [ i ] ) ;
lexical_scores [ i ] = score . Finish ( ) ;
}
CHECK_SCORE ( " in " , lexical_scores [ 0 ] ) ;
CHECK_SCORE ( " biarritz " , lexical_scores [ 1 ] ) ;
CHECK_SCORE ( " watching " , lexical_scores [ 2 ] ) ;
CHECK_SCORE ( " </s> " , lexical_scores [ 6 ] ) ;
ChartState l1 [ 4 ] ;
float l1_scores [ 4 ] ;
{
RuleScore < M > score ( m , l1 [ 0 ] ) ;
score . NonTerminal ( lexical [ 0 ] , lexical_scores [ 0 ] ) ;
score . NonTerminal ( lexical [ 1 ] , lexical_scores [ 1 ] ) ;
CHECK_SCORE ( " in biarritz " , l1_scores [ 0 ] = score . Finish ( ) ) ;
}
{
RuleScore < M > score ( m , l1 [ 1 ] ) ;
score . NonTerminal ( lexical [ 2 ] , lexical_scores [ 2 ] ) ;
score . NonTerminal ( lexical [ 3 ] , lexical_scores [ 3 ] ) ;
CHECK_SCORE ( " watching considering " , l1_scores [ 1 ] = score . Finish ( ) ) ;
}
{
RuleScore < M > score ( m , l1 [ 2 ] ) ;
score . NonTerminal ( lexical [ 4 ] , lexical_scores [ 4 ] ) ;
score . NonTerminal ( lexical [ 5 ] , lexical_scores [ 5 ] ) ;
CHECK_SCORE ( " looking . " , l1_scores [ 2 ] = score . Finish ( ) ) ;
}
BOOST_CHECK_EQUAL ( l1 [ 2 ] . left . length , 1 ) ;
l1 [ 3 ] = lexical [ 6 ] ;
l1_scores [ 3 ] = lexical_scores [ 6 ] ;
ChartState l2 [ 2 ] ;
float l2_scores [ 2 ] ;
{
RuleScore < M > score ( m , l2 [ 0 ] ) ;
score . NonTerminal ( l1 [ 0 ] , l1_scores [ 0 ] ) ;
score . NonTerminal ( l1 [ 1 ] , l1_scores [ 1 ] ) ;
CHECK_SCORE ( " in biarritz watching considering " , l2_scores [ 0 ] = score . Finish ( ) ) ;
}
{
RuleScore < M > score ( m , l2 [ 1 ] ) ;
score . NonTerminal ( l1 [ 2 ] , l1_scores [ 2 ] ) ;
score . NonTerminal ( l1 [ 3 ] , l1_scores [ 3 ] ) ;
CHECK_SCORE ( " looking . </s> " , l2_scores [ 1 ] = score . Finish ( ) ) ;
}
BOOST_CHECK_EQUAL ( l2 [ 1 ] . left . length , 1 ) ;
BOOST_CHECK ( l2 [ 1 ] . full ) ;
ChartState top ;
{
RuleScore < M > score ( m , top ) ;
score . NonTerminal ( l2 [ 0 ] , l2_scores [ 0 ] ) ;
score . NonTerminal ( l2 [ 1 ] , l2_scores [ 1 ] ) ;
CHECK_SCORE ( " in biarritz watching considering looking . </s> " , score . Finish ( ) ) ;
}
}
// 2011-11-17 23:12:19 +04:00
const char * FileLocation ( ) {
if ( boost : : unit_test : : framework : : master_test_suite ( ) . argc < 2 ) {
return " test.arpa " ;
}
return boost : : unit_test : : framework : : master_test_suite ( ) . argv [ 1 ] ;
}
// 2011-09-21 20:06:48 +04:00
template < class M > void Everything ( ) {
Config config ;
config . messages = NULL ;
2011-11-17 23:12:19 +04:00
M m ( FileLocation ( ) , config ) ;
2011-09-21 20:06:48 +04:00
Short ( m ) ;
Charge ( m ) ;
GrowBig ( m ) ;
AlsoWouldConsiderHigher ( m ) ;
GrowSmall ( m ) ;
FullGrow ( m ) ;
}
// One Boost.Test case per model representation; each runs the whole suite
// (Everything) against a model of that type.
BOOST_AUTO_TEST_CASE(ProbingAll) { Everything<Model>(); }

BOOST_AUTO_TEST_CASE(TrieAll) { Everything<TrieModel>(); }

BOOST_AUTO_TEST_CASE(QuantTrieAll) { Everything<QuantTrieModel>(); }

BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) { Everything<QuantArrayTrieModel>(); }

BOOST_AUTO_TEST_CASE(ArrayTrieAll) { Everything<ArrayTrieModel>(); }
} // namespace
} // namespace ngram
} // namespace lm