[number] generate number with bounds more effectively

This commit is contained in:
Vincent Hanquez 2015-05-23 11:59:10 +01:00
parent a4baf9383b
commit 2153e5690f

View File

@ -21,7 +21,7 @@ import Crypto.Random.Types
import Control.Monad (when)
import Foreign.Ptr
import Foreign.Storable
import Data.Bits ((.|.), (.&.), shiftL, shiftR, complement)
import Data.Bits ((.|.), (.&.), shiftL, complement, testBit)
import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
@ -79,23 +79,48 @@ generateParams bits genTopPolicy generateOdd
bit = (bits - 1) `mod` 8;
mask = 0xff `shiftL` (bit + 1);
-- | generate a positive integer x, s.t. 0 <= x < m
generateMax :: MonadRandom m => Integer -> m Integer
generateMax 1 = return 0
generateMax m
| m <= 0 = error "negative value for generateMax"
| otherwise = do
result <- randomInt bytesLength
let result' = result `shiftR` bitsPoppedOff
if result' >= m
then generateMax m
else return result'
-- | Generate a positive integer x, s.t. 0 <= x < range
generateMax :: MonadRandom m
=> Integer -- ^ range
-> m Integer
generateMax range
| range <= 1 = return 0
| range < 127 = generateSimple
| canOverGenerate = loopGenerateOver tries
| otherwise = loopGenerate tries
where
bytesLength = lengthBytes m
bitsLength = log2 (m-1) + 1
bitsPoppedOff = 8 - (bitsLength `mod` 8)
-- this "generator" is mostly for quickcheck benefits. it'll be biased if
-- range is not a multiple of 2, but overall, no security should be
-- assumed for a number between 0 and 127.
generateSimple = flip mod range `fmap` generateParams bits Nothing False
randomInt nbBytes = os2ipBytes <$> getRandomBytes nbBytes
loopGenerate count
| count == 0 = error "internal: generateMax (normal) doesn't seems to work properly"
| otherwise = do
r <- generateParams bits Nothing False
if isValid r then return r else loopGenerate (count-1)
loopGenerateOver count
| count == 0 = error "internal: generateMax (over) doesn't seems to work properly"
| otherwise = do
r <- generateParams (bits+1) Nothing False
let r2 = r - range
r3 = r2 - range
if isValid r
then return r
else if isValid r2
then return r2
else if isValid r3
then return r3
else loopGenerateOver (count-1)
bits = numBits range
canOverGenerate = bits > 3 && not (range `testBit` (bits-2)) && not (range `testBit` (bits-3))
isValid n = n < range
tries :: Int
tries = 100
-- | generate a number between the inclusive bound [low,high].
generateBetween :: MonadRandom m => Integer -> Integer -> m Integer