[p256] fix all the bugs found by the now useful P256 test suite

This commit is contained in:
Vincent Hanquez 2015-06-01 07:48:31 +01:00
parent 2c112b8877
commit f63a3c6025
2 changed files with 230 additions and 28 deletions

View File

@ -11,8 +11,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
{-# OPTIONS_GHC -fno-warn-unused-matches #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
module Crypto.PubKey.ECC.P256
( Scalar
, Point
@ -22,37 +20,51 @@ module Crypto.PubKey.ECC.P256
, pointsMulVarTime
, pointIsValid
, toPoint
, pointToIntegers
, pointFromIntegers
, pointToBinary
, pointFromBinary
-- * scalar arithmetic
, scalarZero
, scalarIsZero
, scalarAdd
, scalarSub
, scalarInv
, scalarCmp
, scalarFromBinary
, scalarToBinary
, scalarFromInteger
, scalarToInteger
) where
import Data.Word
import Foreign.Ptr
import Foreign.C.Types
import Control.Monad
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Crypto.Internal.ByteArray
import qualified Crypto.Internal.ByteArray as B
import Data.Memory.PtrMethods (memSet)
import Crypto.Error
import Crypto.Number.Serialize.Internal (os2ip, i2ospOf)
import qualified Crypto.Number.Serialize as S (os2ip, i2ospOf)
-- | A P256 scalar
newtype Scalar = Scalar ScrubbedBytes
newtype Scalar = Scalar Bytes
deriving (Eq,ByteArrayAccess)
-- | A P256 point
data Point = Point !Bytes !Bytes
newtype Point = Point Bytes
deriving (Show,Eq)
scalarSize :: Int
scalarSize = 32
pointSize :: Int
pointSize = 64
type P256Digit = Word32
data P256Scalar
@ -71,8 +83,11 @@ data P256X
-- > scalar * G
--
toPoint :: Scalar -> Point
toPoint s = withNewPoint $ \px py -> withScalar s $ \p ->
ccryptonite_p256_basepoint_mul p px py
toPoint s
| scalarIsZero s = error "cannot create point from zero"
| otherwise =
withNewPoint $ \px py -> withScalar s $ \p ->
ccryptonite_p256_basepoint_mul p px py
-- | Add a point to another point
pointAdd :: Point -> Point -> Point
@ -104,6 +119,46 @@ pointIsValid p = unsafeDoIO $ withPoint p $ \px py -> do
r <- ccryptonite_p256_is_valid_point px py
return (r /= 0)
pointToIntegers :: Point -> (Integer, Integer)
pointToIntegers p = unsafeDoIO $ withPoint p $ \px py ->
allocTemp 32 (serialize (castPtr px) (castPtr py))
where
serialize px py temp = do
ccryptonite_p256_to_bin px temp
x <- os2ip temp scalarSize
ccryptonite_p256_to_bin py temp
y <- os2ip temp scalarSize
return (x,y)
pointFromIntegers :: (Integer, Integer) -> Point
pointFromIntegers (x,y) = withNewPoint $ \dx dy ->
allocTemp scalarSize (\temp -> fill temp (castPtr dx) x >> fill temp (castPtr dy) y)
where
-- put @n to @temp in big endian format, then from @temp to @dest in p256 scalar format
fill :: Ptr Word8 -> Ptr P256Scalar -> Integer -> IO ()
fill temp dest n = do
-- write the integer in big endian format to temp
memSet temp 0 scalarSize
e <- i2ospOf n temp scalarSize
if e == 0
then error "pointFromIntegers: filling failed"
else return ()
-- then fill dest with the P256 scalar from temp
ccryptonite_p256_from_bin temp dest
pointToBinary :: ByteArray ba => Point -> ba
pointToBinary p = B.unsafeCreate pointSize $ \dst -> withPoint p $ \px py -> do
ccryptonite_p256_to_bin (castPtr px) dst
ccryptonite_p256_to_bin (castPtr py) (dst `plusPtr` 32)
pointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
pointFromBinary ba
| B.length ba /= pointSize = CryptoFailed $ CryptoError_PublicKeySizeInvalid
| otherwise =
CryptoPassed $ withNewPoint $ \px py -> B.withByteArray ba $ \src -> do
ccryptonite_p256_from_bin src (castPtr px)
ccryptonite_p256_from_bin (src `plusPtr` scalarSize) (castPtr py)
------------------------------------------------------------------------
-- Scalar methods
------------------------------------------------------------------------
@ -112,14 +167,27 @@ pointIsValid p = unsafeDoIO $ withPoint p $ \px py -> do
scalarZero :: Scalar
scalarZero = withNewScalarFreeze $ \d -> ccryptonite_p256_init d
scalarIsZero :: Scalar -> Bool
scalarIsZero s = unsafeDoIO $ withScalar s $ \d -> do
result <- ccryptonite_p256_is_zero d
return $ result /= 0
scalarNeedReducing :: Ptr P256Scalar -> IO Bool
scalarNeedReducing d = do
c <- ccryptonite_p256_cmp d ccryptonite_SECP256r1_n
return (c >= 0)
-- | Perform addition between two scalars
--
-- > a + b
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd a b =
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do
void $ ccryptonite_p256_add pa pb d
ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
carry <- ccryptonite_p256_add pa pb d
when (carry /= 0) $ void $ ccryptonite_p256_sub d ccryptonite_SECP256r1_n d
needReducing <- scalarNeedReducing d
when needReducing $ do
ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
-- | Perform subtraction between two scalars
--
@ -127,8 +195,11 @@ scalarAdd a b =
scalarSub :: Scalar -> Scalar -> Scalar
scalarSub a b =
withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb -> do
void $ ccryptonite_p256_sub pa pb d
ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
borrow <- ccryptonite_p256_sub pa pb d
when (borrow /= 0) $ void $ ccryptonite_p256_add d ccryptonite_SECP256r1_n d
--needReducing <- scalarNeedReducing d
--when needReducing $ do
-- ccryptonite_p256_mod ccryptonite_SECP256r1_n d d
-- | Give the inverse of the scalar
--
@ -154,33 +225,40 @@ scalarFromBinary ba
| otherwise =
CryptoPassed $ withNewScalarFreeze $ \p -> B.withByteArray ba $ \b ->
ccryptonite_p256_from_bin b p
{-# NOINLINE scalarFromBinary #-}
-- | convert a scalar to binary
scalarToBinary :: ByteArray ba => Scalar -> ba
scalarToBinary s = B.allocAndFreeze scalarSize $ \b -> withScalar s $ \p ->
scalarToBinary s = B.unsafeCreate scalarSize $ \b -> withScalar s $ \p ->
ccryptonite_p256_to_bin p b
{-# NOINLINE scalarToBinary #-}
scalarFromInteger :: Integer -> CryptoFailable Scalar
scalarFromInteger i =
maybe (CryptoFailed CryptoError_SecretKeySizeInvalid) scalarFromBinary (S.i2ospOf 32 i :: Maybe Bytes)
scalarToInteger :: Scalar -> Integer
scalarToInteger s = S.os2ip (scalarToBinary s :: Bytes)
------------------------------------------------------------------------
-- Memory Helpers
------------------------------------------------------------------------
withNewPoint :: (Ptr P256X -> Ptr P256Y -> IO ()) -> Point
withNewPoint f = unsafeDoIO $ do
(x,y) <- B.allocRet pointCoordSize $ \py -> B.alloc pointCoordSize $ \px -> f px py
return $! Point x y
where pointCoordSize = 32
withNewPoint f = Point $ B.unsafeCreate pointSize $ \px -> f px (pxToPy px)
{-# NOINLINE withNewPoint #-}
withPoint :: Point -> (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withPoint (Point x y) f = B.withByteArray x $ \px -> B.withByteArray y $ \py -> f px py
withPoint (Point d) f = B.withByteArray d $ \px -> f px (pxToPy px)
pxToPy :: Ptr P256X -> Ptr P256Y
pxToPy px = castPtr (px `plusPtr` scalarSize)
withNewScalarFreeze :: (Ptr P256Scalar -> IO ()) -> Scalar
withNewScalarFreeze f = Scalar $ B.allocAndFreeze scalarSize f
{-# NOINLINE withNewScalarFreeze #-}
withTempScalar :: (Ptr P256Scalar -> IO a) -> IO a
withTempScalar f = ignoreSnd <$> B.allocRet scalarSize f
where ignoreSnd :: (a, ScrubbedBytes) -> a
ignoreSnd = fst
withTempScalar f = allocTempScrubbed scalarSize (f . castPtr)
withScalar :: Scalar -> (Ptr P256Scalar -> IO a) -> IO a
withScalar (Scalar d) f = B.withByteArray d f
@ -191,6 +269,18 @@ withScalarZero f =
ccryptonite_p256_init d
f d
allocTemp :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTemp n f = ignoreSnd <$> B.allocRet n f
where
ignoreSnd :: (a, Bytes) -> a
ignoreSnd = fst
allocTempScrubbed :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTempScrubbed n f = ignoreSnd <$> B.allocRet n f
where
ignoreSnd :: (a, ScrubbedBytes) -> a
ignoreSnd = fst
------------------------------------------------------------------------
-- Foreign bindings
------------------------------------------------------------------------
@ -203,10 +293,14 @@ foreign import ccall "&cryptonite_SECP256r1_b"
foreign import ccall "cryptonite_p256_init"
ccryptonite_p256_init :: Ptr P256Scalar -> IO ()
foreign import ccall "cryptonite_p256_is_zero"
ccryptonite_p256_is_zero :: Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_clear"
ccryptonite_p256_clear :: Ptr P256Scalar -> IO ()
foreign import ccall "cryptonite_p256_add"
ccryptonite_p256_add :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_add_d"
ccryptonite_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_sub"
ccryptonite_p256_sub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_cmp"

View File

@ -1,33 +1,141 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module KAT_PubKey.P256 (tests) where
import Control.Arrow (second)
import qualified Crypto.PubKey.ECC.Types as ECC
import qualified Crypto.PubKey.ECC.Prim as ECC
import qualified Crypto.PubKey.ECC.P256 as P256
import Test.Tasty.KAT
import Test.Tasty.KAT.FileLoader
import Data.ByteArray (Bytes)
import Crypto.Number.Serialize (i2ospOf)
import Crypto.Number.Serialize (i2ospOf, os2ip)
import Crypto.Number.ModArithmetic (inverseCoprimes)
import Crypto.Error
import Imports
newtype P256Scalar = P256Scalar Integer
deriving (Show,Eq,Ord)
instance Arbitrary P256Scalar where
arbitrary = P256Scalar . getQAInteger <$> arbitrary
curve = ECC.getCurveByName ECC.SEC_p256r1
curveN = ECC.ecc_n . ECC.common_curve $ curve
curveGen = ECC.ecc_g . ECC.common_curve $ curve
pointP256ToECC :: P256.Point -> ECC.Point
pointP256ToECC = uncurry ECC.Point . P256.pointToIntegers
unP256Scalar :: P256Scalar -> P256.Scalar
unP256Scalar (P256Scalar r') =
let r = if r' == 0 then 0x2901 else (r' `mod` curveN)
rBytes = i2ospScalar r
in case P256.scalarFromBinary rBytes of
CryptoFailed err -> error ("cannot convert scalar: " ++ show err)
CryptoPassed scalar -> scalar
where
i2ospScalar :: Integer -> Bytes
i2ospScalar i =
case i2ospOf 32 i of
Nothing -> error "invalid size of P256 scalar"
Just b -> b
unP256 :: P256Scalar -> Integer
unP256 (P256Scalar r') = if r' == 0 then 0x2901 else (r' `mod` curveN)
p256ScalarToInteger :: P256.Scalar -> Integer
p256ScalarToInteger s = os2ip (P256.scalarToBinary s :: Bytes)
xS = 0xde2444bebc8d36e682edd27e0f271508617519b3221a8fa0b77cab3989da97c9
yS = 0xc093ae7ff36e5380fc01a5aad1e66659702de80f53cec576b6350b243042a256
xT = 0x55a8b00f8da1d44e62f6b3b25316212e39540dc861c89575bb8cf92e35e0986b
yT = 0x5421c3209c2d6c704835d82ac4c3dd90f61a8a52598b9e7ab656e9d8c8b24316
xR = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
yR = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264
tests = testGroup "P256"
[ testGroup "scalar"
[ testProperty "marshalling" $ \(Positive r') ->
[ testProperty "marshalling" $ \(QAInteger r') ->
let r = r' `mod` curveN
rBytes = i2ospScalar r
in case P256.scalarFromBinary rBytes of
CryptoFailed err -> error (show err)
CryptoPassed scalar -> rBytes `propertyEq` P256.scalarToBinary scalar
, testProperty "add" $ \r1 r2 ->
let r = (unP256 r1 + unP256 r2) `mod` curveN
r' = P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2)
in r `propertyEq` p256ScalarToInteger r'
, testProperty "add0" $ \r ->
let v = unP256 r
v' = P256.scalarAdd (unP256Scalar r) P256.scalarZero
in v `propertyEq` p256ScalarToInteger v'
, testProperty "add-n-1" $ \r ->
let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1)
v = unP256 r
v' = P256.scalarAdd (unP256Scalar r) nm1
in (((curveN - 1) + v) `mod` curveN) `propertyEq` p256ScalarToInteger v'
, testProperty "sub" $ \r1 r2 ->
let r = (unP256 r1 - unP256 r2) `mod` curveN
r' = P256.scalarSub (unP256Scalar r1) (unP256Scalar r2)
v = (unP256 r2 - unP256 r1) `mod` curveN
v' = P256.scalarSub (unP256Scalar r2) (unP256Scalar r1)
in propertyHold
[ eqTest "r1-r2" r (p256ScalarToInteger r')
, eqTest "r2-r1" v (p256ScalarToInteger v')
]
, testProperty "sub-n-1" $ \r ->
let nm1 = throwCryptoError $ P256.scalarFromInteger (curveN - 1)
v = unP256 r
v' = P256.scalarSub (unP256Scalar r) nm1
in ((v - (curveN - 1)) `mod` curveN) `propertyEq` p256ScalarToInteger v'
, testProperty "inv" $ \r' ->
let inv = inverseCoprimes (unP256 r') curveN
inv' = P256.scalarInv (unP256Scalar r')
in if unP256 r' == 0 then True else inv `propertyEq` p256ScalarToInteger inv'
]
, testGroup "point"
[ testProperty "marshalling" $ \rx ry ->
let p = P256.pointFromIntegers (unP256 rx, unP256 ry)
b = P256.pointToBinary p :: Bytes
p' = P256.pointFromBinary b
in propertyHold [ eqTest "point" (CryptoPassed p) p' ]
, testProperty "marshalling-integer" $ \rx ry ->
let p = P256.pointFromIntegers (unP256 rx, unP256 ry)
(x,y) = P256.pointToIntegers p
in propertyHold [ eqTest "x" (unP256 rx) x, eqTest "y" (unP256 ry) y ]
, testCase "valid-point-1" $ casePointIsValid (xS,yS)
, testCase "valid-point-2" $ casePointIsValid (xR,yR)
, testCase "valid-point-3" $ casePointIsValid (xT,yT)
, testCase "point-add-1" $
let s = P256.pointFromIntegers (xS, yS)
t = P256.pointFromIntegers (xT, yT)
r = P256.pointFromIntegers (xR, yR)
in r @=? P256.pointAdd s t
, testProperty "lift-to-curve" $ propertyLiftToCurve
, testProperty "point-add" $ propertyPointAdd
]
]
where
curve = ECC.getCurveByName ECC.SEC_p256r1
curveN = ECC.ecc_n . ECC.common_curve $ curve
casePointIsValid pointTuple =
let s = P256.pointFromIntegers pointTuple in True @=? P256.pointIsValid s
propertyLiftToCurve r =
let p = P256.toPoint (unP256Scalar r)
(x,y) = P256.pointToIntegers p
pEcc = ECC.pointMul curve (unP256 r) curveGen
in pEcc `propertyEq` ECC.Point x y
propertyPointAdd r1 r2 =
let p1 = P256.toPoint (unP256Scalar r1)
p2 = P256.toPoint (unP256Scalar r2)
pe1 = ECC.pointMul curve (unP256 r1) curveGen
pe2 = ECC.pointMul curve (unP256 r2) curveGen
pR = P256.toPoint (P256.scalarAdd (unP256Scalar r1) (unP256Scalar r2))
peR = ECC.pointAdd curve pe1 pe2
(x,y) = P256.pointToIntegers (P256.pointAdd p1 p2) -- P256.pointToIntegers pR
in propertyHold [ eqTest "p256" pR (P256.pointAdd p1 p2)
, eqTest "ecc" peR (pointP256ToECC pR)
]
i2ospScalar :: Integer -> Bytes
i2ospScalar i =