Improve cache (#347)

Hide `cache-mutex-buckets` from the user; it is now set equal to the number
of workers. The Python bindings that exposed this option are updated to
reflect the API change. The cache is wrapped in a `std::optional` and
constructed only if caching is enabled; the raw pointers previously used are
replaced with an equivalent `std::optional`.

Fixes: #317
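
For reference, a minimal self-contained sketch of the pattern this commit adopts. The `TranslationCache` stub and the numbers below are illustrative stand-ins, not the real class; only `makeOptionalCache` mirrors code actually introduced in service.cpp:

    #include <cstddef>
    #include <optional>

    // Illustrative stand-in: only the (size, mutexBuckets) constructor shape
    // from this commit is assumed.
    struct TranslationCache {
      TranslationCache(std::size_t size, std::size_t mutexBuckets)
          : size(size), mutexBuckets(mutexBuckets) {}
      std::size_t size;
      std::size_t mutexBuckets;
    };

    // Mirrors the helper introduced in service.cpp: construct the cache only
    // when caching is enabled, otherwise hold std::nullopt.
    std::optional<TranslationCache> makeOptionalCache(bool enabled, std::size_t size,
                                                      std::size_t mutexBuckets) {
      return enabled ? std::make_optional<TranslationCache>(size, mutexBuckets)
                     : std::nullopt;
    }

    int main() {
      std::size_t numWorkers = 4;
      // cache-mutex-buckets is no longer user-facing; it follows the worker count.
      std::optional<TranslationCache> cache =
          makeOptionalCache(/*enabled=*/true, /*size=*/2000, /*mutexBuckets=*/numWorkers);
      if (cache) {  // replaces the old `cache != nullptr` checks
        // use *cache / cache-> exactly as the raw pointer was used before
      }
      return 0;
    }

The `std::optional` keeps the cache by value inside the service, so there is no separate allocation and no nullable raw pointer threaded through the request path.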
Jerin Philip 2022-02-15 11:04:07 +00:00 committed by GitHub
parent a94725b20d
commit 9f55fb4756
7 changed files with 40 additions and 39 deletions

View File

@@ -198,22 +198,19 @@ PYBIND11_MODULE(_bergamot, m) {
.def("pivot", &ServicePyAdapter::pivot);
py::class_<Service::Config>(m, "ServiceConfig")
-.def(py::init<>([](size_t numWorkers, bool cacheEnabled, size_t cacheSize, size_t cacheMutexBuckets,
-                   std::string logging) {
+.def(py::init<>([](size_t numWorkers, bool cacheEnabled, size_t cacheSize, std::string logging) {
Service::Config config;
config.numWorkers = numWorkers;
config.cacheEnabled = cacheEnabled;
config.cacheSize = cacheSize;
-config.cacheMutexBuckets = cacheMutexBuckets;
config.logger.level = logging;
return config;
}),
py::arg("numWorkers") = 1, py::arg("cacheEnabled") = false, py::arg("cacheSize") = 20000,
py::arg("cacheMutexBuckets") = 1, py::arg("logLevel") = "off")
py::arg("logLevel") = "off")
.def_readwrite("numWorkers", &Service::Config::numWorkers)
.def_readwrite("cacheEnabled", &Service::Config::cacheEnabled)
.def_readwrite("cacheSize", &Service::Config::cacheSize)
.def_readwrite("cacheMutexBuckets", &Service::Config::cacheMutexBuckets);
.def_readwrite("cacheSize", &Service::Config::cacheSize);
py::class_<_Model, std::shared_ptr<_Model>>(m, "TranslationModel");
}

View File

@@ -23,7 +23,7 @@ size_t hashForCache(const TranslationModel &model, const marian::Words &words) {
// -----------------------------------------------------------------
Request::Request(size_t Id, const TranslationModel &model, Segments &&segments, ResponseBuilder &&responseBuilder,
-TranslationCache *cache)
+std::optional<TranslationCache> &cache)
: Id_(Id),
model_(model),
segments_(std::move(segments)),
@@ -42,7 +42,7 @@ Request::Request(size_t Id, const TranslationModel &model, Segments &&segments,
counter_ = segments_.size();
histories_.resize(segments_.size());
-if (cache_ != nullptr) {
+if (cache_) {
// Iterate through segments, see if any can be prefilled from cache. If prefilled, mark the particular segments as
// complete (non-empty ProcessedRequestSentence). Also update accounting used elsewhere (counter_) to reflect one
// less segment to translate.
@@ -76,7 +76,7 @@ void Request::processHistory(size_t index, Ptr<History> history) {
// Fill in placeholder from History obtained by freshly translating. Since this was a cache-miss to have got through,
// update cache if available to store the result.
histories_[index] = history;
-if (cache_ != nullptr) {
+if (cache_) {
size_t key = hashForCache(model_, getSegment(index));
cache_->store(key, histories_[index]);
}

View File

@@ -54,7 +54,7 @@ class Request {
/// @param [in] cache: Cache supplied externally to attempt to fetch translations or store them after completion for
/// reuse later.
Request(size_t Id, const TranslationModel &model, Segments &&segments, ResponseBuilder &&responseBuilder,
-TranslationCache *cache);
+std::optional<TranslationCache> &cache);
/// Obtain the count of tokens in the segment corresponding to index. Used to
/// insert sentence from multiple requests into the corresponding size bucket.
@@ -100,8 +100,8 @@ class Request {
/// std::vector<Ptr<Vocab const>> *vocabs_;
ResponseBuilder responseBuilder_;
-/// Cache used to hold unit translations. If nullptr, means no-caching.
-TranslationCache *cache_;
+/// Cache used to hold unit translations. If nullopt, means no-caching.
+std::optional<TranslationCache> &cache_;
};
/// A RequestSentence provides a view to a sentence within a Request. Existence

View File

@@ -30,13 +30,17 @@ Response combine(Response &&first, Response &&second) {
return combined;
}
+std::optional<TranslationCache> makeOptionalCache(bool enabled, size_t size, size_t mutexBuckets) {
+  return enabled ? std::make_optional<TranslationCache>(size, mutexBuckets) : std::nullopt;
+}
} // namespace
BlockingService::BlockingService(const BlockingService::Config &config)
: config_(config),
requestId_(0),
batchingPool_(),
-cache_(config.cacheSize, /*mutexBuckets=*/1),
+cache_(makeOptionalCache(config.cacheEnabled, config.cacheSize, /*mutexBuckets=*/1)),
logger_(config.logger) {}
std::vector<Response> BlockingService::translateMultiple(std::shared_ptr<TranslationModel> translationModel,
@@ -62,9 +66,8 @@ std::vector<Response> BlockingService::translateMultipleRaw(std::shared_ptr<Tran
for (size_t i = 0; i < sources.size(); i++) {
auto callback = [i, &responses](Response &&response) { responses[i] = std::move(response); }; //
-TranslationCache *cache = config_.cacheEnabled ? &cache_ : nullptr;
Ptr<Request> request =
-translationModel->makeRequest(requestId_++, std::move(sources[i]), callback, responseOptions[i], cache);
+translationModel->makeRequest(requestId_++, std::move(sources[i]), callback, responseOptions[i], cache_);
batchingPool_.enqueueRequest(translationModel, request);
}
@@ -101,9 +104,8 @@ std::vector<Response> BlockingService::pivotMultiple(std::shared_ptr<Translation
// it in allows further use in makePivotRequest
auto callback = [i, &pivotsToTargets](Response &&response) { pivotsToTargets[i] = std::move(response); }; //
-TranslationCache *cache = config_.cacheEnabled ? &cache_ : nullptr;
Ptr<Request> request =
-second->makePivotRequest(requestId_++, std::move(intermediate), callback, responseOptions[i], cache);
+second->makePivotRequest(requestId_++, std::move(intermediate), callback, responseOptions[i], cache_);
batchingPool_.enqueueRequest(second, request);
}
@@ -131,7 +133,7 @@ AsyncService::AsyncService(const AsyncService::Config &config)
: requestId_(0),
config_(config),
safeBatchingPool_(),
-cache_(config_.cacheSize, config_.cacheMutexBuckets),
+cache_(makeOptionalCache(config_.cacheEnabled, config_.cacheSize, /*mutexBuckets=*/config_.numWorkers)),
logger_(config.logger) {
ABORT_IF(config_.numWorkers == 0, "Number of workers should be at least 1 in a threaded workflow");
workers_.reserve(config_.numWorkers);
@@ -188,9 +190,8 @@ void AsyncService::pivot(std::shared_ptr<TranslationModel> first, std::shared_pt
};
// Second call.
-TranslationCache *cache = config_.cacheEnabled ? &cache_ : nullptr;
Ptr<Request> request =
-second->makePivotRequest(requestId_++, std::move(intermediate), joiningCallback, responseOptions, cache);
+second->makePivotRequest(requestId_++, std::move(intermediate), joiningCallback, responseOptions, cache_);
safeBatchingPool_.enqueueRequest(second, request);
};
@@ -213,9 +214,8 @@ void AsyncService::translate(std::shared_ptr<TranslationModel> translationModel,
void AsyncService::translateRaw(std::shared_ptr<TranslationModel> translationModel, std::string &&source,
CallbackType callback, const ResponseOptions &responseOptions) {
// Producer thread, a call to this function adds new work items. If batches are available, notifies workers waiting.
-TranslationCache *cache = config_.cacheEnabled ? &cache_ : nullptr;
Ptr<Request> request =
-translationModel->makeRequest(requestId_++, std::move(source), callback, responseOptions, cache);
+translationModel->makeRequest(requestId_++, std::move(source), callback, responseOptions, cache_);
safeBatchingPool_.enqueueRequest(translationModel, request);
}

View File

@@ -31,9 +31,15 @@ class BlockingService {
public:
struct Config {
bool cacheEnabled{false}; ///< Whether to enable cache or not.
-size_t cacheSize{2000}; ///< Size in History items to be stored in the cache. Loosely corresponds to sentences to
-                        /// cache in the real world.
-Logger::Config logger; // Configurations for logging
+/// Size in History items to be stored in the cache. Loosely corresponds to sentences to
+/// cache in the real world. Note that cache has a random-eviction policy. The peak
+/// storage at full occupancy is controlled by this parameter. However, whether we attain
+/// full occupancy or not is controlled by random factors - specifically how uniformly
+/// the hash distributes.
+size_t cacheSize{2000};
+Logger::Config logger; ///< Configurations for logging
template <class App>
static void addOptions(App &app, Config &config) {
@@ -78,7 +84,7 @@ class BlockingService {
std::vector<Response> pivotMultiple(std::shared_ptr<TranslationModel> first, std::shared_ptr<TranslationModel> second,
std::vector<std::string> &&sources,
const std::vector<ResponseOptions> &responseOptions);
-TranslationCache::Stats cacheStats() { return cache_.stats(); }
+TranslationCache::Stats cacheStats() { return cache_ ? cache_->stats() : TranslationCache::Stats(); }
private:
std::vector<Response> translateMultipleRaw(std::shared_ptr<TranslationModel> translationModel,
@@ -97,7 +103,7 @@ class BlockingService {
// Logger which shuts down cleanly with service.
Logger logger_;
-TranslationCache cache_;
+std::optional<TranslationCache> cache_;
};
/// Effectively a threadpool, providing an API to take a translation request of a source-text, parameterized by
@@ -110,18 +116,13 @@ class AsyncService {
bool cacheEnabled{false}; ///< Whether to enable cache or not.
size_t cacheSize{2000}; ///< Size in History items to be stored in the cache. Loosely corresponds to sentences to
/// cache in the real world.
-size_t cacheMutexBuckets{1}; ///< Controls the granularity of locking to reduce contention by bucketing mutexes
-                             ///< guarding cache entry read write. Optimal at min(core, numWorkers) assuming a
-                             ///< reasonably large cache-size.
-Logger::Config logger; // Configurations for logging
+Logger::Config logger; // Configurations for logging
template <class App>
static void addOptions(App &app, Config &config) {
app.add_option("--cpu-threads", config.numWorkers, "Workers to form translation backend");
app.add_option("--cache-translations", config.cacheEnabled, "Whether to cache translations or not.");
app.add_option("--cache-size", config.cacheSize, "Number of entries to store in cache.");
app.add_option("--cache-mutex-buckets", config.cacheMutexBuckets,
"Number of mutex buckets to control locking granularity");
Logger::Config::addOptions(app, config.logger);
}
};
@@ -170,11 +171,12 @@ class AsyncService {
/// If you do not want to wait, call `clear()` before destructor.
~AsyncService();
-TranslationCache::Stats cacheStats() { return cache_.stats(); }
+TranslationCache::Stats cacheStats() { return cache_ ? cache_->stats() : TranslationCache::Stats(); }
private:
void translateRaw(std::shared_ptr<TranslationModel> translationModel, std::string &&source, CallbackType callback,
const ResponseOptions &options = ResponseOptions());
AsyncService::Config config_;
std::vector<std::thread> workers_;
@@ -193,7 +195,7 @@ class AsyncService {
// Logger which shuts down cleanly with service.
Logger logger_;
-TranslationCache cache_;
+std::optional<TranslationCache> cache_;
};
} // namespace bergamot

View File

@@ -90,7 +90,8 @@ void TranslationModel::loadBackend(size_t idx) {
// Make request process is shared between Async and Blocking workflow of translating.
Ptr<Request> TranslationModel::makeRequest(size_t requestId, std::string &&source, CallbackType callback,
-const ResponseOptions &responseOptions, TranslationCache *cache) {
+const ResponseOptions &responseOptions,
+std::optional<TranslationCache> &cache) {
Segments segments;
AnnotatedText annotatedSource;
@@ -103,7 +104,8 @@ Ptr<Request> TranslationModel::makeRequest(size_t requestId, std::string &&sourc
}
Ptr<Request> TranslationModel::makePivotRequest(size_t requestId, AnnotatedText &&previousTarget, CallbackType callback,
-const ResponseOptions &responseOptions, TranslationCache *cache) {
+const ResponseOptions &responseOptions,
+std::optional<TranslationCache> &cache) {
Segments segments;
textProcessor_.processFromAnnotation(previousTarget, segments);

View File

@@ -71,10 +71,10 @@ class TranslationModel {
/// @param [in] responseOptions: Configuration used to prepare the Response corresponding to the created request.
// @returns Request created from the query parameters wrapped within a shared-pointer.
Ptr<Request> makeRequest(size_t requestId, std::string&& source, CallbackType callback,
-const ResponseOptions& responseOptions, TranslationCache* cache);
+const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);
Ptr<Request> makePivotRequest(size_t requestId, AnnotatedText&& previousTarget, CallbackType callback,
-const ResponseOptions& responseOptions, TranslationCache* cache);
+const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);
/// Relays a request to the batching-pool specific to this translation model.
/// @param [in] request: Request constructed through makeRequest