move SharedBitGen to random namespace

This commit is contained in:
Taku Kudo 2024-01-06 15:56:51 +00:00
parent 49afc4c6cc
commit adf9e81b63
4 changed files with 10 additions and 20 deletions

View File

@ -303,8 +303,8 @@ bool TrainerInterface::IsValidSentencePiece(
}
template <typename T>
void AddDPNoise(const TrainerSpec &trainer_spec, absl::SharedBitGen &generator,
T *to_update) {
void AddDPNoise(const TrainerSpec &trainer_spec,
random::SharedBitGen &generator, T *to_update) {
if (trainer_spec.differential_privacy_noise_level() > 0) {
float random_num = absl::Gaussian<float>(
generator, 0, trainer_spec.differential_privacy_noise_level());
@ -480,7 +480,7 @@ END:
for (int n = 0; n < num_workers; ++n) {
pool->Schedule([&, n]() {
// One per thread generator.
absl::SharedBitGen generator;
random::SharedBitGen generator;
for (size_t i = n; i < sentences_.size(); i += num_workers) {
AddDPNoise<int64>(trainer_spec_, generator,
&(sentences_[i].second));

View File

@ -288,6 +288,11 @@ namespace random {
std::mt19937 *GetRandomGenerator();
class SharedBitGen {
public:
std::mt19937 *engine() { return GetRandomGenerator(); }
};
template <typename T>
class ReservoirSampler {
public:

View File

@ -21,8 +21,8 @@
namespace absl {
template <typename T>
T Gaussian(SharedBitGen &generator, T mean, T stddev) {
template <typename T, typename G>
T Gaussian(G &generator, T mean, T stddev) {
std::normal_distribution<> dist(mean, stddev);
return dist(*generator.engine());
}

View File

@ -15,19 +15,4 @@
#ifndef ABSL_CONTAINER_RANDOM_H_
#define ABSL_CONTAINER_RANDOM_H_
#include <random>
#include "../../../src/util.h"
using sentencepiece::random::GetRandomGenerator;
namespace absl {
class SharedBitGen {
public:
std::mt19937 *engine() { return GetRandomGenerator(); }
};
} // namespace absl
#endif // ABSL_CONTAINER_RANDOM_H_