diff --git a/lol/Crypto/Lol/Cyclotomic/CRTSentinel.hs b/lol/Crypto/Lol/Cyclotomic/CRTSentinel.hs index 0c4258e7..ddf0f946 100644 --- a/lol/Crypto/Lol/Cyclotomic/CRTSentinel.hs +++ b/lol/Crypto/Lol/Cyclotomic/CRTSentinel.hs @@ -73,5 +73,5 @@ embedCRTCS _ _ = fromJust embedCRT twaceCRTCS :: (Tensor t, m `Divides` m', CRTrans Maybe r, TElt t r) => CSentinel t m' r -> CSentinel t m r -> t m' r -> t m r twaceCRTCS _ _ = fromJust twaceCRT -{-# INLINABLE twaceCRTCS #-} +{-# INLINE twaceCRTCS #-} diff --git a/lol/Crypto/Lol/Cyclotomic/Cyc.hs b/lol/Crypto/Lol/Cyclotomic/Cyc.hs index b84fa469..ffd48add 100644 --- a/lol/Crypto/Lol/Cyclotomic/Cyc.hs +++ b/lol/Crypto/Lol/Cyclotomic/Cyc.hs @@ -402,7 +402,7 @@ embed' (Sub (c :: Cyc t k r)) = embed' c -- | The "tweaked trace" (twace) function -- \(\Tw(x) = (\hat{m} / \hat{m}') \cdot \Tr((g' / g) \cdot x)\), -- which fixes \(R\) pointwise (i.e., @twace . embed == id@). -twace :: forall t m m' r . (m `Divides` m', CElt t r) +twace :: forall t m m' r . (m `Divides` m', UCRTElt t r, ZeroTestable r) => Cyc t m' r -> Cyc t m r {-# INLINABLE twace #-} twace (Pow u) = Pow $ U.twacePow u @@ -562,10 +562,10 @@ instance (Correct gad zq, Fact m, CElt t zq) => Correct gad (Cyc t m zq) where ---------- Change of representation (internal use only) ---------- -toPow', toDec', toCRT' :: (Fact m, CElt t r) => Cyc t m r -> Cyc t m r -{-# INLINE toPow' #-} -{-# INLINE toDec' #-} -{-# INLINE toCRT' #-} +toPow', toDec', toCRT' :: (Fact m, UCRTElt t r, ZeroTestable r) => Cyc t m r -> Cyc t m r +{-# INLINABLE toPow' #-} +{-# INLINABLE toDec' #-} +{-# INLINABLE toCRT' #-} -- | Force to powerful-basis representation (for internal use only). toPow' c@(Pow _) = c diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor.hs b/lol/Crypto/Lol/Cyclotomic/Tensor.hs index 342ff5eb..fcb53d9a 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor.hs +++ b/lol/Crypto/Lol/Cyclotomic/Tensor.hs @@ -202,7 +202,7 @@ mulGCRT, divGCRT, crt, crtInv :: {-# INLINABLE mulGCRT #-} {-# INLINABLE divGCRT #-} {-# INLINABLE crt #-} -{-# INLINABLE crtInv #-} +{-# INLINE crtInv #-} -- | Multiply by \(g_m\) in the CRT basis. (This function is simply an -- appropriate entry from 'crtFuncs'.) @@ -223,6 +223,7 @@ crtInv = (\(_,_,_,_,f) -> f) <$> crtFuncs -- (This function is simply an appropriate entry from 'crtExtFuncs'.) twaceCRT :: forall t m m' mon r . (CRTrans mon r, Tensor t, m `Divides` m', TElt t r) => mon (t m' r -> t m r) +{-# INLINABLE twaceCRT #-} twaceCRT = proxyT hasCRTFuncs (Proxy::Proxy (t m' r)) *> proxyT hasCRTFuncs (Proxy::Proxy (t m r)) *> (fst <$> crtExtFuncs) @@ -413,6 +414,7 @@ indexInfo = let pps = proxy ppsFact (Proxy::Proxy m) -- the index into the powerful\/decoding basis of \(\O_{m'}\) of the -- \(i\)th entry of the powerful/decoding basis of \(\O_m\). extIndicesPowDec :: (m `Divides` m') => Tagged '(m, m') (U.Vector Int) +{-# INLINABLE extIndicesPowDec #-} extIndicesPowDec = do (_, phi, _, tots) <- indexInfo return $ U.generate phi (fromIndexPair tots . (0,)) @@ -438,15 +440,15 @@ baseWrapper f = do -- | A lookup table for 'toIndexPair' applied to indices \([\varphi(m')]\). baseIndicesPow :: forall m m' . (m `Divides` m') => Tagged '(m, m') (U.Vector (Int,Int)) +{-# INLINABLE baseIndicesPow #-} -- | A lookup table for 'baseIndexDec' applied to indices \([\varphi(m')]\). baseIndicesDec :: forall m m' . (m `Divides` m') => Tagged '(m, m') (U.Vector (Maybe (Int,Bool))) - +{-# INLINABLE baseIndicesDec #-} -- | Same as 'baseIndicesPow', but only includes the second component -- of each pair. baseIndicesCRT :: forall m m' . (m `Divides` m') => Tagged '(m, m') (U.Vector Int) - baseIndicesPow = baseWrapper (toIndexPair . totients) -- this one is more complicated; requires the prime powers diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor.hs b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor.hs index 1db49b41..c9b6c649 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor.hs +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor.hs @@ -21,8 +21,7 @@ -- | Wrapper for a C++ implementation of the 'Tensor' interface. -module Crypto.Lol.Cyclotomic.Tensor.CTensor -( CT ) where +module Crypto.Lol.Cyclotomic.Tensor.CTensor (CT) where import Algebra.Additive as Additive (C) import Algebra.Module as Module (C) @@ -44,7 +43,7 @@ import Data.Traversable as T import Data.Vector.Generic as V (fromList, toList, unzip) import Data.Vector.Storable as SV (Vector, convert, foldl', fromList, generate, - length, map, mapM, replicate, + length, map, replicate, replicateM, thaw, thaw, toList, unsafeFreeze, unsafeWith, zipWith, (!)) @@ -158,13 +157,14 @@ toZV v@(ZV _) = v zvToCT' :: forall m r . (Storable r) => IZipVector m r -> CT' m r zvToCT' v = coerce (convert $ unIZipVector v :: Vector r) -wrap :: (Storable r) => (CT' l r -> CT' m r) -> (CT l r -> CT m r) +wrap :: (Storable s, Storable r) => (CT' l s -> CT' m r) -> (CT l s -> CT m r) {-# INLINABLE wrap #-} wrap f (CT v) = CT $ f v wrap f (ZV v) = CT $ f $ zvToCT' v -wrapM :: (Storable r, Monad mon) => (CT' l r -> mon (CT' m r)) - -> (CT l r -> mon (CT m r)) +wrapM :: (Storable s, Storable r, Monad mon) => (CT' l s -> mon (CT' m r)) + -> (CT l s -> mon (CT m r)) +{-# INLINABLE wrapM #-} wrapM f (CT v) = CT <$> f v wrapM f (ZV v) = CT <$> f (zvToCT' v) @@ -248,15 +248,14 @@ instance Tensor CT where scalarPow = CT . scalarPow' -- Vector code - l = wrap $ untag $ basicDispatch dl - lInv = wrap $ untag $ basicDispatch dlinv + l = wrap $ basicDispatch dl + lInv = wrap $ basicDispatch dlinv - mulGPow = wrap mulGPow' - mulGDec = wrap $ untag $ basicDispatch dmulgdec + mulGPow = wrap $ basicDispatch dmulgpow + mulGDec = wrap $ basicDispatch dmulgdec - divGPow = wrapM divGPow' - -- we divide by p in the C code (for divGDec only(?)), do NOT call checkDiv! - divGDec = wrapM $ Just . untag (basicDispatch dginvdec) + divGPow = wrapM $ dispatchGInv dginvpow + divGDec = wrapM $ dispatchGInv dginvdec crtFuncs = (,,,,) <$> return (CT . repl) <*> @@ -285,14 +284,16 @@ instance Tensor CT where crtSetDec = (CT <$>) <$> coerceBasis crtSetDec' - fmapT f (CT v) = CT $ coerce (SV.map f) v - fmapT f v@(ZV _) = fmapT f $ toCT v + fmapT f = wrap $ coerce (SV.map f) - zipWithT f (CT (CT' v1)) (CT (CT' v2)) = CT $ CT' $ SV.zipWith f v1 v2 - zipWithT f v1 v2 = zipWithT f (toCT v1) (toCT v2) + zipWithT f v1' v2' = + let (CT (CT' v1)) = toCT v1' + (CT (CT' v2)) = toCT v2' + in CT $ CT' $ SV.zipWith f v1 v2 - unzipT (CT (CT' v)) = (CT . CT') *** (CT . CT') $ unzip v - unzipT v = unzipT $ toCT v + unzipT v = + let (CT (CT' x)) = toCT v + in (CT . CT') *** (CT . CT') $ unzip x {-# INLINABLE entailIndexT #-} {-# INLINABLE entailEqT #-} @@ -313,14 +314,13 @@ instance Tensor CT where {-# INLINABLE embedDec #-} {-# INLINABLE tGaussianDec #-} {-# INLINABLE gSqNormDec #-} - {-# INLINABLE crtExtFuncs #-} + {-# INLINE crtExtFuncs #-} {-# INLINABLE coeffs #-} {-# INLINABLE powBasisPow #-} {-# INLINABLE crtSetDec #-} {-# INLINABLE fmapT #-} - {-# INLINABLE zipWithT #-} - {-# INLINABLE unzipT #-} - + {-# INLINE zipWithT #-} + {-# INLINE unzipT #-} coerceTw :: (Functor mon) => TaggedT '(m, m') mon (Vector r -> Vector r) -> mon (CT' m' r -> CT' m r) coerceTw = (coerce <$>) . untagT @@ -338,12 +338,21 @@ coerceCoeffs = coerce coerceBasis :: Tagged '(m,m') [Vector r] -> Tagged m [CT' m' r] coerceBasis = coerce -mulGPow' :: (TElt CT r, Fact m) => CT' m r -> CT' m r -mulGPow' = untag $ basicDispatch dmulgpow - -divGPow' :: (TElt CT r, Fact m, IntegralDomain r, ZeroTestable r) - => CT' m r -> Maybe (CT' m r) -divGPow' = untag $ checkDiv $ basicDispatch dginvpow +dispatchGInv :: forall m r . (Storable r, Fact m) + => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16) + -> CT' m r -> Maybe (CT' m r) +dispatchGInv f = + let factors = proxy (marshalFactors <$> ppsFact) (Proxy::Proxy m) + totm = proxy (fromIntegral <$> totientFact) (Proxy::Proxy m) + numFacts = fromIntegral $ SV.length factors + in \(CT' x) -> unsafePerformIO $ do + yout <- SV.thaw x + ret <- SM.unsafeWith yout (\pout -> + SV.unsafeWith factors (\pfac -> + f pout totm pfac numFacts)) + if ret /= 0 + then Just . CT' <$> unsafeFreeze yout + else return Nothing withBasicArgs :: forall m r . (Fact m, Storable r) => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()) @@ -361,8 +370,8 @@ withBasicArgs f = basicDispatch :: (Storable r, Fact m) => (Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO ()) - -> Tagged m (CT' m r -> CT' m r) -basicDispatch f = return $ unsafePerformIO . withBasicArgs f + -> CT' m r -> CT' m r +basicDispatch f = unsafePerformIO . withBasicArgs f gSqNormDec' :: (Storable r, Fact m, Dispatch r) => Tagged m (CT' m r -> r) @@ -384,19 +393,6 @@ ctCRTInv = do return $ \x -> unsafePerformIO $ withPtrArray ruinv' (\ruptr -> with mhatInv (flip withBasicArgs x . dcrtinv ruptr)) -checkDiv :: (Storable r, IntegralDomain r, ZeroTestable r, Fact m) - => Tagged m (CT' m r -> CT' m r) -> Tagged m (CT' m r -> Maybe (CT' m r)) -checkDiv f = do - f' <- f - oddRad' <- fromIntegral <$> oddRadicalFact - return $ \x -> - let (CT' y) = f' x - in CT' <$> SV.mapM (`divIfDivis` oddRad') y - -divIfDivis :: (IntegralDomain r, ZeroTestable r) => r -> r -> Maybe r -divIfDivis num den = let (q,r) = num `divMod` den - in if isZero r then Just q else Nothing - cZipDispatch :: (Storable r, Fact m) => (Ptr r -> Ptr r -> Int64 -> IO ()) -> Tagged m (CT' m r -> CT' m r -> CT' m r) diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Backend.hs b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Backend.hs index 9e3e5dfa..8e68e647 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Backend.hs +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Backend.hs @@ -166,9 +166,9 @@ class (repr ~ CTypeOf r) => Dispatch' repr r where -- | Equivalent to 'Tensor's @mulGDec@. dmulgdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO () -- | Equivalent to 'Tensor's @divGPow@. - dginvpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO () + dginvpow :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16 -- | Equivalent to 'Tensor's @divGDec@. - dginvdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO () + dginvdec :: Ptr r -> Int64 -> Ptr CPP -> Int16 -> IO Int16 -- | Equivalent to @zipWith (*)@ dmul :: Ptr r -> Ptr r -> Int64 -> IO () @@ -315,12 +315,12 @@ foreign import ccall unsafe "tensorGPowC" tensorGPowC :: Int16 -> Ptr (C foreign import ccall unsafe "tensorGDecR" tensorGDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO () foreign import ccall unsafe "tensorGDecRq" tensorGDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO () foreign import ccall unsafe "tensorGDecC" tensorGDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO () -foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO () -foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO () -foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO () -foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO () -foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO () -foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO () +foreign import ccall unsafe "tensorGInvPowR" tensorGInvPowR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16 +foreign import ccall unsafe "tensorGInvPowRq" tensorGInvPowRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16 +foreign import ccall unsafe "tensorGInvPowC" tensorGInvPowC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16 +foreign import ccall unsafe "tensorGInvDecR" tensorGInvDecR :: Int16 -> Ptr Int64 -> Int64 -> Ptr CPP -> Int16 -> IO Int16 +foreign import ccall unsafe "tensorGInvDecRq" tensorGInvDecRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr Int64 -> IO Int16 +foreign import ccall unsafe "tensorGInvDecC" tensorGInvDecC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> IO Int16 foreign import ccall unsafe "tensorCRTRq" tensorCRTRq :: Int16 -> Ptr (ZqBasic q Int64) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (ZqBasic q Int64)) -> Ptr Int64 -> IO () foreign import ccall unsafe "tensorCRTC" tensorCRTC :: Int16 -> Ptr (Complex Double) -> Int64 -> Ptr CPP -> Int16 -> Ptr (Ptr (Complex Double)) -> IO () diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Extension.hs b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Extension.hs index 98911dd7..69e1fc05 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Extension.hs +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/Extension.hs @@ -50,6 +50,8 @@ backpermute' is v = generate (U.length is) (\i -> v ! (is U.! i)) embedPow', embedDec' :: (Additive r, Storable r, m `Divides` m') => Tagged '(m, m') (Vector r -> Vector r) +{-# INLINABLE embedPow' #-} +{-# INLINABLE embedDec' #-} -- | Embeds an vector in the powerful basis of the the mth cyclotomic ring -- to an vector in the powerful basis of the m'th cyclotomic ring when @m | m'@ embedPow' = (\indices arr -> generate (U.length indices) $ \idx -> @@ -98,6 +100,7 @@ kronToVec v = do twaceCRT' :: forall mon m m' r . (Storable r, CRTrans mon r, m `Divides` m') => TaggedT '(m, m') mon (Vector r -> Vector r) +{-# INLINE twaceCRT' #-} twaceCRT' = tagT $ do g' <- proxyT (kronToVec gCRTK) (Proxy::Proxy m') gInv <- proxyT (kronToVec gInvCRTK) (Proxy::Proxy m) diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/common.h b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/common.h index ad01b7f5..367f43b0 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/common.h +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/common.h @@ -1,17 +1,8 @@ #ifndef COMMON_H_ #define COMMON_H_ -#include -#include #include "types.h" -#define ASSERT(EXP) { \ - if (!(EXP)) { \ - fprintf (stderr, "Assertion in file '%s' line %d : " #EXP " is false\n", __FILE__, __LINE__); \ - exit(-1); \ - } \ -} - // calculates base ** exp hDim_t ipow(hDim_t base, hShort_t exp); diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/g.cpp b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/g.cpp index 3bdebb17..a7f79835 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/g.cpp +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/g.cpp @@ -99,12 +99,11 @@ template void gInvDec (ring* y, hShort_t tupSize, hDim_t lts, hD } ring rp; rp = p; - ring acc = lastOut / rp; - ASSERT ((acc * rp) == lastOut); // this line asserts that lastOut % p == 0, without calling % operator + ring acc = lastOut; for (i = p-2; i > 0; --i) { hDim_t idx = tensorOffset + i*rts; ring tmp = acc; - acc -= y[idx*tupSize]; // we already divided acc by p, do not multiply y[idx] by p + acc -= y[idx*tupSize]*rp; y[idx*tupSize] = tmp; } y[tensorOffset*tupSize] = acc; @@ -144,34 +143,120 @@ extern "C" void tensorGDecC (hShort_t tupSize, Complex* y, hDim_t totm, PrimeExp tensorFuserPrime (y, tupSize, gDec, totm, peArr, sizeOfPE, (hInt_t*)0); } -extern "C" void tensorGInvPowR (hShort_t tupSize, hInt_t* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) +hInt_t oddRad(PrimeExponent* peArr, hShort_t sizeOfPE) { + hInt_t oddrad; + oddrad = 1; + for(int i = 0; i < sizeOfPE; i++) { + hShort_t p = peArr[i].prime; + if (p != 2) { + oddrad *= peArr[i].prime; + } + } + return oddrad; +} + +extern "C" hShort_t tensorGInvPowR (hShort_t tupSize, hInt_t* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) { tensorFuserPrime (y, tupSize, gInvPow, totm, peArr, sizeOfPE, (hInt_t*)0); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + + for(int i = 0; i < tupSize*totm; i++) { + if (y[i] % oddrad) { + y[i] /= oddrad; + } + else { + return 0; + } + } + return 1; } -extern "C" void tensorGInvPowRq (hShort_t tupSize, Zq* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE, hInt_t* qs) +extern "C" hShort_t tensorGInvPowRq (hShort_t tupSize, Zq* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE, hInt_t* qs) { tensorFuserPrime (y, tupSize, gInvPow, totm, peArr, sizeOfPE, qs); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + + for(int i = 0; i < tupSize; i++) { + Zq::q = qs[i]; // global update + hInt_t ori = reciprocal(Zq::q, oddrad); + Zq oddradInv; + oddradInv = ori; + if (ori == 0) { + return 0; // error condition + } + for(hDim_t j = 0; j < totm; j++) { + y[j*tupSize+i] *= oddradInv; + } + } + canonicalizeZq(y,tupSize,totm,qs); + return 1; } -extern "C" void tensorGInvPowC (hShort_t tupSize, Complex* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) +extern "C" hShort_t tensorGInvPowC (hShort_t tupSize, Complex* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) { tensorFuserPrime (y, tupSize, gInvPow, totm, peArr, sizeOfPE, (hInt_t*)0); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + Complex oddradInv; + oddradInv = 1 / oddrad; + for(int i = 0; i < tupSize*totm; i++) { + y[i] *= oddradInv; + } + return 1; } -extern "C" void tensorGInvDecR (hShort_t tupSize, hInt_t* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) +extern "C" hShort_t tensorGInvDecR (hShort_t tupSize, hInt_t* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) { tensorFuserPrime (y, tupSize, gInvDec, totm, peArr, sizeOfPE, (hInt_t*)0); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + + for(int i = 0; i < tupSize*totm; i++) { + if (y[i] % oddrad) { + y[i] /= oddrad; + } + else { + return 0; + } + } + return 1; } -extern "C" void tensorGInvDecRq (hShort_t tupSize, Zq* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE, hInt_t* qs) +extern "C" hShort_t tensorGInvDecRq (hShort_t tupSize, Zq* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE, hInt_t* qs) { tensorFuserPrime (y, tupSize, gInvDec, totm, peArr, sizeOfPE, qs); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + + for(int i = 0; i < tupSize; i++) { + Zq::q = qs[i]; // global update + hInt_t ori = reciprocal(Zq::q, oddrad); + Zq oddradInv; + oddradInv = ori; + if (ori == 0) { + return 0; // error condition + } + for(hDim_t j = 0; j < totm; j++) { + y[j*tupSize+i] *= oddradInv; + } + } + canonicalizeZq(y,tupSize,totm,qs); + return 1; } -extern "C" void tensorGInvDecC (hShort_t tupSize, Complex* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) +extern "C" hShort_t tensorGInvDecC (hShort_t tupSize, Complex* y, hDim_t totm, PrimeExponent* peArr, hShort_t sizeOfPE) { tensorFuserPrime (y, tupSize, gInvDec, totm, peArr, sizeOfPE, (hInt_t*)0); + + hInt_t oddrad = oddRad(peArr, sizeOfPE); + Complex oddradInv; + oddradInv = 1 / oddrad; + for(int i = 0; i < tupSize*totm; i++) { + y[i] *= oddradInv; + } + return 1; } \ No newline at end of file diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/types.h b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/types.h index 6b25a47a..6b7b3c39 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/types.h +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/types.h @@ -3,6 +3,8 @@ #define TENSORTYPES_H_ #include +#include +#include typedef int64_t hInt_t ; typedef int32_t hDim_t ; @@ -18,6 +20,13 @@ typedef struct hInt_t reciprocal (hInt_t a, hInt_t b); +#define ASSERT(EXP) { \ + if (!(EXP)) { \ + fprintf (stderr, "Assertion in file '%s' line %d : " #EXP " is false\n", __FILE__, __LINE__); \ + exit(-1); \ + } \ +} + //http://stackoverflow.com/questions/37572628 #ifdef __cplusplus //http://stackoverflow.com/a/4421719 @@ -55,14 +64,11 @@ class Zq { Zq binv; binv = reciprocal(q,b.x); + ASSERT (binv.x); // binv == 0 indicates that x is not invertible mod q *this *= binv; return *this; } }; -inline char operator==(Zq a, const Zq& b) -{ - return (a.x == b.x); -} inline Zq operator+(Zq a, const Zq& b) { a += b; @@ -129,13 +135,6 @@ class Complex return *this; } }; -inline char operator==(Complex a, const Complex& b) -{ - // This is only used in divGDec, where we do a divisiblity check. - // The divisibility check should always succeed for Complex since \C is a field, - // however if we actually implement equality, it would fail due to roundoff. - return 1; -} inline Complex operator+(Complex a, const Complex& b) { a += b; diff --git a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/zq.cpp b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/zq.cpp index a2f0d840..5b39a31c 100644 --- a/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/zq.cpp +++ b/lol/Crypto/Lol/Cyclotomic/Tensor/CTensor/zq.cpp @@ -17,7 +17,10 @@ hInt_t reciprocal (hInt_t a, hInt_t b) y = lasty - quotient*y; lasty = tmp; } - ASSERT (a==1); // if this one fails, then b is not invertible mod a + // if a!=1, then b is not invertible mod a + if(a!=1) { + return 0; + } // this actually returns EITHER the reciprocal OR reciprocal + fieldSize hInt_t res = lasty + fieldSize; diff --git a/lol/Crypto/Lol/Cyclotomic/UCyc.hs b/lol/Crypto/Lol/Cyclotomic/UCyc.hs index daa1cda9..7c043906 100644 --- a/lol/Crypto/Lol/Cyclotomic/UCyc.hs +++ b/lol/Crypto/Lol/Cyclotomic/UCyc.hs @@ -33,7 +33,7 @@ module Crypto.Lol.Cyclotomic.UCyc ( -- * Data types and constraints - UCyc, P, D, C, E, UCycEC, UCRTElt, NFElt + UCyc, P, D, C, E, UCycEC, UCycPC, UCRTElt, NFElt -- * Changing representation , toPow, toDec, toCRT, fmapPow, fmapDec , unzipPow, unzipDec, unzipCRTC, unzipCRTE @@ -92,6 +92,9 @@ data E -- | Convenient synonym for either CRT representation. type UCycEC t m r = Either (UCyc t m E r) (UCyc t m C r) +-- | Convenient synonym for random sampling. +type UCycPC t m r = Either (UCyc t m P r) (UCyc t m C r) + -- | Represents a cyclotomic ring such as \(\Z[\zeta_m]\), -- \(\Z_q[\zeta_m]\), and \(\Q(\zeta_m)\) in an explicit -- representation: @t@ is the 'Tensor' type for storing coefficient tensors; @@ -330,6 +333,7 @@ unzipCRTC :: (Fact m, UCRTElt t (a,b), UCRTElt t a, UCRTElt t b) => UCyc t m C (a,b) -> (Either (UCyc t m P a) (UCyc t m C a), Either (UCyc t m P b) (UCyc t m C b)) +{-# INLINABLE unzipCRTC #-} unzipCRTC (CRTC s v) = let (ac,bc) = unzipT v (ap,bp) = Pow *** Pow $ unzipT $ crtInvCS s v @@ -343,6 +347,7 @@ unzipCRTE :: (Fact m, UCRTElt t (a,b), UCRTElt t a, UCRTElt t b) => UCyc t m E (a,b) -> (Either (UCyc t m P a) (UCyc t m E a), Either (UCyc t m P b) (UCyc t m E b)) +{-# INLINABLE unzipCRTE #-} unzipCRTE (CRTE _ v) = let (ae,be) = unzipT v (a',b') = unzipT $ fmapT fromExt $ runIdentity crtInv v @@ -353,7 +358,7 @@ unzipCRTE (CRTE _ v) -- | Multiply by the special element \(g_m\). mulG :: (Fact m, UCRTElt t r) => UCyc t m rep r -> UCyc t m rep r -{-# INLINABLE mulG #-} +{-# INLINE mulG #-} mulG (Pow v) = Pow $ mulGPow v mulG (Dec v) = Dec $ mulGDec v mulG (CRTC s v) = CRTC s $ mulGCRTCS s v @@ -381,7 +386,7 @@ divGDec (Dec v) = Dec <$> T.divGDec v -- | Similar to 'divGPow'. divGCRTC :: (Fact m, UCRTElt t r) => UCyc t m C r -> UCyc t m C r -{-# INLINABLE divGCRTC #-} +{-# INLINE divGCRTC #-} divGCRTC (CRTC s v) = CRTC s $ divGCRTCS s v -- | Yield the scaled squared norm of \(g_m \cdot e\) under @@ -481,8 +486,8 @@ twaceDec (Dec v) = Dec $ twacePowDec v -- | Twace into a subring, for the CRT basis. (The output is an -- 'Either' because the subring might not support 'C'.) twaceCRTC :: (m `Divides` m', UCRTElt t r) - => UCyc t m' C r -> Either (UCyc t m P r) (UCyc t m C r) -{-# INLINABLE twaceCRTC #-} + => UCyc t m' C r -> UCycPC t m r +{-# INLINE twaceCRTC #-} twaceCRTC x@(CRTC s' v) = case crtSentinel of -- go to CRTC if valid for target, else go to Pow @@ -663,7 +668,7 @@ instance (Random r, UCRTElt t r, Fact m) => Random (UCyc t m D r) where randomR _ = error "randomR non-sensical for UCyc" instance (Random r, UCRTElt t r, Fact m) - => Random (Either (UCyc t m P r) (UCyc t m C r)) where + => Random (UCycPC t m r) where -- create in CRTC basis if possible, otherwise in powerful random = let cons = case crtSentinel of diff --git a/lol/benchmarks/BenchParams.hs b/lol/benchmarks/BenchParams.hs new file mode 100644 index 00000000..5a46c834 --- /dev/null +++ b/lol/benchmarks/BenchParams.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module BenchParams where + +import Utils + +import Crypto.Lol +import Crypto.Lol.Types +import Crypto.Random.DRBG + +import Data.Singletons +import Data.Promotion.Prelude.Eq +import Data.Singletons.TypeRepStar () + + + + +type Tensors = '[T] +type MRCombos = + '[ '(M, R) ] + +type T = RT +type M = F64*F9*F25 -- F9*F5*F7*F11 -- F64*F9*F25 -- +type R = Zq 1065601 --34651 -- Zq 14401 -- +type M' = M --F3*F5*F11 + +testParam :: Proxy '(T, M, R) +testParam = Proxy + +testParam' :: Proxy '(T,M, R, HashDRBG) +testParam' = Proxy + +twoIdxParam :: Proxy '(T, M', M, R) +twoIdxParam = Proxy + +{- +type Tensors = '[CT,RT] +type MRCombos = + '[ '(F1024, Zq 1051649), -- 1024 / 512 + '(F2048, Zq 1054721), -- 2048 / 1024 + '(F64 * F27, Zq 1048897), -- 1728 / 576 + '(F64 * F81, Zq 1073089), -- 5184 / 1728 + '(F64*F9*F25, Zq 1065601) -- 14400 / 3840 + ] +-} + +type MM'RCombos = + '[ '(F8 * F91, F8 * F91 * F4, Zq 8737), + '(F8 * F91, F8 * F91 * F5, Zq 14561), + '(F128, F128 * F91, Zq 23297) + ] + +-- EAC: must be careful where we use Nub: apparently TypeRepStar doesn't work well with the Tensor constructors +type AllParams = ( '(,) <$> Tensors) <*> MRCombos +allParams :: Proxy AllParams +allParams = Proxy + +type LiftParams = ( '(,) <$> Tensors) <*> MRCombos +liftParams :: Proxy LiftParams +liftParams = Proxy + +type TwoIdxParams = ( '(,) <$> Tensors) <*> MM'RCombos +twoIdxParams :: Proxy TwoIdxParams +twoIdxParams = Proxy + +type ErrorParams = ( '(,) <$> '[HashDRBG]) <*> LiftParams +errorParams :: Proxy ErrorParams +errorParams = Proxy + +data Liftable :: TyFun (Factored, *) Bool -> * +type instance Apply Liftable '(m',r) = Int64 :== (LiftOf r) + +data RemoveM :: TyFun (Factored, Factored, *) (Factored, *) -> * +type instance Apply RemoveM '(m,m',r) = '(m',r) diff --git a/lol/benchmarks/CTBenches.hs b/lol/benchmarks/CTBenches.hs new file mode 100644 index 00000000..8fde8c49 --- /dev/null +++ b/lol/benchmarks/CTBenches.hs @@ -0,0 +1,44 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module CTBenches (ctBenches) where + +import Control.Applicative +import Control.Monad.Random + +import Crypto.Lol.Cyclotomic.Tensor +import Crypto.Lol.Prelude +import Crypto.Lol.Types +import Crypto.Random.DRBG + +import Criterion +import BenchParams + +ctBenches :: IO Benchmark +ctBenches = do + x1 :: T M (R, R) <- getRandom + x2 :: T M R <- getRandom + x3 :: T M R <- getRandom + gen <- newGenIO + return $ bgroup "CT" [ + bench "unzipPow" $ nf unzipT' x1, + bench "unzipDec" $ nf unzipT' x1, + bench "unzipCRT" $ nf unzipT' x1, + bench "zipWith (*)" $ nf (zipWithT' (*) x2) x3, + bench "crt" $ nf (wrap $ fromJust' "CTBenches.crt" crt') x2, + bench "crtInv" $ nf (wrap $ fromJust' "CTBenches.crtInv" crtinv') x2, + bench "l" $ nf (wrap l') x2, + bench "lInv" $ nf (wrap lInv') x2, + bench "*g Pow" $ nf (wrap mulGPow'') x2, + bench "*g CRT" $ nf (wrap $ fromJust' "CTBenches.gcrt" mulGCRT'') x2, + bench "lift" $ nf (fmapT lift) x2, + bench "error" $ nf (evalRand (fmapT (roundMult one) <$> + (CT <$> cDispatchGaussian + (0.1 :: Double) :: Rand (CryptoRand HashDRBG) (T M Double))) :: CryptoRand HashDRBG -> (T M Int64)) gen + + ] \ No newline at end of file diff --git a/lol/benchmarks/CycBenches.hs b/lol/benchmarks/CycBenches.hs index 9c700277..24dbd475 100644 --- a/lol/benchmarks/CycBenches.hs +++ b/lol/benchmarks/CycBenches.hs @@ -1,13 +1,14 @@ -{-# LANGUAGE DataKinds, FlexibleContexts, - NoImplicitPrelude, RebindableSyntax, - ScopedTypeVariables, TypeFamilies, - TypeOperators, UndecidableInstances #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} module CycBenches (cycBenches) where import Apply.Cyc import Benchmarks -import Utils +import BenchParams import Control.Monad.Random @@ -15,24 +16,30 @@ import Crypto.Lol import Crypto.Lol.Types import Crypto.Random.DRBG -import Data.Singletons -import Data.Promotion.Prelude.Eq -import Data.Singletons.TypeRepStar () - cycBenches :: IO Benchmark cycBenches = benchGroup "Cyc" [ - benchGroup "unzipCycPow" $ applyUnzip allParams $ hideArgs bench_unzipCycPow, - benchGroup "unzipCycCRT" $ applyUnzip allParams $ hideArgs bench_unzipCycCRT, - benchGroup "*" $ applyBasic allParams $ hideArgs bench_mul, - benchGroup "crt" $ applyBasic allParams $ hideArgs bench_crt, - benchGroup "crtInv" $ applyBasic allParams $ hideArgs bench_crtInv, - benchGroup "l" $ applyBasic allParams $ hideArgs bench_l, - benchGroup "*g Pow" $ applyBasic allParams $ hideArgs bench_mulgPow, - benchGroup "*g CRT" $ applyBasic allParams $ hideArgs bench_mulgCRT, - benchGroup "lift" $ applyLift liftParams $ hideArgs bench_liftPow, - benchGroup "error" $ applyError errorParams $ hideArgs $ bench_errRounded 0.1, - benchGroup "twace" $ applyTwoIdx twoIdxParams $ hideArgs bench_twacePow, - benchGroup "embed" $ applyTwoIdx twoIdxParams $ hideArgs bench_embedPow + benchGroup "unzipPow" $ [hideArgs bench_unzipCycPow testParam], + benchGroup "unzipDec" $ [hideArgs bench_unzipCycDec testParam], + benchGroup "unzipCRT" $ [hideArgs bench_unzipCycCRT testParam], + benchGroup "zipWith (*)" $ [hideArgs bench_mul testParam], + benchGroup "crt" $ [hideArgs bench_crt testParam], + benchGroup "crtInv" $ [hideArgs bench_crtInv testParam], + benchGroup "l" $ [hideArgs bench_l testParam], + benchGroup "lInv" $ [hideArgs bench_lInv testParam], + benchGroup "*g Pow" $ [hideArgs bench_mulgPow testParam], + benchGroup "*g Dec" $ [hideArgs bench_mulgDec testParam], + benchGroup "*g CRT" $ [hideArgs bench_mulgCRT testParam], + benchGroup "divg Pow" $ [hideArgs bench_divgPow testParam], + benchGroup "divg Dec" $ [hideArgs bench_divgDec testParam], + benchGroup "divg CRT" $ [hideArgs bench_divgCRT testParam], + benchGroup "lift" $ [hideArgs bench_liftPow testParam], + benchGroup "error" $ [hideArgs (bench_errRounded 0.1) testParam'], + benchGroup "twacePow" $ [hideArgs bench_twacePow twoIdxParam], + benchGroup "twaceDec" $ [hideArgs bench_twaceDec twoIdxParam], + benchGroup "twaceCRT" $ [hideArgs bench_twaceCRT twoIdxParam], + benchGroup "embedPow" $ [hideArgs bench_embedPow twoIdxParam], + benchGroup "embedDec" $ [hideArgs bench_embedDec twoIdxParam], + benchGroup "embedCRT" $ [hideArgs bench_embedCRT twoIdxParam] ] bench_unzipCycPow :: (UnzipCtx t m r) => Cyc t m (r,r) -> Bench '(t,m,r) @@ -40,6 +47,11 @@ bench_unzipCycPow a = let a' = advisePow a in bench unzipCyc a' +bench_unzipCycDec :: (UnzipCtx t m r) => Cyc t m (r,r) -> Bench '(t,m,r) +bench_unzipCycDec a = + let a' = adviseDec a + in bench unzipCyc a' + bench_unzipCycCRT :: (UnzipCtx t m r) => Cyc t m (r,r) -> Bench '(t,m,r) bench_unzipCycCRT a = let a' = adviseCRT a @@ -64,18 +76,38 @@ bench_crtInv x = let y = adviseCRT x in bench advisePow y bench_l :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) bench_l x = let y = adviseDec x in bench advisePow y +-- convert input from Pow basis to Dec basis +bench_lInv :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) +bench_lInv x = let y = advisePow x in bench adviseDec y + -- lift an element in the Pow basis bench_liftPow :: forall t m r . (LiftCtx t m r) => Cyc t m r -> Bench '(t,m,r) -bench_liftPow x = let y = advisePow x in bench (liftCyc Pow :: Cyc t m r -> Cyc t m (LiftOf r)) y +bench_liftPow x = let y = advisePow x in bench (liftCyc Pow) y -- multiply by g when input is in Pow basis bench_mulgPow :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) bench_mulgPow x = let y = advisePow x in bench mulG y +-- multiply by g when input is in Dec basis +bench_mulgDec :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) +bench_mulgDec x = let y = adviseDec x in bench mulG y + -- multiply by g when input is in CRT basis bench_mulgCRT :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) bench_mulgCRT x = let y = adviseCRT x in bench mulG y +-- divide by g when input is in Pow basis +bench_divgPow :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) +bench_divgPow x = let y = advisePow $ mulG x in bench divG y + +-- divide by g when input is in Dec basis +bench_divgDec :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) +bench_divgDec x = let y = adviseDec $ mulG x in bench divG y + +-- divide by g when input is in CRT basis +bench_divgCRT :: (BasicCtx t m r) => Cyc t m r -> Bench '(t,m,r) +bench_divgCRT x = let y = adviseCRT x in bench divG y + -- generate a rounded error term bench_errRounded :: forall t m r gen . (ErrorCtx t m r gen) => Double -> Bench '(t,m,r,gen) @@ -89,46 +121,32 @@ bench_twacePow x = let y = advisePow x in bench (twace :: Cyc t m' r -> Cyc t m r) y +bench_twaceDec :: forall t m m' r . (TwoIdxCtx t m m' r) + => Cyc t m' r -> Bench '(t,m,m',r) +bench_twaceDec x = + let y = adviseDec x + in bench (twace :: Cyc t m' r -> Cyc t m r) y + +bench_twaceCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => Cyc t m' r -> Bench '(t,m,m',r) +bench_twaceCRT x = + let y = adviseCRT x + in bench (twace :: Cyc t m' r -> Cyc t m r) y + bench_embedPow :: forall t m m' r . (TwoIdxCtx t m m' r) => Cyc t m r -> Bench '(t,m,m',r) bench_embedPow x = let y = advisePow x - in bench (embed :: Cyc t m r -> Cyc t m' r) y - -type Tensors = '[CT,RT] -type MRCombos = - '[ '(F1024, Zq 1051649), -- 1024 / 512 - '(F2048, Zq 1054721), -- 2048 / 1024 - '(F64 * F27, Zq 1048897), -- 1728 / 576 - '(F64 * F81, Zq 1073089), -- 5184 / 1728 - '(F64*F9*F25, Zq 1065601) -- 14400 / 3840 - ] - -type MM'RCombos = - '[ '(F8 * F91, F8 * F91 * F4, Zq 8737), - '(F8 * F91, F8 * F91 * F5, Zq 14561), - '(F128, F128 * F91, Zq 23297) - ] - --- EAC: must be careful where we use Nub: apparently TypeRepStar doesn't work well with the Tensor constructors -type AllParams = ( '(,) <$> Tensors) <*> MRCombos -allParams :: Proxy AllParams -allParams = Proxy - -type LiftParams = ( '(,) <$> Tensors) <*> MRCombos -liftParams :: Proxy LiftParams -liftParams = Proxy - -type TwoIdxParams = ( '(,) <$> Tensors) <*> MM'RCombos -twoIdxParams :: Proxy TwoIdxParams -twoIdxParams = Proxy - -type ErrorParams = ( '(,) <$> '[HashDRBG]) <*> LiftParams -errorParams :: Proxy ErrorParams -errorParams = Proxy - -data Liftable :: TyFun (Factored, *) Bool -> * -type instance Apply Liftable '(m',r) = Int64 :== (LiftOf r) - -data RemoveM :: TyFun (Factored, Factored, *) (Factored, *) -> * -type instance Apply RemoveM '(m,m',r) = '(m',r) + in bench (advisePow . embed :: Cyc t m r -> Cyc t m' r) y + +bench_embedDec :: forall t m m' r . (TwoIdxCtx t m m' r) + => Cyc t m r -> Bench '(t,m,m',r) +bench_embedDec x = + let y = adviseDec x + in bench (adviseDec . embed :: Cyc t m r -> Cyc t m' r) y + +bench_embedCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => Cyc t m r -> Bench '(t,m,m',r) +bench_embedCRT x = + let y = adviseCRT x + in bench (adviseCRT . embed :: Cyc t m r -> Cyc t m' r) y diff --git a/lol/benchmarks/Main.hs b/lol/benchmarks/Main.hs index 7e45342c..db322fd4 100644 --- a/lol/benchmarks/Main.hs +++ b/lol/benchmarks/Main.hs @@ -1,15 +1,177 @@ +{- +import TensorBenches +import Criterion.Main + +main :: IO () +main = defaultMain =<< sequence [ + tensorBenches + ] +-} +{-# LANGUAGE BangPatterns, RecordWildCards #-} import CycBenches +import SimpleTensorBenches import TensorBenches +import SimpleUCycBenches import UCycBenches -import ZqBenches -import Criterion.Main +import Criterion.Internal (runAndAnalyseOne) +import Criterion.Main.Options (defaultConfig) +import Criterion.Measurement (secs) +import Criterion.Monad (Criterion, withConfig) +import Criterion.Types +import Control.Monad (foldM, forM_, when) +import Control.Monad.IO.Class (MonadIO, liftIO) + +import Control.Exception (evaluate) + +import Control.DeepSeq (rnf) + +import Data.List (transpose) +import qualified Data.Map as Map +import Data.Maybe + +import Statistics.Resampling.Bootstrap (Estimate(..)) +import System.Console.ANSI +import System.IO +import Text.Printf + +-- table print parameters +colWidth, testNameWidth :: Int +colWidth = 15 +testNameWidth = 40 +verb :: Verb +verb = Progress + +benches :: [String] +benches = [ + --"unzipPow", + --"unzipDec", + --"unzipCRT", + "zipWith (*)", + "crt", + "crtInv", + "l", + "lInv", + "*g Pow", + --"*g Dec", + "*g CRT", + "divg Pow", + "divg Dec", + --"divg CRT", + "lift"{-, + "error", + "twacePow", + "twaceDec", + "twaceCRT", + "embedPow", + "embedDec", + "embedCRT"-} -main :: IO () -main = defaultMain =<< sequence [ - zqBenches, - tensorBenches, - ucycBenches, - cycBenches ] + +data Verb = Progress | Abridged | Full deriving (Eq) + +main :: IO () +main = do + hSetBuffering stdout NoBuffering -- for better printing of progress + reports <- mapM (getReports =<<) [ + {-simpleTensorBenches, + tensorBenches, + simpleUCycBenches,-} + ucycBenches{-, + cycBenches-} + ] + when (verb == Progress) $ putStrLn "" + printTable $ map reverse reports + +printTable :: [[Report]] -> IO () +printTable rpts = do + let colLbls = map (takeWhile (/= '/') . reportName . head) rpts + printf testName "" + mapM_ (\lbl -> printf col lbl) colLbls + printf "\n" + mapM_ printRow $ transpose rpts + +col, testName :: String +testName = "%-" ++ (show testNameWidth) ++ "s " +col = "%-" ++ (show colWidth) ++ "s " + +printANSI :: (MonadIO m) => Color -> String -> m () +printANSI sgr str = liftIO $ do + setSGR [SetColor Foreground Vivid sgr] + putStrLn str + setSGR [Reset] + +config :: Config +config = defaultConfig {verbosity = if verb == Full then Normal else Quiet} + +getRuntime :: Report -> Double +getRuntime Report{..} = + let SampleAnalysis{..} = reportAnalysis + (builtin, _) = splitAt 1 anRegress + mests = map (\Regression{..} -> Map.lookup "iters" regCoeffs) builtin + [Estimate{..}] = catMaybes mests + in estPoint + +-- See Criterion.Internal.analyseOne +printRow :: [Report] -> IO () +printRow xs@(rpt : _) = do + printf testName $ stripOuterGroup $ reportName rpt + let times = map getRuntime xs + minTime = minimum times + printCol t = + if t > (1.1*minTime) + then do + setSGR [SetColor Foreground Vivid Red] + printf col $ secs t + setSGR [Reset] + else printf col $ secs t + forM_ times printCol + putStrLn "" + +stripOuterGroup :: String -> String +stripOuterGroup = tail . dropWhile (/= '/') + +getReports :: Benchmark -> IO [Report] +getReports = withConfig config . runAndAnalyse + +-- | Run, and analyse, one or more benchmarks. +-- From Criterion.Internal +runAndAnalyse :: Benchmark + -> Criterion [Report] +runAndAnalyse bs = for bs $ \idx desc bm -> do + when (verb == Abridged || verb == Full) $ liftIO $ putStr $ "benchmark " ++ desc + when (verb == Full) $ liftIO $ putStrLn "" + (Analysed rpt) <- runAndAnalyseOne idx desc bm + when (verb == Progress) $ liftIO $ putStr "." + when (verb == Abridged) $ liftIO $ putStrLn $ "..." ++ (secs $ getRuntime rpt) + return rpt + +-- | Iterate over benchmarks. +-- From Criterion.Internal +for :: MonadIO m => Benchmark + -> (Int -> String -> Benchmarkable -> m a) -> m [a] +for bs0 handle = snd <$> go (0::Int, []) ("", bs0) + where + select = flip elem benches . takeWhile (/= '/') . stripOuterGroup + go (!idx,drs) (pfx, Environment mkenv mkbench) + | shouldRun pfx mkbench = do + e <- liftIO $ do + ee <- mkenv + evaluate (rnf ee) + return ee + go (idx,drs) (pfx, mkbench e) + | otherwise = return (idx,drs) + go (!idx, drs) (pfx, Benchmark desc b) + | select desc' = do + x <- handle idx desc' b; + return (idx + 1, x:drs) + | otherwise = return (idx, drs) + where desc' = addPrefix pfx desc + go (!idx,drs) (pfx, BenchGroup desc bs) = + foldM go (idx,drs) [(addPrefix pfx desc, b) | b <- bs] + + shouldRun pfx mkbench = + any (select . addPrefix pfx) . benchNames . mkbench $ + error "Criterion.env could not determine the list of your benchmarks since they force the environment (see the documentation for details)" diff --git a/lol/benchmarks/SimpleTensorBenches.hs b/lol/benchmarks/SimpleTensorBenches.hs new file mode 100644 index 00000000..ce29fde6 --- /dev/null +++ b/lol/benchmarks/SimpleTensorBenches.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module SimpleTensorBenches (simpleTensorBenches) where + +import Control.Applicative +import Control.Monad.Random + +import Crypto.Lol.Prelude +import Crypto.Lol.Cyclotomic.Tensor +import Crypto.Lol.Types +import Crypto.Random.DRBG + +import Criterion +import BenchParams + +simpleTensorBenches :: IO Benchmark +simpleTensorBenches = do + x1 :: T M (R, R) <- getRandom + x2 :: T M R <- getRandom + x3 :: T M R <- getRandom + x4 :: T M' R <- getRandom + let x2' = mulGPow x2 + x2'' = mulGDec x2 + gen <- newGenIO + return $ bgroup "STensor" [ + bench "unzipPow" $ nf unzipT x1, + bench "unzipDec" $ nf unzipT x1, + bench "unzipCRT" $ nf unzipT x1, + bench "zipWith (*)" $ nf (zipWithT (*) x2) x3, + bench "crt" $ nf (fromJust' "SimpleTensorBenches.crt" crt) x2, + bench "crtInv" $ nf (fromJust' "SimpleTensorBenches.crtInv" crtInv) x2, + bench "l" $ nf l x2, + bench "lInv" $ nf lInv x2, + bench "*g Pow" $ nf mulGPow x2, + bench "*g Dec" $ nf mulGDec x2, + bench "*g CRT" $ nf (fromJust' "SimpleTensorBenches.*gcrt" mulGCRT) x2, + bench "divg Pow" $ nf divGPow x2', + bench "divg Dec" $ nf divGDec x2'', + bench "divg CRT" $ nf (fromJust' "SimpleTensorBenches./gcrt" divGCRT) x2, + bench "lift" $ nf (fmapT lift) x2, + bench "error" $ nf (evalRand (fmapT (roundMult one) <$> + (tGaussianDec (0.1 :: Double) :: Rand (CryptoRand HashDRBG) (T M Double))) :: CryptoRand HashDRBG -> T M Int64) gen, + bench "twacePow" $ nf (twacePowDec :: T M R -> T M' R) x2, + bench "twaceDec" $ nf (twacePowDec :: T M R -> T M' R) x2, + bench "twaceCRT" $ nf (fromJust' "SimpleTensorBenches.twaceCRT" twaceCRT :: T M R -> T M' R) x2, + bench "embedPow" $ nf (embedPow :: T M' R -> T M R) x4, + bench "embedDec" $ nf (embedDec :: T M' R -> T M R) x4, + bench "embedCRT" $ nf (fromJust' "SimpleTensorBenches.embedCRT" embedCRT :: T M' R -> T M R) x4 + ] \ No newline at end of file diff --git a/lol/benchmarks/SimpleUCycBenches.hs b/lol/benchmarks/SimpleUCycBenches.hs new file mode 100644 index 00000000..5933f263 --- /dev/null +++ b/lol/benchmarks/SimpleUCycBenches.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} + +module SimpleUCycBenches (simpleUCycBenches) where + +import Control.Applicative +import Control.Monad.Random +import BenchParams + +import Crypto.Lol.Prelude +import Crypto.Lol.Cyclotomic.UCyc +import Crypto.Lol.Types +import Crypto.Random.DRBG + +import Criterion + +simpleUCycBenches :: IO Benchmark +simpleUCycBenches = do + x1 :: UCyc T M P (R, R) <- getRandom + let x1' = toDec x1 + (Right x2) :: UCycPC T M (R, R) <- getRandom + x3 :: UCycEC T M R <- pcToEC <$> getRandom + x4 :: UCyc T M P R <- getRandom + let x5 = toDec x4 + (Right x6) :: UCycPC T M R <- getRandom + x7 :: UCyc T M' P R <- getRandom + let x8 = toDec x7 + x4' = mulG x4 + x5' = mulG x5 + (Right x9) :: UCycPC T M' R <- getRandom + gen <- newGenIO + return $ bgroup "SUCyc" [ + bench "unzipPow" $ nf unzipPow x1, + bench "unzipDec" $ nf unzipDec x1', + bench "unzipCRT" $ nf unzipCRTC x2, + bench "zipWith (*)" $ nf (x3*) x3, + bench "crt" $ nf toCRT x4, + bench "crtInv" $ nf toPow x6, + bench "l" $ nf toPow x5, + bench "lInv" $ nf toDec x4, + bench "*g Pow" $ nf mulG x4, + bench "*g Dec" $ nf mulG x5, + bench "*g CRT" $ nf mulG x6, + bench "divg Pow" $ nf divGPow x4', + bench "divg Dec" $ nf divGDec x5', + bench "divg CRT" $ nf divGCRTC x6, + bench "lift" $ nf lift x4, + bench "error" $ nf (evalRand (errorRounded (0.1 :: Double) :: Rand (CryptoRand HashDRBG) (UCyc T M D Int64))) gen, + bench "twacePow" $ nf (twacePow :: UCyc T M P R -> UCyc T M' P R) x4, + bench "twaceDec" $ nf (twaceDec :: UCyc T M D R -> UCyc T M' D R) x5, + bench "twaceCRT" $ nf (twaceCRTC :: UCyc T M C R -> UCycPC T M' R) x6, + bench "embedPow" $ nf (embedPow :: UCyc T M' P R -> UCyc T M P R) x7, + bench "embedDec" $ nf (embedDec :: UCyc T M' D R -> UCyc T M D R) x8, + bench "embedCRT" $ nf (embedCRTC :: UCyc T M' C R -> UCycPC T M R) x9 + ] + +pcToEC :: UCycPC t m r -> UCycEC t m r +pcToEC (Right x) = (Right x) \ No newline at end of file diff --git a/lol/benchmarks/TensorBenches.hs b/lol/benchmarks/TensorBenches.hs index 59a61892..820132e7 100644 --- a/lol/benchmarks/TensorBenches.hs +++ b/lol/benchmarks/TensorBenches.hs @@ -1,28 +1,131 @@ -{-# LANGUAGE DataKinds, FlexibleContexts, - NoImplicitPrelude, RebindableSyntax, - ScopedTypeVariables, TypeFamilies, - TypeOperators, UndecidableInstances #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} module TensorBenches (tensorBenches) where import Apply.Cyc import Benchmarks -import Utils +import BenchParams -import Crypto.Lol +import Control.Applicative +import Control.Monad.Random + +import Crypto.Lol.Prelude import Crypto.Lol.Cyclotomic.Tensor import Crypto.Lol.Types +import Crypto.Random.DRBG tensorBenches :: IO Benchmark tensorBenches = benchGroup "Tensor" [ - benchGroup "l" $ applyBasic (Proxy::Proxy QuickParams) $ hideArgs bench_l] + benchGroup "unzipPow" $ [hideArgs bench_unzip testParam], + benchGroup "unzipDec" $ [hideArgs bench_unzip testParam], + benchGroup "unzipCRT" $ [hideArgs bench_unzip testParam], + benchGroup "zipWith (*)" $ [hideArgs bench_mul testParam], + benchGroup "crt" $ [hideArgs bench_crt testParam], + benchGroup "crtInv" $ [hideArgs bench_crtInv testParam], + benchGroup "l" $ [hideArgs bench_l testParam], + benchGroup "lInv" $ [hideArgs bench_lInv testParam], + benchGroup "*g Pow" $ [hideArgs bench_mulgPow testParam], + benchGroup "*g Dec" $ [hideArgs bench_mulgDec testParam], + benchGroup "*g CRT" $ [hideArgs bench_mulgCRT testParam], + benchGroup "divg Pow" $ [hideArgs bench_divgPow testParam], + benchGroup "divg Dec" $ [hideArgs bench_divgDec testParam], + benchGroup "divg CRT" $ [hideArgs bench_divgCRT testParam], + benchGroup "lift" $ [hideArgs bench_liftPow testParam], + benchGroup "error" $ [hideArgs (bench_errRounded 0.1) testParam'], + benchGroup "twacePow" $ [hideArgs bench_twacePow twoIdxParam], + benchGroup "twaceDec" $ [hideArgs bench_twacePow twoIdxParam], -- yes, twacePow is correct here. It's the same function! + benchGroup "twaceCRT" $ [hideArgs bench_twaceCRT twoIdxParam], + benchGroup "embedPow" $ [hideArgs bench_embedPow twoIdxParam], + benchGroup "embedDec" $ [hideArgs bench_embedDec twoIdxParam], + benchGroup "embedCRT" $ [hideArgs bench_embedCRT twoIdxParam] + ] + +bench_unzip :: (UnzipCtx t m r) => t m (r,r) -> Bench '(t,m,r) +bench_unzip = bench unzipT + +-- no CRT conversion, just coefficient-wise multiplication +bench_mul :: (BasicCtx t m r) => t m r -> t m r -> Bench '(t,m,r) +bench_mul a = bench (zipWithT (*) a) + +-- convert input from Pow basis to CRT basis +bench_crt :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_crt = bench (fromJust' "TensorBenches.bench_crt" crt) + +-- convert input from CRT basis to Pow basis +bench_crtInv :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_crtInv = bench (fromJust' "TensorBenches.bench_crtInv" crtInv) -- convert input from Dec basis to Pow basis -bench_l :: (Tensor t, Fact m, Additive r, TElt t r, NFData (t m r)) => t m r -> Bench '(t,m,r) +bench_l :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) bench_l = bench l -type QuickTest = '[ '(F128, Zq 257), - '(F32 * F9, Zq 577), - '(F32 * F9, Int64) ] -type Tensors = '[CT,RT] -type QuickParams = ( '(,) <$> Tensors) <*> QuickTest +-- convert input from Dec basis to Pow basis +bench_lInv :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_lInv = bench lInv + +-- lift an element in the Pow basis +bench_liftPow :: forall t m r . (LiftCtx t m r) => t m r -> Bench '(t,m,r) +bench_liftPow = bench (fmapT lift) + +-- multiply by g when input is in Pow basis +bench_mulgPow :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_mulgPow = bench mulGPow + +-- multiply by g when input is in Dec basis +bench_mulgDec :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_mulgDec = bench mulGDec + +-- multiply by g when input is in CRT basis +bench_mulgCRT :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_mulgCRT = bench (fromJust' "TensorBenches.bench_mulgCRT" mulGCRT) + +-- divide by g when input is in Pow basis +bench_divgPow :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_divgPow x = + let y = mulGPow x + in bench divGPow y + +-- divide by g when input is in Dec basis +bench_divgDec :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_divgDec x = + let y = mulGDec x + in bench divGDec y + +-- divide by g when input is in CRT basis +bench_divgCRT :: (BasicCtx t m r) => t m r -> Bench '(t,m,r) +bench_divgCRT = bench (fromJust' "TensorBenches.bench_divgCRT" divGCRT) + +-- generate a rounded error term +bench_errRounded :: forall t m r gen . (ErrorCtx t m r gen) + => Double -> Bench '(t,m,r,gen) +bench_errRounded v = benchIO $ do + gen <- newGenIO + return $ evalRand + (fmapT (roundMult one) <$> + (tGaussianDec v :: Rand (CryptoRand gen) (t m Double)) :: Rand (CryptoRand gen) (t m (LiftOf r))) gen + +bench_twacePow :: forall t m m' r . (TwoIdxCtx t m m' r) + => t m' r -> Bench '(t,m,m',r) +bench_twacePow = bench (twacePowDec :: t m' r -> t m r) + +bench_twaceCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => t m' r -> Bench '(t,m,m',r) +bench_twaceCRT = bench (fromJust' "TensorBenches.bench_twaceCRT" twaceCRT :: t m' r -> t m r) + +bench_embedPow :: forall t m m' r . (TwoIdxCtx t m m' r) + => t m r -> Bench '(t,m,m',r) +bench_embedPow = bench (embedPow :: t m r -> t m' r) + +bench_embedDec :: forall t m m' r . (TwoIdxCtx t m m' r) + => t m r -> Bench '(t,m,m',r) +bench_embedDec = bench (embedDec :: t m r -> t m' r) + +bench_embedCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => t m r -> Bench '(t,m,m',r) +bench_embedCRT = bench (fromJust' "TensorBenches.bench_embedCRT" embedCRT :: t m r -> t m' r) diff --git a/lol/benchmarks/UCycBenches.hs b/lol/benchmarks/UCycBenches.hs index d46a9190..793caeab 100644 --- a/lol/benchmarks/UCycBenches.hs +++ b/lol/benchmarks/UCycBenches.hs @@ -1,48 +1,142 @@ -{-# LANGUAGE DataKinds, FlexibleContexts, - NoImplicitPrelude, RebindableSyntax, - ScopedTypeVariables, TypeFamilies, - TypeOperators, UndecidableInstances #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} module UCycBenches (ucycBenches) where import Apply.Cyc import Benchmarks -import Utils +import BenchParams -import Crypto.Lol +import Control.Monad.Random + +import Crypto.Lol.Prelude import Crypto.Lol.Cyclotomic.UCyc import Crypto.Lol.Types +import Crypto.Random.DRBG ucycBenches :: IO Benchmark ucycBenches = benchGroup "UCyc" [ - benchGroup "l" $ applyBasic (Proxy::Proxy QuickParams) $ hideArgs bench_l, - benchGroup "twace" $ applyTwoIdx twoIdxParams $ hideArgs bench_twacePow, - benchGroup "embed" $ applyTwoIdx twoIdxParams $ hideArgs bench_embedPow + benchGroup "unzipPow" $ [hideArgs bench_unzipUCycPow testParam], + benchGroup "unzipDec" $ [hideArgs bench_unzipUCycDec testParam], + benchGroup "unzipCRT" $ [hideArgs bench_unzipUCycCRT testParam], + benchGroup "zipWith (*)" $ [hideArgs bench_mul testParam], + benchGroup "crt" $ [hideArgs bench_crt testParam], + benchGroup "crtInv" $ [hideArgs bench_crtInv testParam], + benchGroup "l" $ [hideArgs bench_l testParam], + benchGroup "lInv" $ [hideArgs bench_lInv testParam], + benchGroup "*g Pow" $ [hideArgs bench_mulgPow testParam], + benchGroup "*g Dec" $ [hideArgs bench_mulgDec testParam], + benchGroup "*g CRT" $ [hideArgs bench_mulgCRT testParam], + benchGroup "divg Pow" $ [hideArgs bench_divgPow testParam], + benchGroup "divg Dec" $ [hideArgs bench_divgDec testParam], + benchGroup "divg CRT" $ [hideArgs bench_divgCRT testParam], + benchGroup "lift" $ [hideArgs bench_liftPow testParam], + benchGroup "error" $ [hideArgs (bench_errRounded 0.1) testParam'], + benchGroup "twacePow" $ [hideArgs bench_twacePow twoIdxParam], + benchGroup "twaceDec" $ [hideArgs bench_twaceDec twoIdxParam], + benchGroup "twaceCRT" $ [hideArgs bench_twaceCRT twoIdxParam], + benchGroup "embedPow" $ [hideArgs bench_embedPow twoIdxParam], + benchGroup "embedDec" $ [hideArgs bench_embedDec twoIdxParam], + benchGroup "embedCRT" $ [hideArgs bench_embedCRT twoIdxParam] ] +bench_unzipUCycPow :: (UnzipCtx t m r) => UCyc t m P (r,r) -> Bench '(t,m,r) +bench_unzipUCycPow = bench unzipPow + +bench_unzipUCycDec :: (UnzipCtx t m r) => UCyc t m D (r,r) -> Bench '(t,m,r) +bench_unzipUCycDec = bench unzipDec + +bench_unzipUCycCRT :: (UnzipCtx t m r) => UCycPC t m (r,r) -> Bench '(t,m,r) +bench_unzipUCycCRT (Right a) = bench unzipCRTC a + +pcToEC :: UCycPC t m r -> UCycEC t m r +pcToEC (Right x) = (Right x) + +-- no CRT conversion, just coefficient-wise multiplication +bench_mul :: (BasicCtx t m r) => UCycPC t m r -> UCycPC t m r -> Bench '(t,m,r) +bench_mul a b = + let a' = pcToEC a + b' = pcToEC b + in bench (a' *) b' + +-- convert input from Pow basis to CRT basis +bench_crt :: (BasicCtx t m r) => UCyc t m P r -> Bench '(t,m,r) +bench_crt = bench toCRT + +-- convert input from CRT basis to Pow basis +bench_crtInv :: (BasicCtx t m r) => UCycPC t m r -> Bench '(t,m,r) +bench_crtInv (Right a) = bench toPow a + -- convert input from Dec basis to Pow basis bench_l :: (BasicCtx t m r) => UCyc t m D r -> Bench '(t,m,r) bench_l = bench toPow +-- convert input from Pow basis to Dec basis +bench_lInv :: (BasicCtx t m r) => UCyc t m P r -> Bench '(t,m,r) +bench_lInv = bench toDec + +-- lift an element in the Pow basis +bench_liftPow :: (LiftCtx t m r) => UCyc t m P r -> Bench '(t,m,r) +bench_liftPow = bench lift + +-- multiply by g when input is in Pow basis +bench_mulgPow :: (BasicCtx t m r) => UCyc t m P r -> Bench '(t,m,r) +bench_mulgPow = bench mulG + +-- multiply by g when input is in Dec basis +bench_mulgDec :: (BasicCtx t m r) => UCyc t m D r -> Bench '(t,m,r) +bench_mulgDec = bench mulG + +-- multiply by g when input is in CRT basis +bench_mulgCRT :: (BasicCtx t m r) => UCycPC t m r -> Bench '(t,m,r) +bench_mulgCRT (Right a) = bench mulG a + +-- divide by g when input is in Pow basis +bench_divgPow :: (BasicCtx t m r) => UCyc t m P r -> Bench '(t,m,r) +bench_divgPow x = + let y = mulG x + in bench divGPow y + +-- divide by g when input is in Dec basis +bench_divgDec :: (BasicCtx t m r) => UCyc t m D r -> Bench '(t,m,r) +bench_divgDec x = + let y = mulG x + in bench divGDec y + +-- divide by g when input is in CRT basis +bench_divgCRT :: (BasicCtx t m r) => UCycPC t m r -> Bench '(t,m,r) +bench_divgCRT (Right a) = bench divGCRTC a + +-- generate a rounded error term +bench_errRounded :: forall t m r gen . (ErrorCtx t m r gen) + => Double -> Bench '(t,m,r,gen) +bench_errRounded v = benchIO $ do + gen <- newGenIO + return $ evalRand (errorRounded v :: Rand (CryptoRand gen) (UCyc t m D (LiftOf r))) gen + bench_twacePow :: forall t m m' r . (TwoIdxCtx t m m' r) => UCyc t m' P r -> Bench '(t,m,m',r) bench_twacePow = bench (twacePow :: UCyc t m' P r -> UCyc t m P r) +bench_twaceDec :: forall t m m' r . (TwoIdxCtx t m m' r) + => UCyc t m' D r -> Bench '(t,m,m',r) +bench_twaceDec = bench (twaceDec :: UCyc t m' D r -> UCyc t m D r) + +bench_twaceCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => UCycPC t m' r -> Bench '(t,m,m',r) +bench_twaceCRT (Right a) = bench (twaceCRTC :: UCyc t m' C r -> UCycPC t m r) a + bench_embedPow :: forall t m m' r . (TwoIdxCtx t m m' r) => UCyc t m P r -> Bench '(t,m,m',r) bench_embedPow = bench (embedPow :: UCyc t m P r -> UCyc t m' P r) -type QuickTest = '[ '(F128, Zq 257), - '(F32 * F9, Zq 577), - '(F32 * F9, Int64) ] -type Tensors = '[CT,RT] -type QuickParams = ( '(,) <$> Tensors) <*> QuickTest - -type MM'RCombos = - '[ '(F8 * F91, F8 * F91 * F4, Zq 8737), - '(F8 * F91, F8 * F91 * F5, Zq 14561), - '(F128, F128 * F91, Zq 23297) - ] -type TwoIdxParams = ( '(,) <$> Tensors) <*> MM'RCombos -twoIdxParams :: Proxy TwoIdxParams -twoIdxParams = Proxy +bench_embedDec :: forall t m m' r . (TwoIdxCtx t m m' r) + => UCyc t m D r -> Bench '(t,m,m',r) +bench_embedDec = bench (embedDec :: UCyc t m D r -> UCyc t m' D r) + +bench_embedCRT :: forall t m m' r . (TwoIdxCtx t m m' r) + => UCycPC t m r -> Bench '(t,m,m',r) +bench_embedCRT (Right a) = bench (embedCRTC :: UCyc t m C r -> UCycPC t m' r) a diff --git a/lol/lol.cabal b/lol/lol.cabal index b8a4b832..966334fd 100644 --- a/lol/lol.cabal +++ b/lol/lol.cabal @@ -84,8 +84,8 @@ library -- ghc optimizations if flag(opt) - ghc-options: -O3 -Odph -funbox-strict-fields -fno-liberate-case -funfolding-use-threshold1000 -funfolding-keeness-factor1000 - + -- makes lift much faster! + ghc-options: -funfolding-use-threshold1000 exposed-modules: Crypto.Lol Crypto.Lol.Types @@ -205,14 +205,15 @@ Benchmark bench-lol if flag(llvm) ghc-options: -fllvm -optlo-O3 - -- ghc-options: -threaded -rtsopts - ghc-options: -O3 -Odph -funbox-strict-fields -fno-liberate-case -funfolding-use-threshold1000 -funfolding-keeness-factor1000 - -- ghc-options: -O2 -Odph -funbox-strict-fields -fwarn-dodgy-imports -rtsopts - -- ghc-options: -fno-liberate-case -funfolding-use-threshold1000 -funfolding-keeness-factor1000 + ghc-options: -O2 + ghc-options: -ddump-to-file -ddump-simpl + ghc-options: -dsuppress-coercions -dsuppress-type-applications -dsuppress-uniques -dsuppress-module-prefixes build-depends: + ansi-terminal, arithmoi, base, + containers, criterion, deepseq, DRBG, @@ -220,6 +221,7 @@ Benchmark bench-lol MonadRandom, mtl, singletons, + statistics, transformers, vector, repa diff --git a/lol/utils/Apply/Cyc.hs b/lol/utils/Apply/Cyc.hs index 9ddce052..bf785d4b 100644 --- a/lol/utils/Apply/Cyc.hs +++ b/lol/utils/Apply/Cyc.hs @@ -46,7 +46,7 @@ applyBasic params g = run params $ \(BC p) -> g p data UnzipCtxD type UnzipCtx t m r = - (Fact m, CElt t (r,r), Random (t m (r,r)), CElt t r, ShowType '(t,m,r), NFElt r, Random r) + (Fact m, CElt t (r,r), Random (t m (r,r)), CElt t r, ShowType '(t,m,r), NFElt r, Random r, NFData (t m r)) data instance ArgsCtx UnzipCtxD where UzC :: (UnzipCtx t m r) => Proxy '(t,m,r) -> ArgsCtx UnzipCtxD instance (params `Satisfy` UnzipCtxD, UnzipCtx t m r) @@ -64,7 +64,7 @@ applyUnzip params g = run params $ \(UzC p) -> g p data LiftCtxD type LiftCtx t m r = (BasicCtx t m r, Lift' r, CElt t (LiftOf r), NFElt (LiftOf r), ToInteger (LiftOf r), - TElt CT r, TElt RT r, TElt CT (LiftOf r), TElt RT (LiftOf r)) + TElt CT r, TElt RT r, TElt CT (LiftOf r), TElt RT (LiftOf r), NFData (t m (LiftOf r))) data instance ArgsCtx LiftCtxD where LC :: (LiftCtx t m r) => Proxy '(t,m,r) -> ArgsCtx LiftCtxD instance (params `Satisfy` LiftCtxD, LiftCtx t m r) @@ -81,7 +81,7 @@ applyLift params g = run params $ \(LC p) -> g p data ErrorCtxD type ErrorCtx t m r gen = (CElt t r, Fact m, ShowType '(t,m,r,gen), CElt t (LiftOf r), NFElt (LiftOf r), Lift' r, - ToInteger (LiftOf r), CryptoRandomGen gen) + ToInteger (LiftOf r), CryptoRandomGen gen, NFData (t m (LiftOf r))) data instance ArgsCtx ErrorCtxD where EC :: (ErrorCtx t m r gen) => Proxy '(t,m,r,gen) -> ArgsCtx ErrorCtxD instance (params `Satisfy` ErrorCtxD, ErrorCtx t m r gen) @@ -97,7 +97,7 @@ applyError params g = run params $ \(EC p) -> g p data TwoIdxCtxD type TwoIdxCtx t m m' r = (m `Divides` m', CElt t r, IntegralDomain r, Eq r, Random r, NFElt r, - ShowType '(t,m,m',r), Random (t m r), Random (t m' r)) + ShowType '(t,m,m',r), Random (t m r), Random (t m' r), NFData (t m r), NFData (t m' r)) data instance ArgsCtx TwoIdxCtxD where TI :: (TwoIdxCtx t m m' r) => Proxy '(t,m,m',r) -> ArgsCtx TwoIdxCtxD instance (params `Satisfy` TwoIdxCtxD, TwoIdxCtx t m m' r)