Ticket #13059: StoreImpl.hs

File StoreImpl.hs, 13.4 KB (added by RyanGlScott, 2 years ago)
Line 
1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE DataKinds #-}
3{-# LANGUAGE DefaultSignatures #-}
4{-# LANGUAGE DeriveFunctor #-}
5{-# LANGUAGE EmptyCase #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE FlexibleInstances #-}
8{-# LANGUAGE KindSignatures #-}
9{-# LANGUAGE MagicHash #-}
10{-# LANGUAGE MultiParamTypeClasses #-}
11{-# LANGUAGE RankNTypes #-}
12{-# LANGUAGE ScopedTypeVariables #-}
13{-# LANGUAGE TypeFamilies #-}
14{-# LANGUAGE TypeOperators #-}
15{-# LANGUAGE UnboxedTuples #-}
16{-# LANGUAGE UndecidableInstances #-}
17module StoreImpl where
18
19import Control.Exception (Exception(..), throwIO)
20import Control.Monad (when)
21import Data.Proxy
22import Data.Typeable
23import Data.Word
24import Foreign.Ptr
25import qualified Foreign.Storable as Storable (Storable(..))
26import Foreign.Storable (Storable, sizeOf)
27import GHC.Generics
28import GHC.Ptr (Ptr(..))
29import GHC.Prim (ByteArray#, MutableByteArray#, RealWorld,
30                 copyByteArrayToAddr#, copyAddrToByteArray#,
31                 newByteArray#, unsafeFreezeByteArray#)
32import GHC.Types (IO(..), Int(..))
33import GHC.TypeLits
34
35class Store a where
36    size :: Size a
37    poke :: a -> Poke ()
38    peek :: Peek a
39
40    default size :: (Generic a, GStoreSize (Rep a)) => Size a
41    size = genericSize
42    {-# INLINE size #-}
43
44    default poke :: (Generic a, GStorePoke (Rep a)) => a -> Poke ()
45    poke = genericPoke
46    {-# INLINE poke #-}
47
48    default peek :: (Generic a , GStorePeek (Rep a)) => Peek a
49    peek = genericPeek
50    {-# INLINE peek #-}
51
52data Size a
53    = VarSize (a -> Int)
54    | ConstSize !Int
55
56getSizeWith :: Size a -> a -> Int
57getSizeWith (VarSize f) x = f x
58getSizeWith (ConstSize n) _ = n
59{-# INLINE getSizeWith #-}
60
61contramapSize :: (a -> b) -> Size b -> Size a
62contramapSize f (VarSize g) = VarSize (g . f)
63contramapSize _ (ConstSize n) = ConstSize n
64{-# INLINE contramapSize #-}
65
66combineSize :: forall a b c. (Store a, Store b) => (c -> a) -> (c -> b) -> Size c
67combineSize toA toB = combineSizeWith toA toB size size
68{-# INLINE combineSize #-}
69
70
71combineSizeWith :: forall a b c. (c -> a) -> (c -> b) -> Size a -> Size b -> Size c
72combineSizeWith toA toB sizeA sizeB =
73    case (sizeA, sizeB) of
74        (VarSize f, VarSize g) -> VarSize (\x -> f (toA x) + g (toB x))
75        (VarSize f, ConstSize m) -> VarSize (\x -> f (toA x) + m)
76        (ConstSize n, VarSize g) -> VarSize (\x -> n + g (toB x))
77        (ConstSize n, ConstSize m) -> ConstSize (n + m)
78{-# INLINE combineSizeWith #-}
79
80sizeStorableTy :: forall a. Storable a => String -> Size a
81sizeStorableTy ty = ConstSize (sizeOf (error msg :: a))
82  where
83    msg = "In Data.Store.storableSize: " ++ ty ++ "'s sizeOf evaluated its argument."
84{-# INLINE sizeStorableTy #-}
85
86peekStorable :: forall a. (Storable a, Typeable a) => Peek a
87peekStorable = peekStorableTy (show (typeRep (Proxy :: Proxy a)))
88{-# INLINE peekStorable #-}
89
90peekStorableTy :: forall a. Storable a => String -> Peek a
91peekStorableTy ty = Peek $ \ps ptr -> do
92    let ptr' = ptr `plusPtr` sz
93        sz = sizeOf (undefined :: a)
94        remaining = peekStateEndPtr ps `minusPtr` ptr
95    when (sz > remaining) $ -- Do not perform the check on the new pointer, since it could have overflowed
96        tooManyBytes sz remaining ty
97    x <- Storable.peek (castPtr ptr)
98    return (ptr', x)
99
100pokeStorable :: Storable a => a -> Poke ()
101pokeStorable x = Poke $ \ps offset -> do
102    let targetPtr = pokeStatePtr ps `plusPtr` offset
103    Storable.poke targetPtr x
104    let !newOffset = offset + sizeOf x
105    return (newOffset, ())
106{-# INLINE pokeStorable #-}
107
108type Offset = Int
109
110data PeekException = PeekException
111    { peekExBytesFromEnd :: Offset
112    , peekExMessage :: String -- T.Text
113    } deriving (Eq, Show, Typeable)
114
115instance Exception PeekException where
116    displayException (PeekException offset msg) =
117        "Exception while peeking, " ++
118        show offset ++
119        " bytes from end: " ++
120        {-T.unpack-} msg
121
122peekException :: {-T.Text-} String -> Peek a
123peekException msg = Peek $ \ps ptr -> throwIO (PeekException (peekStateEndPtr ps `minusPtr` ptr) msg)
124
125tooManyBytes :: Int -> Int -> String -> IO void
126tooManyBytes needed remaining ty =
127    throwIO $ PeekException remaining $ {- T.pack $ -}
128        "Attempted to read too many bytes for " ++
129        ty ++
130        ". Needed " ++
131        show needed ++ ", but only " ++
132        show remaining ++ " remain."
133
134negativeBytes :: Int -> Int -> String -> IO void
135negativeBytes needed remaining ty =
136    throwIO $ PeekException remaining $ {- T.pack $ -}
137        "Attempted to read negative number of bytes for " ++
138        ty ++
139        ". Tried to read " ++
140        show needed ++ ".  This probably means that we're trying to read invalid data."
141
142data PokeException = PokeException
143    { pokeExByteIndex :: Offset
144    , pokeExMessage :: String -- T.Text
145    }
146    deriving (Eq, Show, Typeable)
147
148instance Exception PokeException where
149    displayException (PokeException offset msg) =
150        "Exception while poking, at byte index " ++
151        show offset ++
152        " : " ++
153        {-T.unpack-} msg
154
155pokeException :: {-T.Text-} String -> Poke a
156pokeException msg = Poke $ \_ off -> throwIO (PokeException off msg)
157
158newtype Poke a = Poke
159    { runPoke :: PokeState -> Offset -> IO (Offset, a)
160    } deriving Functor
161
162instance Applicative Poke where
163    pure x = Poke $ \_ptr offset -> pure (offset, x)
164    {-# INLINE pure #-}
165    Poke f <*> Poke g = Poke $ \ptr offset1 -> do
166        (offset2, f') <- f ptr offset1
167        (offset3, g') <- g ptr offset2
168        return (offset3, f' g')
169    {-# INLINE (<*>) #-}
170    Poke f *> Poke g = Poke $ \ptr offset1 -> do
171        (offset2, _) <- f ptr offset1
172        g ptr offset2
173    {-# INLINE (*>) #-}
174
175instance Monad Poke where
176    return = pure
177    {-# INLINE return #-}
178    (>>) = (*>)
179    {-# INLINE (>>) #-}
180    Poke x >>= f = Poke $ \ptr offset1 -> do
181        (offset2, x') <- x ptr offset1
182        runPoke (f x') ptr offset2
183    {-# INLINE (>>=) #-}
184    fail = pokeException {- . T.pack -}
185    {-# INLINE fail #-}
186
187newtype PokeState = PokeState
188    { pokeStatePtr :: Ptr Word8
189    }
190
191newtype Peek a = Peek
192    { runPeek :: PeekState -> Ptr Word8 -> IO (Ptr Word8, a)
193    } deriving Functor
194
195instance Applicative Peek where
196    pure x = Peek (\_ ptr -> return (ptr, x))
197    {-# INLINE pure #-}
198    Peek f <*> Peek g = Peek $ \end ptr1 -> do
199        (ptr2, f') <- f end ptr1
200        (ptr3, g') <- g end ptr2
201        return (ptr3, f' g')
202    {-# INLINE (<*>) #-}
203    Peek f *> Peek g = Peek $ \end ptr1 -> do
204        (ptr2, _) <- f end ptr1
205        g end ptr2
206    {-# INLINE (*>) #-}
207
208instance Monad Peek where
209    return = pure
210    {-# INLINE return #-}
211    (>>) = (*>)
212    {-# INLINE (>>) #-}
213    Peek x >>= f = Peek $ \end ptr1 -> do
214        (ptr2, x') <- x end ptr1
215        runPeek (f x') end ptr2
216    {-# INLINE (>>=) #-}
217    fail = peekException {- . T.pack -}
218    {-# INLINE fail #-}
219
220newtype PeekState = PeekState
221    { peekStateEndPtr :: Ptr Word8 }
222
223genericSize :: (Generic a, GStoreSize (Rep a)) => Size a
224genericSize = contramapSize from gsize
225{-# INLINE genericSize #-}
226
227genericPoke :: (Generic a, GStorePoke (Rep a)) => a -> Poke ()
228genericPoke = gpoke . from
229{-# INLINE genericPoke #-}
230
231genericPeek :: (Generic a , GStorePeek (Rep a)) => Peek a
232genericPeek = to <$> gpeek
233{-# INLINE genericPeek #-}
234
235pokeFromByteArray :: ByteArray# -> Int -> Int -> Poke ()
236pokeFromByteArray sourceArr sourceOffset len =
237    Poke $ \targetState targetOffset -> do
238        let target = (pokeStatePtr targetState) `plusPtr` targetOffset
239        copyByteArrayToAddr sourceArr sourceOffset target len
240        let !newOffset = targetOffset + len
241        return (newOffset, ())
242{-# INLINE pokeFromByteArray #-}
243
244peekToByteArray :: String -> Int -> Peek ByteArray
245peekToByteArray ty len =
246    Peek $ \ps sourcePtr -> do
247        let ptr2 = sourcePtr `plusPtr` len
248            remaining = peekStateEndPtr ps `minusPtr` sourcePtr
249        when (len > remaining) $ -- Do not perform the check on the new pointer, since it could have overflowed
250            tooManyBytes len remaining ty
251        when (len < 0) $
252            negativeBytes len remaining ty
253        marr <- newByteArray len
254        copyAddrToByteArray sourcePtr marr 0 len
255        x <- unsafeFreezeByteArray marr
256        return (ptr2, x)
257{-# INLINE peekToByteArray #-}
258
259copyByteArrayToAddr :: ByteArray# -> Int -> Ptr a -> Int -> IO ()
260copyByteArrayToAddr arr (I# offset) (Ptr addr) (I# len) =
261    IO (\s -> (# copyByteArrayToAddr# arr offset addr len s, () #))
262{-# INLINE copyByteArrayToAddr  #-}
263
264copyAddrToByteArray :: Ptr a -> MutableByteArray RealWorld -> Int -> Int -> IO ()
265copyAddrToByteArray (Ptr addr) (MutableByteArray arr) (I# offset) (I# len) =
266    IO (\s -> (# copyAddrToByteArray# addr arr offset len s, () #))
267{-# INLINE copyAddrToByteArray  #-}
268
269type family SumArity (a :: * -> *) :: Nat where
270    SumArity (C1 c a) = 1
271    SumArity (x :+: y) = SumArity x + SumArity y
272
273class GStoreSize f where gsize :: Size (f a)
274class GStorePoke f where gpoke :: f a -> Poke ()
275class GStorePeek f where gpeek :: Peek (f a)
276
277instance GStoreSize f => GStoreSize (M1 i c f) where
278    gsize = contramapSize unM1 gsize
279    {-# INLINE gsize #-}
280instance GStorePoke f => GStorePoke (M1 i c f) where
281    gpoke = gpoke . unM1
282    {-# INLINE gpoke #-}
283instance GStorePeek f => GStorePeek (M1 i c f) where
284    gpeek = fmap M1 gpeek
285    {-# INLINE gpeek #-}
286
287instance Store a => GStoreSize (K1 i a) where
288    gsize = contramapSize unK1 size
289    {-# INLINE gsize #-}
290instance Store a => GStorePoke (K1 i a) where
291    gpoke = poke . unK1
292    {-# INLINE gpoke #-}
293instance Store a => GStorePeek (K1 i a) where
294    gpeek = fmap K1 peek
295    {-# INLINE gpeek #-}
296
297instance GStoreSize U1 where
298    gsize = ConstSize 0
299    {-# INLINE gsize #-}
300instance GStorePoke U1 where
301    gpoke _ = return ()
302    {-# INLINE gpoke #-}
303instance GStorePeek U1 where
304    gpeek = return U1
305    {-# INLINE gpeek #-}
306
307instance GStoreSize V1 where
308    gsize = ConstSize 0
309    {-# INLINE gsize #-}
310instance GStorePoke V1 where
311    gpoke x = case x of {}
312    {-# INLINE gpoke #-}
313instance GStorePeek V1 where
314    gpeek = undefined
315    {-# INLINE gpeek #-}
316
317instance (GStoreSize a, GStoreSize b) => GStoreSize (a :*: b) where
318    gsize = combineSizeWith (\(x :*: _) -> x) (\(_ :*: y) -> y) gsize gsize
319    {-# INLINE gsize #-}
320instance (GStorePoke a, GStorePoke b) => GStorePoke (a :*: b) where
321    gpoke (a :*: b) = gpoke a >> gpoke b
322    {-# INLINE gpoke #-}
323instance (GStorePeek a, GStorePeek b) => GStorePeek (a :*: b) where
324    gpeek = (:*:) <$> gpeek <*> gpeek
325    {-# INLINE gpeek #-}
326
327instance (SumArity (a :+: b) <= 255, GStoreSizeSum 0 (a :+: b))
328         => GStoreSize (a :+: b) where
329    gsize = VarSize $ \x -> sizeOf (undefined :: Word8) + gsizeSum x (Proxy :: Proxy 0)
330    {-# INLINE gsize #-}
331instance (SumArity (a :+: b) <= 255, GStorePokeSum 0 (a :+: b))
332         => GStorePoke (a :+: b) where
333    gpoke x = gpokeSum x (Proxy :: Proxy 0)
334    {-# INLINE gpoke #-}
335instance (SumArity (a :+: b) <= 255, GStorePeekSum 0 (a :+: b))
336         => GStorePeek (a :+: b) where
337    gpeek = do
338        tag <- peekStorable
339        gpeekSum tag (Proxy :: Proxy 0)
340    {-# INLINE gpeek #-}
341
342class KnownNat n => GStoreSizeSum (n :: Nat) (f :: * -> *) where gsizeSum :: f a -> Proxy n -> Int
343class KnownNat n => GStorePokeSum (n :: Nat) (f :: * -> *) where gpokeSum :: f p -> Proxy n -> Poke ()
344class KnownNat n => GStorePeekSum (n :: Nat) (f :: * -> *) where gpeekSum :: Word8 -> Proxy n -> Peek (f p)
345
346instance (GStoreSizeSum n a, GStoreSizeSum (n + SumArity a) b, KnownNat n)
347         => GStoreSizeSum n (a :+: b) where
348    gsizeSum (L1 l) _ = gsizeSum l (Proxy :: Proxy n)
349    gsizeSum (R1 r) _ = gsizeSum r (Proxy :: Proxy (n + SumArity a))
350    {-# INLINE gsizeSum #-}
351instance (GStorePokeSum n a, GStorePokeSum (n + SumArity a) b, KnownNat n)
352         => GStorePokeSum n (a :+: b) where
353    gpokeSum (L1 l) _ = gpokeSum l (Proxy :: Proxy n)
354    gpokeSum (R1 r) _ = gpokeSum r (Proxy :: Proxy (n + SumArity a))
355    {-# INLINE gpokeSum #-}
356instance (GStorePeekSum n a, GStorePeekSum (n + SumArity a) b, KnownNat n)
357         => GStorePeekSum n (a :+: b) where
358    gpeekSum tag proxyL
359        | tag < sizeL = L1 <$> gpeekSum tag proxyL
360        | otherwise = R1 <$> gpeekSum tag (Proxy :: Proxy (n + SumArity a))
361      where
362        sizeL = fromInteger (natVal (Proxy :: Proxy (n + SumArity a)))
363    {-# INLINE gpeekSum #-}
364
365instance (GStoreSize a, KnownNat n) => GStoreSizeSum n (C1 c a) where
366    gsizeSum x _ = getSizeWith gsize x
367    {-# INLINE gsizeSum #-}
368instance (GStorePoke a, KnownNat n) => GStorePokeSum n (C1 c a) where
369    gpokeSum x _ = do
370        pokeStorable (fromInteger (natVal (Proxy :: Proxy n)) :: Word8)
371        gpoke x
372    {-# INLINE gpokeSum #-}
373instance (GStorePeek a, KnownNat n) => GStorePeekSum n (C1 c a) where
374    gpeekSum tag _
375        | tag == cur = gpeek
376        | tag > cur = peekException "Sum tag invalid"
377        | otherwise = peekException "Error in implementation of Store Generics"
378      where
379        cur = fromInteger (natVal (Proxy :: Proxy n))
380    {-# INLINE gpeekSum #-}
381
382data ByteArray = ByteArray ByteArray#
383data MutableByteArray s = MutableByteArray (MutableByteArray# s)
384
385newByteArray :: Int -> IO (MutableByteArray RealWorld)
386{-# INLINE newByteArray #-}
387newByteArray (I# n#)
388  = IO (\s# -> case newByteArray# n# s# of
389                        (# s'#, arr# #) -> (# s'#, MutableByteArray arr# #))
390
391unsafeFreezeByteArray
392  :: MutableByteArray RealWorld -> IO ByteArray
393{-# INLINE unsafeFreezeByteArray #-}
394unsafeFreezeByteArray (MutableByteArray arr#)
395  = IO (\s# -> case unsafeFreezeByteArray# arr# s# of
396                        (# s'#, arr'# #) -> (# s'#, ByteArray arr'# #))