Merge up to 3791 from trunk.

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3792 1f5c12ca-751b-0410-a591-d2e778427230
bhaddow committed 2011-01-05 13:49:44 +00:00
commit a2730c445d
233 changed files with 9091 additions and 2899 deletions


@@ -8,9 +8,12 @@ and build it. The SRILM can be downloaded from
http://www.speech.sri.com/projects/srilm/download.html .
If you want to use IRST's, you will need to download its source
and build it. The IRSTLM can be downloaded from
http://sourceforge.net/projects/irstlm/
If you want to use IRST's, you will need to download its source and
build it. The IRSTLM can be downloaded from either the SourceForge
website
http://sourceforge.net/projects/irstlm
or the official IRSTLM website
http://hlt.fbk.eu/en/irstlm
Ken's LM is included with the Moses distribution.


@@ -1,4 +1,23 @@
#undef _GLIBCXX_DEBUG
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <algorithm>
#include <iostream>
#include <string>
@@ -54,7 +73,7 @@ int main (int argc, char * const argv[])
if (lineNum%100000 == 0) cerr << lineNum << flush;
//cerr << lineNum << " " << line << endl;
std::vector<float> misc;
std::vector<float> misc(1);
SourcePhrase sourcePhrase;
TargetPhrase *targetPhrase = new TargetPhrase(numScores);
Tokenize(sourcePhrase, *targetPhrase, line, onDiskWrapper, numScores, misc);
@@ -135,11 +154,8 @@ void Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhrase, char *line
break;
case 5:
{ // count info. Only store the 2nd one
if (misc.size() == 0)
{
float val = Moses::Scan<float>(tok);
misc.push_back(val);
}
float val = Moses::Scan<float>(tok);
misc[0] = val;
++stage;
break;
}


@@ -1,12 +1,23 @@
#pragma once
/*
* Main.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <string>
#include "../../OnDiskPt/src/SourcePhrase.h"
#include "../../OnDiskPt/src/TargetPhrase.h"


@@ -1,11 +1,22 @@
/*
* OnDiskWrapper.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#ifdef WIN32
#include <direct.h>
#endif


@@ -1,12 +1,23 @@
#pragma once
/*
* OnDiskWrapper.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <string>
#include <fstream>
#include "Vocab.h"


@@ -1,11 +1,22 @@
/*
* Phrase.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <iostream>
#include <cassert>
#include "../../moses/src/Util.h"


@@ -1,12 +1,23 @@
#pragma once
/*
* Phrase.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <vector>
#include <iostream>
#include "Word.h"


@@ -1,11 +1,22 @@
/*
* PhraseNode.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 01/01/2010.
* Copyright 2010 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <cassert>
#include "PhraseNode.h"
#include "OnDiskWrapper.h"


@@ -1,12 +1,23 @@
#pragma once
/*
* PhraseNode.h
* CreateOnDisk
*
* Created by Hieu Hoang on 01/01/2010.
* Copyright 2010 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <fstream>
#include <vector>
#include <map>


@@ -1,11 +1,22 @@
/*
* SourcePhrase.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <cassert>
#include "SourcePhrase.h"


@@ -1,12 +1,23 @@
#pragma once
/*
* SourcePhrase.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <vector>
#include "Phrase.h"
#include "Word.h"


@@ -1,11 +1,23 @@
/*
* TargetPhrase.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <algorithm>
#include <iostream>
#include "../../moses/src/Util.h"
@@ -188,14 +200,10 @@ Moses::TargetPhrase *TargetPhrase::ConvertToMoses(const std::vector<Moses::Facto
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const Moses::Phrase &sourcePhrase) const
, const Moses::LMList &lmList) const
{
Moses::TargetPhrase *ret = new Moses::TargetPhrase(Moses::Output);
// source phrase
ret->SetSourcePhrase(&sourcePhrase);
// words
size_t phraseSize = GetSize();
assert(phraseSize > 0); // last word is lhs


@@ -1,12 +1,24 @@
#pragma once
/*
* TargetPhrase.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <fstream>
#include <string>
#include <vector>
@@ -72,8 +84,7 @@ public:
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const Moses::Phrase &sourcePhrase) const;
, const Moses::LMList &lmList) const;
UINT64 ReadOtherInfoFromFile(UINT64 filePos, std::fstream &fileTPColl);
UINT64 ReadFromFile(std::fstream &fileTP, size_t numFactors);


@@ -1,3 +1,22 @@
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <algorithm>
#include <iostream>
@@ -103,7 +122,6 @@ Moses::TargetPhraseCollection *TargetPhraseCollection::ConvertToMoses(const std:
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const Moses::Phrase &sourcePhrase
, const std::string &filePath
, Vocab &vocab) const
{
@@ -118,8 +136,7 @@ Moses::TargetPhraseCollection *TargetPhraseCollection::ConvertToMoses(const std:
, phraseDict
, weightT
, wpProducer
, lmList
, sourcePhrase);
, lmList);
/*
// debugging output


@@ -1,3 +1,22 @@
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#pragma once
#include "TargetPhrase.h"
@@ -53,7 +72,6 @@ public:
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const Moses::Phrase &sourcePhrase
, const std::string &filePath
, Vocab &vocab) const;
void ReadFromFile(size_t tableLimit, UINT64 filePos, OnDiskWrapper &onDiskWrapper);


@@ -1,11 +1,22 @@
/*
* Vocab.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <string>
#include <fstream>
#include "OnDiskWrapper.h"


@@ -1,12 +1,23 @@
#pragma once
/*
* Vocab.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <string>
#include <map>
#include "../../moses/src/TypeDef.h"


@@ -1,11 +1,22 @@
/*
* Word.cpp
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "../../moses/src/Util.h"
#include "../../moses/src/Word.h"


@@ -1,12 +1,23 @@
#pragma once
/*
* Word.h
* CreateOnDisk
*
* Created by Hieu Hoang on 31/12/2009.
* Copyright 2009 __MyCompanyName__. All rights reserved.
*
*/
// $Id$
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <string>
#include <vector>
#include <iostream>


@@ -207,13 +207,21 @@ then
SAVE_CPPFLAGS="$CPPFLAGS"
CPPFLAGS="$CPPFLAGS -I${with_irstlm}/include"
AC_MSG_NOTICE([])
AC_MSG_NOTICE([!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!])
AC_MSG_NOTICE([!!! You are linking the IRSTLM library; be sure the release is >= 5.50.01 !!!])
AC_MSG_NOTICE([!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!])
AC_MSG_NOTICE([])
AC_CHECK_HEADER(n_gram.h,
[AC_DEFINE([HAVE_IRSTLM], [], [flag for IRSTLM])],
[AC_MSG_ERROR([Cannot find IRST-LM in ${with_irstlm}])])
MY_ARCH=`uname -m`
LIB_IRSTLM="-lirstlm"
LDFLAGS="$LDFLAGS -L${with_irstlm}/lib/${MY_ARCH}"
LDFLAGS="$LDFLAGS -L${with_irstlm}/lib"
LIBS="$LIBS $LIB_IRSTLM"
FMTLIBS="$FMTLIBS libirstlm.a"
AM_CONDITIONAL([IRST_LM], true)
@@ -224,9 +232,9 @@ then
SAVE_CPPFLAGS="$CPPFLAGS"
CPPFLAGS="$CPPFLAGS -I${PWD}/kenlm"
AC_CHECK_HEADER(lm/ngram.hh,
AC_CHECK_HEADER(lm/model.hh,
[AC_DEFINE([HAVE_KENLM], [], [flag for KENLM])],
[AC_MSG_ERROR([Cannot find KEN-LM in ${with_kenlm}])])
[AC_MSG_ERROR([Cannot find KEN-LM in ${PWD}/kenlm])])
LIB_KENLM="-lkenlm"
LDFLAGS="$LDFLAGS -L${PWD}/kenlm"


@@ -4,7 +4,12 @@ bin_PROGRAMS = query build_binary
AM_CPPFLAGS = -W -Wall -ffor-scope -D_FILE_OFFSET_BITS=64 -D_LARGE_FILES $(BOOST_CPPFLAGS)
libkenlm_a_SOURCES = \
lm/lm_exception.cc \
lm/ngram.cc \
lm/config.cc \
lm/model.cc \
lm/search_hashed.cc \
lm/search_trie.cc \
lm/trie.cc \
lm/binary_format.cc \
lm/read_arpa.cc \
lm/virtual_interface.cc \
lm/vocab.cc \
@@ -15,13 +20,14 @@ libkenlm_a_SOURCES = \
util/file_piece.cc \
util/ersatz_progress.cc \
util/exception.cc \
util/string_piece.cc
util/string_piece.cc \
util/bit_packing.cc
query_SOURCES = lm/ngram_query.cc
query_DEPENDENCIES = libkenlm.a
query_LDADD = -L$(top_srcdir)/kenlm -lkenlm
query_LDADD = -L$(top_srcdir)/kenlm -lkenlm -lz
build_binary_SOURCES = lm/ngram_build_binary.cc
build_binary_SOURCES = lm/build_binary.cc
build_binary_DEPENDENCIES = libkenlm.a
build_binary_LDADD = -L$(top_srcdir)/kenlm -lkenlm
build_binary_LDADD = -L$(top_srcdir)/kenlm -lkenlm -lz


@@ -1,12 +1,7 @@
Language model inference code by Kenneth Heafield <infer at kheafield.com>
See LICENSE for list of files by other people and their licenses.
The official website is http://kheafield.com/code/mt/infer.html . If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
Compile: ./compile.sh
Run: ./query lm/test.arpa <text
Build binary format: ./build_binary lm/test.arpa test.binary
Use binary format: ./query test.binary <text
Test (uses Boost): ./test.sh
This documentation is directed at decoder developers.
Currently, it loads an ARPA file in 2/3 the time SRI takes and uses 6.5 GB when SRI takes 11 GB. These are compared to the default SRI build (i.e. without their smaller structures). I'm working on optimizing this even further.
@@ -14,13 +9,13 @@ Binary format via mmap is supported. Run ./build_binary to make one then pass t
Currently, it assumes POSIX APIs for errno, strerror_r, open, close, mmap, munmap, ftruncate, fstat, and read. This is tested on Linux and the non-UNIX Mac OS X. I welcome submissions porting (via #ifdef) to other systems (e.g. Windows) but proudly have no machine on which to test it.
A brief note to Mac OS X users: your gcc is too old to recognize the pack pragma. The warning effectively means that, on 64-bit machines, the model will use 16 bytes instead of 12 bytes per n-gram of maximum order.
A brief note to Mac OS X users: your gcc is too old to recognize the pack pragma. The warning effectively means that, on 64-bit machines, the model will use 16 bytes instead of 12 bytes per n-gram of maximum order (those of lower order are already 16 bytes) in the probing and sorted models. The trie is not impacted by this.
It does not depend on Boost or ICU. However, if you use Boost and/or ICU in the rest of your code, you should define USE_BOOST and/or USE_ICU in util/string_piece.hh. Defining USE_BOOST will let you hash StringPiece. Defining USE_ICU will use ICU's StringPiece to prevent a conflict with the one provided here. By the way, ICU's StringPiece is buggy and I reported this bug: http://bugs.icu-project.org/trac/ticket/7924 .
It does not depend on Boost or ICU. However, if you use Boost and/or ICU in the rest of your code, you should define HAVE_BOOST and/or HAVE_ICU in util/string_piece.hh. Defining HAVE_BOOST will let you hash StringPiece. Defining HAVE_ICU will use ICU's StringPiece to prevent a conflict with the one provided here. By the way, ICU's StringPiece is buggy and I reported this bug: http://bugs.icu-project.org/trac/ticket/7924 .
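As an editorial illustration (not part of this commit or the README), a minimal sketch of what the HAVE_BOOST hook enables, assuming HAVE_BOOST is set at the top of util/string_piece.hh as just described:

#include "util/string_piece.hh"
#include <boost/unordered_map.hpp>

int main() {
  // With HAVE_BOOST defined, StringPiece is hashable (per the README), so it
  // can key a boost::unordered_map directly, without copying to std::string.
  boost::unordered_map<StringPiece, int> counts;
  StringPiece word("example");
  ++counts[word];
  return 0;
}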
The recommended way to use this:
Copy the code and distribute with your decoder.
Set USE_ICU and USE_BOOST at the top of util/string_piece.hh as instructed above.
Set HAVE_ICU and HAVE_BOOST at the top of util/string_piece.hh as instructed above.
Look at compile.sh and reimplement using your build system.
Use either the interface in lm/ngram.hh or lm/virtual_interface.hh
Interface documentation is in comments of lm/virtual_interface.hh (including for lm/ngram.hh).
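For illustration only, a hedged sketch of querying a model through that interface; it assumes the lm/model.hh header this commit adds to the build and the facade methods (GetVocabulary, BeginSentenceState, Score) it exposes:

#include "lm/model.hh"
#include <iostream>

int main() {
  lm::ngram::Model model("lm/test.arpa");  // loads an ARPA or binary file
  const lm::ngram::Model::Vocabulary &vocab = model.GetVocabulary();
  lm::ngram::Model::State state(model.BeginSentenceState()), out_state;
  // Log10 probability of "the" following <s>; out_state carries the context.
  float score = model.Score(state, vocab.Index("the"), out_state);
  std::cout << score << std::endl;
  return 0;
}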


@@ -9,19 +9,118 @@
/* Begin PBXBuildFile section */
1E2B85C412555DB1000770D6 /* lm_exception.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E2B85C112555DB1000770D6 /* lm_exception.cc */; };
1E2B85C512555DB1000770D6 /* lm_exception.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E2B85C212555DB1000770D6 /* lm_exception.hh */; };
1E37EBC112496A7D00C1C73A /* ngram.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E37EBBF12496A7D00C1C73A /* ngram.cc */; };
1E37EBC212496A7D00C1C73A /* ngram.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E37EBC012496A7D00C1C73A /* ngram.hh */; };
1E37EBC712496AB400C1C73A /* virtual_interface.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E37EBC512496AB400C1C73A /* virtual_interface.cc */; };
1E37EBC812496AB400C1C73A /* virtual_interface.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E37EBC612496AB400C1C73A /* virtual_interface.hh */; };
1E8A94FE1288BD570022C4EB /* build_binary.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94F41288BD570022C4EB /* build_binary.cc */; };
1E8A94FF1288BD570022C4EB /* config.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94F51288BD570022C4EB /* config.cc */; };
1E8A95001288BD570022C4EB /* config.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8A94F61288BD570022C4EB /* config.hh */; };
1E8A95011288BD570022C4EB /* model_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94F71288BD570022C4EB /* model_test.cc */; };
1E8A95021288BD570022C4EB /* model.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94F81288BD570022C4EB /* model.cc */; };
1E8A95031288BD570022C4EB /* model.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8A94F91288BD570022C4EB /* model.hh */; };
1E8A95041288BD570022C4EB /* search_hashed.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94FA1288BD570022C4EB /* search_hashed.cc */; };
1E8A95051288BD570022C4EB /* search_hashed.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8A94FB1288BD570022C4EB /* search_hashed.hh */; };
1E8A95061288BD570022C4EB /* search_trie.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8A94FC1288BD570022C4EB /* search_trie.cc */; };
1E8A95071288BD570022C4EB /* search_trie.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8A94FD1288BD570022C4EB /* search_trie.hh */; };
1E8BF78A1278A434009F10C1 /* binary_format.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8BF7871278A434009F10C1 /* binary_format.cc */; };
1E8BF78B1278A434009F10C1 /* binary_format.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8BF7881278A434009F10C1 /* binary_format.hh */; };
1E8BF78C1278A434009F10C1 /* enumerate_vocab.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8BF7891278A434009F10C1 /* enumerate_vocab.hh */; };
1E8BF79D1278A443009F10C1 /* trie.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8BF7951278A443009F10C1 /* trie.cc */; };
1E8BF79E1278A443009F10C1 /* trie.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1E8BF7961278A443009F10C1 /* trie.hh */; };
1E8BF7D51278A600009F10C1 /* bit_packing.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1E8BF7D41278A600009F10C1 /* bit_packing.cc */; };
1EBB16D7126C158600AE6102 /* ersatz_progress.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16BF126C158600AE6102 /* ersatz_progress.cc */; };
1EBB16D8126C158600AE6102 /* ersatz_progress.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16C0126C158600AE6102 /* ersatz_progress.hh */; };
1EBB16D9126C158600AE6102 /* exception.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16C1126C158600AE6102 /* exception.cc */; };
1EBB16DA126C158600AE6102 /* exception.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16C2126C158600AE6102 /* exception.hh */; };
1EBB16DC126C158600AE6102 /* file_piece.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16C4126C158600AE6102 /* file_piece.cc */; };
1EBB16DD126C158600AE6102 /* file_piece.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16C5126C158600AE6102 /* file_piece.hh */; };
1EBB16DE126C158600AE6102 /* joint_sort_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16C6126C158600AE6102 /* joint_sort_test.cc */; };
1EBB16DF126C158600AE6102 /* joint_sort.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16C7126C158600AE6102 /* joint_sort.hh */; };
1EBB16E0126C158600AE6102 /* key_value_packing_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16C8126C158600AE6102 /* key_value_packing_test.cc */; };
1EBB16E1126C158600AE6102 /* key_value_packing.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16C9126C158600AE6102 /* key_value_packing.hh */; };
1EBB16E2126C158600AE6102 /* mmap.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16CA126C158600AE6102 /* mmap.cc */; };
1EBB16E3126C158600AE6102 /* mmap.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16CB126C158600AE6102 /* mmap.hh */; };
1EBB16E4126C158600AE6102 /* murmur_hash.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16CC126C158600AE6102 /* murmur_hash.cc */; };
1EBB16E5126C158600AE6102 /* murmur_hash.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16CD126C158600AE6102 /* murmur_hash.hh */; };
1EBB16E6126C158600AE6102 /* probing_hash_table_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16CE126C158600AE6102 /* probing_hash_table_test.cc */; };
1EBB16E7126C158600AE6102 /* probing_hash_table.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16CF126C158600AE6102 /* probing_hash_table.hh */; };
1EBB16E8126C158600AE6102 /* proxy_iterator.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D0126C158600AE6102 /* proxy_iterator.hh */; };
1EBB16E9126C158600AE6102 /* scoped.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16D1126C158600AE6102 /* scoped.cc */; };
1EBB16EA126C158600AE6102 /* scoped.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D2126C158600AE6102 /* scoped.hh */; };
1EBB16EB126C158600AE6102 /* sorted_uniform_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */; };
1EBB16EC126C158600AE6102 /* sorted_uniform.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D4126C158600AE6102 /* sorted_uniform.hh */; };
1EBB16ED126C158600AE6102 /* string_piece.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16D5126C158600AE6102 /* string_piece.cc */; };
1EBB16EE126C158600AE6102 /* string_piece.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D6126C158600AE6102 /* string_piece.hh */; };
1EBB1717126C15C500AE6102 /* facade.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1708126C15C500AE6102 /* facade.hh */; };
1EBB171A126C15C500AE6102 /* ngram_query.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB170B126C15C500AE6102 /* ngram_query.cc */; };
1EBB171C126C15C500AE6102 /* read_arpa.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB170D126C15C500AE6102 /* read_arpa.cc */; };
1EBB171D126C15C500AE6102 /* read_arpa.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB170E126C15C500AE6102 /* read_arpa.hh */; };
1EBB171E126C15C500AE6102 /* sri_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB170F126C15C500AE6102 /* sri_test.cc */; };
1EBB171F126C15C500AE6102 /* sri.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB1710126C15C500AE6102 /* sri.cc */; };
1EBB1720126C15C500AE6102 /* sri.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1711126C15C500AE6102 /* sri.hh */; };
1EBB1721126C15C500AE6102 /* vocab.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB1713126C15C500AE6102 /* vocab.cc */; };
1EBB1722126C15C500AE6102 /* vocab.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1714126C15C500AE6102 /* vocab.hh */; };
1EBB1723126C15C500AE6102 /* weights.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1715126C15C500AE6102 /* weights.hh */; };
1EBB1724126C15C500AE6102 /* word_index.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1716126C15C500AE6102 /* word_index.hh */; };
1ED9988712783457006BBB6C /* file_piece_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1ED9988612783457006BBB6C /* file_piece_test.cc */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
1E2B85C112555DB1000770D6 /* lm_exception.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = lm_exception.cc; path = lm/lm_exception.cc; sourceTree = "<group>"; };
1E2B85C212555DB1000770D6 /* lm_exception.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = lm_exception.hh; path = lm/lm_exception.hh; sourceTree = "<group>"; };
1E37EBBF12496A7D00C1C73A /* ngram.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ngram.cc; path = lm/ngram.cc; sourceTree = "<group>"; };
1E37EBC012496A7D00C1C73A /* ngram.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = ngram.hh; path = lm/ngram.hh; sourceTree = "<group>"; };
1E37EBC512496AB400C1C73A /* virtual_interface.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = virtual_interface.cc; path = lm/virtual_interface.cc; sourceTree = "<group>"; };
1E37EBC612496AB400C1C73A /* virtual_interface.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = virtual_interface.hh; path = lm/virtual_interface.hh; sourceTree = "<group>"; };
1E8A94F41288BD570022C4EB /* build_binary.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = build_binary.cc; path = lm/build_binary.cc; sourceTree = "<group>"; };
1E8A94F51288BD570022C4EB /* config.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = config.cc; path = lm/config.cc; sourceTree = "<group>"; };
1E8A94F61288BD570022C4EB /* config.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = config.hh; path = lm/config.hh; sourceTree = "<group>"; };
1E8A94F71288BD570022C4EB /* model_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = model_test.cc; path = lm/model_test.cc; sourceTree = "<group>"; };
1E8A94F81288BD570022C4EB /* model.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = model.cc; path = lm/model.cc; sourceTree = "<group>"; };
1E8A94F91288BD570022C4EB /* model.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = model.hh; path = lm/model.hh; sourceTree = "<group>"; };
1E8A94FA1288BD570022C4EB /* search_hashed.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = search_hashed.cc; path = lm/search_hashed.cc; sourceTree = "<group>"; };
1E8A94FB1288BD570022C4EB /* search_hashed.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = search_hashed.hh; path = lm/search_hashed.hh; sourceTree = "<group>"; };
1E8A94FC1288BD570022C4EB /* search_trie.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = search_trie.cc; path = lm/search_trie.cc; sourceTree = "<group>"; };
1E8A94FD1288BD570022C4EB /* search_trie.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = search_trie.hh; path = lm/search_trie.hh; sourceTree = "<group>"; };
1E8BF7871278A434009F10C1 /* binary_format.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = binary_format.cc; path = lm/binary_format.cc; sourceTree = "<group>"; };
1E8BF7881278A434009F10C1 /* binary_format.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = binary_format.hh; path = lm/binary_format.hh; sourceTree = "<group>"; };
1E8BF7891278A434009F10C1 /* enumerate_vocab.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = enumerate_vocab.hh; path = lm/enumerate_vocab.hh; sourceTree = "<group>"; };
1E8BF7951278A443009F10C1 /* trie.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = trie.cc; path = lm/trie.cc; sourceTree = "<group>"; };
1E8BF7961278A443009F10C1 /* trie.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = trie.hh; path = lm/trie.hh; sourceTree = "<group>"; };
1E8BF7D41278A600009F10C1 /* bit_packing.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = bit_packing.cc; path = util/bit_packing.cc; sourceTree = "<group>"; };
1EBB16BF126C158600AE6102 /* ersatz_progress.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ersatz_progress.cc; path = util/ersatz_progress.cc; sourceTree = "<group>"; };
1EBB16C0126C158600AE6102 /* ersatz_progress.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = ersatz_progress.hh; path = util/ersatz_progress.hh; sourceTree = "<group>"; };
1EBB16C1126C158600AE6102 /* exception.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = exception.cc; path = util/exception.cc; sourceTree = "<group>"; };
1EBB16C2126C158600AE6102 /* exception.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = exception.hh; path = util/exception.hh; sourceTree = "<group>"; };
1EBB16C4126C158600AE6102 /* file_piece.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = file_piece.cc; path = util/file_piece.cc; sourceTree = "<group>"; };
1EBB16C5126C158600AE6102 /* file_piece.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = file_piece.hh; path = util/file_piece.hh; sourceTree = "<group>"; };
1EBB16C6126C158600AE6102 /* joint_sort_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = joint_sort_test.cc; path = util/joint_sort_test.cc; sourceTree = "<group>"; };
1EBB16C7126C158600AE6102 /* joint_sort.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = joint_sort.hh; path = util/joint_sort.hh; sourceTree = "<group>"; };
1EBB16C8126C158600AE6102 /* key_value_packing_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = key_value_packing_test.cc; path = util/key_value_packing_test.cc; sourceTree = "<group>"; };
1EBB16C9126C158600AE6102 /* key_value_packing.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = key_value_packing.hh; path = util/key_value_packing.hh; sourceTree = "<group>"; };
1EBB16CA126C158600AE6102 /* mmap.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = mmap.cc; path = util/mmap.cc; sourceTree = "<group>"; };
1EBB16CB126C158600AE6102 /* mmap.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = mmap.hh; path = util/mmap.hh; sourceTree = "<group>"; };
1EBB16CC126C158600AE6102 /* murmur_hash.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = murmur_hash.cc; path = util/murmur_hash.cc; sourceTree = "<group>"; };
1EBB16CD126C158600AE6102 /* murmur_hash.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = murmur_hash.hh; path = util/murmur_hash.hh; sourceTree = "<group>"; };
1EBB16CE126C158600AE6102 /* probing_hash_table_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = probing_hash_table_test.cc; path = util/probing_hash_table_test.cc; sourceTree = "<group>"; };
1EBB16CF126C158600AE6102 /* probing_hash_table.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = probing_hash_table.hh; path = util/probing_hash_table.hh; sourceTree = "<group>"; };
1EBB16D0126C158600AE6102 /* proxy_iterator.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = proxy_iterator.hh; path = util/proxy_iterator.hh; sourceTree = "<group>"; };
1EBB16D1126C158600AE6102 /* scoped.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = scoped.cc; path = util/scoped.cc; sourceTree = "<group>"; };
1EBB16D2126C158600AE6102 /* scoped.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = scoped.hh; path = util/scoped.hh; sourceTree = "<group>"; };
1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = sorted_uniform_test.cc; path = util/sorted_uniform_test.cc; sourceTree = "<group>"; };
1EBB16D4126C158600AE6102 /* sorted_uniform.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = sorted_uniform.hh; path = util/sorted_uniform.hh; sourceTree = "<group>"; };
1EBB16D5126C158600AE6102 /* string_piece.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = string_piece.cc; path = util/string_piece.cc; sourceTree = "<group>"; };
1EBB16D6126C158600AE6102 /* string_piece.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = string_piece.hh; path = util/string_piece.hh; sourceTree = "<group>"; };
1EBB1708126C15C500AE6102 /* facade.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = facade.hh; path = lm/facade.hh; sourceTree = "<group>"; };
1EBB170B126C15C500AE6102 /* ngram_query.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ngram_query.cc; path = lm/ngram_query.cc; sourceTree = "<group>"; };
1EBB170D126C15C500AE6102 /* read_arpa.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = read_arpa.cc; path = lm/read_arpa.cc; sourceTree = "<group>"; };
1EBB170E126C15C500AE6102 /* read_arpa.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = read_arpa.hh; path = lm/read_arpa.hh; sourceTree = "<group>"; };
1EBB170F126C15C500AE6102 /* sri_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = sri_test.cc; path = lm/sri_test.cc; sourceTree = "<group>"; };
1EBB1710126C15C500AE6102 /* sri.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = sri.cc; path = lm/sri.cc; sourceTree = "<group>"; };
1EBB1711126C15C500AE6102 /* sri.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = sri.hh; path = lm/sri.hh; sourceTree = "<group>"; };
1EBB1712126C15C500AE6102 /* test.arpa */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; name = test.arpa; path = lm/test.arpa; sourceTree = "<group>"; };
1EBB1713126C15C500AE6102 /* vocab.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = vocab.cc; path = lm/vocab.cc; sourceTree = "<group>"; };
1EBB1714126C15C500AE6102 /* vocab.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = vocab.hh; path = lm/vocab.hh; sourceTree = "<group>"; };
1EBB1715126C15C500AE6102 /* weights.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = weights.hh; path = lm/weights.hh; sourceTree = "<group>"; };
1EBB1716126C15C500AE6102 /* word_index.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = word_index.hh; path = lm/word_index.hh; sourceTree = "<group>"; };
1ED9988612783457006BBB6C /* file_piece_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = file_piece_test.cc; path = util/file_piece_test.cc; sourceTree = "<group>"; };
D2AAC046055464E500DB518D /* libkenlm.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libkenlm.a; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
@@ -49,12 +148,62 @@
08FB7795FE84155DC02AAC07 /* Source */ = {
isa = PBXGroup;
children = (
1E8A94F41288BD570022C4EB /* build_binary.cc */,
1E8A94F51288BD570022C4EB /* config.cc */,
1E8A94F61288BD570022C4EB /* config.hh */,
1E8A94F71288BD570022C4EB /* model_test.cc */,
1E8A94F81288BD570022C4EB /* model.cc */,
1E8A94F91288BD570022C4EB /* model.hh */,
1E8A94FA1288BD570022C4EB /* search_hashed.cc */,
1E8A94FB1288BD570022C4EB /* search_hashed.hh */,
1E8A94FC1288BD570022C4EB /* search_trie.cc */,
1E8A94FD1288BD570022C4EB /* search_trie.hh */,
1E8BF7D41278A600009F10C1 /* bit_packing.cc */,
1E8BF7951278A443009F10C1 /* trie.cc */,
1E8BF7961278A443009F10C1 /* trie.hh */,
1E8BF7871278A434009F10C1 /* binary_format.cc */,
1E8BF7881278A434009F10C1 /* binary_format.hh */,
1E8BF7891278A434009F10C1 /* enumerate_vocab.hh */,
1ED9988612783457006BBB6C /* file_piece_test.cc */,
1EBB1708126C15C500AE6102 /* facade.hh */,
1EBB170B126C15C500AE6102 /* ngram_query.cc */,
1EBB170D126C15C500AE6102 /* read_arpa.cc */,
1EBB170E126C15C500AE6102 /* read_arpa.hh */,
1EBB170F126C15C500AE6102 /* sri_test.cc */,
1EBB1710126C15C500AE6102 /* sri.cc */,
1EBB1711126C15C500AE6102 /* sri.hh */,
1EBB1712126C15C500AE6102 /* test.arpa */,
1EBB1713126C15C500AE6102 /* vocab.cc */,
1EBB1714126C15C500AE6102 /* vocab.hh */,
1EBB1715126C15C500AE6102 /* weights.hh */,
1EBB1716126C15C500AE6102 /* word_index.hh */,
1EBB16BF126C158600AE6102 /* ersatz_progress.cc */,
1EBB16C0126C158600AE6102 /* ersatz_progress.hh */,
1EBB16C1126C158600AE6102 /* exception.cc */,
1EBB16C2126C158600AE6102 /* exception.hh */,
1EBB16C4126C158600AE6102 /* file_piece.cc */,
1EBB16C5126C158600AE6102 /* file_piece.hh */,
1EBB16C6126C158600AE6102 /* joint_sort_test.cc */,
1EBB16C7126C158600AE6102 /* joint_sort.hh */,
1EBB16C8126C158600AE6102 /* key_value_packing_test.cc */,
1EBB16C9126C158600AE6102 /* key_value_packing.hh */,
1EBB16CA126C158600AE6102 /* mmap.cc */,
1EBB16CB126C158600AE6102 /* mmap.hh */,
1EBB16CC126C158600AE6102 /* murmur_hash.cc */,
1EBB16CD126C158600AE6102 /* murmur_hash.hh */,
1EBB16CE126C158600AE6102 /* probing_hash_table_test.cc */,
1EBB16CF126C158600AE6102 /* probing_hash_table.hh */,
1EBB16D0126C158600AE6102 /* proxy_iterator.hh */,
1EBB16D1126C158600AE6102 /* scoped.cc */,
1EBB16D2126C158600AE6102 /* scoped.hh */,
1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */,
1EBB16D4126C158600AE6102 /* sorted_uniform.hh */,
1EBB16D5126C158600AE6102 /* string_piece.cc */,
1EBB16D6126C158600AE6102 /* string_piece.hh */,
1E2B85C112555DB1000770D6 /* lm_exception.cc */,
1E2B85C212555DB1000770D6 /* lm_exception.hh */,
1E37EBC512496AB400C1C73A /* virtual_interface.cc */,
1E37EBC612496AB400C1C73A /* virtual_interface.hh */,
1E37EBBF12496A7D00C1C73A /* ngram.cc */,
1E37EBC012496A7D00C1C73A /* ngram.hh */,
);
name = Source;
sourceTree = "<group>";
@@ -81,9 +230,33 @@
isa = PBXHeadersBuildPhase;
buildActionMask = 2147483647;
files = (
1E37EBC212496A7D00C1C73A /* ngram.hh in Headers */,
1E37EBC812496AB400C1C73A /* virtual_interface.hh in Headers */,
1E2B85C512555DB1000770D6 /* lm_exception.hh in Headers */,
1EBB16D8126C158600AE6102 /* ersatz_progress.hh in Headers */,
1EBB16DA126C158600AE6102 /* exception.hh in Headers */,
1EBB16DD126C158600AE6102 /* file_piece.hh in Headers */,
1EBB16DF126C158600AE6102 /* joint_sort.hh in Headers */,
1EBB16E1126C158600AE6102 /* key_value_packing.hh in Headers */,
1EBB16E3126C158600AE6102 /* mmap.hh in Headers */,
1EBB16E5126C158600AE6102 /* murmur_hash.hh in Headers */,
1EBB16E7126C158600AE6102 /* probing_hash_table.hh in Headers */,
1EBB16E8126C158600AE6102 /* proxy_iterator.hh in Headers */,
1EBB16EA126C158600AE6102 /* scoped.hh in Headers */,
1EBB16EC126C158600AE6102 /* sorted_uniform.hh in Headers */,
1EBB16EE126C158600AE6102 /* string_piece.hh in Headers */,
1EBB1717126C15C500AE6102 /* facade.hh in Headers */,
1EBB171D126C15C500AE6102 /* read_arpa.hh in Headers */,
1EBB1720126C15C500AE6102 /* sri.hh in Headers */,
1EBB1722126C15C500AE6102 /* vocab.hh in Headers */,
1EBB1723126C15C500AE6102 /* weights.hh in Headers */,
1EBB1724126C15C500AE6102 /* word_index.hh in Headers */,
1E8BF78B1278A434009F10C1 /* binary_format.hh in Headers */,
1E8BF78C1278A434009F10C1 /* enumerate_vocab.hh in Headers */,
1E8BF79E1278A443009F10C1 /* trie.hh in Headers */,
1E8A95001288BD570022C4EB /* config.hh in Headers */,
1E8A95031288BD570022C4EB /* model.hh in Headers */,
1E8A95051288BD570022C4EB /* search_hashed.hh in Headers */,
1E8A95071288BD570022C4EB /* search_trie.hh in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -129,9 +302,34 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
1E37EBC112496A7D00C1C73A /* ngram.cc in Sources */,
1E37EBC712496AB400C1C73A /* virtual_interface.cc in Sources */,
1E2B85C412555DB1000770D6 /* lm_exception.cc in Sources */,
1EBB16D7126C158600AE6102 /* ersatz_progress.cc in Sources */,
1EBB16D9126C158600AE6102 /* exception.cc in Sources */,
1EBB16DC126C158600AE6102 /* file_piece.cc in Sources */,
1EBB16DE126C158600AE6102 /* joint_sort_test.cc in Sources */,
1EBB16E0126C158600AE6102 /* key_value_packing_test.cc in Sources */,
1EBB16E2126C158600AE6102 /* mmap.cc in Sources */,
1EBB16E4126C158600AE6102 /* murmur_hash.cc in Sources */,
1EBB16E6126C158600AE6102 /* probing_hash_table_test.cc in Sources */,
1EBB16E9126C158600AE6102 /* scoped.cc in Sources */,
1EBB16EB126C158600AE6102 /* sorted_uniform_test.cc in Sources */,
1EBB16ED126C158600AE6102 /* string_piece.cc in Sources */,
1EBB171A126C15C500AE6102 /* ngram_query.cc in Sources */,
1EBB171C126C15C500AE6102 /* read_arpa.cc in Sources */,
1EBB171E126C15C500AE6102 /* sri_test.cc in Sources */,
1EBB171F126C15C500AE6102 /* sri.cc in Sources */,
1EBB1721126C15C500AE6102 /* vocab.cc in Sources */,
1ED9988712783457006BBB6C /* file_piece_test.cc in Sources */,
1E8BF78A1278A434009F10C1 /* binary_format.cc in Sources */,
1E8BF79D1278A443009F10C1 /* trie.cc in Sources */,
1E8BF7D51278A600009F10C1 /* bit_packing.cc in Sources */,
1E8A94FE1288BD570022C4EB /* build_binary.cc in Sources */,
1E8A94FF1288BD570022C4EB /* config.cc in Sources */,
1E8A95011288BD570022C4EB /* model_test.cc in Sources */,
1E8A95021288BD570022C4EB /* model.cc in Sources */,
1E8A95041288BD570022C4EB /* search_hashed.cc in Sources */,
1E8A95061288BD570022C4EB /* search_trie.cc in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -174,6 +372,7 @@
HEADER_SEARCH_PATHS = (
/Users/hieuhoang/workspace/sourceforge/trunk/kenlm,
/usr/local/include,
../srilm/include,
);
ONLY_ACTIVE_ARCH = YES;
PREBINDING = NO;
@@ -191,6 +390,7 @@
HEADER_SEARCH_PATHS = (
/Users/hieuhoang/workspace/sourceforge/trunk/kenlm,
/usr/local/include,
../srilm/include,
);
PREBINDING = NO;
SDKROOT = macosx10.6;

kenlm/lm/binary_format.cc (new file, 190 lines)

@@ -0,0 +1,190 @@
#include "lm/binary_format.hh"
#include "lm/lm_exception.hh"
#include "util/file_piece.hh"
#include <limits>
#include <string>
#include <cstring>
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
namespace lm {
namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0";
const long int kMagicVersion = 1;
// Test values.
struct Sanity {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
one_uint64 = 1;
}
};
const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
std::size_t Align8(std::size_t in) {
std::size_t off = in % 8;
if (!off) return in;
return in + 8 - off;
}
std::size_t TotalHeaderSize(unsigned char order) {
return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
void ReadLoop(int fd, void *to_void, std::size_t size) {
uint8_t *to = static_cast<uint8_t*>(to_void);
while (size) {
ssize_t ret = read(fd, to, size);
if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file");
if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short");
to += ret;
size -= ret;
}
}
void WriteHeader(void *to, const Parameters &params) {
Sanity header = Sanity();
header.SetToReference();
memcpy(to, &header, sizeof(Sanity));
char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
*reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
out += sizeof(FixedWidthParameters);
uint64_t *counts = reinterpret_cast<uint64_t*>(out);
for (std::size_t i = 0; i < params.counts.size(); ++i) {
counts[i] = params.counts[i];
}
}
} // namespace
namespace detail {
bool IsBinaryFormat(int fd) {
const off_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false;
// Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
} catch (const util::Exception &e) {
return false;
}
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
}
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
}
return false;
}
void ReadHeader(int fd, Parameters &out) {
if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file");
ReadLoop(fd, &out.fixed, sizeof(out.fixed));
if (out.fixed.probing_multiplier < 1.0)
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
out.counts.resize(static_cast<std::size_t>(out.fixed.order));
ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
void MatchCheck(ModelType model_type, const Parameters &params) {
if (params.fixed.model_type != model_type) {
if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
}
}
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
const off_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory);
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
if (config.enumerate_vocab) {
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
}
return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size());
}
uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
// Write out an mmap file.
backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED);
Parameters params;
params.counts = counts;
params.fixed.order = counts.size();
params.fixed.probing_multiplier = config.probing_multiplier;
params.fixed.model_type = model_type;
params.fixed.has_vocabulary = config.include_vocab;
WriteHeader(backing.memory.get(), params);
if (params.fixed.has_vocabulary) {
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words");
}
return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size());
} else {
backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.memory.get());
}
}
void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
} else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
}
}
} // namespace detail
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
if (!detail::IsBinaryFormat(fd.get())) return false;
Parameters params;
detail::ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
} // namespace ngram
} // namespace lm
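
The RecognizeBinary entry point declared in binary_format.hh lets a caller check a file and dispatch on the stored model type without committing to a full load. A minimal sketch of such a caller, assuming a placeholder file name:

#include "lm/binary_format.hh"
#include <iostream>

int main() {
  lm::ngram::ModelType type;
  // RecognizeBinary returns false for ARPA (text) files and throws on I/O errors.
  if (lm::ngram::RecognizeBinary("file.mmap", type)) {
    // type is one of HASH_PROBING, HASH_SORTED, or TRIE_SORTED.
    std::cout << "Binary model, type " << static_cast<unsigned int>(type) << std::endl;
  } else {
    std::cout << "Not binary; treat as ARPA." << std::endl;
  }
  return 0;
}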
kenlm/lm/binary_format.hh Normal file
@@ -0,0 +1,94 @@
#ifndef LM_BINARY_FORMAT__
#define LM_BINARY_FORMAT__
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "util/file_piece.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include <cstddef>
#include <vector>
#include <inttypes.h>
namespace lm {
namespace ngram {
typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType;
struct FixedWidthParameters {
unsigned char order;
float probing_multiplier;
// What type of model is this?
ModelType model_type;
// Does the end of the file have the actual strings in the vocabulary?
bool has_vocabulary;
};
struct Parameters {
FixedWidthParameters fixed;
std::vector<uint64_t> counts;
};
struct Backing {
// File behind memory, if any.
util::scoped_fd file;
// Raw block of memory backing the language model data structures
util::scoped_memory memory;
};
namespace detail {
bool IsBinaryFormat(int fd);
void ReadHeader(int fd, Parameters &params);
void MatchCheck(ModelType model_type, const Parameters &params);
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
} // namespace detail
bool RecognizeBinary(const char *file, ModelType &recognized);
template <class To> void LoadLM(const char *file, const Config &config, To &to) {
Backing &backing = to.MutableBacking();
backing.file.reset(util::OpenReadOrThrow(file));
Parameters params;
try {
if (detail::IsBinaryFormat(backing.file.get())) {
detail::ReadHeader(backing.file.get(), params);
detail::MatchCheck(To::kModelType, params);
// Replace the probing_multiplier.
Config new_config(config);
new_config.probing_multiplier = params.fixed.probing_multiplier;
std::size_t memory_size = To::Size(params.counts, new_config);
uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
detail::ComplainAboutARPA(config, To::kModelType);
util::FilePiece f(backing.file.release(), file, config.messages);
ReadARPACounts(f, params.counts);
std::size_t memory_size = To::Size(params.counts, config);
uint8_t *start = detail::SetupZeroed(config, To::kModelType, params.counts, memory_size, backing);
to.InitializeFromARPA(file, f, start, params, config);
}
} catch (util::Exception &e) {
e << " in file " << file;
throw;
}
}
} // namespace ngram
} // namespace lm
#endif // LM_BINARY_FORMAT__
kenlm/lm/build_binary.cc Normal file
@@ -0,0 +1,113 @@
#include "lm/model.hh"
#include "util/file_piece.hh"
#include <iostream>
#include <iomanip>
#include <math.h>
#include <stdlib.h>
#include <unistd.h>
namespace lm {
namespace ngram {
namespace {
void Usage(const char *name) {
std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
"Where type is one of probing, trie, or sorted:\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
"-t is the temporary directory prefix. Default is the output file name.\n"
"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n"
"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
"It uses more memory than trie and is also slower, so there's no real reason to\n"
"use it.\n\n"
"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
"Passing only an input file will print memory usage of each data structure.\n"
"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
exit(1);
}
// I could really use boost::lexical_cast right about now.
float ParseFloat(const char *from) {
char *end;
float ret = strtod(from, &end);
if (*end) throw util::ParseNumberException(from);
return ret;
}
unsigned long int ParseUInt(const char *from) {
char *end;
unsigned long int ret = strtoul(from, &end, 10);
if (*end) throw util::ParseNumberException(from);
return ret;
}
void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
std::size_t probing_size = ProbingModel::Size(counts, config);
// probing is always largest so use it to determine number of columns.
long int length = std::max<long int>(5, lrint(ceil(log10(probing_size))));
std::cout << "Memory usage:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 5; ++i) std::cout << ' ';
std::cout << "bytes\n"
"probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"
"sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";
}
} // namespace
} // namespace ngram
} // namespace lm
int main(int argc, char *argv[]) {
using namespace lm::ngram;
lm::ngram::Config config;
int opt;
while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) {
switch(opt) {
case 'u':
config.unknown_missing_prob = ParseFloat(optarg);
break;
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
case 't':
config.temporary_directory_prefix = optarg;
break;
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
default:
Usage(argv[0]);
}
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
} else if (optind + 2 == argc) {
config.write_mmap = argv[optind + 1];
ProbingModel(argv[optind], config);
} else if (optind + 3 == argc) {
const char *model_type = argv[optind];
const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
if (!strcmp(model_type, "probing")) {
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "sorted")) {
SortedModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
TrieModel(from_file, config);
} else {
Usage(argv[0]);
}
} else {
Usage(argv[0]);
}
return 0;
}
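
Given the option parsing in main above, typical invocations look like the following (file names are placeholders):

build_binary input.arpa output.mmap                   writes the default probing format
build_binary trie input.arpa output.mmap              builds the bit-packed trie via an on-disk sort
build_binary -p 2.0 probing input.arpa output.mmap    sparser probing table, fewer collisions
build_binary -m 2048 -t /tmp/lm trie input.arpa output.mmap
build_binary input.arpa                               prints the memory estimate for each type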
kenlm/lm/config.cc Normal file
@@ -0,0 +1,22 @@
#include "lm/config.hh"
#include <iostream>
namespace lm {
namespace ngram {
Config::Config() :
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
unknown_missing_prob(0.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
temporary_directory_prefix(NULL),
arpa_complain(ALL),
write_mmap(NULL),
include_vocab(true),
load_method(util::POPULATE_OR_READ) {}
} // namespace ngram
} // namespace lm
kenlm/lm/config.hh Normal file
@@ -0,0 +1,83 @@
#ifndef LM_CONFIG__
#define LM_CONFIG__
#include <iosfwd>
#include "util/mmap.hh"
/* Configuration for ngram model. Separate header to reduce pollution. */
namespace lm { namespace ngram {
class EnumerateVocab;
struct Config {
// EFFECTIVE FOR BOTH ARPA AND BINARY READS
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
// This will be called with every string in the vocabulary. See
// enumerate_vocab.hh for more detail. Config does not take ownership; you
// are still responsible for deleting it (or stack allocating).
EnumerateVocab *enumerate_vocab;
// ONLY EFFECTIVE WHEN READING ARPA
// What to do when <unk> isn't in the provided model.
typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
UnknownMissing unknown_missing;
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_prob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
// for sorted variant.
// If you find yourself setting this to a low number, consider using the
// Sorted version instead which has lower memory consumption.
float probing_multiplier;
// Amount of memory to use for building. The actual memory usage will be
// higher since this just sets sort buffer size. Only applies to trie
// models.
std::size_t building_memory;
// Template for temporary directory appropriate for passing to mkdtemp.
// The characters XXXXXX are appended before passing to mkdtemp. Only
// applies to trie. If NULL, defaults to write_mmap. If that's NULL,
// defaults to input file name.
const char *temporary_directory_prefix;
// Level of complaining to do when an ARPA instead of a binary format.
typedef enum {ALL, EXPENSIVE, NONE} ARPALoadComplain;
ARPALoadComplain arpa_complain;
// While loading an ARPA file, also write out this binary format file. Set
// to NULL to disable.
const char *write_mmap;
// Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
// ONLY EFFECTIVE WHEN READING BINARY
// How to get the giant array into memory: lazy mmap, populate, read etc.
// See util/mmap.hh for details of MapMethod.
util::LoadMethod load_method;
// Set defaults.
Config();
};
} /* namespace ngram */ } /* namespace lm */
#endif // LM_CONFIG__
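
Since every field of Config gets a default from the constructor in config.cc, the intended pattern is to stack-allocate one, override the fields you care about, and pass it to a model constructor. A small sketch under the definitions above (file names are placeholders):

#include "lm/model.hh"

int main() {
  lm::ngram::Config config;
  config.messages = NULL;             // silence loading messages and the progress bar
  config.write_mmap = "output.mmap";  // also write a binary file while reading the ARPA
  config.include_vocab = true;        // keep the vocabulary strings in that binary
  lm::ngram::Model model("input.arpa", config);
  return 0;
}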
kenlm/lm/enumerate_vocab.hh Normal file
@@ -0,0 +1,30 @@
#ifndef LM_ENUMERATE_VOCAB__
#define LM_ENUMERATE_VOCAB__
#include "lm/word_index.hh"
#include "util/string_piece.hh"
namespace lm {
namespace ngram {
/* If you need the actual strings in the vocabulary, inherit from this class
* and implement Add. Then put a pointer in Config.enumerate_vocab; it does
* not take ownership. Add is called once per vocab word. index starts at 0
* and increases by 1 each time. This is only used by the Model constructor;
* the pointer is not retained by the class.
*/
class EnumerateVocab {
public:
virtual ~EnumerateVocab() {}
virtual void Add(WordIndex index, const StringPiece &str) = 0;
protected:
EnumerateVocab() {}
};
} // namespace ngram
} // namespace lm
#endif // LM_ENUMERATE_VOCAB__
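
Following the comment above, a subclass only needs to implement Add; model_test.cc's ExpectEnumerateVocab later in this commit is the in-tree example. A minimal sketch that just prints each vocabulary entry (the class name is hypothetical):

#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include <iostream>

class PrintVocab : public lm::ngram::EnumerateVocab {
 public:
  // Called once per word while the model loads; index counts up from 0.
  void Add(lm::WordIndex index, const StringPiece &str) {
    std::cout << index << ' ' << str << '\n';
  }
};

int main() {
  PrintVocab printer;
  lm::ngram::Config config;
  config.enumerate_vocab = &printer;  // Config does not take ownership
  lm::ngram::Model model("input.arpa", config);
  return 0;
}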
kenlm/lm/lm_exception.cc
@@ -5,14 +5,18 @@
namespace lm {
ConfigException::ConfigException() throw() {}
ConfigException::~ConfigException() throw() {}
LoadException::LoadException() throw() {}
LoadException::~LoadException() throw() {}
VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}
FormatLoadException::FormatLoadException() throw() {}
FormatLoadException::~FormatLoadException() throw() {}
VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}
SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
*this << "Missing special word " << which;
}
kenlm/lm/lm_exception.hh
@@ -1,5 +1,7 @@
#ifndef LM_EXCEPTION__
#define LM_EXCEPTION__
#ifndef LM_LM_EXCEPTION__
#define LM_LM_EXCEPTION__
// Named to avoid conflict with util/exception.hh.
#include "util/exception.hh"
#include "util/string_piece.hh"
@@ -9,6 +11,12 @@
namespace lm {
class ConfigException : public util::Exception {
public:
ConfigException() throw();
~ConfigException() throw();
};
class LoadException : public util::Exception {
public:
virtual ~LoadException() throw();
@@ -17,18 +25,18 @@ class LoadException : public util::Exception {
LoadException() throw();
};
class VocabLoadException : public LoadException {
public:
virtual ~VocabLoadException() throw();
VocabLoadException() throw();
};
class FormatLoadException : public LoadException {
public:
FormatLoadException() throw();
~FormatLoadException() throw();
};
class VocabLoadException : public LoadException {
public:
virtual ~VocabLoadException() throw();
VocabLoadException() throw();
};
class SpecialWordMissingException : public VocabLoadException {
public:
explicit SpecialWordMissingException(StringPiece which) throw();
@@ -37,4 +45,4 @@ class SpecialWordMissingException : public VocabLoadException {
} // namespace lm
#endif
#endif // LM_LM_EXCEPTION
kenlm/lm/model.cc Normal file
@@ -0,0 +1,240 @@
#include "lm/model.hh"
#include "lm/lm_exception.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
#include "util/murmur_hash.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
namespace lm {
namespace ngram {
size_t hash_value(const State &state) {
return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
}
namespace detail {
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
uint8_t *start = static_cast<uint8_t*>(base);
size_t allocated = VocabularyT::Size(counts[0], config);
vocab_.SetupMemory(start, allocated, counts[0], config);
start += allocated;
start = search_.SetupMemory(start, counts, config);
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
LoadLM(file, config, *this);
// g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.valid_length_ = 1;
begin_sentence.history_[0] = vocab_.BeginSentence();
begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
State null_context = State();
null_context.valid_length_ = 0;
P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
SetupMemory(start, params.counts, config);
vocab_.LoadedBinary(fd, config.enumerate_vocab);
search_.unigram.LoadedBinary();
for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) {
i->LoadedBinary();
}
search_.longest.LoadedBinary();
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
SetupMemory(start, params.counts, config);
if (config.write_mmap) {
WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
}
// TODO: fail faster?
if (!vocab_.SawUnk()) {
switch(config.unknown_missing) {
case Config::THROW_UP:
{
SpecialWordMissingException e("<unk>");
e << " and configuration was set to throw if unknown is missing";
throw e;
}
case Config::COMPLAIN:
if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
// There's no break;. This is by design.
case Config::SILENT:
// Default probabilities for unknown.
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_prob;
break;
}
}
if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
unsigned char backoff_start;
FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
if (backoff_start - 1 < in_state.valid_length_) {
ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
}
return ret;
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
unsigned char backoff_start;
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
return ret;
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin || *context_rbegin == 0) {
out_state.valid_length_ = 0;
return;
}
float ignored_prob;
typename Search::Node node;
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
float *backoff_out = out_state.backoff_ + 1;
const WordIndex *i = context_rbegin + 1;
for (; i < context_rend; ++i, ++backoff_out) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
out_state.valid_length_ = i - context_rbegin;
std::copy(context_rbegin, i, out_state.history_);
return;
}
}
std::copy(context_rbegin, context_rend, out_state.history_);
out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
}
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
// Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
float ret = 0.0;
if (start == 1) {
ret += search_.unigram.Lookup(*context_rbegin).backoff;
start = 2;
}
typename Search::Node node;
if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
return 0.0;
}
float backoff;
// i is the order of the backoff we're looking for.
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
ret += backoff;
}
return ret;
}
/* Ugly optimized function. Produce a score excluding backoff.
* The search goes in increasing order of ngram length.
* Context goes backward, so context_rbegin is the word immediately preceding
* new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
const WordIndex *context_rbegin,
const WordIndex *context_rend,
const WordIndex new_word,
unsigned char &backoff_start,
State &out_state) const {
FullScoreReturn ret;
typename Search::Node node;
float *backoff_out(out_state.backoff_);
search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
if (new_word == 0) {
ret.ngram_length = out_state.valid_length_ = 0;
// All of backoff.
backoff_start = 1;
return ret;
}
out_state.history_[0] = new_word;
if (context_rbegin == context_rend) {
ret.ngram_length = out_state.valid_length_ = 1;
// No backoff because we don't have the history for it.
backoff_start = P::Order();
return ret;
}
++backoff_out;
// OK, now we know that the bigram contains known words. Start by looking it up.
const WordIndex *hist_iter = context_rbegin;
typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. No backoff.
backoff_start = P::Order();
std::copy(context_rbegin, context_rend, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
// ret.prob was already set.
return ret;
}
if (mid_iter == search_.middle.end()) break;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
// Didn't find an ngram using hist_iter.
// The history used in the found n-gram is [context_rbegin, hist_iter).
std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
// Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
backoff_start = mid_iter - search_.middle.begin() + 1;
// ret.prob was already set.
return ret;
}
}
// It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram.
// All that's left is to check search_.longest.
if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
// It's an (P::Order()-1)-gram
std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
backoff_start = P::Order() - 1;
// ret.prob was already set.
return ret;
}
// It's an P::Order()-gram
// out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
ret.ngram_length = P::Order();
backoff_start = P::Order();
return ret;
}
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
template class GenericModel<SortedHashedSearch, SortedVocabulary>;
template class GenericModel<trie::TrieSearch, SortedVocabulary>;
} // namespace detail
} // namespace ngram
} // namespace lm
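
As the comment on ScoreExceptBackoff notes, context arrays run backwards: context_rbegin points at the word immediately before new_word. A sketch of the stateless entry point under that convention, mirroring model_test.cc's StatelessTest (the words come from its test data):

#include "lm/model.hh"

// Score "a" in the context "<s> looking on", passed in reverse order.
float ScoreStateless(const lm::ngram::Model &model) {
  const lm::ngram::Vocabulary &vocab = model.GetVocabulary();
  lm::WordIndex context[3];
  context[0] = vocab.Index("on");       // nearest word first
  context[1] = vocab.Index("looking");
  context[2] = vocab.Index("<s>");
  lm::ngram::State out;
  lm::FullScoreReturn ret =
      model.FullScoreForgotState(context, context + 3, vocab.Index("a"), out);
  return ret.prob;
}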
kenlm/lm/model.hh Normal file
@@ -0,0 +1,145 @@
#ifndef LM_MODEL__
#define LM_MODEL__
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include <algorithm>
#include <vector>
#include <string.h>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
// If you need higher order, change this and recompile.
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
const unsigned char kMaxOrder = 6;
// This is a POD but if you want memcmp to return the same as operator==, call
// ZeroRemaining first.
class State {
public:
bool operator==(const State &other) const {
if (valid_length_ != other.valid_length_) return false;
const WordIndex *end = history_ + valid_length_;
for (const WordIndex *first = history_, *second = other.history_;
first != end; ++first, ++second) {
if (*first != *second) return false;
}
// If the histories are equal, so are the backoffs.
return true;
}
// Three way comparison function.
int Compare(const State &other) const {
if (valid_length_ == other.valid_length_) {
return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
}
return (valid_length_ < other.valid_length_) ? -1 : 1;
}
// Call this before using raw memcmp.
void ZeroRemaining() {
for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
history_[i] = 0;
backoff_[i] = 0.0;
}
}
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
float backoff_[kMaxOrder - 1];
unsigned char valid_length_;
};
size_t hash_value(const State &state);
namespace detail {
// Should return the same results as SRI.
// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
// Get the size of memory that will be mapped given ngram counts. This
// does not include small non-mapped control structures, such as this class
// itself.
static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
GenericModel(const char *file, const Config &config = Config());
FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
/* Slower call without in_state. Don't use this if you can avoid it. This
* is mostly a hack for Hieu to integrate it into Moses which sometimes
* forgets LM state (i.e. it doesn't store it with the phrase). Sigh.
* The context indices should be in an array.
* If context_rbegin != context_rend then *context_rbegin is the word
* before new_word.
*/
FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
/* Get the state for a context. Don't use this if you can avoid it. Use
* BeginSentenceState or EmptyContextState and extend from those. If
* you're only going to use this state to call FullScore once, use
* FullScoreForgotState. */
void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
private:
friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
Backing &MutableBacking() { return backing_; }
static const ModelType kModelType = Search::kModelType;
Backing backing_;
VocabularyT vocab_;
typedef typename Search::Unigram Unigram;
typedef typename Search::Middle Middle;
typedef typename Search::Longest Longest;
Search search_;
};
} // namespace detail
// These must also be instantiated in the cc file.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
// Default implementation. No real reason for it to be the default.
typedef ProbingModel Model;
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;
typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
} // namespace ngram
} // namespace lm
#endif // LM_MODEL__
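
The stateful pattern intended by FullScore, visible in model_test.cc's AppendTest macro just below, threads State through successive calls so each query sees the history and backoff weights of the previous one. A minimal sketch (file name is a placeholder):

#include "lm/model.hh"
#include <iostream>

int main() {
  lm::ngram::Model model("input.arpa");
  lm::ngram::State state(model.BeginSentenceState()), out;
  const char *words[] = {"looking", "on", "a", "little"};
  float total = 0.0;
  for (unsigned int i = 0; i < sizeof(words) / sizeof(const char*); ++i) {
    lm::FullScoreReturn ret =
        model.FullScore(state, model.GetVocabulary().Index(words[i]), out);
    total += ret.prob;  // log10 probability, as in the ARPA file
    state = out;        // carry history and backoffs forward
  }
  std::cout << "log10 p = " << total << std::endl;
  return 0;
}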
kenlm/lm/model_test.cc Normal file
@@ -0,0 +1,202 @@
#include "lm/model.hh"
#include <stdlib.h>
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
namespace lm {
namespace ngram {
namespace {
#define StartTest(word, ngram, score) \
ret = model.FullScore( \
state, \
model.GetVocabulary().Index(word), \
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
state = out;
template <class M> void Starters(const M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
StartTest("looking", 2, -0.4846522);
// , probability plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
}
template <class M> void Continuation(const M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
AppendTest("looking", 2, -0.484652);
AppendTest("on", 3, -0.348837);
AppendTest("a", 4, -0.0155266);
AppendTest("little", 5, -0.00306122);
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
AppendTest("not_found", 0, -2.29666);
AppendTest("more", 1, -1.20632);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
state = preserve;
AppendTest("more", 5, -0.00181395);
AppendTest("loin", 5, -0.0432557);
}
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
ret = model.FullScore(before, indices[num_words - word - 1], out); \
BOOST_CHECK(state == out); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
template <class M> void Stateless(const M &model) {
const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
const size_t num_words = sizeof(words) / sizeof(const char*);
// Silence "array subscript is above array bounds" when extracting the end pointer.
WordIndex indices[num_words + 1];
for (unsigned int i = 0; i < num_words; ++i) {
indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
}
FullScoreReturn ret;
State state, out, before;
ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
StatelessTest(1, 1, 2, -0.484652);
// looking
StatelessTest(1, 2, 2, -0.484652);
// on
AppendTest("on", 3, -0.348837);
StatelessTest(2, 3, 3, -0.348837);
StatelessTest(2, 2, 3, -0.348837);
StatelessTest(2, 1, 2, -0.4638903);
// a
StatelessTest(3, 4, 4, -0.0155266);
// little
AppendTest("little", 5, -0.00306122);
StatelessTest(4, 5, 5, -0.00306122);
// the
AppendTest("the", 1, -4.04005);
StatelessTest(5, 5, 1, -4.04005);
// No context of the.
StatelessTest(5, 0, 1, -1.687872);
// biarritz
StatelessTest(6, 1, 1, -1.9889);
// not found
StatelessTest(7, 1, 0, -2.29666);
StatelessTest(7, 0, 0, -1.995635);
WordIndex unk[1];
unk[0] = 0;
model.GetState(unk, unk + 1, state);
BOOST_CHECK_EQUAL(0, state.valid_length_);
}
//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"};
class ExpectEnumerateVocab : public EnumerateVocab {
public:
ExpectEnumerateVocab() {}
void Add(WordIndex index, const StringPiece &str) {
BOOST_CHECK_EQUAL(seen.size(), index);
seen.push_back(std::string(str.data(), str.length()));
}
void Check(const base::Vocabulary &vocab) {
BOOST_CHECK_EQUAL(34ULL, seen.size());
BOOST_REQUIRE(!seen.empty());
BOOST_CHECK_EQUAL("<unk>", seen[0]);
for (WordIndex i = 0; i < seen.size(); ++i) {
BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
}
}
void Clear() {
seen.clear();
}
std::vector<std::string> seen;
};
template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
config.probing_multiplier = 2.0;
ModelT m("test.arpa", config);
enumerate.Check(m.GetVocabulary());
Starters(m);
Continuation(m);
Stateless(m);
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
BOOST_AUTO_TEST_CASE(sorted) {
LoadingTest<SortedModel>();
}
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
template <class ModelT> void BinaryTest() {
Config config;
config.write_mmap = "test.binary";
config.messages = NULL;
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
{
ModelT copy_model("test.arpa", config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
}
config.write_mmap = NULL;
ModelT binary("test.binary", config);
enumerate.Check(binary.GetVocabulary());
Starters(binary);
Continuation(binary);
Stateless(binary);
unlink("test.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<Model>();
}
BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
BinaryTest<SortedModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
}
} // namespace
} // namespace ngram
} // namespace lm
kenlm/lm/ngram.cc
@@ -1,438 +0,0 @@
#include "ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "util/file_piece.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <limits>
#include <string>
#include <cmath>
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
namespace lm {
namespace ngram {
size_t hash_value(const State &state) {
return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
}
namespace {
/* All of the entropy is in low order bits and boost::hash does poorly with
* these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
* stable point: 0. But 0 is <unk> which won't be queried here anyway.
*/
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
return ret;
}
uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
if (word == word_end) return 0;
uint64_t current = static_cast<uint64_t>(*word);
for (++word; word != word_end; ++word) {
current = CombineWordHash(current, *word);
}
return current;
}
template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
ReadNGramHeader(f, n);
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
typename Store::Packing::Value value;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, vocab_ids, value);
uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
store.Insert(Store::Packing::Make(key, value));
}
if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset());
store.FinishedInserting();
}
} // namespace
namespace detail {
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) {
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier);
memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk>
for (unsigned char n = 2; n < counts.size(); ++n) {
memory_size += Middle::Size(counts[n - 1], config.probing_multiplier);
}
memory_size += Longest::Size(counts.back(), config.probing_multiplier);
return memory_size;
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) {
char *start = base;
size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier);
vocab_.Init(start, allocated, counts[0]);
start += allocated;
unigram_ = reinterpret_cast<ProbBackoff*>(start);
start += sizeof(ProbBackoff) * (counts[0] + 1);
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
start += allocated;
}
allocated = Longest::Size(counts.back(), config.probing_multiplier);
longest_ = Longest(start, allocated);
start += allocated;
if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config));
}
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0";
struct BinaryFileHeader {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
one_uint64 = 1;
}
};
bool IsBinaryFormat(int fd, off_t size) {
if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false;
// Try reading the header.
util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader));
if (memory.get() == MAP_FAILED) return false;
BinaryFileHeader reference_header = BinaryFileHeader();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true;
if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?");
return false;
}
std::size_t Align8(std::size_t in) {
std::size_t off = in % 8;
if (!off) return in;
return in + 8 - off;
}
std::size_t TotalHeaderSize(unsigned int order) {
return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */);
}
void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) {
const char *from_char = reinterpret_cast<const char*>(from);
if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information.");
// Skip over the BinaryFileHeader which was read by IsBinaryFormat.
from_char += sizeof(BinaryFileHeader);
unsigned char order = *reinterpret_cast<const unsigned char*>(from_char);
if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header.");
out.resize(static_cast<std::size_t>(order));
const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1);
for (std::size_t i = 0; i < out.size(); ++i) {
out[i] = static_cast<std::size_t>(counts[i]);
}
const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size());
probing_multiplier = *probing_ptr;
search_tag = *reinterpret_cast<const char*>(probing_ptr + 1);
}
void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) {
BinaryFileHeader header = BinaryFileHeader();
header.SetToReference();
memcpy(to, &header, sizeof(BinaryFileHeader));
char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader);
*reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size());
uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1);
for (std::size_t i = 0; i < from.size(); ++i) {
counts[i] = from[i];
}
float *probing_ptr = reinterpret_cast<float*>(counts + from.size());
*probing_ptr = probing_multiplier;
*reinterpret_cast<char*>(probing_ptr + 1) = search_tag;
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) {
const off_t file_size = util::SizeFile(mapped_file_.get());
std::vector<size_t> counts;
if (IsBinaryFormat(mapped_file_.get(), file_size)) {
memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size);
unsigned char search_tag;
ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag);
if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0.");
if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested.");
size_t memory_size = Size(counts, config);
char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration.");
SetupMemory(start, counts, config);
vocab_.LoadedBinary();
for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
i->LoadedBinary();
}
longest_.LoadedBinary();
} else {
if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
util::FilePiece f(file, mapped_file_.release(), config.messages);
ReadARPACounts(f, counts);
size_t memory_size = Size(counts, config);
char *start;
if (config.write_mmap) {
// Write out an mmap file.
util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_);
WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag);
start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
} else {
memory_.reset(util::MapAnonymous(memory_size), memory_size);
start = reinterpret_cast<char*>(memory_.get());
}
SetupMemory(start, counts, config);
try {
LoadFromARPA(f, counts, config);
} catch (FormatLoadException &e) {
e << " in file " << file;
throw;
}
}
// g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.valid_length_ = 1;
begin_sentence.history_[0] = vocab_.BeginSentence();
begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff;
State null_context = State();
null_context.valid_length_ = 0;
P::Init(begin_sentence, null_context, vocab_, counts.size());
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) {
// Read the unigrams.
Read1Grams(f, counts[0], vocab_, unigram_);
bool saw_unk = vocab_.SawUnk();
if (!saw_unk) {
switch(config.unknown_missing) {
case Config::THROW_UP:
{
SpecialWordMissingException e("<unk>");
e << " and configuration was set to throw if unknown is missing";
throw e;
}
case Config::COMPLAIN:
if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
// There's no break;. This is by design.
case Config::SILENT:
// Default probabilities for unknown.
unigram_[0].backoff = 0.0;
unigram_[0].prob = config.unknown_missing_prob;
break;
}
}
// Read the n-grams.
for (unsigned int n = 2; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]);
}
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_);
if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff);
}
/* Ugly optimized function.
* in_state contains the previous ngram's length and backoff probabilities to
* be used here. out_state is populated with the found ngram length and
* backoffs that the next call will find useful.
*
* The search goes in increasing order of ngram length.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(
const State &in_state,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
const ProbBackoff &unigram = unigram_[new_word];
if (new_word == 0) {
ret.ngram_length = out_state.valid_length_ = 0;
// all of backoff.
ret.prob = std::accumulate(
in_state.backoff_,
in_state.backoff_ + in_state.valid_length_,
unigram.prob);
return ret;
}
float *backoff_out(out_state.backoff_);
*backoff_out = unigram.backoff;
ret.prob = unigram.prob;
out_state.history_[0] = new_word;
if (in_state.valid_length_ == 0) {
ret.ngram_length = out_state.valid_length_ = 1;
// No backoff because NGramLength() == 0 and unknown can't have backoff.
return ret;
}
++backoff_out;
// OK, now we know that the bigram contains known words. Start by looking it up.
uint64_t lookup_hash = static_cast<uint64_t>(new_word);
const WordIndex *hist_iter = in_state.history_;
const WordIndex *const hist_end = hist_iter + in_state.valid_length_;
typename std::vector<Middle>::const_iterator mid_iter = middle_.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == hist_end) {
// Used history [in_state.history_, hist_end) and ran out. No backoff.
std::copy(in_state.history_, hist_end, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1;
// ret.prob was already set.
return ret;
}
lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
if (mid_iter == middle_.end()) break;
typename Middle::ConstIterator found;
if (!mid_iter->Find(lookup_hash, found)) {
// Didn't find an ngram using hist_iter.
// The history used in the found n-gram is [in_state.history_, hist_iter).
std::copy(in_state.history_, hist_iter, out_state.history_ + 1);
// Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word.
ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1;
ret.prob = std::accumulate(
in_state.backoff_ + (mid_iter - middle_.begin()),
in_state.backoff_ + in_state.valid_length_,
ret.prob);
return ret;
}
*backoff_out = found->GetValue().backoff;
ret.prob = found->GetValue().prob;
}
typename Longest::ConstIterator found;
if (!longest_.Find(lookup_hash, found)) {
// It's an (P::Order()-1)-gram
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
ret.prob += in_state.backoff_[P::Order() - 2];
return ret;
}
// It's an P::Order()-gram
// out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
ret.ngram_length = P::Order();
ret.prob = found->GetValue().prob;
return ret;
}
/* Until somebody implements stateful language models in Moses, here's a slower stateless version. It also provides a mostly meaningless void * value that can be used for pruning. */
template <class Search, class VocabularyT> HieuShouldRefactorMoses GenericModel<Search, VocabularyT>::SlowStatelessScore(
const WordIndex *begin, const WordIndex *end) const {
begin = std::max(begin, end - P::Order());
HieuShouldRefactorMoses ret;
// This is end pointer passed to SumBackoffs.
const ProbBackoff &unigram = unigram_[*(end - 1)];
if (!*(end - 1)) {
ret.ngram_length = 0;
// all of backoff.
ret.prob = unigram.prob + SlowBackoffLookup(begin, end - 1, 1);
ret.meaningless_unique_state = 0;
return ret;
}
ret.prob = unigram.prob;
if (begin == end - 1) {
ret.ngram_length = 1;
ret.meaningless_unique_state = reinterpret_cast<void*>(*(end - 1));
// No backoff because the context is empty.
return ret;
}
// OK, now we know that the bigram contains known words. Start by looking it up.
uint64_t lookup_hash = static_cast<uint64_t>(*(end - 1));
const WordIndex *hist_iter = end - 2;
const WordIndex *const hist_none = begin - 1;
typename std::vector<Middle>::const_iterator mid_iter = middle_.begin();
for (; ; ++mid_iter, --hist_iter) {
if (hist_iter == hist_none) {
// Ran out. No backoff.
ret.ngram_length = end - begin;
// ret.prob was already set.
ret.meaningless_unique_state = reinterpret_cast<void*>(lookup_hash + 1);
return ret;
}
lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
if (mid_iter == middle_.end()) break;
typename Middle::ConstIterator found;
if (!mid_iter->Find(lookup_hash, found)) {
// Didn't find an ngram using hist_iter.
ret.ngram_length = end - 1 - hist_iter;
ret.prob += SlowBackoffLookup(begin, end - 1, mid_iter - middle_.begin() + 1);
ret.meaningless_unique_state = reinterpret_cast<void*>(lookup_hash + 2);
return ret;
}
ret.prob = found->GetValue().prob;
}
typename Longest::ConstIterator found;
if (!longest_.Find(lookup_hash, found)) {
// It's an (P::Order()-1)-gram
ret.ngram_length = P::Order() - 1;
ret.prob += SlowBackoffLookup(begin, end - 1, P::Order() - 1);
ret.meaningless_unique_state = reinterpret_cast<void*>(lookup_hash + 3);
return ret;
}
// It's an P::Order()-gram
ret.ngram_length = P::Order();
ret.prob = found->GetValue().prob;
ret.meaningless_unique_state = reinterpret_cast<void*>(lookup_hash + 4);
return ret;
}
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
const WordIndex *const begin, const WordIndex *const end, unsigned char start) const {
// Add the backoff weights for n-grams of order start to (end - begin).
if (end - begin < static_cast<std::ptrdiff_t>(start)) return 0.0;
float ret = 0.0;
if (start == 1) {
ret += unigram_[*(end - 1)].backoff;
start = 2;
}
uint64_t lookup_hash = static_cast<uint64_t>(*(end - 1));
for (unsigned char i = 2; i < start; ++i) {
lookup_hash = CombineWordHash(lookup_hash, *(end - i));
}
typename Middle::ConstIterator found;
// i is the order of the backoff we're looking for.
for (unsigned char i = start; i <= static_cast<unsigned char>(end - begin); ++i) {
lookup_hash = CombineWordHash(lookup_hash, *(end - i));
if (!middle_[i - 2].Find(lookup_hash, found)) break;
ret += found->GetValue().backoff;
}
return ret;
}
template class GenericModel<ProbingSearch, ProbingVocabulary>;
template class GenericModel<SortedUniformSearch, SortedVocabulary>;
} // namespace detail
} // namespace ngram
} // namespace lm

View File

@ -1,148 +0,0 @@
#ifndef LM_NGRAM__
#define LM_NGRAM__
#include "lm/facade.hh"
#include "lm/ngram_config.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include "util/key_value_packing.hh"
#include "util/mmap.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
#include <algorithm>
#include <vector>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
// If you need higher order, change this and recompile.
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
const std::size_t kMaxOrder = 6;
// This is a POD.
class State {
public:
bool operator==(const State &other) const {
if (valid_length_ != other.valid_length_) return false;
const WordIndex *end = history_ + valid_length_;
for (const WordIndex *first = history_, *second = other.history_;
first != end; ++first, ++second) {
if (*first != *second) return false;
}
// If the histories are equal, so are the backoffs.
return true;
}
// You shouldn't need to touch anything below this line, but the members are public so State will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
float backoff_[kMaxOrder - 1];
unsigned char valid_length_;
};
size_t hash_value(const State &state);
// TODO(hieuhoang1972): refactor language models to keep arbitrary state, not a void* pointer. Then use FullScore like good people do. For now, you get a stateless interface.
struct HieuShouldRefactorMoses {
float prob;
unsigned char ngram_length;
void *meaningless_unique_state;
};
namespace detail {
// std::identity is an SGI extension :-(
struct IdentityHash : public std::unary_function<uint64_t, size_t> {
size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
// Should return the same results as SRI.
// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
// Get the size of memory that will be mapped given ngram counts. This
// does not include small non-mapped control structures, such as this class
// itself.
static size_t Size(const std::vector<size_t> &counts, const Config &config = Config());
GenericModel(const char *file, Config config = Config());
FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
/* Slower but stateless call. Don't use this if you can avoid it. This
* is mostly a hack for Hieu to integrate it into Moses which is currently
* unable to handle arbitrary LM state. Sigh.
* The word indices should be in an array. *begin is the earliest word of context.
* *(end-1) is the word being appended.
*/
HieuShouldRefactorMoses SlowStatelessScore(const WordIndex *begin, const WordIndex *end) const;
private:
float SlowBackoffLookup(const WordIndex *const begin, const WordIndex *const end, unsigned char start) const;
// Appears after Size in the cc file.
void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config);
void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config);
util::scoped_fd mapped_file_;
// memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_.
util::scoped_mmap memory_;
VocabularyT vocab_;
ProbBackoff *unigram_;
typedef typename Search::template Table<ProbBackoff>::T Middle;
std::vector<Middle> middle_;
typedef typename Search::template Table<Prob>::T Longest;
Longest longest_;
};
struct ProbingSearch {
typedef float Init;
static const unsigned char kBinaryTag = 1;
template <class Value> struct Table {
typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
typedef util::ProbingHashTable<Packing, IdentityHash> T;
};
};
struct SortedUniformSearch {
// This is ignored.
typedef float Init;
static const unsigned char kBinaryTag = 2;
template <class Value> struct Table {
typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
typedef util::SortedUniformMap<Packing> T;
};
};
} // namespace detail
// These must also be instantiated in the cc file.
typedef ::lm::ProbingVocabulary Vocabulary;
typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model;
typedef ::lm::SortedVocabulary SortedVocabulary;
typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel;
} // namespace ngram
} // namespace lm
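/* A minimal usage sketch for the stateless interface above (illustrative;
 * the vocabulary words and the "test.arpa" path are assumptions borrowed
 * from the unit test):
 *
 *   lm::ngram::Model model("test.arpa");
 *   lm::WordIndex words[3];
 *   words[0] = model.GetVocabulary().Index("looking");
 *   words[1] = model.GetVocabulary().Index("on");
 *   words[2] = model.GetVocabulary().Index("a");  // the word being appended
 *   lm::ngram::HieuShouldRefactorMoses ret =
 *       model.SlowStatelessScore(words, words + 3);
 *   // ret.prob is a log10 probability; ret.ngram_length is the order matched.
 */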
#endif // LM_NGRAM__

View File

@ -1,13 +0,0 @@
#include "lm/ngram.hh"
#include <iostream>
int main(int argc, char *argv[]) {
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
return 1;
}
lm::ngram::Config config;
config.write_mmap = argv[2];
lm::ngram::Model(argv[1], config);
}

View File

@ -1,58 +0,0 @@
#ifndef LM_NGRAM_CONFIG__
#define LM_NGRAM_CONFIG__
/* Configuration for ngram model. Separate header to reduce pollution. */
#include <iostream>
namespace lm { namespace ngram {
struct Config {
/* EFFECTIVE FOR BOTH ARPA AND BINARY READS */
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
/* ONLY EFFECTIVE WHEN READING ARPA */
// What to do when <unk> isn't in the provided model.
typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
UnknownMissing unknown_missing;
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_prob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
// for sorted variant.
// If you find yourself setting this to a low number, consider using the
// Sorted version instead which has lower memory consumption.
float probing_multiplier;
// While loading an ARPA file, also write out this binary format file. Set
// to NULL to disable.
const char *write_mmap;
/* ONLY EFFECTIVE WHEN READING BINARY */
bool prefault;
// Defaults.
Config() :
messages(&std::cerr),
unknown_missing(COMPLAIN),
unknown_missing_prob(0.0),
probing_multiplier(1.5),
write_mmap(NULL),
prefault(false) {}
};
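// Illustrative arithmetic for the comment above: the default
// probing_multiplier of 1.5 allocates 1.5 buckets per entry, giving an
// expected probe cost of about 1.5 / (1.5 - 1) = 3 lookups.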
} /* namespace ngram */ } /* namespace lm */
#endif // LM_NGRAM_CONFIG__

View File

@ -1,4 +1,4 @@
#include "lm/ngram.hh"
#include "lm/model.hh"
#include <cstdlib>
#include <fstream>

View File

@ -1,126 +0,0 @@
#include "lm/ngram.hh"
#include <stdlib.h>
#define BOOST_TEST_MODULE NGramTest
#include <boost/test/unit_test.hpp>
namespace lm {
namespace ngram {
namespace {
#define StartTest(word, ngram, score) \
ret = model.FullScore( \
state, \
model.GetVocabulary().Index(word), \
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
state = out;
template <class M> void Starters(const M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
StartTest("looking", 2, -0.4846522);
// Probability of "," plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
}
template <class M> void Continuation(const M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
AppendTest("looking", 2, -0.484652);
AppendTest("on", 3, -0.348837);
AppendTest("a", 4, -0.0155266);
AppendTest("little", 5, -0.00306122);
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
AppendTest("not_found", 0, -2.29666);
AppendTest("more", 1, -1.20632);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
state = preserve;
AppendTest("more", 5, -0.00181395);
AppendTest("loin", 5, -0.0432557);
}
#define StatelessTest(begin, end, ngram, score) \
ret = model.SlowStatelessScore(begin, end); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
template <class M> void Stateless(const M &model) {
const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
WordIndex indices[sizeof(words) / sizeof(const char*)];
for (unsigned int i = 0; i < sizeof(words) / sizeof(const char*); ++i) {
indices[i] = model.GetVocabulary().Index(words[i]);
}
FullScoreReturn ret;
StatelessTest(indices, indices + 2, 2, -0.484652);
StatelessTest(indices, indices + 3, 3, -0.348837);
StatelessTest(indices, indices + 4, 4, -0.0155266);
StatelessTest(indices, indices + 5, 5, -0.00306122);
// the
StatelessTest(indices, indices + 6, 1, -4.04005);
StatelessTest(indices + 1, indices + 6, 1, -4.04005);
// biarritz
StatelessTest(indices, indices + 7, 1, -1.9889);
// not found
StatelessTest(indices, indices + 8, 0, -2.29666);
}
BOOST_AUTO_TEST_CASE(probing) {
Model m("test.arpa");
Starters(m);
Continuation(m);
Stateless(m);
}
BOOST_AUTO_TEST_CASE(sorted) {
SortedModel m("test.arpa");
Starters(m);
Continuation(m);
Stateless(m);
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
Config config;
config.write_mmap = "test.binary";
{
Model copy_model("test.arpa", config);
}
Model binary("test.binary");
Starters(binary);
Continuation(binary);
Stateless(binary);
}
BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
Config config;
config.write_mmap = "test.binary";
config.prefault = true;
{
SortedModel copy_model("test.arpa", config);
}
SortedModel binary("test.binary");
Starters(binary);
Continuation(binary);
Stateless(binary);
}
} // namespace
} // namespace ngram
} // namespace lm

View File

@ -1,7 +1,10 @@
#include "lm/read_arpa.hh"
#include <cstdlib>
#include <vector>
#include <ctype.h>
#include <inttypes.h>
namespace lm {
@ -14,10 +17,15 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
return true;
}
template <class F> void GenericReadARPACounts(F &in, std::vector<size_t> &number) {
template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line;
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank");
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank");
}
if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\" doesn't begin with \"ngram \"");
@ -41,7 +49,7 @@ template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead. ");
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
}
template <class F> void GenericReadEnd(F &in) {
@ -69,6 +77,11 @@ class FakeFilePiece {
return ret;
}
const char *FileName() const {
// This is only used for error messages, and we don't know the file name...
return "$file";
}
private:
std::istream &in_;
std::string buffer_;
@ -76,10 +89,10 @@ class FakeFilePiece {
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<std::size_t> &number) {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
GenericReadARPACounts(in, number);
}
void ReadARPACounts(std::istream &in, std::vector<std::size_t> &number) {
void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) {
FakeFilePiece fake(in);
GenericReadARPACounts(fake, number);
}
@ -91,13 +104,13 @@ void ReadNGramHeader(std::istream &in, unsigned int length) {
GenericReadNGramHeader(fake, length);
}
void ReadBackoff(util::FilePiece &in, Prob &weights) {
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
switch (in.get()) {
case '\t':
{
float got = in.ReadFloat();
if (got != 0.0)
UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff.");
UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");
}
break;
case '\n':
@ -123,6 +136,15 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
void ReadEnd(util::FilePiece &in) {
GenericReadEnd(in);
StringPiece line;
try {
while (true) {
line = in.ReadLine();
if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line);
}
} catch (const util::EndOfFileException &e) {
return;
}
}
void ReadEnd(std::istream &in) {
FakeFilePiece fake(in);

View File

@ -2,20 +2,26 @@
#define LM_READ_ARPA__
#include "lm/lm_exception.hh"
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "lm/weights.hh"
#include "util/file_piece.hh"
#include <cstddef>
#include <iosfwd>
#include <vector>
namespace lm {
void ReadARPACounts(util::FilePiece &in, std::vector<std::size_t> &number);
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
void ReadNGramHeader(std::istream &in, unsigned int length);
void ReadBackoff(util::FilePiece &f, Prob &weights);
void ReadBackoff(util::FilePiece &f, ProbBackoff &weights);
void ReadBackoff(util::FilePiece &in, Prob &weights);
void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
void ReadEnd(util::FilePiece &in);
void ReadEnd(std::istream &in);
template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) {
try {
@ -36,7 +42,6 @@ template <class Voc> void Read1Grams(util::FilePiece &f, std::size_t count, Voc
for (std::size_t i = 0; i < count; ++i) {
Read1Gram(f, vocab, unigrams);
}
if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset());
vocab.FinishedLoading(unigrams);
}
@ -49,7 +54,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {
e << " in the " << n << "-gram at byte " << f.Offset();
e << " in the " << static_cast<unsigned int>(n) << "-gram at byte " << f.Offset();
throw;
}
}

66
kenlm/lm/search_hashed.cc Normal file
View File

@ -0,0 +1,66 @@
#include "lm/search_hashed.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "lm/vocab.hh"
#include "util/file_piece.hh"
#include <string>
namespace lm {
namespace ngram {
namespace {
/* All of the entropy is in low order bits and boost::hash does poorly with
* these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
* stable point: 0. But 0 is <unk> which won't be queried here anyway.
*/
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
return ret;
}
uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
if (word == word_end) return 0;
uint64_t current = static_cast<uint64_t>(*word);
for (++word; word != word_end; ++word) {
current = CombineWordHash(current, *word);
}
return current;
}
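// Illustrative example (the vocab ids are made up): a trigram stored in
// reverse order as ids[] = {2, 9, 4} gets the key
//   ChainedWordHash(ids, ids + 3)
//     == CombineWordHash(CombineWordHash(uint64_t(2), 9), 4),
// so any n-gram over the same word sequence maps to the same 64-bit key.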
template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
ReadNGramHeader(f, n);
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
typename Store::Packing::Value value;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, vocab_ids, value);
uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
store.Insert(Store::Packing::Make(key, value));
}
store.FinishedInserting();
}
} // namespace
namespace detail {
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &/*config*/, Voc &vocab) {
Read1Grams(f, counts[0], vocab, unigram.Raw());
// Read the n-grams.
for (unsigned int n = 2; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]);
}
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest);
}
template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab);
template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab);
} // namespace detail
} // namespace ngram
} // namespace lm

156
kenlm/lm/search_hashed.hh Normal file
View File

@ -0,0 +1,156 @@
#ifndef LM_SEARCH_HASHED__
#define LM_SEARCH_HASHED__
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/weights.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include <algorithm>
#include <vector>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
namespace detail {
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
return ret;
}
struct HashedSearch {
typedef uint64_t Node;
class Unigram {
public:
Unigram() {}
Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
static std::size_t Size(uint64_t count) {
return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
}
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
ProbBackoff &Unknown() { return unigram_[0]; }
void LoadedBinary() {}
// For building.
ProbBackoff *Raw() { return unigram_; }
private:
ProbBackoff *unigram_;
};
Unigram unigram;
bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
const ProbBackoff &entry = unigram.Lookup(word);
prob = entry.prob;
backoff = entry.backoff;
next = static_cast<Node>(word);
return true;
}
};
template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
typedef MiddleT Middle;
std::vector<Middle> middle;
typedef LongestT Longest;
Longest longest;
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Unigram::Size(counts[0]);
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
ret += Middle::Size(counts[n], config.probing_multiplier);
}
return ret + Longest::Size(counts.back(), config.probing_multiplier);
}
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
unigram = Unigram(start, allocated);
start += allocated;
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle.push_back(Middle(start, allocated));
start += allocated;
}
allocated = Longest::Size(counts.back(), config.probing_multiplier);
longest = Longest(start, allocated);
start += allocated;
return start;
}
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);
bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
prob = found->GetValue().prob;
backoff = found->GetValue().backoff;
return true;
}
bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle.Find(node, found)) return false;
backoff = found->GetValue().backoff;
return true;
}
bool LookupLongest(WordIndex word, float &prob, Node &node) const {
node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
if (!longest.Find(node, found)) return false;
prob = found->GetValue().prob;
return true;
}
// Generate a node without necessarily checking that it actually exists.
// Optionally return false if it's known not to exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
node = static_cast<Node>(*begin);
for (const WordIndex *i = begin + 1; i < end; ++i) {
node = CombineWordHash(node, *i);
}
return true;
}
};
// std::identity is an SGI extension :-(
struct IdentityHash : public std::unary_function<uint64_t, size_t> {
size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
struct ProbingHashedSearch : public TemplateHashedSearch<
util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
static const ModelType kModelType = HASH_PROBING;
};
struct SortedHashedSearch : public TemplateHashedSearch<
util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
static const ModelType kModelType = HASH_SORTED;
};
} // namespace detail
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_HASHED__

494
kenlm/lm/search_trie.cc Normal file
View File

@ -0,0 +1,494 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/file_piece.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <deque>
#include <iostream>
#include <limits>
//#include <parallel/algorithm>
#include <vector>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <stdlib.h>
namespace lm {
namespace ngram {
namespace trie {
namespace {
/* An entry is a n-gram with probability. It consists of:
* WordIndex[order]
* float probability
* backoff probability (omitted for highest order n-gram)
* These are stored consecutively in memory. We want to sort them.
*
* The problem is the length depends on order (but all n-grams being compared
* have the same order). Allocating each entry on the heap (i.e. std::vector
* or std::string) then sorting pointers is the normal solution. But that's
* too memory inefficient. A lot of this code is just here to force std::sort
* to work with records where length is specified at runtime (and avoid using
* Boost for LM code). I could have used qsort, but the point is to also
* support __gnu_cxx:parallel_sort which doesn't have a qsort version.
*/
class EntryIterator {
public:
EntryIterator() {}
EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {}
bool operator==(const EntryIterator &other) const {
return ptr_ == other.ptr_;
}
bool operator<(const EntryIterator &other) const {
return ptr_ < other.ptr_;
}
EntryIterator &operator+=(std::ptrdiff_t amount) {
ptr_ += amount * size_;
return *this;
}
std::ptrdiff_t operator-(const EntryIterator &other) const {
return (ptr_ - other.ptr_) / size_;
}
const void *Data() const { return ptr_; }
void *Data() { return ptr_; }
std::size_t EntrySize() const { return size_; }
private:
uint8_t *ptr_;
std::size_t size_;
};
class EntryProxy {
public:
EntryProxy() {}
EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {}
operator std::string() const {
return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize());
}
EntryProxy &operator=(const EntryProxy &from) {
memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize());
return *this;
}
EntryProxy &operator=(const std::string &from) {
memcpy(inner_.Data(), from.data(), inner_.EntrySize());
return *this;
}
const WordIndex *Indices() const {
return static_cast<const WordIndex*>(inner_.Data());
}
private:
friend class util::ProxyIterator<EntryProxy>;
typedef std::string value_type;
typedef EntryIterator InnerIterator;
InnerIterator &Inner() { return inner_; }
const InnerIterator &Inner() const { return inner_; }
InnerIterator inner_;
};
typedef util::ProxyIterator<EntryProxy> NGramIter;
class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
public:
explicit CompareRecords(unsigned char order) : order_(order) {}
bool operator()(const EntryProxy &first, const EntryProxy &second) const {
return Compare(first.Indices(), second.Indices());
}
bool operator()(const EntryProxy &first, const std::string &second) const {
return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
}
bool operator()(const std::string &first, const EntryProxy &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
}
private:
bool Compare(const WordIndex *first, const WordIndex *second) const {
const WordIndex *end = first + order_;
for (; first != end; ++first, ++second) {
if (*first < *second) return true;
if (*first > *second) return false;
}
return false;
}
unsigned char order_;
};
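// How the pieces above fit together (a sketch mirroring the std::sort call
// in ConvertToSorted below): given count packed entries of entry_size bytes
// starting at raw, the proxy iterator lets std::sort reorder the
// runtime-sized records in place:
//
//   void SortEntries(void *raw, std::size_t count, std::size_t entry_size,
//                    unsigned char order) {
//     EntryProxy proxy_begin(raw, entry_size);
//     EntryProxy proxy_end(static_cast<uint8_t*>(raw) + count * entry_size,
//                          entry_size);
//     std::sort(NGramIter(proxy_begin), NGramIter(proxy_end),
//               CompareRecords(order));
//   }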
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}
void ReadOrThrow(FILE *from, void *data, size_t size) {
if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size);
}
void CopyOrThrow(FILE *from, FILE *to, size_t size) {
const size_t kBufSize = 512;
char buf[kBufSize];
for (size_t i = 0; i < size; i += kBufSize) {
std::size_t amount = std::min(size - i, kBufSize);
ReadOrThrow(from, buf, amount);
WriteOrThrow(to, buf, amount);
}
}
std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
std::string ret(assembled.str());
util::scoped_FILE out(fopen(ret.c_str(), "w"));
if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
// Compress entries that begin with the same (order-1) words.
for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
const uint8_t *group_end = group_begin;
for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {}
WriteOrThrow(out.get(), group_begin, prefix_size);
WordIndex group_size = (group_end - group_begin) / entry_size;
WriteOrThrow(out.get(), &group_size, sizeof(group_size));
for (const uint8_t *i = group_begin; i != group_end; i += entry_size) {
WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex));
WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size);
}
group_begin = group_end;
}
return ret;
}
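// On-disk layout written by DiskFlush, per group of entries sharing the same
// (order - 1)-word prefix (this is the format SortedFileReader reads back):
//   WordIndex prefix[order - 1];   // the shared context words
//   WordIndex group_size;          // entries in this group
//   group_size records of { WordIndex last_word; weights_size bytes of weights }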
class SortedFileReader {
public:
SortedFileReader() {}
void Init(const std::string &name, unsigned char order) {
file_.reset(fopen(name.c_str(), "r"));
if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
header_.resize(order - 1);
NextHeader();
}
// Preceding words.
const WordIndex *Header() const {
return &*header_.begin();
}
const std::vector<WordIndex> &HeaderVector() const { return header_;}
std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); }
void NextHeader() {
if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) {
UTIL_THROW(util::ErrnoException, "Short read of counts");
}
}
void ReadCount(WordIndex &to) {
ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
}
void ReadWord(WordIndex &to) {
ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
}
template <class Weights> void ReadWeights(Weights &to) {
ReadOrThrow(file_.get(), &to, sizeof(Weights));
}
bool Ended() {
return feof(file_.get());
}
FILE *File() { return file_.get(); }
private:
util::scoped_FILE file_;
std::vector<WordIndex> header_;
};
void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) {
WriteOrThrow(to, from.Header(), from.HeaderBytes());
WordIndex count;
from.ReadCount(count);
WriteOrThrow(to, &count, sizeof(WordIndex));
CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
}
void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
SortedFileReader first, second;
first.Init(first_name, order);
second.Init(second_name, order);
util::scoped_FILE out_file(fopen(out, "w"));
if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
while (!first.Ended() && !second.Ended()) {
if (first.HeaderVector() < second.HeaderVector()) {
CopyFullRecord(first, out_file.get(), weights_size);
first.NextHeader();
continue;
}
if (first.HeaderVector() > second.HeaderVector()) {
CopyFullRecord(second, out_file.get(), weights_size);
second.NextHeader();
continue;
}
// Merge at the entry level.
WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes());
WordIndex first_count, second_count;
first.ReadCount(first_count); second.ReadCount(second_count);
WordIndex total_count = first_count + second_count;
WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex));
WordIndex first_word, second_word;
first.ReadWord(first_word); second.ReadWord(second_word);
WordIndex first_index = 0, second_index = 0;
while (true) {
if (first_word < second_word) {
WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
CopyOrThrow(first.File(), out_file.get(), weights_size);
if (++first_index == first_count) break;
first.ReadWord(first_word);
} else {
WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
CopyOrThrow(second.File(), out_file.get(), weights_size);
if (++second_index == second_count) break;
second.ReadWord(second_word);
}
}
if (first_index == first_count) {
WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
} else {
WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
}
first.NextHeader();
second.NextHeader();
}
for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) {
CopyFullRecord(remaining, out_file.get(), weights_size);
}
}
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
if (order == 1) return;
ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights: a probability, plus a backoff except at the highest order.
const size_t words_size = sizeof(WordIndex) * order;
const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
const size_t entry_size = words_size + weights_size;
const size_t batch_size = std::min(count, mem.size() / entry_size);
uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get());
std::deque<std::string> files;
for (std::size_t batch = 0, done = 0; done < count; ++batch) {
uint8_t *out = begin;
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) {
for (; out != out_end; out += entry_size) {
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size));
}
} else {
for (; out != out_end; out += entry_size) {
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
}
}
// TODO: __gnu_parallel::sort here.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
done += (out_end - begin) / entry_size;
}
// All individual files created. Merge them.
std::size_t merge_count = 0;
while (files.size() > 1) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
}
if (!files.empty()) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
std::string merged_name(assembled.str());
if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
}
}
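// In short, ConvertToSorted is an external merge sort per order: n-grams are
// read in batches sized to the sort buffer, sorted in memory, flushed by
// DiskFlush, and merged pairwise by MergeSortedFiles until a single
// "<order>_merged" file remains.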
void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
util::scoped_mmap unigram_mmap;
unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
}
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
ReadEnd(f);
}
struct RecursiveInsertParams {
WordIndex *words;
SortedFileReader *files;
unsigned char max_order;
// This is an array of size order - 2.
BitPackedMiddle *middle;
// This has exactly one entry.
BitPackedLongest *longest;
};
uint64_t RecursiveInsert(RecursiveInsertParams &params, unsigned char order) {
SortedFileReader &file = params.files[order - 2];
const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex();
if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1)))
return ret;
WordIndex count;
file.ReadCount(count);
WordIndex key;
if (order == params.max_order) {
Prob value;
for (WordIndex i = 0; i < count; ++i) {
file.ReadWord(key);
file.ReadWeights(value);
params.longest->Insert(key, value.prob);
}
file.NextHeader();
return ret;
}
ProbBackoff value;
for (WordIndex i = 0; i < count; ++i) {
file.ReadWord(params.words[order - 1]);
file.ReadWeights(value);
params.middle[order - 2].Insert(
params.words[order - 1],
value.prob,
value.backoff,
RecursiveInsert(params, order + 1));
}
file.NextHeader();
return ret;
}
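// RecursiveInsert walks the orders depth-first: for the context in
// words[0..order-2] it copies matching entries from that order's sorted
// file, and each middle insert records the child's insert index as its next
// pointer by recursing one order higher.  For a trigram model
// (illustrative):
//   unigrams[w].next = RecursiveInsert(params, 2);  // inserts w's bigrams;
//   // each bigram insert calls RecursiveInsert(params, 3) for its trigrams.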
void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) {
UnigramValue *unigrams = out.unigram.Raw();
// Load unigrams. Leave the next pointers uninitialized.
{
std::string name(file_prefix + "unigrams");
util::scoped_FILE file(fopen(name.c_str(), "r"));
if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
for (WordIndex i = 0; i < counts[0]; ++i) {
ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
}
unlink(name.c_str());
}
// inputs[0] is bigrams.
SortedFileReader inputs[counts.size() - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
inputs[i-2].Init(assembled.str(), i);
unlink(assembled.str().c_str());
}
// words[0] is unigrams.
WordIndex words[counts.size()];
RecursiveInsertParams params;
params.words = words;
params.files = inputs;
params.max_order = static_cast<unsigned char>(counts.size());
params.middle = &*out.middle.begin();
params.longest = &out.longest;
{
util::ErsatzProgress progress(messages, "Building trie", counts[0]);
for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) {
unigrams[words[0]].next = RecursiveInsert(params, 2);
}
}
/* Set ending offsets so the last entry will be sized properly */
if (!out.middle.empty()) {
unigrams[counts[0]].next = out.middle.front().InsertIndex();
for (size_t i = 0; i < out.middle.size() - 1; ++i) {
out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
}
out.middle.back().FinishedLoading(out.longest.InsertIndex());
} else {
unigrams[counts[0]].next = out.longest.InsertIndex();
}
}
} // namespace
void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
} else if (config.write_mmap) {
temporary_directory = config.write_mmap;
} else {
temporary_directory = file;
}
// The trailing null is a kludge to ensure null termination for mkdtemp.
temporary_directory += "-tmp-XXXXXX\0";
if (!mkdtemp(&temporary_directory[0])) {
UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
}
// Chop off null kludge.
temporary_directory.resize(strlen(temporary_directory.c_str()));
// Add directory delimiter. Assumes a real operating system.
temporary_directory += '/';
// At least 1MB sorting memory.
ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);
if (rmdir(temporary_directory.c_str())) {
std::cerr << "Failed to delete " << temporary_directory << std::endl;
}
}
} // namespace trie
} // namespace ngram
} // namespace lm

83
kenlm/lm/search_trie.hh Normal file
View File

@ -0,0 +1,83 @@
#ifndef LM_SEARCH_TRIE__
#define LM_SEARCH_TRIE__
#include "lm/binary_format.hh"
#include "lm/trie.hh"
#include "lm/weights.hh"
#include <assert.h>
namespace lm {
namespace ngram {
class SortedVocabulary;
namespace trie {
struct TrieSearch {
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
typedef trie::BitPackedMiddle Middle;
std::vector<Middle> middle;
typedef trie::BitPackedLongest Longest;
Longest longest;
static const ModelType kModelType = TRIE_SORTED;
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
std::size_t ret = Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
ret += Middle::Size(counts[i], counts[0], counts[i+1]);
}
return ret + Longest::Size(counts.back(), counts[0]);
}
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
unigram.Init(start);
start += Unigram::Size(counts[0]);
middle.resize(counts.size() - 2);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
middle[i-1].Init(start, counts[0], counts[i+1]);
start += Middle::Size(counts[i], counts[0], counts[i+1]);
}
longest.Init(start, counts[0]);
return start + Longest::Size(counts.back(), counts[0]);
}
void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
return unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
return mid.Find(word, prob, backoff, node);
}
bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
return mid.FindNoProb(word, backoff, node);
}
bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
return longest.Find(word, prob, node);
}
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
// TODO: don't decode prob.
assert(begin != end);
float ignored_prob, ignored_backoff;
LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
for (const WordIndex *i = begin + 1; i < end; ++i) {
if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
}
return true;
}
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_TRIE__

View File

@ -31,8 +31,7 @@ void Vocabulary::FinishedLoading() {
SetSpecial(
sri_->ssIndex(),
sri_->seIndex(),
sri_->unkIndex(),
sri_->highIndex() + 1);
sri_->unkIndex());
}
namespace {

167
kenlm/lm/trie.cc Normal file
View File

@ -0,0 +1,167 @@
#include "lm/trie.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
#include "util/proxy_iterator.hh"
#include "util/sorted_uniform.hh"
#include <assert.h>
namespace lm {
namespace ngram {
namespace trie {
namespace {
// Assumes key is first.
class JustKeyProxy {
public:
JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {}
operator uint64_t() const { return GetKey(); }
uint64_t GetKey() const {
uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);
return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_);
}
private:
friend class util::ProxyIterator<JustKeyProxy>;
friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
: inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
// This is a read-only iterator.
JustKeyProxy &operator=(const JustKeyProxy &other);
typedef uint64_t value_type;
typedef uint64_t InnerIterator;
uint64_t &Inner() { return inner_; }
const uint64_t &Inner() const { return inner_; }
// The address in bits is base_ * 8 + inner_ * total_bits_.
uint64_t inner_;
const uint8_t *const base_;
const uint64_t key_mask_;
const uint8_t key_bits_, total_bits_;
};
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> out;
if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
at_index = out.Inner();
return true;
}
} // namespace
std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
// +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
// Note that this waste is O(order), not O(number of ngrams).
return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t);
}
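// Worked example of the size arithmetic above (illustrative; assumes
// util::RequiredBits(100000) == 17): with 31 probability bits and 52
// remaining bits, total_bits is 100, so 1000 entries need
//   ((1 + 1000) * 100 + 7) / 8 + sizeof(uint64_t) = 12513 + 8 = 12521 bytes.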
void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
util::BitPackingSanity();
word_bits_ = util::RequiredBits(max_vocab);
word_mask_ = (1ULL << word_bits_) - 1ULL;
if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
prob_bits_ = 31;
total_bits_ = word_bits_ + prob_bits_ + remaining_bits;
base_ = static_cast<uint8_t*>(base);
insert_index_ = 0;
}
std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr));
}
void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
backoff_bits_ = 32;
next_bits_ = util::RequiredBits(max_next);
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
next_mask_ = (1ULL << next_bits_) - 1;
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
}
void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
assert(word <= word_mask_);
assert(next <= next_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
at_pointer += prob_bits_;
util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
at_pointer += backoff_bits_;
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
++insert_index_;
}
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
assert(next_end <= next_mask_);
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
}
void BitPackedLongest::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
++insert_index_;
}
bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer = at_pointer * total_bits_ + word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
return true;
}
} // namespace trie
} // namespace ngram
} // namespace lm

129
kenlm/lm/trie.hh Normal file
View File

@ -0,0 +1,129 @@
#ifndef LM_TRIE__
#define LM_TRIE__
#include <inttypes.h>
#include <cstddef>
#include "lm/word_index.hh"
#include "lm/weights.hh"
namespace lm {
namespace ngram {
namespace trie {
struct NodeRange {
uint64_t begin, end;
};
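// A node is the half-open range [begin, end) of its children in the
// next-lower-order array; the Find() methods below recover it from this
// entry's next pointer and the following entry's next pointer.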
// TODO: if the number of unigrams is a concern, also bit pack these records.
struct UnigramValue {
ProbBackoff weights;
uint64_t next;
uint64_t Next() const { return next; }
};
class Unigram {
public:
Unigram() {}
void Init(void *start) {
unigram_ = static_cast<UnigramValue*>(start);
}
static std::size_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
ProbBackoff &Unknown() { return unigram_[0].weights; }
UnigramValue *Raw() {
return unigram_;
}
void LoadedBinary() {}
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
prob = val->weights.prob;
backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
return true;
}
private:
UnigramValue *unigram_;
};
class BitPacked {
public:
BitPacked() {}
uint64_t InsertIndex() const {
return insert_index_;
}
void LoadedBinary() {}
protected:
static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
uint8_t word_bits_, prob_bits_;
uint8_t total_bits_;
uint64_t word_mask_;
uint8_t *base_;
uint64_t insert_index_;
};
class BitPackedMiddle : public BitPacked {
public:
BitPackedMiddle() {}
static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
void Init(void *base, uint64_t max_vocab, uint64_t max_next);
void Insert(WordIndex word, float prob, float backoff, uint64_t next);
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
void FinishedLoading(uint64_t next_end);
private:
uint8_t backoff_bits_, next_bits_;
uint64_t next_mask_;
};
class BitPackedLongest : public BitPacked {
public:
BitPackedLongest() {}
static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, 0);
}
void Init(void *base, uint64_t max_vocab) {
return BaseInit(base, max_vocab, 0);
}
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_TRIE__

View File

@ -1,4 +1,5 @@
#include "lm/virtual_interface.hh"
#include "lm/lm_exception.hh"
namespace lm {
@ -6,11 +7,10 @@ namespace base {
Vocabulary::~Vocabulary() {}
void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
begin_sentence_ = begin_sentence;
end_sentence_ = end_sentence;
not_found_ = not_found;
available_ = available;
if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
}

View File

@ -37,8 +37,6 @@ class Vocabulary {
WordIndex BeginSentence() const { return begin_sentence_; }
WordIndex EndSentence() const { return end_sentence_; }
WordIndex NotFound() const { return not_found_; }
// Returns the start index of unused word assignments.
WordIndex Available() const { return available_; }
/* Most implementations allow StringPiece lookups and need only override
* Index(StringPiece). SRI requires null termination and overrides all
@ -56,13 +54,13 @@ class Vocabulary {
// Call SetSpecial afterward.
Vocabulary() {}
Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
SetSpecial(begin_sentence, end_sentence, not_found, available);
Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
SetSpecial(begin_sentence, end_sentence, not_found);
}
void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available);
void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found);
WordIndex begin_sentence_, end_sentence_, not_found_, available_;
WordIndex begin_sentence_, end_sentence_, not_found_;
private:
// Disable copy constructors. They're private and undefined.
@ -97,7 +95,7 @@ class Vocabulary {
* missing these methods, see facade.hh.
*
* This is the fastest way to use a model and presents a normal State class to
* be included in hypothesis state structure.
* be included in a hypothesis state structure.
*
*
* OPTION 2: Use the virtual interface below.

View File

@ -1,6 +1,10 @@
#include "lm/vocab.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
@ -8,6 +12,7 @@
#include <string>
namespace lm {
namespace ngram {
namespace detail {
uint64_t HashForVocab(const char *str, std::size_t len) {
@ -22,23 +27,81 @@ namespace {
const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
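// Vocabulary words are stored in the binary file as strings separated by
// null bytes (written by WriteWordsWrapper::Add below); ReadWords streams
// them back and hands each one to the EnumerateVocab callback.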
void ReadWords(int fd, EnumerateVocab *enumerate) {
if (!enumerate) return;
const std::size_t kInitialRead = 16384;
std::string buf;
buf.reserve(kInitialRead + 100);
buf.resize(kInitialRead);
WordIndex index = 0;
while (true) {
ssize_t got = read(fd, &buf[0], kInitialRead);
if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
if (got == 0) return;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
ssize_t ret = read(fd, &next_char, 1);
if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
buf.push_back(next_char);
}
// OK, now we have null-terminated strings.
for (const char *i = buf.data(); i != buf.data() + buf.size();) {
std::size_t length = strlen(i);
enumerate->Add(index++, StringPiece(i, length));
i += length + 1 /* null byte */;
}
}
}
void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
const uint8_t *data = static_cast<const uint8_t*>(data_void);
while (size) {
ssize_t ret = write(fd, data, size);
if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");
data += ret;
size -= ret;
}
}
} // namespace
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {}
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {}
WriteWordsWrapper::~WriteWordsWrapper() {}
std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) {
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
if (inner_) inner_->Add(index, str);
WriteOrThrow(fd_, str.data(), str.size());
char null_byte = 0;
// Inefficient because it's unbuffered. Sue me.
WriteOrThrow(fd_, &null_byte, 1);
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(Entry) * entries;
}
void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) {
assert(allocated >= Size(entries));
void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
assert(allocated >= Size(entries, config));
// Leave space for number of entries.
begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
end_ = begin_;
saw_unk_ = false;
}
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
enumerate_->Add(0, "<unk>");
strings_to_enumerate_.resize(max_entries);
}
}
WordIndex SortedVocabulary::Insert(const StringPiece &str) {
uint64_t hashed = detail::HashForVocab(str);
if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
@ -46,32 +109,56 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
return 0;
}
end_->key = hashed;
if (enumerate_) {
strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
}
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
return end_ - begin_;
}
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
util::JointSort(begin_, end_, reorder_vocab + 1);
SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
if (enumerate_) {
util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
// <unk> strikes again: +1 here.
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
strings_to_enumerate_.clear();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
SetSpecial(Index("<s>"), Index("</s>"), 0);
// Save size.
*(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
}
void SortedVocabulary::LoadedBinary() {
void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
ProbingVocabulary::ProbingVocabulary() {}
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
void ProbingVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) {
std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
return Lookup::Size(entries, config.probing_multiplier);
}
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
lookup_ = Lookup(start, allocated);
available_ = 1;
// Later if available_ != expected_available_ then we can throw UnknownMissingException.
saw_unk_ = false;
}
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
enumerate_->Add(0, "<unk>");
}
}
WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
uint64_t hashed = detail::HashForVocab(str);
// Prevent unknown from going into the table.
@ -79,19 +166,22 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
saw_unk_ = true;
return 0;
} else {
if (enumerate_) enumerate_->Add(available_, str);
lookup_.Insert(Lookup::Packing::Make(hashed, available_));
return available_++;
}
}
void ProbingVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
lookup_.FinishedInserting();
SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
void ProbingVocabulary::LoadedBinary() {
void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
lookup_.LoadedBinary();
SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
} // namespace ngram
} // namespace lm
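The enumerate hook used throughout this file is just a callback: during loading, every vocabulary word is handed to EnumerateVocab::Add exactly once, with index 0 reserved for <unk>. A minimal collector, as a sketch assuming only the virtual Add(WordIndex, const StringPiece &) interface from enumerate_vocab.hh and the namespaces used in this file:

    #include <string>
    #include <vector>
    #include "lm/enumerate_vocab.hh"

    // Hypothetical collector: stores each word at its reported index.
    class CollectVocab : public lm::ngram::EnumerateVocab {
      public:
        void Add(lm::WordIndex index, const StringPiece &str) {
          if (words_.size() <= index) words_.resize(index + 1);
          words_[index].assign(str.data(), str.size());
        }
        const std::vector<std::string> &Words() const { return words_; }
      private:
        std::vector<std::string> words_;  // words_[0] ends up as "<unk>"
    };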

View File

@ -1,16 +1,23 @@
#ifndef LM_VOCAB__
#define LM_VOCAB__
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
namespace lm {
#include <string>
#include <vector>
namespace lm {
class ProbBackoff;
namespace ngram {
class Config;
class EnumerateVocab;
namespace detail {
uint64_t HashForVocab(const char *str, std::size_t len);
inline uint64_t HashForVocab(const StringPiece &str) {
@ -18,6 +25,19 @@ inline uint64_t HashForVocab(const StringPiece &str) {
}
} // namespace detail
class WriteWordsWrapper : public EnumerateVocab {
public:
WriteWordsWrapper(EnumerateVocab *inner, int fd);
~WriteWordsWrapper();
void Add(WordIndex index, const StringPiece &str);
private:
EnumerateVocab *inner_;
int fd_;
};
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base::Vocabulary {
private:
@ -43,10 +63,12 @@ class SortedVocabulary : public base::Vocabulary {
}
// Ignores second argument for consistency with probing hash which has a float here.
static size_t Size(std::size_t entries, float ignored = 0.0);
static size_t Size(std::size_t entries, const Config &config);
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void Init(void *start, std::size_t allocated, std::size_t entries);
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@ -55,12 +77,17 @@ class SortedVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
void LoadedBinary();
void LoadedBinary(int fd, EnumerateVocab *to);
private:
Entry *begin_, *end_;
bool saw_unk_;
EnumerateVocab *enumerate_;
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
std::vector<std::string> strings_to_enumerate_;
};
// Vocabulary storing a map from uint64_t to WordIndex.
@ -73,12 +100,12 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
}
static size_t Size(std::size_t entries, float probing_multiplier) {
return Lookup::Size(entries, probing_multiplier);
}
static size_t Size(std::size_t entries, const Config &config);
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void Init(void *start, std::size_t allocated, std::size_t entries);
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
WordIndex Insert(const StringPiece &str);
@ -86,7 +113,7 @@ class ProbingVocabulary : public base::Vocabulary {
bool SawUnk() const { return saw_unk_; }
void LoadedBinary();
void LoadedBinary(int fd, EnumerateVocab *to);
private:
// std::identity is an SGI extension :-(
@ -98,9 +125,14 @@ class ProbingVocabulary : public base::Vocabulary {
Lookup lookup_;
WordIndex available_;
bool saw_unk_;
EnumerateVocab *enumerate_;
};
} // namespace ngram
} // namespace lm
#endif // LM_VOCAB__
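To persist the vocabulary strings while writing a binary model, the loading code can interpose WriteWordsWrapper between the user's enumerator and the vocabulary. A sketch of the call sequence (vocab_fd, user_enumerate, counts, and unigrams are assumed to come from the surrounding loader):

    // Tee each word to the user's callback and, null-terminated, to vocab_fd.
    lm::ngram::WriteWordsWrapper wrap(user_enumerate /* may be NULL */, vocab_fd);
    vocab.ConfigureEnumerate(&wrap, counts[0]);
    // ... Insert() every word while parsing the ARPA unigrams ...
    vocab.FinishedLoading(unigrams);  // fires wrap.Add once per sorted word
    // A later LoadedBinary(vocab_fd, user_enumerate) replays them via ReadWords.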

kenlm/util/bit_packing.cc Normal file
View File

@ -0,0 +1,40 @@
#include "util/bit_packing.hh"
#include "util/exception.hh"
#include <string.h>
namespace util {
namespace {
template <bool> struct StaticCheck {};
template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; };
// If your float isn't 4 bytes, we're hosed.
typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize;
} // namespace
uint8_t RequiredBits(uint64_t max_value) {
if (!max_value) return 0;
uint8_t ret = 1;
while (max_value >>= 1) ++ret;
return ret;
}
void BitPackingSanity() {
const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000");
char mem[57+8];
memset(mem, 0, sizeof(mem));
const uint64_t test57 = 0x123456789abcdefULL;
for (uint64_t b = 0; b < 57 * 8; b += 57) {
WriteInt57(mem + b / 8, b % 8, 57, test57);
}
for (uint64_t b = 0; b < 57 * 8; b += 57) {
if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1))
UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler.");
}
// TODO: more checks.
}
} // namespace util
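As a worked example of the sizing logic: values up to 1000 need 10 bits, since 2^9 - 1 = 511 < 1000 <= 1023 = 2^10 - 1. A tiny self-check, hypothetical and not part of the library:

    #include <assert.h>
    #include "util/bit_packing.hh"

    int main() {
      util::BitsMask m;
      m.FromMax(1000);
      assert(m.bits == 10);     // RequiredBits(1000) == 10
      assert(m.mask == 0x3FF);  // (1ULL << 10) - 1
      return 0;
    }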

kenlm/util/bit_packing.hh Normal file
View File

@ -0,0 +1,100 @@
#ifndef UTIL_BIT_PACKING__
#define UTIL_BIT_PACKING__
/* Bit-level packing routines */
#include <assert.h>
#ifdef __APPLE__
#include <architecture/byte_order.h>
#elif __linux__
#include <endian.h>
#else
#include <arpa/nameser_compat.h>
#endif
#include <inttypes.h>
namespace util {
/* WARNING WARNING WARNING:
* The write functions assume that memory is zero initially. This makes them
* faster and is the appropriate case for mmapped language model construction.
* These routines assume that unaligned access to uint64_t is fast and that
* storage is little endian. This is the case on x86_64. I'm not sure how
* fast unaligned 64-bit access is on x86 but my target audience is large
* language models for which 64-bit is necessary.
*
* Call the BitPackingSanity function to sanity check. Calling once suffices,
* but it may be called multiple times when that's inconvenient.
*/
inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
#if BYTE_ORDER == LITTLE_ENDIAN
return bit;
#elif BYTE_ORDER == BIG_ENDIAN
return 64 - length - bit;
#else
#error "Bit packing code isn't written for your byte order."
#endif
}
/* Pack integers up to 57 bits using their least significant bits.
* The length is specified using mask:
* Assumes mask == (1 << length) - 1 where length <= 57.
*/
inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) {
return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask;
}
/* Assumes value < (1 << length) and length <= 57.
* Assumes the memory is zero initially.
*/
inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) {
*reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
}
namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
inline float ReadFloat32(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
return encoded.f;
}
inline void WriteFloat32(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
encoded.f = value;
WriteInt57(base, bit, 32, encoded.i);
}
inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
// Sign bit set means negative.
encoded.i |= 0x80000000;
return encoded.f;
}
inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
assert(value <= 0.0);
detail::FloatEnc encoded;
encoded.f = value;
encoded.i &= ~0x80000000;
WriteInt57(base, bit, 31, encoded.i);
}
void BitPackingSanity();
// Return bits required to store integers up to max_value. Not the most
// efficient implementation, but this is only called a few times to size tries.
uint8_t RequiredBits(uint64_t max_value);
struct BitsMask {
void FromMax(uint64_t max_value) {
bits = RequiredBits(max_value);
mask = (1ULL << bits) - 1;  // 1ULL: bits may exceed 31, so shift in 64-bit
}
uint8_t bits;
uint64_t mask;
};
} // namespace util
#endif // UTIL_BIT_PACKING__
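The 31-bit float routines exploit the fact that log probabilities are never positive: the sign bit carries no information, so WriteNonPositiveFloat31 strips it and ReadNonPositiveFloat31 restores it, saving one bit per value. A round-trip sketch under the same zero-initialized-memory assumption as the other writers:

    #include <assert.h>
    #include <string.h>
    #include "util/bit_packing.hh"

    int main() {
      char mem[8];
      memset(mem, 0, sizeof(mem));  // the write routines require zeroed memory
      util::WriteNonPositiveFloat31(mem, 3, -2.5f);
      assert(util::ReadNonPositiveFloat31(mem, 3) == -2.5f);
      return 0;
    }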

View File

@ -0,0 +1,46 @@
#include "util/bit_packing.hh"
#define BOOST_TEST_MODULE BitPackingTest
#include <boost/test/unit_test.hpp>
#include <string.h>
namespace util {
namespace {
const uint64_t test57 = 0x123456789abcdefULL;
BOOST_AUTO_TEST_CASE(ZeroBit) {
char mem[16];
memset(mem, 0, sizeof(mem));
WriteInt57(mem, 0, 57, test57);
BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1));
}
BOOST_AUTO_TEST_CASE(EachBit) {
char mem[16];
for (uint8_t b = 0; b < 8; ++b) {
memset(mem, 0, sizeof(mem));
WriteInt57(mem, b, 57, test57);
BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1));
}
}
BOOST_AUTO_TEST_CASE(Consecutive) {
char mem[57+8];
memset(mem, 0, sizeof(mem));
for (uint64_t b = 0; b < 57 * 8; b += 57) {
WriteInt57(mem + (b / 8), b % 8, 57, test57);
BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
}
for (uint64_t b = 0; b < 57 * 8; b += 57) {
BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1));
}
}
BOOST_AUTO_TEST_CASE(Sanity) {
BitPackingSanity();
}
} // namespace
} // namespace util

View File

@ -13,10 +13,7 @@ ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::s
ErsatzProgress::~ErsatzProgress() {
if (!out_) return;
for (; stones_written_ < kWidth; ++stones_written_) {
(*out_) << '*';
}
*out_ << '\n';
Finished();
}
ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete)
@ -36,8 +33,8 @@ void ErsatzProgress::Milestone() {
for (; stones_written_ < stone; ++stones_written_) {
(*out_) << '*';
}
if (current_ >= complete_) {
if (stone == kWidth) {
(*out_) << std::endl;
next_ = std::numeric_limits<std::size_t>::max();
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);

View File

@ -19,7 +19,7 @@ class ErsatzProgress {
~ErsatzProgress();
ErsatzProgress &operator++() {
if (++current_ == next_) Milestone();
if (++current_ >= next_) Milestone();
return *this;
}
@ -33,6 +33,10 @@ class ErsatzProgress {
Milestone();
}
void Finished() {
Set(complete_);
}
private:
void Milestone();
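Together with the operator++ change above, this makes completion idempotent: overshooting next_ still triggers a milestone, and Finished() (also invoked from the destructor) pads the bar to full width. Typical use, as a sketch (header path assumed):

    #include <iostream>
    #include "util/ersatz_progress.hh"

    void Process(std::size_t total) {
      util::ErsatzProgress bar(&std::cerr, "Processing", total);
      for (std::size_t i = 0; i < total; ++i, ++bar) {
        // ... work on item i ...
      }
    }  // the destructor finishes the bar even if we bailed out early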

View File

@ -16,7 +16,7 @@ const char *HandleStrerror(int ret, const char *buf) {
}
// The GNU version.
const char *HandleStrerror(const char *ret, const char *) {
const char *HandleStrerror(const char *ret, const char *buf) {
return ret;
}
} // namespace
@ -24,7 +24,12 @@ const char *HandleStrerror(const char *ret, const char *) {
ErrnoException::ErrnoException() throw() : errno_(errno) {
char buf[200];
buf[0] = 0;
#ifdef sun
const char *add = strerror(errno);
#else
const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf);
#endif
if (add) {
*this << add << ' ';
}

View File

@ -7,14 +7,18 @@
#include <limits>
#include <assert.h>
#include <cstdlib>
#include <ctype.h>
#include <fcntl.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#ifdef HAVE_ZLIB
#include <zlib.h>
#endif
namespace util {
EndOfFileException::EndOfFileException() throw() {
@ -23,7 +27,14 @@ EndOfFileException::EndOfFileException() throw() {
EndOfFileException::~EndOfFileException() throw() {}
ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a float";
*this << "Could not parse \"" << value << "\" into a number";
}
GZException::GZException(void *file) {
#ifdef HAVE_ZLIB
int num;
*this << gzerror(file, &num) << " from zlib";
#endif // HAVE_ZLIB
}
int OpenReadOrThrow(const char *name) {
@ -38,77 +49,33 @@ off_t SizeFile(int fd) {
return sb.st_size;
}
FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :
FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) :
FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) {
if (total_size_ == kBadSize) {
fallback_to_read_ = true;
if (show_progress)
*show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl;
} else {
fallback_to_read_ = false;
}
default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2);
position_ = NULL;
position_end_ = NULL;
mapped_offset_ = 0;
at_end_ = false;
Shift();
}
float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
char *end;
float ret = std::strtof(buffer.c_str(), &end);
if (buffer.c_str() == end) throw ParseNumberException(buffer);
position_ += end - buffer.c_str();
return ret;
}
Shift();
}
char *end;
float ret = std::strtof(position_, &end);
if (end == position_) throw ParseNumberException(ReadDelimited());
position_ = end;
return ret;
}
void FilePiece::SkipSpaces() throw (EndOfFileException) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!isspace(*position_)) return;
}
}
const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) {
for (const char *i = position_; i <= last_space_; ++i) {
if (isspace(*i)) return i;
}
while (!at_end_) {
size_t skip = position_end_ - position_;
Shift();
for (const char *i = position_ + skip; i <= last_space_; ++i) {
if (isspace(*i)) return i;
FilePiece::~FilePiece() {
#ifdef HAVE_ZLIB
if (gz_file_) {
// zlib took ownership
file_.release();
int ret;
if (Z_OK != (ret = gzclose(gz_file_))) {
std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl;
abort();
}
}
return position_end_;
#endif
}
StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {
StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
const char *start = position_;
do {
for (const char *i = start; i < position_end_; ++i) {
@ -124,17 +91,129 @@ StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {
} while (!at_end_);
StringPiece ret(position_, position_end_ - position_);
position_ = position_end_;
return position_;
return ret;
}
void FilePiece::Shift() throw(EndOfFileException) {
if (at_end_) throw EndOfFileException();
float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
return ReadNumber<float>();
}
double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) {
return ReadNumber<double>();
}
long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) {
return ReadNumber<long int>();
}
unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) {
return ReadNumber<unsigned long int>();
}
void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!isspace(*position_)) return;
}
}
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
#ifdef HAVE_ZLIB
gz_file_ = NULL;
#endif
file_name_ = name;
default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2);
position_ = NULL;
position_end_ = NULL;
mapped_offset_ = 0;
at_end_ = false;
if (total_size_ == kBadSize) {
// So the assertion passes.
fallback_to_read_ = false;
if (show_progress)
*show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl;
TransitionToRead();
} else {
fallback_to_read_ = false;
}
Shift();
// gzip detect.
if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) {
#ifndef HAVE_ZLIB
UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in.");
#endif
if (!fallback_to_read_) {
at_end_ = false;
TransitionToRead();
}
}
}
namespace {
void ParseNumber(const char *begin, char *&end, float &out) {
#ifdef sun
out = static_cast<float>(strtod(begin, &end));
#else
out = strtof(begin, &end);
#endif
}
void ParseNumber(const char *begin, char *&end, double &out) {
out = strtod(begin, &end);
}
void ParseNumber(const char *begin, char *&end, long int &out) {
out = strtol(begin, &end, 10);
}
void ParseNumber(const char *begin, char *&end, unsigned long int &out) {
out = strtoul(begin, &end, 10);
}
} // namespace
template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
char *end;
T ret;
ParseNumber(buffer.c_str(), end, ret);
if (buffer.c_str() == end) throw ParseNumberException(buffer);
position_ += end - buffer.c_str();
return ret;
}
Shift();
}
char *end;
T ret;
ParseNumber(position_, end, ret);
if (end == position_) throw ParseNumberException(ReadDelimited());
position_ = end;
return ret;
}
const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {
for (const char *i = position_; i <= last_space_; ++i) {
if (isspace(*i)) return i;
}
while (!at_end_) {
size_t skip = position_end_ - position_;
Shift();
for (const char *i = position_ + skip; i <= last_space_; ++i) {
if (isspace(*i)) return i;
}
}
return position_end_;
}
void FilePiece::Shift() throw(GZException, EndOfFileException) {
if (at_end_) {
progress_.Finished();
throw EndOfFileException();
}
off_t desired_begin = position_ - data_.begin() + mapped_offset_;
progress_.Set(desired_begin);
if (!fallback_to_read_) MMapShift(desired_begin);
// Notice an mmap failure might set the fallback.
if (fallback_to_read_) ReadShift(desired_begin);
if (fallback_to_read_) ReadShift();
for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) {
if (isspace(*last_space_)) break;
@ -163,28 +242,41 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {
data_.reset();
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
if (data_.get() == MAP_FAILED) {
fallback_to_read_ = true;
if (desired_begin) {
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");
}
// The mmap was scheduled to end the file, but now we're going to read it.
at_end_ = false;
TransitionToRead();
return;
}
mapped_offset_ = mapped_offset;
position_ = data_.begin() + ignore;
position_end_ = data_.begin() + mapped_size;
progress_.Set(desired_begin);
}
void FilePiece::ReadShift(off_t desired_begin) throw() {
void FilePiece::TransitionToRead() throw (GZException) {
assert(!fallback_to_read_);
fallback_to_read_ = true;
data_.reset();
data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_;
#ifdef HAVE_ZLIB
assert(!gz_file_);
gz_file_ = gzdopen(file_.get(), "r");
if (!gz_file_) {
UTIL_THROW(GZException, "zlib failed to open " << file_name_);
}
#endif
}
void FilePiece::ReadShift() throw(GZException, EndOfFileException) {
assert(fallback_to_read_);
if (data_.source() != scoped_memory::MALLOC_ALLOCATED) {
// First call.
data_.reset();
data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_;
}
// Bytes [data_.begin(), position_) have been consumed.
// Bytes [position_, position_end_) have been read into the buffer.
@ -215,9 +307,23 @@ void FilePiece::ReadShift(off_t desired_begin) throw() {
}
}
ssize_t read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
ssize_t read_return;
#ifdef HAVE_ZLIB
read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
if (read_return == -1) throw GZException(gz_file_);
if (total_size_ != kBadSize) {
// Just get the position, don't actually seek. Apparently this is how you do it. . .
off_t ret = lseek(file_.get(), 0, SEEK_CUR);
if (ret != -1) progress_.Set(ret);
}
#else
read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
if (read_return == -1) UTIL_THROW(ErrnoException, "read failed");
if (read_return == 0) at_end_ = true;
progress_.Set(mapped_offset_);
#endif
if (read_return == 0) {
at_end_ = true;
}
position_end_ += read_return;
}

View File

@ -11,6 +11,8 @@
#include <cstddef>
#define HAVE_ZLIB
namespace util {
class EndOfFileException : public Exception {
@ -25,6 +27,13 @@ class ParseNumberException : public Exception {
~ParseNumberException() throw() {}
};
class GZException : public Exception {
public:
explicit GZException(void *file);
GZException() throw() {}
~GZException() throw() {}
};
int OpenReadOrThrow(const char *name);
// Return value for SizeFile when it can't size properly.
@ -34,40 +43,47 @@ off_t SizeFile(int fd);
class FilePiece {
public:
// 32 MB default.
explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
// Takes ownership of fd. name is used for messages.
explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
~FilePiece();
char get() throw(EndOfFileException) {
if (position_ == position_end_) Shift();
char get() throw(GZException, EndOfFileException) {
if (position_ == position_end_) {
Shift();
if (at_end_) throw EndOfFileException();
}
return *(position_++);
}
// Memory backing the returned StringPiece may vanish on the next call.
// Leaves the delimiter, if any, to be returned by get().
StringPiece ReadDelimited() throw(EndOfFileException) {
StringPiece ReadDelimited() throw(GZException, EndOfFileException) {
SkipSpaces();
return Consume(FindDelimiterOrEOF());
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n') throw(EndOfFileException);
StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
float ReadFloat() throw(EndOfFileException, ParseNumberException);
float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException);
double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException);
long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException);
unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException);
void SkipSpaces() throw (EndOfFileException);
void SkipSpaces() throw (GZException, EndOfFileException);
off_t Offset() const {
return position_ - data_.begin() + mapped_offset_;
}
// Only for testing.
void ForceFallbackToRead() {
fallback_to_read_ = true;
}
const std::string &FileName() const { return file_name_; }
private:
void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer);
void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException);
template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException);
StringPiece Consume(const char *to) {
StringPiece ret(position_, to - position_);
@ -75,12 +91,14 @@ class FilePiece {
return ret;
}
const char *FindDelimiterOrEOF() throw(EndOfFileException);
const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException);
void Shift() throw (EndOfFileException);
void Shift() throw (EndOfFileException, GZException);
// Backends to Shift().
void MMapShift(off_t desired_begin) throw ();
void ReadShift(off_t desired_begin) throw ();
void TransitionToRead() throw (GZException);
void ReadShift() throw (GZException, EndOfFileException);
const char *position_, *last_space_, *position_end_;
@ -98,6 +116,12 @@ class FilePiece {
bool fallback_to_read_;
ErsatzProgress progress_;
std::string file_name_;
#ifdef HAVE_ZLIB
void *gz_file_;
#endif // HAVE_ZLIB
};
} // namespace util
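In use, FilePiece behaves like a tokenizer over an mmapped (or, transparently, gzipped) file, and signals end of input with EndOfFileException. A sketch that reads alternating word/count tokens from a hypothetical file:

    #include <iostream>
    #include <string>
    #include "util/file_piece.hh"

    int main() {
      util::FilePiece f("counts.txt", &std::cerr);  // progress bar on stderr
      double total = 0.0;
      try {
        while (true) {
          StringPiece token = f.ReadDelimited();  // valid only until the next call
          std::string word(token.data(), token.size());
          total += f.ReadDouble();
          std::cout << word << '\n';
        }
      } catch (const util::EndOfFileException &) {}
      std::cout << "total " << total << std::endl;
      return 0;
    }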

View File

@ -1,15 +1,21 @@
#include "util/file_piece.hh"
#include "util/scoped.hh"
#define BOOST_TEST_MODULE FilePieceTest
#include <boost/test/unit_test.hpp>
#include <fstream>
#include <iostream>
#include <stdio.h>
#include <sys/types.h>
#include <sys/stat.h>
namespace util {
namespace {
/* mmap implementation */
BOOST_AUTO_TEST_CASE(MMapLine) {
BOOST_AUTO_TEST_CASE(MMapReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FilePiece test("file_piece.cc", NULL, 1);
std::string ref_line;
@ -20,13 +26,21 @@ BOOST_AUTO_TEST_CASE(MMapLine) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
#ifndef __APPLE__
/* Apple isn't happy with the popen, fileno, dup. And I don't want to
* reimplement popen. This is an issue with the test.
*/
/* read() implementation */
BOOST_AUTO_TEST_CASE(ReadLine) {
BOOST_AUTO_TEST_CASE(StreamReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FilePiece test("file_piece.cc", NULL, 1);
test.ForceFallbackToRead();
FILE *catter = popen("cat file_piece.cc", "r");
BOOST_REQUIRE(catter);
FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
@ -35,7 +49,54 @@ BOOST_AUTO_TEST_CASE(ReadLine) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
BOOST_REQUIRE(!pclose(catter));
}
#endif // __APPLE__
#ifdef HAVE_ZLIB
// gzip file
BOOST_AUTO_TEST_CASE(PlainZipReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
BOOST_REQUIRE_EQUAL(0, system("gzip <file_piece.cc >file_piece.cc.gz"));
FilePiece test("file_piece.cc.gz", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
// I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
if (!test_line.empty() || !ref_line.empty()) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
// gzip stream. Apple doesn't like popen, fileno, dup. This is an issue with
// the test.
#ifndef __APPLE__
BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FILE * catter = popen("gzip <file_piece.cc", "r");
BOOST_REQUIRE(catter);
FilePiece test(dup(fileno(catter)), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
// I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
if (!test_line.empty() || !ref_line.empty()) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
BOOST_CHECK_THROW(test.get(), EndOfFileException);
BOOST_REQUIRE(!pclose(catter));
}
#endif // __APPLE__
#endif // HAVE_ZLIB
} // namespace
} // namespace util

View File

@ -119,6 +119,12 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi
} // namespace detail
template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > {
public:
PairedIterator(const KeyIter &key, const ValueIter &value) :
ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {}
};
template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) {
ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin));
detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less);
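PairedIterator is what lets SortedVocabulary::FinishedLoading sort entries while dragging both the weights and the enumeration strings along. The plain three-argument JointSort used elsewhere in vocab.cc works on any parallel arrays; a sketch assuming its default less-than comparison:

    #include <inttypes.h>
    #include <string>
    #include "util/joint_sort.hh"

    void Example() {
      uint64_t keys[3] = {42, 7, 19};
      std::string values[3] = {"c", "a", "b"};
      util::JointSort(keys, keys + 3, values);
      // keys are now {7, 19, 42}; values moved with them: {"a", "b", "c"}.
    }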

View File

@ -2,8 +2,9 @@
#include "util/mmap.hh"
#include "util/scoped.hh"
#include <iostream>
#include <assert.h>
#include <err.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/mman.h>
@ -14,8 +15,10 @@ namespace util {
scoped_mmap::~scoped_mmap() {
if (data_ != (void*)-1) {
if (munmap(data_, size_))
err(1, "munmap failed ");
if (munmap(data_, size_)) {
std::cerr << "munmap failed for " << size_ << " bytes." << std::endl;
abort();
}
}
}
@ -62,8 +65,49 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
return ret;
}
void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) {
return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset);
namespace {
void ReadAll(int fd, void *to_void, std::size_t amount) {
uint8_t *to = static_cast<uint8_t*>(to_void);
while (amount) {
ssize_t ret = read(fd, to, amount);
if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed.");
if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read.");
amount -= ret;
to += ret;
}
}
const int kFileFlags =
#ifdef MAP_FILE
MAP_FILE | MAP_SHARED
#else
MAP_SHARED
#endif
;
} // namespace
void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) {
switch (method) {
case LAZY:
out.reset(MapOrThrow(size, false, kFileFlags, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
break;
case POPULATE_OR_LAZY:
#ifdef MAP_POPULATE
case POPULATE_OR_READ:
#endif
out.reset(MapOrThrow(size, false, kFileFlags, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
break;
#ifndef MAP_POPULATE
case POPULATE_OR_READ:
#endif
case READ:
out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED);
if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc");
if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed.");
ReadAll(fd, out.get(), size);
break;
}
}
void *MapAnonymous(std::size_t size) {
@ -76,14 +120,14 @@ void *MapAnonymous(std::size_t size) {
| MAP_PRIVATE, false, -1, 0);
}
void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) {
void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH));
if (-1 == file.get())
UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing");
if (-1 == ftruncate(file.get(), size))
UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed");
try {
mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size);
return MapOrThrow(size, true, kFileFlags, false, file.get(), 0);
} catch (ErrnoException &e) {
e << " in file " << name;
throw;

View File

@ -6,6 +6,7 @@
#include <cstddef>
#include <inttypes.h>
#include <sys/types.h>
namespace util {
@ -19,8 +20,8 @@ class scoped_mmap {
void *get() const { return data_; }
const char *begin() const { return reinterpret_cast<char*>(data_); }
const char *end() const { return reinterpret_cast<char*>(data_) + size_; }
const uint8_t *begin() const { return reinterpret_cast<uint8_t*>(data_); }
const uint8_t *end() const { return reinterpret_cast<uint8_t*>(data_) + size_; }
std::size_t size() const { return size_; }
void reset(void *data, std::size_t size) {
@ -79,23 +80,27 @@ class scoped_memory {
scoped_memory &operator=(const scoped_memory &);
};
struct scoped_mapped_file {
scoped_fd fd;
scoped_mmap mem;
};
typedef enum {
// mmap with no prepopulate
LAZY,
// On linux, pass MAP_POPULATE to mmap.
POPULATE_OR_LAZY,
// Populate on Linux. malloc and read on non-Linux.
POPULATE_OR_READ,
// malloc and read.
READ
} LoadMethod;
// Wrapper around mmap to check it worked and hide some platform macros.
void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0);
void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0);
void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out);
void *MapAnonymous(std::size_t size);
// Open file name with mmap of size bytes, all of which are initially zero.
void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem);
inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) {
MapZeroedWrite(name, size, out.fd, out.mem);
}
void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
} // namespace util
#endif // UTIL_SCOPED__
#endif // UTIL_MMAP__
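LoadMethod lets callers pick this tradeoff at runtime, and POPULATE_OR_READ degrades gracefully to malloc-plus-read() on platforms without MAP_POPULATE. A sketch, with fd, offset, and size supplied by the caller:

    #include "util/mmap.hh"

    // Bring [offset, offset + size) of fd into memory; scoped_memory
    // remembers whether the destructor must munmap or free.
    void LoadRegion(int fd, off_t offset, std::size_t size, util::scoped_memory &out) {
      util::MapRead(util::POPULATE_OR_READ, fd, offset, size, out);
    }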

View File

@ -1,129 +1,129 @@
/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
* code is released to the public domain. For business purposes, Murmurhash is
* under the MIT license."
* This is modified from the original:
* ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
* length changed to unsigned int.
* placed in namespace util
* add MurmurHashNative
* default option = 0 for seed
*/
#include "util/murmur_hash.hh"
namespace util {
//-----------------------------------------------------------------------------
// MurmurHash2, 64-bit versions, by Austin Appleby
// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
// and endian-ness issues if used across multiple platforms.
// 64-bit hash for 64-bit platforms
uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
{
const uint64_t m = 0xc6a4a7935bd1e995ULL;
const int r = 47;
uint64_t h = seed ^ (len * m);
const uint64_t * data = (const uint64_t *)key;
const uint64_t * end = data + (len/8);
while(data != end)
{
uint64_t k = *data++;
k *= m;
k ^= k >> r;
k *= m;
h ^= k;
h *= m;
}
const unsigned char * data2 = (const unsigned char*)data;
switch(len & 7)
{
case 7: h ^= uint64_t(data2[6]) << 48;
case 6: h ^= uint64_t(data2[5]) << 40;
case 5: h ^= uint64_t(data2[4]) << 32;
case 4: h ^= uint64_t(data2[3]) << 24;
case 3: h ^= uint64_t(data2[2]) << 16;
case 2: h ^= uint64_t(data2[1]) << 8;
case 1: h ^= uint64_t(data2[0]);
h *= m;
};
h ^= h >> r;
h *= m;
h ^= h >> r;
return h;
}
// 64-bit hash for 32-bit platforms
uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
{
const unsigned int m = 0x5bd1e995;
const int r = 24;
unsigned int h1 = seed ^ len;
unsigned int h2 = 0;
const unsigned int * data = (const unsigned int *)key;
while(len >= 8)
{
unsigned int k1 = *data++;
k1 *= m; k1 ^= k1 >> r; k1 *= m;
h1 *= m; h1 ^= k1;
len -= 4;
unsigned int k2 = *data++;
k2 *= m; k2 ^= k2 >> r; k2 *= m;
h2 *= m; h2 ^= k2;
len -= 4;
}
if(len >= 4)
{
unsigned int k1 = *data++;
k1 *= m; k1 ^= k1 >> r; k1 *= m;
h1 *= m; h1 ^= k1;
len -= 4;
}
switch(len)
{
case 3: h2 ^= ((unsigned char*)data)[2] << 16;
case 2: h2 ^= ((unsigned char*)data)[1] << 8;
case 1: h2 ^= ((unsigned char*)data)[0];
h2 *= m;
};
h1 ^= h2 >> 18; h1 *= m;
h2 ^= h1 >> 22; h2 *= m;
h1 ^= h2 >> 17; h1 *= m;
h2 ^= h1 >> 19; h2 *= m;
uint64_t h = h1;
h = (h << 32) | h2;
return h;
}
uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
if (sizeof(int) == 4) {
return MurmurHash64B(key, len, seed);
} else {
return MurmurHash64A(key, len, seed);
}
}
} // namespace util
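Usage is a single call; per the notes above the seed defaults to 0, which is what the vocabulary hashing relies on. A sketch:

    #include <string.h>
    #include "util/murmur_hash.hh"

    // Hash a null-terminated word the way HashForVocab in lm/vocab.cc does.
    uint64_t HashWord(const char *word) {
      return util::MurmurHashNative(word, strlen(word));
    }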

View File

@ -1,7 +1,7 @@
#ifndef UTIL_MURMUR_HASH__
#define UTIL_MURMUR_HASH__
#include <cstddef>
#include <stdint.h>
#include <inttypes.h>
namespace util {

View File

@ -78,6 +78,8 @@ template <class Proxy> class ProxyIterator {
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
const InnerIterator &Inner() { return p_.Inner(); }
private:
InnerIterator &I() { return p_.Inner(); }
const InnerIterator &I() const { return p_.Inner(); }

View File

@ -1,12 +1,24 @@
#include "util/scoped.hh"
#include <err.h>
#include <iostream>
#include <stdlib.h>
#include <unistd.h>
namespace util {
scoped_fd::~scoped_fd() {
if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_);
if (fd_ != -1 && close(fd_)) {
std::cerr << "Could not close file " << fd_ << std::endl;
abort();
}
}
scoped_FILE::~scoped_FILE() {
if (file_ && fclose(file_)) {
std::cerr << "Could not close file " << std::endl;
abort();
}
}
} // namespace util

View File

@ -4,6 +4,7 @@
/* Other scoped objects in the style of scoped_ptr. */
#include <cstddef>
#include <cstdio>
namespace util {
@ -61,6 +62,24 @@ class scoped_fd {
scoped_fd &operator=(const scoped_fd &);
};
class scoped_FILE {
public:
explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {}
~scoped_FILE();
std::FILE *get() { return file_; }
const std::FILE *get() const { return file_; }
void reset(std::FILE *to = NULL) {
scoped_FILE other(file_);
file_ = to;
}
private:
std::FILE *file_;
};
} // namespace util
#endif // UTIL_SCOPED__
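scoped_FILE mirrors scoped_fd for stdio streams: the destructor calls fclose and aborts if the close fails. A usage sketch:

    #include <stdio.h>
    #include "util/scoped.hh"

    void PrintFirstLine(const char *name) {
      util::scoped_FILE in(fopen(name, "r"));
      if (!in.get()) return;  // fopen failed; nothing to close
      char buf[256];
      if (fgets(buf, sizeof(buf), in.get())) fputs(buf, stdout);
    }  // fclose runs here via the destructor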

View File

@ -75,7 +75,7 @@ template <class PackingT> class SortedUniformMap {
#endif
{}
SortedUniformMap(void *start, std::size_t/*allocated*/) :
SortedUniformMap(void *start, std::size_t /*allocated*/) :
begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start))
#ifdef DEBUG

View File

@ -30,14 +30,14 @@
#include "util/string_piece.hh"
#ifdef USE_BOOST
#ifdef HAVE_BOOST
#include <boost/functional/hash/hash.hpp>
#endif
#include <algorithm>
#include <iostream>
#ifdef USE_ICU
#ifdef HAVE_ICU
U_NAMESPACE_BEGIN
#endif
@ -46,12 +46,12 @@ std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
return o;
}
#ifdef USE_BOOST
#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length());
}
#endif
#ifdef USE_ICU
#ifdef HAVE_ICU
U_NAMESPACE_END
#endif

View File

@ -1,4 +1,4 @@
/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If
/* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If
* you don't use ICU, then this will use the Google implementation from Chrome.
* This has been modified from the original version to let you choose.
*/
@ -49,14 +49,14 @@
#define BASE_STRING_PIECE_H__
//Uncomment this line if you use ICU in your code.
//#define USE_ICU
//#define HAVE_ICU
//Uncomment this line if you want boost hashing for your StringPieces.
//#define USE_BOOST
//#define HAVE_BOOST
#include <cstring>
#include <iosfwd>
#ifdef USE_ICU
#ifdef HAVE_ICU
#include <unicode/stringpiece.h>
U_NAMESPACE_BEGIN
#else
@ -230,7 +230,7 @@ inline bool operator>=(const StringPiece& x, const StringPiece& y) {
// allow StringPiece to be logged (needed for unit testing).
extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece);
#ifdef USE_BOOST
#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str);
/* Support for lookup of StringPiece in boost::unordered_map<std::string> */
@ -253,7 +253,7 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
}
#endif
#ifdef USE_ICU
#ifdef HAVE_ICU
U_NAMESPACE_END
#endif

View File

@ -316,7 +316,7 @@ statscore_t Optimizer::Run(Point& P)const{
exit(2);
}
if (scorer->getReferenceSize()!=FData->size()){
cerr<<"error size mismatch between FeatureData and Scorer"<<endl;
cerr<<"error length mismatch between feature file and score file"<<endl;
exit(2);
}

View File

@ -46,10 +46,10 @@ int main(int argc, char **argv) {
srcphrase = Moses::Tokenize<std::string>(line);
std::vector<Moses::StringTgtCand> tgtcands;
std::vector<Moses::StringWordAlignmentCand> src_wa, tgt_wa;
std::vector<std::string> wordAlignment;
if(useAlignments)
ptree.GetTargetCandidates(srcphrase, tgtcands, src_wa, tgt_wa);
ptree.GetTargetCandidates(srcphrase, tgtcands, wordAlignment);
else
ptree.GetTargetCandidates(srcphrase, tgtcands);
@ -60,19 +60,7 @@ int main(int argc, char **argv) {
std::cout << " |||";
if(useAlignments) {
for(uint j = 0; j < src_wa[i].second.size(); j++)
if(src_wa[i].second[j] == "-1")
std::cout << " ()";
else
std::cout << " (" << src_wa[i].second[j] << ")";
std::cout << " |||";
for(uint j = 0; j < tgt_wa[i].second.size(); j++)
if(tgt_wa[i].second[j] == "-1")
std::cout << " ()";
else
std::cout << " (" << tgt_wa[i].second[j] << ")";
std::cout << " |||";
std::cout << " " << wordAlignment[i] << " |||";
}
for(uint j = 0; j < tgtcands[i].second.size(); j++)

View File

@ -227,7 +227,14 @@
isa = PBXProject;
buildConfigurationList = 1DEB923508733DC60010E9CD /* Build configuration list for PBXProject "moses-chart-cmd" */;
compatibilityVersion = "Xcode 3.1";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* moses-chart-cmd */;
projectDirPath = "";
projectReferences = (
@ -336,7 +343,7 @@
HEADER_SEARCH_PATHS = ../moses/src;
INSTALL_PATH = /usr/local/bin;
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib,
../srilm/lib/macosx,
../kenlm/lm,
../randlm/lib,
@ -349,7 +356,6 @@
"-loolm",
"-lflm",
"-llattice",
"-lkenlm",
"-lrandlm",
);
PRODUCT_NAME = "moses-chart-cmd";
@ -365,7 +371,7 @@
HEADER_SEARCH_PATHS = ../moses/src;
INSTALL_PATH = /usr/local/bin;
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib,
../srilm/lib/macosx,
../kenlm/lm,
../randlm/lib,
@ -378,7 +384,6 @@
"-loolm",
"-lflm",
"-llattice",
"-lkenlm",
"-lrandlm",
);
PRODUCT_NAME = "moses-chart-cmd";

View File

@ -83,6 +83,46 @@ bool ReadInput(IOWrapper &ioWrapper, InputTypeEnum inputType, InputType*& source
return (source ? true : false);
}
static void PrintFeatureWeight(const FeatureFunction* ff) {
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
cout << ff->GetScoreProducerDescription() << " "
<< ff->GetScoreProducerWeightShortName() << " "
<< values[i] << endl;
}
} else {
cout << ff->GetScoreProducerDescription() << " " <<
ff->GetScoreProducerWeightShortName() << " sparse" << endl;
}
}
static void ShowWeights() {
cout.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for (size_t i = 0; i < sff.size(); ++i) {
PrintFeatureWeight(sff[i]);
}
for (size_t i = 0; i < pds.size(); ++i) {
PrintFeatureWeight(pds[i]);
}
for (size_t i = 0; i < gds.size(); ++i) {
PrintFeatureWeight(gds[i]);
}
for (size_t i = 0; i < slf.size(); ++i) {
PrintFeatureWeight(slf[i]);
}
}
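Each printed line is the feature's description, its weight short name, and one value per component (or the literal "sparse" for unbounded feature sets). Hypothetical output for a small phrase-based setup, with names depending entirely on the configuration, might look like:

    Distortion d 0.3
    LM lm 0.5
    PhraseModel tm 0.2
    WordPenalty w -1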
int main(int argc, char* argv[])
{
@ -108,6 +148,11 @@ int main(int argc, char* argv[])
const StaticData &staticData = StaticData::Instance();
if (!StaticData::LoadDataStatic(&parameter))
return EXIT_FAILURE;
if (parameter.isParamSpecified("show-weights")) {
ShowWeights();
exit(0);
}
assert(staticData.GetSearchAlgorithm() == ChartDecoding);

View File

@ -14,6 +14,7 @@
1C8CFF4F0AD68D3600FA22E2 /* TranslationAnalysis.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1C8CFF470AD68D3600FA22E2 /* TranslationAnalysis.cpp */; };
1C8CFF500AD68D3600FA22E2 /* TranslationAnalysis.h in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1C8CFF480AD68D3600FA22E2 /* TranslationAnalysis.h */; };
1CE646E411679F6900EC77CC /* libOnDiskPt.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CE646E311679F5F00EC77CC /* libOnDiskPt.a */; };
1EBB175F126C16B800AE6102 /* libkenlm.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 1EBB175A126C169000AE6102 /* libkenlm.a */; };
B219B8540E93812700EAB407 /* libmoses.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 03306D670C0B240B00CA1311 /* libmoses.a */; };
B219B8580E9381AC00EAB407 /* IOWrapper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = B219B8560E9381AC00EAB407 /* IOWrapper.cpp */; };
B28B1ED3110F52BB00AAD188 /* LatticeMBR.cpp in Sources */ = {isa = PBXBuildFile; fileRef = B28B1ED2110F52BB00AAD188 /* LatticeMBR.cpp */; };
@ -48,6 +49,20 @@
remoteGlobalIDString = D2AAC045055464E500DB518D;
remoteInfo = OnDiskPt;
};
1EBB1759126C169000AE6102 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 1EBB1752126C169000AE6102 /* kenlm.xcodeproj */;
proxyType = 2;
remoteGlobalIDString = D2AAC046055464E500DB518D;
remoteInfo = kenlm;
};
1ED0E9661277CFC500AC18B1 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = 1EBB1752126C169000AE6102 /* kenlm.xcodeproj */;
proxyType = 1;
remoteGlobalIDString = D2AAC045055464E500DB518D;
remoteInfo = kenlm;
};
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
@ -74,6 +89,7 @@
1C8CFF470AD68D3600FA22E2 /* TranslationAnalysis.cpp */ = {isa = PBXFileReference; fileEncoding = 30; lastKnownFileType = sourcecode.cpp.cpp; name = TranslationAnalysis.cpp; path = src/TranslationAnalysis.cpp; sourceTree = "<group>"; };
1C8CFF480AD68D3600FA22E2 /* TranslationAnalysis.h */ = {isa = PBXFileReference; fileEncoding = 30; lastKnownFileType = sourcecode.c.h; name = TranslationAnalysis.h; path = src/TranslationAnalysis.h; sourceTree = "<group>"; };
1CE646DB11679F5F00EC77CC /* OnDiskPt.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = OnDiskPt.xcodeproj; path = ../OnDiskPt/OnDiskPt.xcodeproj; sourceTree = SOURCE_ROOT; };
1EBB1752126C169000AE6102 /* kenlm.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = kenlm.xcodeproj; path = ../kenlm/kenlm.xcodeproj; sourceTree = SOURCE_ROOT; };
8DD76F6C0486A84900D96B5E /* moses-cmd */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "moses-cmd"; sourceTree = BUILT_PRODUCTS_DIR; };
B219B8560E9381AC00EAB407 /* IOWrapper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = IOWrapper.cpp; path = src/IOWrapper.cpp; sourceTree = "<group>"; };
B219B8570E9381AC00EAB407 /* IOWrapper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = IOWrapper.h; path = src/IOWrapper.h; sourceTree = "<group>"; };
@ -86,6 +102,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
1EBB175F126C16B800AE6102 /* libkenlm.a in Frameworks */,
1CE646E411679F6900EC77CC /* libOnDiskPt.a in Frameworks */,
B219B8540E93812700EAB407 /* libmoses.a in Frameworks */,
);
@ -105,6 +122,7 @@
08FB7794FE84155DC02AAC07 /* moses-cmd */ = {
isa = PBXGroup;
children = (
1EBB1752126C169000AE6102 /* kenlm.xcodeproj */,
1CE646DB11679F5F00EC77CC /* OnDiskPt.xcodeproj */,
03306D5F0C0B240B00CA1311 /* moses.xcodeproj */,
08FB7795FE84155DC02AAC07 /* Source */,
@ -147,6 +165,14 @@
name = Products;
sourceTree = "<group>";
};
1EBB1753126C169000AE6102 /* Products */ = {
isa = PBXGroup;
children = (
1EBB175A126C169000AE6102 /* libkenlm.a */,
);
name = Products;
sourceTree = "<group>";
};
C6859E8C029090F304C91782 /* Documentation */ = {
isa = PBXGroup;
children = (
@ -170,6 +196,7 @@
dependencies = (
03306D780C0B244800CA1311 /* PBXTargetDependency */,
1CE6472E1167A11600EC77CC /* PBXTargetDependency */,
1ED0E9671277CFC500AC18B1 /* PBXTargetDependency */,
);
name = "moses-cmd";
productInstallPath = "$(HOME)/bin";
@ -184,10 +211,21 @@
isa = PBXProject;
buildConfigurationList = 03306D3F0C0B23F200CA1311 /* Build configuration list for PBXProject "moses-cmd" */;
compatibilityVersion = "Xcode 2.4";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* moses-cmd */;
projectDirPath = "";
projectReferences = (
{
ProductGroup = 1EBB1753126C169000AE6102 /* Products */;
ProjectRef = 1EBB1752126C169000AE6102 /* kenlm.xcodeproj */;
},
{
ProductGroup = 03306D600C0B240B00CA1311 /* Products */;
ProjectRef = 03306D5F0C0B240B00CA1311 /* moses.xcodeproj */;
@ -219,6 +257,13 @@
remoteRef = 1CE646E211679F5F00EC77CC /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
1EBB175A126C169000AE6102 /* libkenlm.a */ = {
isa = PBXReferenceProxy;
fileType = archive.ar;
path = libkenlm.a;
remoteRef = 1EBB1759126C169000AE6102 /* PBXContainerItemProxy */;
sourceTree = BUILT_PRODUCTS_DIR;
};
/* End PBXReferenceProxy section */
/* Begin PBXSourcesBuildPhase section */
@ -247,6 +292,11 @@
name = OnDiskPt;
targetProxy = 1CE6472D1167A11600EC77CC /* PBXContainerItemProxy */;
};
1ED0E9671277CFC500AC18B1 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
name = kenlm;
targetProxy = 1ED0E9661277CFC500AC18B1 /* PBXContainerItemProxy */;
};
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
@ -269,7 +319,7 @@
HEADER_SEARCH_PATHS = ../moses/src;
INSTALL_PATH = "$(HOME)/bin";
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib/,
../srilm/lib/macosx,
../kenlm/lm,
../randlm/lib,
@ -281,7 +331,6 @@
"-ldstruct",
"-lz",
"-lirstlm",
"-lkenlm",
"-lrandlm",
);
PREBINDING = NO;
@ -308,7 +357,7 @@
HEADER_SEARCH_PATHS = ../moses/src;
INSTALL_PATH = "$(HOME)/bin";
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib/,
../srilm/lib/macosx,
../kenlm/lm,
../randlm/lib,
@ -320,7 +369,6 @@
"-ldstruct",
"-lz",
"-lirstlm",
"-lkenlm",
"-lrandlm",
);
PREBINDING = NO;
@ -339,7 +387,7 @@
HEADER_SEARCH_PATHS = ../moses/src;
INSTALL_PATH = "$(HOME)/bin";
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib/,
../srilm/lib/macosx,
../kenlm/lm,
../randlm/lib,
@ -351,7 +399,6 @@
"-ldstruct",
"-lz",
"-lirstlm",
"-lkenlm",
"-lrandlm",
);
PREBINDING = NO;

View File

@ -319,7 +319,7 @@ void OutputNBest(std::ostream& out, const Moses::TrellisPathList &nBestList, con
bool labeledOutput = staticData.IsLabeledNBestList();
bool reportAllFactors = staticData.GetReportAllFactorsNBest();
bool includeAlignment = staticData.NBestIncludesAlignment();
//bool includeWordAlignment = staticData.PrintAlignmentInfoInNbest();
bool includeWordAlignment = staticData.PrintAlignmentInfoInNbest();
TrellisPathList::const_iterator iter;
for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter)
@ -454,13 +454,31 @@ void OutputNBest(std::ostream& out, const Moses::TrellisPathList &nBestList, con
out<< "-" << targetRange.GetEndPos();
}
}
}
}
if (StaticData::Instance().IsPathRecoveryEnabled()) {
out << "|||";
OutputInput(out, edges[0]);
}
if (includeWordAlignment) {
out << " |||";
for (int currEdge = (int)edges.size() - 2 ; currEdge >= 0 ; currEdge--)
{
const Hypothesis &edge = *edges[currEdge];
const WordsRange &sourceRange = edge.GetCurrSourceWordsRange();
WordsRange targetRange = path.GetTargetWordsRange(edge);
const int sourceOffset = sourceRange.GetStartPos();
const int targetOffset = targetRange.GetStartPos();
const AlignmentInfo AI = edge.GetCurrTargetPhrase().GetAlignmentInfo();
AlignmentInfo::const_iterator iter;
for (iter = AI.begin(); iter != AI.end(); ++iter)
{
out << " " << iter->first+sourceOffset << "-" << iter->second+targetOffset;
}
}
}
if (StaticData::Instance().IsPathRecoveryEnabled()) {
out << "|||";
OutputInput(out, edges[0]);
}
out << endl;
}

View File

@ -57,10 +57,12 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
using namespace std;
using namespace Moses;
static const size_t PRECISION = 3;
/** Enforce rounding */
void fix(std::ostream& stream) {
void fix(std::ostream& stream, size_t size) {
stream.setf(std::ios::fixed);
stream.precision(3);
stream.precision(size);
}
@ -145,7 +147,7 @@ class TranslationTask : public Task {
//Word Graph
if (m_wordGraphCollector) {
ostringstream out;
fix(out);
fix(out,PRECISION);
manager.GetWordGraph(m_lineNumber, out);
m_wordGraphCollector->Write(m_lineNumber, out.str());
}
@ -153,7 +155,7 @@ class TranslationTask : public Task {
//Search Graph
if (m_searchGraphCollector) {
ostringstream out;
fix(out);
fix(out,PRECISION);
manager.OutputSearchGraph(m_lineNumber, out);
m_searchGraphCollector->Write(m_lineNumber, out.str());
@ -174,7 +176,7 @@ class TranslationTask : public Task {
if (m_outputCollector) {
ostringstream out;
ostringstream debug;
fix(debug);
fix(debug,PRECISION);
//All derivations - send them to debug stream
if (staticData.PrintAllDerivations()) {
@ -267,7 +269,7 @@ class TranslationTask : public Task {
//detailed translation reporting
if (m_detailedTranslationCollector) {
ostringstream out;
fix(out);
fix(out,PRECISION);
TranslationAnalysis::PrintTranslationAnalysis(manager.GetTranslationSystem(), out, manager.GetBestHypothesis());
m_detailedTranslationCollector->Write(m_lineNumber,out.str());
}
@ -291,6 +293,45 @@ class TranslationTask : public Task {
};
static void PrintFeatureWeight(const FeatureFunction* ff) {
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
cout << ff->GetScoreProducerDescription() << " "
<< ff->GetScoreProducerWeightShortName() << " "
<< values[i] << endl;
}
} else {
cout << ff->GetScoreProducerDescription() << " " <<
ff->GetScoreProducerWeightShortName() << " sparse" << endl;
}
}
static void ShowWeights() {
fix(cout,6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for (size_t i = 0; i < sff.size(); ++i) {
PrintFeatureWeight(sff[i]);
}
for (size_t i = 0; i < slf.size(); ++i) {
PrintFeatureWeight(slf[i]);
}
for (size_t i = 0; i < pds.size(); ++i) {
PrintFeatureWeight(pds[i]);
}
for (size_t i = 0; i < gds.size(); ++i) {
PrintFeatureWeight(gds[i]);
}
}
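For context, a sketch of what the new show-weights path does (flag name taken from the isParamSpecified call below; output format from PrintFeatureWeight above) — this is a reading of the code, not official documentation:

// moses -f moses.ini -show-weights
//   prints one line per feature weight to stdout in fixed 6-digit precision:
//     "<score-producer-description> <weight-short-name> <value>"
//   (or "... sparse" when GetNumScoreComponents() is unlimited),
//   then exits with status 0 before any translation is attempted.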
int main(int argc, char** argv) {
#ifdef HAVE_PROTOBUF
@ -303,8 +344,8 @@ int main(int argc, char** argv) {
TRACE_ERR(endl);
}
fix(cout);
fix(cerr);
fix(cout,PRECISION);
fix(cerr,PRECISION);
Parameter* params = new Parameter();
@ -334,6 +375,11 @@ int main(int argc, char** argv) {
if (!StaticData::LoadDataStatic(params)) {
exit(1);
}
if (params->isParamSpecified("show-weights")) {
ShowWeights();
exit(0);
}
const StaticData& staticData = StaticData::Instance();
// set up read/writing class

View File

@ -9,6 +9,8 @@
/* Begin PBXBuildFile section */
1E2B861D12555F25000770D6 /* LanguageModelRandLM.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E2B861B12555F25000770D6 /* LanguageModelRandLM.cpp */; };
1E2B861E12555F25000770D6 /* LanguageModelRandLM.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E2B861C12555F25000770D6 /* LanguageModelRandLM.h */; };
1E5AAF1512B25C9E0071864D /* LanguageModelImplementation.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1E5AAF1312B25C9E0071864D /* LanguageModelImplementation.cpp */; };
1E5AAF1612B25C9E0071864D /* LanguageModelImplementation.h in Headers */ = {isa = PBXBuildFile; fileRef = 1E5AAF1412B25C9E0071864D /* LanguageModelImplementation.h */; };
1ED00036124BC2690029177F /* ChartTranslationOption.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1ED00034124BC2690029177F /* ChartTranslationOption.cpp */; };
1ED00037124BC2690029177F /* ChartTranslationOption.h in Headers */ = {isa = PBXBuildFile; fileRef = 1ED00035124BC2690029177F /* ChartTranslationOption.h */; };
1ED0FE2A124BB9380029177F /* AlignmentInfo.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1ED0FD4C124BB9380029177F /* AlignmentInfo.cpp */; };
@ -223,11 +225,14 @@
1ED0FF02124BB9380029177F /* XmlOption.h in Headers */ = {isa = PBXBuildFile; fileRef = 1ED0FE29124BB9380029177F /* XmlOption.h */; };
1ED0FFD3124BC0BF0029177F /* ChartTranslationOptionList.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1ED0FFD1124BC0BF0029177F /* ChartTranslationOptionList.cpp */; };
1ED0FFD4124BC0BF0029177F /* ChartTranslationOptionList.h in Headers */ = {isa = PBXBuildFile; fileRef = 1ED0FFD2124BC0BF0029177F /* ChartTranslationOptionList.h */; };
1EEB43EE1264A6F200739BA5 /* PhraseDictionarySCFG.h in Headers */ = {isa = PBXBuildFile; fileRef = 1EEB43ED1264A6F200739BA5 /* PhraseDictionarySCFG.h */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
1E2B861B12555F25000770D6 /* LanguageModelRandLM.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = LanguageModelRandLM.cpp; path = src/LanguageModelRandLM.cpp; sourceTree = "<group>"; };
1E2B861C12555F25000770D6 /* LanguageModelRandLM.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = LanguageModelRandLM.h; path = src/LanguageModelRandLM.h; sourceTree = "<group>"; };
1E5AAF1312B25C9E0071864D /* LanguageModelImplementation.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = LanguageModelImplementation.cpp; path = src/LanguageModelImplementation.cpp; sourceTree = "<group>"; };
1E5AAF1412B25C9E0071864D /* LanguageModelImplementation.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = LanguageModelImplementation.h; path = src/LanguageModelImplementation.h; sourceTree = "<group>"; };
1ED00034124BC2690029177F /* ChartTranslationOption.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ChartTranslationOption.cpp; path = src/ChartTranslationOption.cpp; sourceTree = "<group>"; };
1ED00035124BC2690029177F /* ChartTranslationOption.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ChartTranslationOption.h; path = src/ChartTranslationOption.h; sourceTree = "<group>"; };
1ED0FD4C124BB9380029177F /* AlignmentInfo.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = AlignmentInfo.cpp; path = src/AlignmentInfo.cpp; sourceTree = "<group>"; };
@ -444,6 +449,7 @@
1ED0FE29124BB9380029177F /* XmlOption.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = XmlOption.h; path = src/XmlOption.h; sourceTree = "<group>"; };
1ED0FFD1124BC0BF0029177F /* ChartTranslationOptionList.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ChartTranslationOptionList.cpp; path = src/ChartTranslationOptionList.cpp; sourceTree = "<group>"; };
1ED0FFD2124BC0BF0029177F /* ChartTranslationOptionList.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ChartTranslationOptionList.h; path = src/ChartTranslationOptionList.h; sourceTree = "<group>"; };
1EEB43ED1264A6F200739BA5 /* PhraseDictionarySCFG.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = PhraseDictionarySCFG.h; path = src/PhraseDictionarySCFG.h; sourceTree = "<group>"; };
D2AAC046055464E500DB518D /* libmoses.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libmoses.a; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
@ -471,12 +477,6 @@
08FB7795FE84155DC02AAC07 /* Source */ = {
isa = PBXGroup;
children = (
1E2B861B12555F25000770D6 /* LanguageModelRandLM.cpp */,
1E2B861C12555F25000770D6 /* LanguageModelRandLM.h */,
1ED00034124BC2690029177F /* ChartTranslationOption.cpp */,
1ED00035124BC2690029177F /* ChartTranslationOption.h */,
1ED0FFD1124BC0BF0029177F /* ChartTranslationOptionList.cpp */,
1ED0FFD2124BC0BF0029177F /* ChartTranslationOptionList.h */,
1ED0FD4C124BB9380029177F /* AlignmentInfo.cpp */,
1ED0FD4D124BB9380029177F /* AlignmentInfo.h */,
1ED0FD4E124BB9380029177F /* BilingualDynSuffixArray.cpp */,
@ -484,6 +484,10 @@
1ED0FD50124BB9380029177F /* BitmapContainer.cpp */,
1ED0FD51124BB9380029177F /* BitmapContainer.h */,
1ED0FD52124BB9380029177F /* CellCollection.h */,
1ED00034124BC2690029177F /* ChartTranslationOption.cpp */,
1ED00035124BC2690029177F /* ChartTranslationOption.h */,
1ED0FFD1124BC0BF0029177F /* ChartTranslationOptionList.cpp */,
1ED0FFD2124BC0BF0029177F /* ChartTranslationOptionList.h */,
1ED0FD57124BB9380029177F /* ConfusionNet.cpp */,
1ED0FD58124BB9380029177F /* ConfusionNet.h */,
1ED0FD59124BB9380029177F /* DecodeFeature.cpp */,
@ -505,6 +509,7 @@
1ED0FD69124BB9380029177F /* DummyScoreProducers.cpp */,
1ED0FD6A124BB9380029177F /* DummyScoreProducers.h */,
1ED0FD6B124BB9380029177F /* DynSAInclude */,
1ED4FC7C11BDC0D2004E826A /* DynSAInclude */,
1ED0FD73124BB9380029177F /* DynSuffixArray.cpp */,
1ED0FD74124BB9380029177F /* DynSuffixArray.h */,
1ED0FD75124BB9380029177F /* Factor.cpp */,
@ -547,6 +552,8 @@
1ED0FD9A124BB9380029177F /* LanguageModelDelegate.h */,
1ED0FD9B124BB9380029177F /* LanguageModelFactory.cpp */,
1ED0FD9C124BB9380029177F /* LanguageModelFactory.h */,
1E5AAF1312B25C9E0071864D /* LanguageModelImplementation.cpp */,
1E5AAF1412B25C9E0071864D /* LanguageModelImplementation.h */,
1ED0FD9D124BB9380029177F /* LanguageModelInternal.cpp */,
1ED0FD9E124BB9380029177F /* LanguageModelInternal.h */,
1ED0FD9F124BB9380029177F /* LanguageModelIRST.cpp */,
@ -559,6 +566,8 @@
1ED0FDA6124BB9380029177F /* LanguageModelMultiFactor.h */,
1ED0FDA7124BB9380029177F /* LanguageModelParallelBackoff.cpp */,
1ED0FDA8124BB9380029177F /* LanguageModelParallelBackoff.h */,
1E2B861B12555F25000770D6 /* LanguageModelRandLM.cpp */,
1E2B861C12555F25000770D6 /* LanguageModelRandLM.h */,
1ED0FDAB124BB9380029177F /* LanguageModelRemote.cpp */,
1ED0FDAC124BB9380029177F /* LanguageModelRemote.h */,
1ED0FDAD124BB9380029177F /* LanguageModelSingleFactor.cpp */,
@ -607,6 +616,7 @@
1ED0FDDB124BB9380029177F /* PhraseDictionaryOnDisk.h */,
1ED0FDDC124BB9380029177F /* PhraseDictionaryOnDiskChart.cpp */,
1ED0FDDD124BB9380029177F /* PhraseDictionarySCFG.cpp */,
1EEB43ED1264A6F200739BA5 /* PhraseDictionarySCFG.h */,
1ED0FDDE124BB9380029177F /* PhraseDictionarySCFGChart.cpp */,
1ED0FDDF124BB9380029177F /* PhraseDictionaryTree.cpp */,
1ED0FDE0124BB9380029177F /* PhraseDictionaryTree.h */,
@ -683,7 +693,6 @@
1ED0FE27124BB9380029177F /* WordsRange.h */,
1ED0FE28124BB9380029177F /* XmlOption.cpp */,
1ED0FE29124BB9380029177F /* XmlOption.h */,
1ED4FC7C11BDC0D2004E826A /* DynSAInclude */,
);
name = Source;
sourceTree = "<group>";
@ -846,6 +855,8 @@
1ED0FFD4124BC0BF0029177F /* ChartTranslationOptionList.h in Headers */,
1ED00037124BC2690029177F /* ChartTranslationOption.h in Headers */,
1E2B861E12555F25000770D6 /* LanguageModelRandLM.h in Headers */,
1EEB43EE1264A6F200739BA5 /* PhraseDictionarySCFG.h in Headers */,
1E5AAF1612B25C9E0071864D /* LanguageModelImplementation.h in Headers */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@ -876,7 +887,14 @@
isa = PBXProject;
buildConfigurationList = 1DEB91EF08733DB70010E9CD /* Build configuration list for PBXProject "moses" */;
compatibilityVersion = "Xcode 3.1";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* moses */;
projectDirPath = "";
projectRoot = "";
@ -994,6 +1012,7 @@
1ED0FFD3124BC0BF0029177F /* ChartTranslationOptionList.cpp in Sources */,
1ED00036124BC2690029177F /* ChartTranslationOption.cpp in Sources */,
1E2B861D12555F25000770D6 /* LanguageModelRandLM.cpp in Sources */,
1E5AAF1512B25C9E0071864D /* LanguageModelImplementation.cpp in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};

View File

@ -17,7 +17,8 @@ BilingualDynSuffixArray::BilingualDynSuffixArray():
m_trgSA = 0;
m_srcCorpus = new std::vector<wordID_t>();
m_trgCorpus = new std::vector<wordID_t>();
m_vocab = new Vocab(false);
m_srcVocab = new Vocab(false);
m_trgVocab = new Vocab(false);
m_scoreCmp = 0;
}
@ -25,7 +26,8 @@ BilingualDynSuffixArray::~BilingualDynSuffixArray()
{
if(m_srcSA) delete m_srcSA;
if(m_trgSA) delete m_trgSA;
if(m_vocab) delete m_vocab;
if(m_srcVocab) delete m_srcVocab;
if(m_trgVocab) delete m_trgVocab;
if(m_srcCorpus) delete m_srcCorpus;
if(m_trgCorpus) delete m_trgCorpus;
if(m_scoreCmp) delete m_scoreCmp;
@ -37,17 +39,17 @@ bool BilingualDynSuffixArray::Load(
std::string source, std::string target, std::string alignments,
const std::vector<float> &weight)
{
m_inputFactors = FactorMask(inputFactors);
m_outputFactors = FactorMask(outputFactors);
m_inputFactors = inputFactors;
m_outputFactors = outputFactors;
m_scoreCmp = new ScoresComp(weight);
InputFileStream sourceStrme(source);
InputFileStream targetStrme(target);
cerr << "Loading source and target parallel corpus...\n";
LoadCorpus(sourceStrme, inputFactors, Input, *m_srcCorpus, m_srcSntBreaks);
LoadCorpus(targetStrme, outputFactors, Output, *m_trgCorpus, m_trgSntBreaks);
cerr << "Loading source corpus...\n";
LoadCorpus(sourceStrme, m_inputFactors, Input, *m_srcCorpus, m_srcSntBreaks, m_srcVocab);
cerr << "Loading target corpus...\n";
LoadCorpus(targetStrme, m_outputFactors, Output, *m_trgCorpus, m_trgSntBreaks, m_trgVocab);
assert(m_srcSntBreaks.size() == m_trgSntBreaks.size());
m_vocab->MakeClosed(); // avoid adding new words to vocabulary
// build suffix arrays and auxiliary arrays
cerr << "Building Source Suffix Array...\n";
@ -73,7 +75,10 @@ int BilingualDynSuffixArray::LoadRawAlignments(InputFileStream& align)
Utils::splitToInt(line, vtmp, "- ");
assert(vtmp.size() % 2 == 0);
std::vector<short> vAlgn; // store as short ints for memory
iterate(vtmp, itr) vAlgn.push_back(short(*itr));
for (std::vector<int>::const_iterator itr = vtmp.begin();
itr != vtmp.end(); ++itr) {
vAlgn.push_back(short(*itr));
}
m_rawAlignments.push_back(vAlgn);
}
return m_rawAlignments.size();
@ -84,7 +89,10 @@ int BilingualDynSuffixArray::LoadRawAlignments(string& align) {
Utils::splitToInt(align, vtmp, "- ");
assert(vtmp.size() % 2 == 0);
vector<short> vAlgn; // store as short ints for memory
iterate(vtmp, itr) vAlgn.push_back(short(*itr));
for (std::vector<int>::const_iterator itr = vtmp.begin();
itr != vtmp.end(); ++itr) {
vAlgn.push_back(short(*itr));
}
m_rawAlignments.push_back(vAlgn);
return m_rawAlignments.size();
}
@ -165,7 +173,8 @@ void BilingualDynSuffixArray::CleanUp()
}
int BilingualDynSuffixArray::LoadCorpus(InputFileStream& corpus, const FactorList& factors,
const FactorDirection& direction, std::vector<wordID_t>& cArray, std::vector<wordID_t>& sntArray)
const FactorDirection& direction, std::vector<wordID_t>& cArray, std::vector<wordID_t>& sntArray,
Vocab* vocab)
{
std::string line, word;
int sntIdx(0);
@ -178,11 +187,12 @@ int BilingualDynSuffixArray::LoadCorpus(InputFileStream& corpus, const FactorLis
phrase.CreateFromString( factors, line, factorDelimiter);
// store words in vocabulary and corpus
for( size_t i = 0; i < phrase.GetSize(); ++i) {
cArray.push_back( m_vocab->GetWordID( phrase.GetWord(i) ) );
cArray.push_back( vocab->GetWordID(phrase.GetWord(i)) );
}
sntIdx += phrase.GetSize();
}
//cArray.push_back(m_vocab->GetkOOVWordID); // signify end of corpus
//cArray.push_back(vocab->GetkOOVWordID); // signify end of corpus
vocab->MakeClosed(); // avoid adding words
return cArray.size();
}
@ -192,8 +202,8 @@ bool BilingualDynSuffixArray::GetLocalVocabIDs(const Phrase& src, SAPhrase &outp
size_t phraseSize = src.GetSize();
for (size_t pos = 0; pos < phraseSize; ++pos) {
const Word &word = src.GetWord(pos);
wordID_t arrayId = m_vocab->GetWordID(word);
if (arrayId == m_vocab->GetkOOVWordID())
wordID_t arrayId = m_srcVocab->GetWordID(word);
if (arrayId == m_srcVocab->GetkOOVWordID())
{ // oov
return false;
}
@ -219,7 +229,7 @@ pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& p
wordID_t srcWord = m_srcCorpus->at(srcIdx + m_srcSntBreaks[phrasepair.m_sntIndex]); // localIDs
const std::vector<int>& srcWordAlignments = alignment.alignedList.at(srcIdx);
if(srcWordAlignments.size() == 0) { // get p(NULL|src)
pair<wordID_t, wordID_t> wordpair = std::make_pair(srcWord, m_vocab->GetkOOVWordID());
pair<wordID_t, wordID_t> wordpair = std::make_pair(srcWord, m_srcVocab->GetkOOVWordID());
itrCache = m_wordPairCache.find(wordpair);
if(itrCache == m_wordPairCache.end()) { // if not in cache
CacheWordProbs(srcWord);
@ -251,10 +261,11 @@ pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& p
for(int trgIdx = phrasepair.m_startTarget; trgIdx <= phrasepair.m_endTarget; ++trgIdx) {
float trgSumPairProbs(0);
wordID_t trgWord = m_trgCorpus->at(trgIdx + m_trgSntBreaks[phrasepair.m_sntIndex]);
iterate(targetProbs, trgItr) {
for (std::map<pair<wordID_t, wordID_t>, float>::const_iterator trgItr
= targetProbs.begin(); trgItr != targetProbs.end(); ++trgItr) {
if(trgItr->first.second == trgWord)
trgSumPairProbs += trgItr->second;
}
}
if(trgSumPairProbs == 0) continue; // currently don't store target-side SA
int noAligned = alignment.numberAligned.at(trgIdx);
float trgNormalizer = noAligned < 2 ? 1.0 : 1.0 / float(noAligned);
@ -280,7 +291,7 @@ void BilingualDynSuffixArray::CacheWordProbs(wordID_t srcWord) const
const std::vector<int> srcAlg = GetSentenceAlignment(sntIdx).alignedList.at(srcWrdSntIdx); // list of target words for this source word
//const std::vector<int>& srcAlg = m_alignments.at(sntIdx).alignedList.at(srcWrdSntIdx); // list of target words for this source word
if(srcAlg.size() == 0) {
++counts[m_vocab->GetkOOVWordID()]; // if not alligned then align to NULL word
++counts[m_srcVocab->GetkOOVWordID()]; // if not aligned then align to NULL word
++denom;
}
else { //get target words aligned to srcword in this sentence
@ -319,11 +330,10 @@ TargetPhrase* BilingualDynSuffixArray::GetMosesFactorIDs(const SAPhrase& phrase)
{
TargetPhrase* targetPhrase = new TargetPhrase(Output);
for(size_t i=0; i < phrase.words.size(); ++i) { // look up trg words
Word& word = m_vocab->GetWord( phrase.words[i]);
assert(word != m_vocab->GetkOOVWord());
Word& word = m_trgVocab->GetWord( phrase.words[i]);
assert(word != m_trgVocab->GetkOOVWord());
targetPhrase->AddWord(word);
}
// scoring
return targetPhrase;
}
@ -349,7 +359,7 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
int sntIndex = sntIndexes.at(snt); // get corpus index for sentence
if(sntIndex == -1) continue; // bad flag set by GetSntIndexes()
ExtractPhrases(sntIndex, wrdIndices[snt], sourceSize, phrasePairs);
cerr << "extracted " << phrasePairs.size() << endl;
//cerr << "extracted " << phrasePairs.size() << endl;
totalTrgPhrases += phrasePairs.size(); // keep track of count of each extracted phrase pair
std::vector<PhrasePair*>::iterator iterPhrasePair;
for (iterPhrasePair = phrasePairs.begin(); iterPhrasePair != phrasePairs.end(); ++iterPhrasePair) {
@ -384,7 +394,7 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
for(ritr = phraseScores.rbegin(); ritr != phraseScores.rend(); ++ritr) {
Scores scoreVector = ritr->first;
TargetPhrase *targetPhrase = GetMosesFactorIDs(*ritr->second);
cerr << *targetPhrase << endl;
//cerr << *targetPhrase << endl;
target.push_back( make_pair( scoreVector, targetPhrase));
if(target.size() == maxReturn) break;
@ -423,27 +433,33 @@ std::vector<unsigned> BilingualDynSuffixArray::SampleSelection(std::vector<unsig
void BilingualDynSuffixArray::addSntPair(string& source, string& target, string& alignment) {
vuint_t srcFactor, trgFactor;
cerr << "source, target, alignment = " << source << ", " << target << ", " << alignment << endl;
std::istringstream sss(source), sst(target), ssa(alignment);
string word;
const std::string& factorDelimiter = StaticData::Instance().GetFactorDelimiter();
const unsigned oldSrcCrpSize = m_srcCorpus->size(), oldTrgCrpSize = m_trgCorpus->size();
cerr << "old source corpus size = " << oldSrcCrpSize << "\told target size = " << oldTrgCrpSize << endl;
m_vocab->MakeOpen();
while(sss >> word) {
srcFactor.push_back(m_vocab->GetWordID(word)); // get vocab id
Phrase sphrase(Input);
sphrase.CreateFromString(m_inputFactors, source, factorDelimiter);
m_srcVocab->MakeOpen();
// store words in vocabulary and corpus
for(size_t i = 0; i < sphrase.GetSize(); ++i) {
srcFactor.push_back(m_srcVocab->GetWordID(sphrase.GetWord(i))); // get vocab id
cerr << "srcFactor[" << (srcFactor.size() - 1) << "] = " << srcFactor.back() << endl;
m_srcCorpus->push_back(srcFactor.back()); // add word to corpus
}
m_srcSntBreaks.push_back(oldSrcCrpSize); // former end of corpus is index of new sentence
while(sst >> word) {
trgFactor.push_back(m_vocab->GetWordID(word));
m_srcVocab->MakeClosed();
Phrase tphrase(Output);
tphrase.CreateFromString(m_outputFactors, target, factorDelimiter);
m_trgVocab->MakeOpen();
for(size_t i = 0; i < tphrase.GetSize(); ++i) {
trgFactor.push_back(m_trgVocab->GetWordID(tphrase.GetWord(i))); // get vocab id
cerr << "trgFactor[" << (trgFactor.size() - 1) << "] = " << trgFactor.back() << endl;
m_trgCorpus->push_back(trgFactor.back());
}
m_trgSntBreaks.push_back(oldTrgCrpSize);
m_srcSA->InsertFactor(&srcFactor, oldSrcCrpSize);
m_trgSA->InsertFactor(&trgFactor, oldTrgCrpSize);
m_srcSA->Insert(&srcFactor, oldSrcCrpSize);
m_trgSA->Insert(&trgFactor, oldTrgCrpSize);
LoadRawAlignments(alignment);
m_vocab->MakeClosed();
m_trgVocab->MakeClosed();
}
SentenceAlignment::SentenceAlignment(int sntIndex, int sourceSize, int targetSize)
:m_sntIndex(sntIndex)

View File

@ -62,8 +62,14 @@ public:
bool operator()(const Scores& s1, const Scores& s2) const {
float score1(1), score2(1);
int idx1(0), idx2(0);
iterate(s1, itr) score1 += (*itr * m_weights.at(idx1++));
iterate(s2, itr) score2 += (*itr * m_weights.at(idx2++));
for (Scores::const_iterator itr = s1.begin();
itr != s1.end(); ++itr) {
score1 += (*itr * m_weights.at(idx1++));
}
for (Scores::const_iterator itr = s2.begin();
itr != s2.end(); ++itr) {
score2 += (*itr * m_weights.at(idx2++));
}
return score1 < score2;
}
private:
@ -86,13 +92,12 @@ private:
DynSuffixArray* m_trgSA;
std::vector<wordID_t>* m_srcCorpus;
std::vector<wordID_t>* m_trgCorpus;
FactorMask m_inputFactors;
FactorMask m_outputFactors;
std::vector<FactorType> m_inputFactors;
std::vector<FactorType> m_outputFactors;
std::vector<unsigned> m_srcSntBreaks, m_trgSntBreaks;
Vocab* m_vocab;
Vocab* m_srcVocab, *m_trgVocab;
ScoresComp* m_scoreCmp;
std::vector<SentenceAlignment> m_alignments;
@ -102,7 +107,8 @@ private:
const size_t m_maxPhraseLength, m_maxSampleSize;
int LoadCorpus(InputFileStream&, const std::vector<FactorType>& factors,
const FactorDirection& direction, std::vector<wordID_t>&, std::vector<wordID_t>&);
const FactorDirection& direction, std::vector<wordID_t>&, std::vector<wordID_t>&,
Vocab*);
int LoadAlignments(InputFileStream& aligs);
int LoadRawAlignments(InputFileStream& aligs);
int LoadRawAlignments(string& aligs);
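With the accumulation restored above, ScoresComp defines a strict weak ordering by weighted score sum. A minimal usage sketch, assuming Scores is the std::vector<float> typedef used elsewhere in Moses; RankExample and its values are illustrative only:

#include <map>
#include <vector>

// Orders Scores vectors by ascending weighted sum, so that rbegin() walks
// candidates best-first, as GetTargetPhrasesByLexicalWeight does.
void RankExample() {
  std::vector<float> weights(2, 0.5f);
  ScoresComp cmp(weights);
  std::multimap<Scores, int, ScoresComp> ranked(cmp);
  ranked.insert(std::make_pair(Scores(2, 0.1f), /*id*/ 1));
  ranked.insert(std::make_pair(Scores(2, 0.9f), /*id*/ 2));
  // ranked.rbegin()->second == 2: the larger weighted sum sorts last.
}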

View File

@ -19,6 +19,7 @@
***********************************************************************/
#include <algorithm>
#include <iostream>
#include "../../moses/src/StaticData.h"
#include "ChartTranslationOptionList.h"
#include "ChartTranslationOption.h"

View File

@ -8,8 +8,6 @@
#include <typeinfo>
#include <stdint.h>
#define iterate(c, i) for(typeof(c.begin()) i = c.begin(); i != c.end(); ++i)
#define piterate(c, i) for(typeof(c->begin()) i = c->begin(); i != c->end(); ++i)
#define THREADED false
#define THREAD_MAX 2
#define MAX_NGRAM_ORDER 8

View File

@ -88,8 +88,10 @@ namespace Moses {
{
// then each vcb entry
*vcbout << m_ids2words.size() << "\n";
iterate(m_ids2words, iter)
for (Id2Word::const_iterator iter = m_ids2words.begin();
iter != m_ids2words.end(); ++iter) {
*vcbout << iter->second << "\t" << iter->first << "\n";
}
return true;
}
@ -134,10 +136,14 @@ namespace Moses {
}
void Vocab::PrintVocab()
{
iterate(m_ids2words, iter)
for (Id2Word::const_iterator iter = m_ids2words.begin();
iter != m_ids2words.end(); ++iter ) {
std::cerr << iter->second << "\t" << iter->first << "\n";
iterate(m_words2ids, iter)
}
for (Word2Id::const_iterator iter = m_words2ids.begin();
iter != m_words2ids.end(); ++iter ) {
std::cerr << iter->second << "\t" << iter->first << "\n";
}
}
} //end namespace

View File

@ -85,7 +85,7 @@ int DynSuffixArray::LastFirstFunc(unsigned L_idx) {
return fIdx;
}
void DynSuffixArray::InsertFactor(vuint_t* newSent, unsigned newIndex) {
void DynSuffixArray::Insert(vuint_t* newSent, unsigned newIndex) {
// for sentences
//stages 1, 2, 4 stay same from 1char case
//(use last word of new text in step 2 and save Ltmp until last insert?)
@ -109,12 +109,13 @@ void DynSuffixArray::InsertFactor(vuint_t* newSent, unsigned newIndex) {
int theLWord = (j == 0 ? Ltmp : newSent->at(j-1));
m_L->insert(m_L->begin() + kprime, theLWord);
piterate(m_SA, itr)
for (vuint_t::iterator itr = m_SA->begin(); itr != m_SA->end(); ++itr) {
if(*itr >= newIndex) ++(*itr);
}
m_SA->insert(m_SA->begin() + kprime, newIndex);
piterate(m_ISA, itr)
for (vuint_t::iterator itr = m_ISA->begin(); itr != m_ISA->end(); ++itr) {
if((int)*itr >= kprime) ++(*itr);
}
m_ISA->insert(m_ISA->begin() + newIndex, kprime);
k = kprime;
@ -152,7 +153,7 @@ void DynSuffixArray::Reorder(unsigned j, unsigned jprime) {
}
}
void DynSuffixArray::DeleteFactor(unsigned index, unsigned num2del) {
void DynSuffixArray::Delete(unsigned index, unsigned num2del) {
int ltmp = m_L->at(m_ISA->at(index));
int true_pos = LastFirstFunc(m_ISA->at(index)); // track cycle shift (newIndex - 1)
for(size_t q = 0; q < num2del; ++q) {
@ -164,19 +165,21 @@ void DynSuffixArray::DeleteFactor(unsigned index, unsigned num2del) {
m_F->erase(m_F->begin() + row);
m_ISA->erase(m_ISA->begin() + index); // order is important
piterate(m_ISA, itr)
for (vuint_t::iterator itr = m_ISA->begin(); itr != m_ISA->end(); ++itr) {
if((int)*itr > row) --(*itr);
}
m_SA->erase(m_SA->begin() + row);
piterate(m_SA, itr)
for (vuint_t::iterator itr = m_SA->begin(); itr != m_SA->end(); ++itr) {
if(*itr > index) --(*itr);
}
}
m_L->at(m_ISA->at(index))= ltmp;
Reorder(LastFirstFunc(m_ISA->at(index)), true_pos);
PrintAuxArrays();
}
void DynSuffixArray::SubstituteFactor(vuint_t* newSents, unsigned newIndex) {
void DynSuffixArray::Substitute(vuint_t* newSents, unsigned newIndex) {
std::cerr << "NEEDS TO IMPLEMENT SUBSITITUTE FACTOR\n";
return;
}

View File

@ -22,9 +22,9 @@ public:
bool GetCorpusIndex(const vuint_t*, vuint_t*);
void Load(FILE*);
void Save(FILE*);
void InsertFactor(vuint_t*, unsigned);
void DeleteFactor(unsigned, unsigned);
void SubstituteFactor(vuint_t*, unsigned);
void Insert(vuint_t*, unsigned);
void Delete(unsigned, unsigned);
void Substitute(vuint_t*, unsigned);
private:
vuint_t* m_SA;
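The renamed Insert/Delete entry points are what BilingualDynSuffixArray::addSntPair now calls; a minimal sketch of the calling convention (assumes vuint_t is std::vector<unsigned> as in the sources, and that sa is a DynSuffixArray already built over a corpus whose length is corpusSize; names and word IDs are illustrative):

// Register one new sentence's word IDs with the suffix array at the former
// end of the corpus, mirroring addSntPair, then (hypothetically) remove it.
void GrowAndShrinkExample(DynSuffixArray &sa, unsigned corpusSize) {
  vuint_t sentence;
  sentence.push_back(42);                  // made-up word IDs
  sentence.push_back(7);
  sa.Insert(&sentence, corpusSize);        // was InsertFactor()
  sa.Delete(corpusSize, sentence.size());  // was DeleteFactor()
}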

View File

@ -22,10 +22,12 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <cassert>
#include <limits>
#include <iostream>
#include <memory>
#include <sstream>
#include "FFState.h"
#include "LanguageModel.h"
#include "LanguageModelImplementation.h"
#include "TypeDef.h"
#include "Util.h"
#include "Manager.h"
@ -37,9 +39,34 @@ using namespace std;
namespace Moses
{
LanguageModel::LanguageModel() : StatefulFeatureFunction("LM") {}
LanguageModel::~LanguageModel() {}
LanguageModel::LanguageModel(LanguageModelImplementation *implementation) :
StatefulFeatureFunction("LM"),
m_implementation(implementation)
{
#ifndef WITH_THREADS
// ref counting handled by boost otherwise
m_implementation->IncrementReferenceCount();
#endif
}
LanguageModel::LanguageModel(LanguageModel *loadedLM) :
StatefulFeatureFunction("LM"),
m_implementation(loadedLM->m_implementation)
{
#ifndef WITH_THREADS
// ref counting handled by boost otherwise
m_implementation->IncrementReferenceCount();
#endif
}
LanguageModel::~LanguageModel()
{
#ifndef WITH_THREADS
if(m_implementation->DecrementReferenceCount() == 0)
delete m_implementation;
#endif
}
// don't inline virtual funcs...
size_t LanguageModel::GetNumScoreComponents() const
@ -56,10 +83,12 @@ void LanguageModel::CalcScore(const Phrase &phrase
ngramScore = 0;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
vector<const Word*> contextFactor;
contextFactor.reserve(m_nGramOrder);
contextFactor.reserve(GetNGramOrder());
std::auto_ptr<FFState> state(m_implementation->NewState((phrase.GetWord(0) == m_implementation->GetSentenceStartArray()) ?
m_implementation->GetBeginSentenceState() : m_implementation->GetNullContextState()));
size_t currPos = 0;
while (currPos < phraseSize)
{
@ -67,28 +96,32 @@ void LanguageModel::CalcScore(const Phrase &phrase
if (word.IsNonTerminal())
{ // do nothing. reset ngram. needed to score target phrases during pt loading in chart decoding
contextFactor.clear();
if (!contextFactor.empty()) {
// TODO: state operator= ?
state.reset(m_implementation->NewState(m_implementation->GetNullContextState()));
contextFactor.clear();
}
}
else
{
ShiftOrPush(contextFactor, word);
assert(contextFactor.size() <= m_nGramOrder);
assert(contextFactor.size() <= GetNGramOrder());
if (word == GetSentenceStartArray())
if (word == m_implementation->GetSentenceStartArray())
{ // do nothing, don't include prob for <s> unigram
assert(currPos == 0);
}
else
{
float partScore = GetValue(contextFactor);
float partScore = m_implementation->GetValueGivenState(contextFactor, *state);
fullScore += partScore;
if (contextFactor.size() == m_nGramOrder)
if (contextFactor.size() == GetNGramOrder())
ngramScore += partScore;
}
}
currPos++;
}
}
}
void LanguageModel::CalcScoreChart(const Phrase &phrase
@ -100,10 +133,12 @@ void LanguageModel::CalcScoreChart(const Phrase &phrase
ngramScore = 0;
size_t phraseSize = phrase.GetSize();
if (!phraseSize) return;
vector<const Word*> contextFactor;
contextFactor.reserve(m_nGramOrder);
contextFactor.reserve(GetNGramOrder());
std::auto_ptr<FFState> state(m_implementation->NewState((phrase.GetWord(0) == m_implementation->GetSentenceStartArray()) ?
m_implementation->GetBeginSentenceState() : m_implementation->GetNullContextState()));
size_t currPos = 0;
while (currPos < phraseSize)
{
@ -111,17 +146,17 @@ void LanguageModel::CalcScoreChart(const Phrase &phrase
assert(!word.IsNonTerminal());
ShiftOrPush(contextFactor, word);
assert(contextFactor.size() <= m_nGramOrder);
assert(contextFactor.size() <= GetNGramOrder());
if (word == GetSentenceStartArray())
if (word == m_implementation->GetSentenceStartArray())
{ // do nothing, don't include prob for <s> unigram
assert(currPos == 0);
}
else
{
float partScore = GetValue(contextFactor);
float partScore = m_implementation->GetValueGivenState(contextFactor, *state);
if (contextFactor.size() == m_nGramOrder)
if (contextFactor.size() == GetNGramOrder())
ngramScore += partScore;
else
beginningBitsOnly += partScore;
@ -130,46 +165,26 @@ void LanguageModel::CalcScoreChart(const Phrase &phrase
currPos++;
}
}
void LanguageModel::ShiftOrPush(vector<const Word*> &contextFactor, const Word &word) const
{
if (contextFactor.size() < m_nGramOrder)
if (contextFactor.size() < GetNGramOrder())
{
contextFactor.push_back(&word);
}
else
{ // shift
for (size_t currNGramOrder = 0 ; currNGramOrder < m_nGramOrder - 1 ; currNGramOrder++)
for (size_t currNGramOrder = 0 ; currNGramOrder < GetNGramOrder() - 1 ; currNGramOrder++)
{
contextFactor[currNGramOrder] = contextFactor[currNGramOrder + 1];
}
contextFactor[m_nGramOrder - 1] = &word;
contextFactor[GetNGramOrder() - 1] = &word;
}
}
LanguageModel::State LanguageModel::GetState(const std::vector<const Word*> &contextFactor, unsigned int* len) const
{
State state;
unsigned int dummy;
if (!len) len = &dummy;
GetValue(contextFactor,&state,len);
return state;
}
struct LMState : public FFState {
const void* lmstate;
LMState(const void* lms) { lmstate = lms; }
virtual int Compare(const FFState& o) const {
const LMState& other = static_cast<const LMState&>(o);
if (other.lmstate > lmstate) return 1;
else if (other.lmstate < lmstate) return -1;
return 0;
}
};
const FFState* LanguageModel::EmptyHypothesisState(const InputType &/*input*/) const {
return new LMState(NULL);
// This is actually correct. The empty _hypothesis_ has <s> in it. Phrases use m_emptyContextState.
return m_implementation->NewState(m_implementation->GetBeginSentenceState());
}
FFState* LanguageModel::Evaluate(
@ -180,71 +195,76 @@ FFState* LanguageModel::Evaluate(
// phrase boundary. Phrase-internal scores are taken directly from the
// translation option. In the unigram case, there is no overlap, so we don't
// need to do anything.
if(m_nGramOrder <= 1)
if(GetNGramOrder() <= 1)
return NULL;
clock_t t=0;
IFVERBOSE(2) { t = clock(); } // track time
const void* prevlm = ps ? (static_cast<const LMState *>(ps)->lmstate) : NULL;
LMState* res = new LMState(prevlm);
if (hypo.GetCurrTargetLength() == 0)
return res;
return ps ? m_implementation->NewState(ps) : NULL;
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
// 1st n-gram
vector<const Word*> contextFactor(m_nGramOrder);
vector<const Word*> contextFactor(GetNGramOrder());
size_t index = 0;
for (int currPos = (int) startPos - (int) m_nGramOrder + 1 ; currPos <= (int) startPos ; currPos++)
for (int currPos = (int) startPos - (int) GetNGramOrder() + 1 ; currPos <= (int) startPos ; currPos++)
{
if (currPos >= 0)
contextFactor[index++] = &hypo.GetWord(currPos);
else
contextFactor[index++] = &GetSentenceStartArray();
else
{
contextFactor[index++] = &m_implementation->GetSentenceStartArray();
}
}
float lmScore = GetValue(contextFactor);
//cout<<"context factor: "<<GetValue(contextFactor)<<endl;
unsigned int statelen;
FFState *res = m_implementation->NewState(ps);
float lmScore = ps ? m_implementation->GetValueGivenState(contextFactor, *res, &statelen) : m_implementation->GetValueForgotState(contextFactor, *res, &statelen);
// main loop
size_t endPos = std::min(startPos + m_nGramOrder - 2
size_t endPos = std::min(startPos + GetNGramOrder() - 2
, currEndPos);
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++)
{
// shift all args down 1 place
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i++)
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
// add last factor
contextFactor.back() = &hypo.GetWord(currPos);
lmScore += GetValue(contextFactor);
lmScore += m_implementation->GetValueGivenState(contextFactor, *res, &statelen);
}
// end of sentence
if (hypo.IsSourceCompleted())
{
const size_t size = hypo.GetSize();
contextFactor.back() = &GetSentenceEndArray();
contextFactor.back() = &m_implementation->GetSentenceEndArray();
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i ++)
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i ++)
{
int currPos = (int)(size - m_nGramOrder + i + 1);
int currPos = (int)(size - GetNGramOrder() + i + 1);
if (currPos < 0)
contextFactor[i] = &GetSentenceStartArray();
contextFactor[i] = &m_implementation->GetSentenceStartArray();
else
contextFactor[i] = &hypo.GetWord((size_t)currPos);
}
lmScore += GetValue(contextFactor, &res->lmstate);
lmScore += m_implementation->GetValueForgotState(contextFactor, *res);
} else {
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &hypo.GetWord(currPos);
if (endPos < currEndPos){
//need to get the LM state (otherwise the last LM state is fine)
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < GetNGramOrder() - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &hypo.GetWord(currPos);
}
m_implementation->GetState(contextFactor, *res);
}
res->lmstate = GetState(contextFactor);
}
out->PlusEquals(this, lmScore);
IFVERBOSE(2) { hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t ); }
IFVERBOSE(2) { hypo.GetManager().GetSentenceStats().AddTimeCalcLM( clock()-t ); }
return res;
}

View File

@ -29,6 +29,11 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "Util.h"
#include "FeatureFunction.h"
#include "Word.h"
#include "LanguageModelImplementation.h"
#ifdef WITH_THREADS
#include <boost/shared_ptr.hpp>
#endif
namespace Moses
{
@ -41,36 +46,38 @@ class Phrase;
class LanguageModel : public StatefulFeatureFunction
{
protected:
std::string m_filePath; //! for debugging purposes
size_t m_nGramOrder; //! max n-gram length contained in this LM
Word m_sentenceStartArray, m_sentenceEndArray; //! Contains factors which represents the beging and end words for this LM.
//! Usually <s> and </s>
#ifdef WITH_THREADS
// if we have threads, we also have boost and can let it handle thread-safe reference counting
boost::shared_ptr<LanguageModelImplementation> m_implementation;
#else
LanguageModelImplementation *m_implementation;
#endif
public:
void ShiftOrPush(std::vector<const Word*> &contextFactor, const Word &word) const;
/** constructor to be called by inherited class
/**
* Create a new language model
*/
LanguageModel();
LanguageModel(LanguageModelImplementation* implementation);
public:
/* Returned from LM implementations which points at the state used. For example, if a trigram score was requested
* but the LM backed off to using the trigram, the State pointer will point to the bigram.
* Used for more agressive pruning of hypothesis
/**
* Create a new language model reusing an already loaded implementation
*/
typedef const void* State;
LanguageModel(LanguageModel *implementation);
virtual ~LanguageModel();
//! see ScoreProducer.h
size_t GetNumScoreComponents() const;
//! Single or multi-factor
virtual LMType GetLMType() const = 0;
/* whether this LM can be used on a particular phrase.
* Should return false if phrase size = 0 or factor types required don't exist
*/
virtual bool Useable(const Phrase &phrase) const = 0;
bool Useable(const Phrase &phrase) const {
return m_implementation->Useable(phrase);
}
/* calc total unweighted LM score of this phrase and return score via arguments.
* Return scores should always be in natural log, regardless of representation with LM implementation.
@ -79,43 +86,23 @@ public:
* \param fullScore scores of all unigram, bigram... of contiguous n-gram of the phrase
* \param ngramScore score of only n-gram of order m_nGramOrder
*/
void CalcScore(const Phrase &phrase
, float &fullScore
, float &ngramScore) const;
void CalcScore(
const Phrase &phrase,
float &fullScore,
float &ngramScore) const;
void CalcScoreChart(const Phrase &phrase
, float &beginningBitsOnly
, float &ngramScore) const;
void CalcScoreChart(
const Phrase &phrase,
float &beginningBitsOnly,
float &ngramScore) const;
/* get score of n-gram. n-gram should not be bigger than m_nGramOrder
* Specific implementation can return State and len data to be used in hypothesis pruning
* \param contextFactor n-gram to be scored
* \param finalState state used by LM. Return arg
* \param len ???
*/
virtual float GetValue(const std::vector<const Word*> &contextFactor
, State* finalState = 0
, unsigned int* len = 0) const = 0;
//! get State for a particular n-gram
State GetState(const std::vector<const Word*> &contextFactor, unsigned int* len = 0) const;
//! max n-gram order of LM
size_t GetNGramOrder() const
{
return m_nGramOrder;
return m_implementation->GetNGramOrder();
}
//! Contains factors which represents the beging and end words for this LM. Usually <s> and </s>
const Word &GetSentenceStartArray() const
{
return m_sentenceStartArray;
}
const Word &GetSentenceEndArray() const
{
return m_sentenceEndArray;
}
//! scoring weight. Shouldn't this now be superseded by ScoreProducer???
float GetWeight() const;
std::string GetScoreProducerWeightShortName() const
@ -123,9 +110,15 @@ public:
return "lm";
}
//! overridable functions for IRST LM to clean up. Maybe something to do with on demand/cache loading/unloading
virtual void InitializeBeforeSentenceProcessing(){};
virtual void CleanUpAfterSentenceProcessing() {};
void InitializeBeforeSentenceProcessing()
{
m_implementation->InitializeBeforeSentenceProcessing();
}
void CleanUpAfterSentenceProcessing()
{
m_implementation->CleanUpAfterSentenceProcessing();
}
virtual const FFState* EmptyHypothesisState(const InputType &input) const;

View File

@ -1,68 +0,0 @@
// $Id: LanguageModel.h 3078 2010-04-08 17:16:10Z hieuhoang1972 $
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2006 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#ifndef moses_LanguageModelDelegate_h
#define moses_LanguageModelDelegate_h
#include "LanguageModelSingleFactor.h"
namespace Moses {
//! A language model which delegates all its calculation to another language model.
//! Used when you want to have the same language model with two different weights.
class LanguageModelDelegate: public LanguageModelSingleFactor {
public:
LanguageModelDelegate(LanguageModelSingleFactor* delegate) : m_delegate(delegate)
{
m_nGramOrder = m_delegate->GetNGramOrder();
m_factorType = m_delegate->GetFactorType();
m_sentenceStart = m_delegate->GetSentenceStart();
m_sentenceEnd = m_delegate->GetSentenceEnd();
m_sentenceStartArray = m_delegate->GetSentenceStartArray();
m_sentenceEndArray = m_delegate->GetSentenceEndArray();
}
virtual bool Load(const std::string &
, FactorType
, size_t)
{
/* do nothing */
return true;
}
virtual float GetValue(const std::vector<const Word*> &contextFactor, State* finalState, unsigned int* len) const {
return m_delegate->GetValue(contextFactor, finalState, len);
}
private:
LanguageModelSingleFactor* m_delegate;
};
}
#endif

View File

@ -44,6 +44,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
# include "LanguageModelKen.h"
#endif
#include "LanguageModel.h"
#include "LanguageModelInternal.h"
#include "LanguageModelSkip.h"
#include "LanguageModelJoint.h"
@ -62,7 +63,7 @@ namespace LanguageModelFactory
, const std::string &languageModelFile
, int dub)
{
LanguageModel *lm = NULL;
LanguageModelImplementation *lm = NULL;
switch (lmImplementation)
{
case RandLM:
@ -95,7 +96,12 @@ namespace LanguageModelFactory
break;
case Ken:
#ifdef LM_KEN
lm = new LanguageModelKen();
lm = ConstructKenLM(languageModelFile, false);
#endif
break;
case LazyKen:
#ifdef LM_KEN
lm = ConstructKenLM(languageModelFile, true);
#endif
break;
case Joint:
@ -144,7 +150,7 @@ namespace LanguageModelFactory
}
}
return lm;
return new LanguageModel(lm);
}
}

View File

@ -119,16 +119,14 @@ bool LanguageModelIRST::Load(const std::string &filePath,
m_unknownId = m_lmtb->getDict()->oovcode(); // at the level of micro tags
CreateFactors(factorCollection);
VERBOSE(1, "IRST: m_unknownId=" << m_unknownId << std::endl);
VERBOSE(0, "IRST: m_unknownId=" << m_unknownId << std::endl);
//install caches
m_lmtb->init_probcache();
m_lmtb->init_statecache();
m_lmtb->init_lmtcaches(m_lmtb->maxlevel()>2?m_lmtb->maxlevel()-1:2);
m_lmtb->init_caches(m_lmtb->maxlevel()>2?m_lmtb->maxlevel()-1:2);
if (m_lmtb_dub >0) m_lmtb->setlogOOVpenalty(m_lmtb_dub);
free(filenamesOrig);
free(filenamesOrig);
return true;
}
@ -180,6 +178,12 @@ int LanguageModelIRST::GetLmID( const std::string &str ) const
return m_lmtb->getDict()->encode( str.c_str() ); // at the level of micro tags
}
int LanguageModelIRST::GetLmID( const Factor *factor ) const
{
size_t factorId = factor->GetId();
return ( factorId >= m_lmIdLookup.size()) ? m_unknownId : m_lmIdLookup[factorId];
}
float LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, State* finalState, unsigned int* len) const
{
unsigned int dummy;
@ -188,31 +192,28 @@ float LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, Stat
// set up context
size_t count = contextFactor.size();
m_lmtb_ng->size=0;
if (count< (size_t)(m_lmtb_size-1)) m_lmtb_ng->pushc(m_lmtb_sentenceEnd);
if (count< (size_t)m_lmtb_size) m_lmtb_ng->pushc(m_lmtb_sentenceStart);
if (count < 0) { cerr << "ERROR count < 0\n"; exit(100); };
for (size_t i = 0 ; i < count ; i++)
{
//int lmId = GetLmID((*contextFactor[i])[factorType]);
#ifdef DEBUG
cout << "i=" << i << " -> " << (*contextFactor[i])[factorType]->GetString() << "\n";
#endif
int lmId = GetLmID((*contextFactor[i])[factorType]->GetString());
// cerr << (*contextFactor[i])[factorType]->GetString() << " = " << lmId;
m_lmtb_ng->pushc(lmId);
}
if (finalState){
*finalState=(State *)m_lmtb->cmaxsuffptr(*m_lmtb_ng);
// back off stats not currently available
*len = 0;
}
// set up context
int codes[MAX_NGRAM_SIZE];
size_t idx=0;
//fill the farthest positions with at most ONE sentenceEnd symbol and at most ONE sentenceStart symbol, if "empty" positions are available
//so that the vector looks like "</s> <s> context_word context_word" for a two-word context and an LM of order 5
if (count < (size_t) (m_lmtb_size-1)) codes[idx++] = m_lmtb_sentenceEnd;
if (count < (size_t) m_lmtb_size) codes[idx++] = m_lmtb_sentenceStart;
for (size_t i = 0 ; i < count ; i++)
codes[idx++] = GetLmID((*contextFactor[i])[factorType]);
float prob;
char* msp = NULL;
unsigned int ilen;
prob = m_lmtb->clprob(codes,idx,NULL,NULL,&msp,&ilen);
if (finalState) *finalState=(State *) msp;
if (len) *len = ilen;
float prob = m_lmtb->clprob(*m_lmtb_ng);
return TransformLMScore(prob);
}
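A worked instance of the padding rule in the comment above, assuming an LM of order m_lmtb_size = 5 and a two-word context:

// count = 2, m_lmtb_size = 5:
//   2 < 4 (m_lmtb_size - 1)  ->  codes[0] = m_lmtb_sentenceEnd    (</s>)
//   2 < 5 (m_lmtb_size)      ->  codes[1] = m_lmtb_sentenceStart  (<s>)
//   codes[2], codes[3]       =   LM IDs of the two context words
// idx ends up 4, so clprob() scores the padded 4-gram "</s> <s> w1 w2".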

View File

@ -40,7 +40,7 @@ class Phrase;
/** Implementation of single factor LM using IRST's code.
* This is the default LM for Moses and is available from the same sourceforge repository
*/
class LanguageModelIRST : public LanguageModelSingleFactor
class LanguageModelIRST : public LanguageModelPointerState
{
protected:
std::vector<int> m_lmIdLookup;
@ -59,11 +59,7 @@ protected:
void CreateFactors(FactorCollection &factorCollection);
int GetLmID( const std::string &str ) const;
int GetLmID( const Factor *factor ) const{
size_t factorId = factor->GetId();
return ( factorId >= m_lmIdLookup.size()) ? m_unknownId : m_lmIdLookup[factorId];
};
int GetLmID( const Factor *factor ) const;
public:
LanguageModelIRST(int dub);

View File

@ -0,0 +1,56 @@
// $Id$
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2006 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <cassert>
#include <limits>
#include <iostream>
#include <memory>
#include <sstream>
#include "FFState.h"
#include "LanguageModelImplementation.h"
#include "TypeDef.h"
#include "Util.h"
#include "Manager.h"
#include "FactorCollection.h"
#include "Phrase.h"
#include "StaticData.h"
using namespace std;
namespace Moses
{
float LanguageModelImplementation::GetValueGivenState(
const std::vector<const Word*> &contextFactor,
FFState &state,
unsigned int* len) const
{
return GetValueForgotState(contextFactor, state, len);
}
void LanguageModelImplementation::GetState(
const std::vector<const Word*> &contextFactor,
FFState &state) const
{
GetValueForgotState(contextFactor, state, NULL);
}
}

View File

@ -0,0 +1,136 @@
// $Id$
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2006 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#ifndef moses_LanguageModelImplementation_h
#define moses_LanguageModelImplementation_h
#include <string>
#include <vector>
#include "Factor.h"
#include "TypeDef.h"
#include "Util.h"
#include "FeatureFunction.h"
#include "Word.h"
namespace Moses
{
class FactorCollection;
class Factor;
class Phrase;
//! Abstract base class which represent a language model on a contiguous phrase
class LanguageModelImplementation
{
#ifndef WITH_THREADS
protected:
/** constructor to be called by inherited class
*/
LanguageModelImplementation() : m_referenceCount(0) {}
private:
// ref counting handled by boost if we have threads
unsigned int m_referenceCount;
#else
// default constructor is ok
#endif
protected:
std::string m_filePath; //! for debugging purposes
size_t m_nGramOrder; //! max n-gram length contained in this LM
Word m_sentenceStartArray, m_sentenceEndArray; //! Contains factors which represent the begin and end words for this LM.
//! Usually <s> and </s>
public:
virtual ~LanguageModelImplementation() {}
//! Single or multi-factor
virtual LMType GetLMType() const = 0;
/* whether this LM can be used on a particular phrase.
* Should return false if phrase size = 0 or factor types required don't exist
*/
virtual bool Useable(const Phrase &phrase) const = 0;
/* get score of n-gram. n-gram should not be bigger than m_nGramOrder
* Specific implementation can return State and len data to be used in hypothesis pruning
* \param contextFactor n-gram to be scored
* \param state LM state. Input and output. state must be initialized. If state isn't initialized, you want GetValueWithoutState.
* \param len If non-null, the n-gram length is written here.
*/
virtual float GetValueGivenState(const std::vector<const Word*> &contextFactor, FFState &state, unsigned int* len = 0) const;
// Like GetValueGivenState but state may not be initialized (however it is non-NULL).
// For example, state just came from NewState(NULL).
virtual float GetValueForgotState(
const std::vector<const Word*> &contextFactor,
FFState &outState,
unsigned int* len = 0) const = 0;
//! get State for a particular n-gram. We don't care what the score is.
// This is here so models can implement a shortcut that skips the score computation done by GetValueForgotState.
virtual void GetState(
const std::vector<const Word*> &contextFactor,
FFState &outState) const;
virtual FFState *GetNullContextState() const = 0;
virtual FFState *GetBeginSentenceState() const = 0;
virtual FFState *NewState(const FFState *from = NULL) const = 0;
//! max n-gram order of LM
size_t GetNGramOrder() const
{
return m_nGramOrder;
}
//! Contains factors which represent the begin and end words for this LM. Usually <s> and </s>
const Word &GetSentenceStartArray() const
{
return m_sentenceStartArray;
}
const Word &GetSentenceEndArray() const
{
return m_sentenceEndArray;
}
//! overridable functions for IRST LM to clean up. Maybe something to do with on demand/cache loading/unloading
virtual void InitializeBeforeSentenceProcessing(){};
virtual void CleanUpAfterSentenceProcessing() {};
#ifndef WITH_THREADS
// ref counting handled by boost otherwise
unsigned int IncrementReferenceCount()
{
return ++m_referenceCount;
}
unsigned int DecrementReferenceCount()
{
return --m_referenceCount;
}
#endif
};
}
#endif
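To make the new contract concrete, here is a hypothetical minimal backend. UniformState and UniformLM are invented names, not part of this commit; the sketch implements only the pure-virtual surface above and assigns every n-gram the same score.

#include "FFState.h"
#include "LanguageModelImplementation.h"
#include "Phrase.h"

class UniformState : public Moses::FFState {
public:
  int Compare(const Moses::FFState &) const { return 0; } // all states identical
};

class UniformLM : public Moses::LanguageModelImplementation {
public:
  UniformLM() { m_nGramOrder = 3; } // assumed order for the sketch
  Moses::LMType GetLMType() const { return Moses::SingleFactor; }
  bool Useable(const Moses::Phrase &phrase) const { return phrase.GetSize() > 0; }
  float GetValueForgotState(const std::vector<const Moses::Word*> &contextFactor,
                            Moses::FFState &, unsigned int *len = 0) const {
    if (len) *len = contextFactor.size(); // pretend the full n-gram was found
    return -1.0f;                         // constant log score for every n-gram
  }
  Moses::FFState *GetNullContextState() const { return new UniformState(); }
  Moses::FFState *GetBeginSentenceState() const { return new UniformState(); }
  Moses::FFState *NewState(const Moses::FFState * = NULL) const { return new UniformState(); }
};

// Wrapped as: Moses::LanguageModel lm(new UniformLM());
// LanguageModel then owns the lifetime (boost::shared_ptr under WITH_THREADS,
// the manual reference counter above otherwise).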

View File

@ -10,7 +10,6 @@ using namespace std;
namespace Moses
{
bool LanguageModelInternal::Load(const std::string &filePath
, FactorType factorType
, size_t nGramOrder)

View File

@ -9,7 +9,7 @@ namespace Moses
/** Guaranteed cross-platform LM implementation designed to mimic LM used in regression tests
*/
class LanguageModelInternal : public LanguageModelSingleFactor
class LanguageModelInternal : public LanguageModelPointerState
{
protected:
std::vector<const NGramNode*> m_lmIdLookup;
@ -26,7 +26,6 @@ protected:
float GetValue(const Factor *factor0, const Factor *factor1, const Factor *factor2, State* finalState) const;
public:
LanguageModelInternal() {}
bool Load(const std::string &filePath
, FactorType factorType
, size_t nGramOrder);

View File

@ -82,7 +82,7 @@ public:
return m_lmImpl->Load(filePath, m_implFactor, nGramOrder);
}
float GetValue(const std::vector<const Word*> &contextFactor, State* finalState = NULL, unsigned int* len = NULL) const
float GetValueForgotState(const std::vector<const Word*> &contextFactor, FFState &outState, unsigned int* len = NULL) const
{
if (contextFactor.size() == 0)
{
@ -117,12 +117,27 @@ public:
}
// calc score on chunked phrase
float ret = m_lmImpl->GetValue(jointContext, finalState, len);
float ret = m_lmImpl->GetValueForgotState(jointContext, outState, len);
RemoveAllInColl(jointContext);
return ret;
}
FFState *GetNullContextState() const
{
return m_lmImpl->GetNullContextState();
}
FFState *GetBeginSentenceState() const
{
return m_lmImpl->GetBeginSentenceState();
}
FFState *NewState(const FFState *from) const
{
return m_lmImpl->NewState(from);
}
};

View File

@ -20,12 +20,14 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <cassert>
#include <limits>
#include <cstring>
#include <iostream>
#include <fstream>
#include "lm/ngram.hh"
#include "lm/binary_format.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include "LanguageModelKen.h"
#include "FFState.h"
#include "TypeDef.h"
#include "Util.h"
#include "FactorCollection.h"
@ -38,26 +40,103 @@ using namespace std;
namespace Moses
{
LanguageModelKen::LanguageModelKen()
:m_ngram(NULL)
namespace {
class MappingBuilder : public lm::ngram::EnumerateVocab {
public:
MappingBuilder(FactorType factorType, FactorCollection &factorCollection, std::vector<lm::WordIndex> &mapping)
: m_factorType(factorType), m_factorCollection(factorCollection), m_mapping(mapping) {}
void Add(lm::WordIndex index, const StringPiece &str) {
m_word.assign(str.data(), str.size());
std::size_t factorId = m_factorCollection.AddFactor(Output, m_factorType, m_word)->GetId();
if (m_mapping.size() <= factorId) {
// 0 is <unk> :-)
m_mapping.resize(factorId + 1);
}
m_mapping[factorId] = index;
}
private:
std::string m_word;
FactorType m_factorType;
FactorCollection &m_factorCollection;
std::vector<lm::WordIndex> &m_mapping;
};
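// MappingBuilder is handed to KenLM through config.enumerate_vocab (see
// Load() below); KenLM calls Add() once per vocabulary entry while the model
// loads, so the factor-id -> lm::WordIndex table is built in a single pass.
// Factors that never receive an entry keep index 0, KenLM's <unk>.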
struct KenLMState : public FFState {
lm::ngram::State state;
int Compare(const FFState &o) const {
const KenLMState &other = static_cast<const KenLMState &>(o);
if (state.valid_length_ < other.state.valid_length_) return -1;
if (state.valid_length_ > other.state.valid_length_) return 1;
return std::memcmp(state.history_, other.state.history_, sizeof(lm::WordIndex) * state.valid_length_);
}
};
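// Compare() defines a total order on KenLM states; Moses uses FFState
// comparison to decide when two hypotheses carry equivalent LM context and
// can be recombined. States compare equal iff their valid histories match.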
/** Implementation of single factor LM using Ken's code.
*/
template <class Model> class LanguageModelKen : public LanguageModelSingleFactor
{
private:
Model *m_ngram;
std::vector<lm::WordIndex> m_lmIdLookup;
bool m_lazy;
FFState *m_nullContextState;
FFState *m_beginSentenceState;
void TranslateIDs(const std::vector<const Word*> &contextFactor, lm::WordIndex *indices) const;
public:
LanguageModelKen(bool lazy);
~LanguageModelKen();
bool Load(const std::string &filePath
, FactorType factorType
, size_t nGramOrder);
float GetValueGivenState(const std::vector<const Word*> &contextFactor, FFState &state, unsigned int* len = 0) const;
float GetValueForgotState(const std::vector<const Word*> &contextFactor, FFState &outState, unsigned int* len=0) const;
void GetState(const std::vector<const Word*> &contextFactor, FFState &outState) const;
FFState *GetNullContextState() const;
FFState *GetBeginSentenceState() const;
FFState *NewState(const FFState *from = NULL) const;
lm::WordIndex GetLmID(const std::string &str) const;
void CleanUpAfterSentenceProcessing() {}
void InitializeBeforeSentenceProcessing() {}
};
template <class Model> void LanguageModelKen<Model>::TranslateIDs(const std::vector<const Word*> &contextFactor, lm::WordIndex *indices) const
{
FactorType factorType = GetFactorType();
// set up context
for (size_t i = 0 ; i < contextFactor.size(); i++)
{
std::size_t factor = contextFactor[i]->GetFactor(factorType)->GetId();
lm::WordIndex new_word = (factor >= m_lmIdLookup.size() ? 0 : m_lmIdLookup[factor]);
indices[contextFactor.size() - 1 - i] = new_word;
}
}
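// Note the reversed fill above: indices[] ends up with the most recent word
// first, matching the context_rbegin/context_rend convention that KenLM's
// FullScoreForgotState() and GetState() calls below expect.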
template <class Model> LanguageModelKen<Model>::LanguageModelKen(bool lazy)
:m_ngram(NULL), m_lazy(lazy)
{
}
LanguageModelKen::~LanguageModelKen()
template <class Model> LanguageModelKen<Model>::~LanguageModelKen()
{
delete m_ngram;
}
bool LanguageModelKen::Load(const std::string &filePath,
template <class Model> bool LanguageModelKen<Model>::Load(const std::string &filePath,
FactorType factorType,
size_t nGramOrder)
size_t /*nGramOrder*/)
{
cerr << "In LanguageModelKen::Load: nGramOrder = " << nGramOrder << " will be ignored. Using whatever the file has.\n";
m_ngram = new lm::ngram::Model(filePath.c_str());
m_factorType = factorType;
m_nGramOrder = m_ngram->Order();
m_filePath = filePath;
FactorCollection &factorCollection = FactorCollection::Instance();
@@ -65,43 +144,43 @@ bool LanguageModelKen::Load(const std::string &filePath,
m_sentenceStartArray[m_factorType] = m_sentenceStart;
m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_);
m_sentenceEndArray[m_factorType] = m_sentenceEnd;
MappingBuilder builder(m_factorType, factorCollection, m_lmIdLookup);
lm::ngram::Config config;
IFVERBOSE(1) {
config.messages = &std::cerr;
} else {
config.messages = NULL;
}
config.enumerate_vocab = &builder;
config.load_method = m_lazy ? util::LAZY : util::POPULATE_OR_READ;
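// util::LAZY mmaps the binary file and faults pages in on first use, while
// POPULATE_OR_READ brings the whole model into memory up front (assumption:
// semantics per kenlm's util/mmap.hh at the time of this commit).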
m_ngram = new Model(filePath.c_str(), config);
m_nGramOrder = m_ngram->Order();
KenLMState *tmp = new KenLMState();
tmp->state = m_ngram->NullContextState();
m_nullContextState = tmp;
tmp = new KenLMState();
tmp->state = m_ngram->BeginSentenceState();
m_beginSentenceState = tmp;
return true;
}
/* Get score of n-gram. The n-gram should not be bigger than m_nGramOrder.
 * Specific implementations can return state and len data to be used in hypothesis pruning.
 * \param contextFactor n-gram to be scored
 * \param state LM state carrying the context; scoring appends the new word to it in place
 * \param len if non-NULL, receives the length of the n-gram the model actually matched
 */
float LanguageModelKen::GetValue(const vector<const Word*> &contextFactor, State* finalState, unsigned int* len) const
template <class Model> float LanguageModelKen<Model>::GetValueGivenState(const std::vector<const Word*> &contextFactor, FFState &state, unsigned int* len) const
{
FactorType factorType = GetFactorType();
size_t count = contextFactor.size();
assert(count <= GetNGramOrder());
if (count == 0)
if (contextFactor.empty())
{
finalState = NULL;
return 0;
}
// set up context
vector<lm::WordIndex> ngramId(count);
for (size_t i = 0 ; i < count; i++)
{
const Factor *factor = contextFactor[i]->GetFactor(factorType);
const string &word = factor->GetString();
// TODO(hieuhoang1972): precompute this.
ngramId[i] = m_ngram->GetVocabulary().Index(word);
}
lm::ngram::State &realState = static_cast<KenLMState&>(state).state;
std::size_t factor = contextFactor.back()->GetFactor(GetFactorType())->GetId();
lm::WordIndex new_word = (factor >= m_lmIdLookup.size() ? 0 : m_lmIdLookup[factor]);
lm::ngram::State copied(realState);
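// The context is read from `copied` while `realState` receives the successor
// state, so FullScore() effectively updates the hypothesis state in place.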
lm::FullScoreReturn ret(m_ngram->FullScore(copied, new_word, realState));
// TODO(hieuhoang1972): use my stateful interface instead of this stateless one you asked heafield to kludge for you.
lm::ngram::HieuShouldRefactorMoses ret(m_ngram->SlowStatelessScore(&*ngramId.begin(), &*ngramId.end()));
if (finalState)
{
*finalState = ret.meaningless_unique_state;
}
if (len)
{
*len = ret.ngram_length;
@@ -109,9 +188,75 @@ float LanguageModelKen::GetValue(const vector<const Word*> &contextFactor, State
return TransformLMScore(ret.prob);
}
lm::WordIndex LanguageModelKen::GetLmID(const std::string &str) const {
template <class Model> float LanguageModelKen<Model>::GetValueForgotState(const vector<const Word*> &contextFactor, FFState &outState, unsigned int* len) const
{
if (contextFactor.empty())
{
static_cast<KenLMState&>(outState).state = m_ngram->NullContextState();
return 0;
}
lm::WordIndex indices[contextFactor.size()];
TranslateIDs(contextFactor, indices);
lm::FullScoreReturn ret(m_ngram->FullScoreForgotState(indices + 1, indices + contextFactor.size(), indices[0], static_cast<KenLMState&>(outState).state));
if (len)
{
*len = ret.ngram_length;
}
return TransformLMScore(ret.prob);
}
template <class Model> void LanguageModelKen<Model>::GetState(const std::vector<const Word*> &contextFactor, FFState &outState) const {
if (contextFactor.empty()) {
static_cast<KenLMState&>(outState).state = m_ngram->NullContextState();
return;
}
lm::WordIndex indices[contextFactor.size()];
TranslateIDs(contextFactor, indices);
m_ngram->GetState(indices, indices + contextFactor.size(), static_cast<KenLMState&>(outState).state);
}
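// This is the score-free shortcut promised by the base class's GetState():
// it computes only the successor state, skipping the probability lookup.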
template <class Model> FFState *LanguageModelKen<Model>::GetNullContextState() const {
return m_nullContextState;
}
template <class Model> FFState *LanguageModelKen<Model>::GetBeginSentenceState() const {
return m_beginSentenceState;
}
template <class Model> FFState *LanguageModelKen<Model>::NewState(const FFState *from) const {
KenLMState *ret = new KenLMState;
if (from) {
ret->state = static_cast<const KenLMState&>(*from).state;
}
return ret;
}
template <class Model> lm::WordIndex LanguageModelKen<Model>::GetLmID(const std::string &str) const {
return m_ngram->GetVocabulary().Index(str);
}
} // namespace
LanguageModelSingleFactor *ConstructKenLM(const std::string &file, bool lazy) {
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) {
switch(model_type) {
case lm::ngram::HASH_PROBING:
return new LanguageModelKen<lm::ngram::ProbingModel>(lazy);
case lm::ngram::HASH_SORTED:
return new LanguageModelKen<lm::ngram::SortedModel>(lazy);
case lm::ngram::TRIE_SORTED:
return new LanguageModelKen<lm::ngram::TrieModel>(lazy);
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
} else {
return new LanguageModelKen<lm::ngram::ProbingModel>(lazy);
}
}
}
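
For reference, client code only ever touches the factory; a hedged usage sketch follows (the file name is made up):

// Dispatches on the binary file's header; ARPA text input falls back to the
// probing model.
LanguageModelSingleFactor *lm = ConstructKenLM("model.binary", /*lazy=*/false);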

Some files were not shown because too many files have changed in this diff.