batch-size -> mini-batch

This commit is contained in:
Hieu Hoang 2017-01-22 22:21:51 +00:00
parent fc0d8bf735
commit 280345b0f8
4 changed files with 7 additions and 7 deletions

View File

@ -202,7 +202,7 @@ void Config::AddOptions(size_t argc, char** argv) {
("cpu-threads", po::value<size_t>()->default_value(1),
"Number of threads on the CPU.")
#endif
("batch-size", po::value<size_t>()->default_value(1),
("mini-batch", po::value<size_t>()->default_value(1),
"Number of sentences in one batch.")
("bunch-size", po::value<size_t>()->default_value(1),
"Number of batches in one bunch.")
@ -285,7 +285,7 @@ void Config::AddOptions(size_t argc, char** argv) {
SET_OPTION("no-debpe", bool);
SET_OPTION("beam-size", size_t);
SET_OPTION("cpu-threads", size_t);
SET_OPTION("batch-size", size_t);
SET_OPTION("mini-batch", size_t);
SET_OPTION("bunch-size", size_t);
#ifdef CUDA
SET_OPTION("gpu-threads", size_t);

View File

@ -26,7 +26,7 @@ int main(int argc, char* argv[]) {
std::size_t taskCounter = 0;
size_t bunchSize = god.Get<size_t>("bunch-size");
size_t maxBatchSize = god.Get<size_t>("batch-size");
size_t maxBatchSize = god.Get<size_t>("mini-batch");
std::cerr << "mode=" << god.Get("mode") << std::endl;
if (god.Get<bool>("wipo") || god.Get<size_t>("cpu-threads")) {

View File

@ -18,10 +18,10 @@ class BestHyps : public BestHypsBase
public:
BestHyps(const BestHyps &copy) = delete;
BestHyps(const God &god)
: nthElement_(god.Get<size_t>("beam-size"), god.Get<size_t>("batch-size"),
: nthElement_(god.Get<size_t>("beam-size"), god.Get<size_t>("mini-batch"),
mblas::CudaStreamHandler::GetStream()),
keys(god.Get<size_t>("beam-size") * god.Get<size_t>("batch-size")),
Costs(god.Get<size_t>("beam-size") * god.Get<size_t>("batch-size")),
keys(god.Get<size_t>("beam-size") * god.Get<size_t>("mini-batch")),
Costs(god.Get<size_t>("beam-size") * god.Get<size_t>("mini-batch")),
weights_(god.GetScorerWeights())
{
//std::cerr << "BestHyps::BestHyps" << std::endl;

View File

@ -101,7 +101,7 @@ class Decoder {
Alignment(const God &god, const Weights& model)
: w_(model),
WC_(w_.C_(0,0)),
dBatchMapping_(god.Get<size_t>("batch-size") * god.Get<size_t>("beam-size"), 0)
dBatchMapping_(god.Get<size_t>("mini-batch") * god.Get<size_t>("beam-size"), 0)
{}
void Init(const mblas::Matrix& SourceContext) {