mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
batch-size -> mini-batch
This commit is contained in:
parent
fc0d8bf735
commit
280345b0f8
@ -202,7 +202,7 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
("cpu-threads", po::value<size_t>()->default_value(1),
|
||||
"Number of threads on the CPU.")
|
||||
#endif
|
||||
("batch-size", po::value<size_t>()->default_value(1),
|
||||
("mini-batch", po::value<size_t>()->default_value(1),
|
||||
"Number of sentences in one batch.")
|
||||
("bunch-size", po::value<size_t>()->default_value(1),
|
||||
"Number of batches in one bunch.")
|
||||
@ -285,7 +285,7 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
SET_OPTION("no-debpe", bool);
|
||||
SET_OPTION("beam-size", size_t);
|
||||
SET_OPTION("cpu-threads", size_t);
|
||||
SET_OPTION("batch-size", size_t);
|
||||
SET_OPTION("mini-batch", size_t);
|
||||
SET_OPTION("bunch-size", size_t);
|
||||
#ifdef CUDA
|
||||
SET_OPTION("gpu-threads", size_t);
|
||||
|
@ -26,7 +26,7 @@ int main(int argc, char* argv[]) {
|
||||
std::size_t taskCounter = 0;
|
||||
|
||||
size_t bunchSize = god.Get<size_t>("bunch-size");
|
||||
size_t maxBatchSize = god.Get<size_t>("batch-size");
|
||||
size_t maxBatchSize = god.Get<size_t>("mini-batch");
|
||||
std::cerr << "mode=" << god.Get("mode") << std::endl;
|
||||
|
||||
if (god.Get<bool>("wipo") || god.Get<size_t>("cpu-threads")) {
|
||||
|
@ -18,10 +18,10 @@ class BestHyps : public BestHypsBase
|
||||
public:
|
||||
BestHyps(const BestHyps ©) = delete;
|
||||
BestHyps(const God &god)
|
||||
: nthElement_(god.Get<size_t>("beam-size"), god.Get<size_t>("batch-size"),
|
||||
: nthElement_(god.Get<size_t>("beam-size"), god.Get<size_t>("mini-batch"),
|
||||
mblas::CudaStreamHandler::GetStream()),
|
||||
keys(god.Get<size_t>("beam-size") * god.Get<size_t>("batch-size")),
|
||||
Costs(god.Get<size_t>("beam-size") * god.Get<size_t>("batch-size")),
|
||||
keys(god.Get<size_t>("beam-size") * god.Get<size_t>("mini-batch")),
|
||||
Costs(god.Get<size_t>("beam-size") * god.Get<size_t>("mini-batch")),
|
||||
weights_(god.GetScorerWeights())
|
||||
{
|
||||
//std::cerr << "BestHyps::BestHyps" << std::endl;
|
||||
|
@ -101,7 +101,7 @@ class Decoder {
|
||||
Alignment(const God &god, const Weights& model)
|
||||
: w_(model),
|
||||
WC_(w_.C_(0,0)),
|
||||
dBatchMapping_(god.Get<size_t>("batch-size") * god.Get<size_t>("beam-size"), 0)
|
||||
dBatchMapping_(god.Get<size_t>("mini-batch") * god.Get<size_t>("beam-size"), 0)
|
||||
{}
|
||||
|
||||
void Init(const mblas::Matrix& SourceContext) {
|
||||
|
Loading…
Reference in New Issue
Block a user