Improves the thread utilization in batch encoding/decoding

This commit is contained in:
Taku Kudo 2023-08-05 09:01:02 +00:00
parent 635fe8423a
commit 8cbdf13794
3 changed files with 656 additions and 611 deletions

View File

@ -1,13 +1,10 @@
# This file was automatically generated by SWIG (http://www.swig.org).
# Version 4.0.2
# This file was automatically generated by SWIG (https://www.swig.org).
# Version 4.1.0
#
# Do not make changes to this file unless you know what you are doing--modify
# Do not make changes to this file unless you know what you are doing - modify
# the SWIG interface file instead.
from sys import version_info as _swig_python_version_info
if _swig_python_version_info < (2, 7, 0):
raise RuntimeError("Python 2.7 or later required")
# Import the low-level C/C++ module
if __package__ or "." in __name__:
from . import _sentencepiece
@ -29,10 +26,10 @@ def _swig_repr(self):
def _swig_setattr_nondynamic_instance_variable(set):
def set_instance_attr(self, name, value):
if name == "thisown":
self.this.own(value)
elif name == "this":
if name == "this":
set(self, name, value)
elif name == "thisown":
self.this.own(value)
elif hasattr(self, name) and isinstance(getattr(type(self), name), property):
set(self, name, value)
else:
@ -109,7 +106,6 @@ class ImmutableSentencePieceText_ImmutableSentencePiece(object):
# Register ImmutableSentencePieceText_ImmutableSentencePiece in _sentencepiece:
_sentencepiece.ImmutableSentencePieceText_ImmutableSentencePiece_swigregister(ImmutableSentencePieceText_ImmutableSentencePiece)
class ImmutableSentencePieceText(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
@ -179,7 +175,6 @@ class ImmutableSentencePieceText(object):
# Register ImmutableSentencePieceText in _sentencepiece:
_sentencepiece.ImmutableSentencePieceText_swigregister(ImmutableSentencePieceText)
class ImmutableNBestSentencePieceText(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
@ -237,7 +232,6 @@ class ImmutableNBestSentencePieceText(object):
# Register ImmutableNBestSentencePieceText in _sentencepiece:
_sentencepiece.ImmutableNBestSentencePieceText_swigregister(ImmutableNBestSentencePieceText)
class SentencePieceProcessor(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
__repr__ = _swig_repr
@ -908,7 +902,6 @@ class SentencePieceProcessor(object):
# Register SentencePieceProcessor in _sentencepiece:
_sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
def SetRandomGeneratorSeed(seed):
return _sentencepiece.SetRandomGeneratorSeed(seed)
class SentencePieceTrainer(object):
@ -992,22 +985,6 @@ class SentencePieceTrainer(object):
# Register SentencePieceTrainer in _sentencepiece:
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
def SentencePieceTrainer__TrainFromString(arg):
return _sentencepiece.SentencePieceTrainer__TrainFromString(arg)
def SentencePieceTrainer__TrainFromMap(args):
return _sentencepiece.SentencePieceTrainer__TrainFromMap(args)
def SentencePieceTrainer__TrainFromMap2(args, iter):
return _sentencepiece.SentencePieceTrainer__TrainFromMap2(args, iter)
def SentencePieceTrainer__TrainFromMap3(args):
return _sentencepiece.SentencePieceTrainer__TrainFromMap3(args)
def SentencePieceTrainer__TrainFromMap4(args, iter):
return _sentencepiece.SentencePieceTrainer__TrainFromMap4(args, iter)
import re
import csv
@ -1084,4 +1061,3 @@ class _LogStream(object):
self.ostream.close()

View File

@ -3,6 +3,7 @@
%{
#include <atomic>
#include <iostream>
#include <algorithm>
#include <functional>
@ -246,9 +247,11 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
InitNumThreads(ins, &num_threads); \
{ \
ThreadPool pool(ins.size()); \
std::atomic<size_t> index = 0; \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&, n]() { \
for (size_t i = n; i < ins.size(); i += num_threads) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
auto out = enable_sampling ? \
self->Sample##FuncName(ins[i], \
nbest_size, alpha) : \
@ -267,10 +270,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
std::vector<OutType> outs(ins.size()); \
InitNumThreads(ins, &num_threads); \
{ \
std::atomic<size_t> index = 0; \
ThreadPool pool(ins.size()); \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&, n]() { \
for (size_t i = n; i < ins.size(); i += num_threads) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
CheckIds(ins[i], self->GetPieceSize()); \
auto out = self->FuncName(ins[i]); \
ConvertToUnicodeSpans(&out); \
@ -655,12 +660,14 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
InitNumThreads(ins, &num_threads);
{
ThreadPool pool(ins.size());
std::atomic<size_t> index = 0;
for (int n = 0; n < num_threads; ++n) {
pool.Schedule([&, n]() {
for (size_t i = n; i < ins.size(); i += num_threads) {
outs[i] = self->CalculateEntropy(ins[i], alpha);
}
});
pool.Schedule([&]() {
size_t i = 0;
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) {
outs[i] = self->CalculateEntropy(ins[i], alpha);
}
});
}
}
return outs;

File diff suppressed because it is too large Load Diff