--- /dev/null
+module Network.DNS.Unpacker
+ ( Unpacker
+ , UnpackingState(..)
+
+ , unpack
+ , unpack'
+
+ , getState
+ , setState
+ , modifyState
+
+ , skip
+ , lookAhead
+ , bytesRead
+
+ , getByteString
+ , getLazyByteString
+ , getWord8
+ , getWord16be
+ , getWord32be
+
+ , getBinary
+ , liftToBinary
+ )
+ where
+
+import qualified Data.Binary as Binary
+import qualified Data.Binary.Get as Bin
+import qualified Data.ByteString as Strict
+import qualified Data.ByteString.Lazy as Lazy
+import Data.Bits
+import Data.Int
+import Data.Word
+
+
+data UnpackingState s
+ = UnpackingState {
+ stSource :: !Lazy.ByteString
+ , stBytesRead :: !Int64
+ , stUserState :: s
+ }
+
+newtype Unpacker s a = U { unU :: UnpackingState s -> (a, UnpackingState s) }
+
+instance Monad (Unpacker s) where
+ return a = U (\ s -> (a, s))
+ m >>= k = U (\ s -> let (a, s') = unU m s
+ in
+ unU (k a) s')
+ fail err = do bytes <- get stBytesRead
+ U (error (err
+ ++ ". Failed unpacking at byte position "
+ ++ show bytes))
+
+get :: (UnpackingState s -> a) -> Unpacker s a
+get f = U (\ s -> (f s, s))
+
+set :: (UnpackingState s -> UnpackingState s) -> Unpacker s ()
+set f = U (\ s -> ((), f s))
+
+mkState :: Lazy.ByteString -> Int64 -> s -> UnpackingState s
+mkState xs n s
+ = UnpackingState {
+ stSource = xs
+ , stBytesRead = n
+ , stUserState = s
+ }
+
+unpack' :: Unpacker s a -> s -> Lazy.ByteString -> (a, s)
+unpack' m s xs
+ = let (a, s') = unU m (mkState xs 0 s)
+ in
+ (a, stUserState s')
+
+unpack :: Unpacker s a -> s -> Lazy.ByteString -> a
+unpack = ((fst .) .) . unpack'
+
+getState :: Unpacker s s
+getState = get stUserState
+
+setState :: s -> Unpacker s ()
+setState = modifyState . const
+
+modifyState :: (s -> s) -> Unpacker s ()
+modifyState f
+ = set $ \ st -> st { stUserState = f (stUserState st) }
+
+skip :: Int64 -> Unpacker s ()
+skip n = getLazyByteString n >> return ()
+
+lookAhead :: Unpacker s a -> Unpacker s a
+lookAhead m = U (\ s -> let (a, _) = unU m s
+ in
+ (a, s))
+
+bytesRead :: Integral i => Unpacker s i
+bytesRead = get stBytesRead >>= return . fromIntegral
+
+getByteString :: Int -> Unpacker s Strict.ByteString
+getByteString n = getLazyByteString (fromIntegral n) >>= return . Strict.concat . Lazy.toChunks
+
+getLazyByteString :: Int64 -> Unpacker s Lazy.ByteString
+getLazyByteString n
+ = do src <- get stSource
+ let (xs, ys) = Lazy.splitAt n src
+ if Lazy.length xs /= n then
+ fail "Too few bytes"
+ else
+ do set $ \ st -> st {
+ stSource = ys
+ , stBytesRead = stBytesRead st + n
+ }
+ return xs
+
+getWord8 :: Unpacker s Word8
+getWord8 = getLazyByteString 1 >>= return . (`Lazy.index` 0)
+
+getWord16be :: Unpacker s Word16
+getWord16be = do xs <- getLazyByteString 2
+ return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 8) .|.
+ (fromIntegral (xs `Lazy.index` 1))
+
+getWord32be :: Unpacker s Word32
+getWord32be = do xs <- getLazyByteString 4
+ return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 24) .|.
+ (fromIntegral (xs `Lazy.index` 1) `shiftL` 16) .|.
+ (fromIntegral (xs `Lazy.index` 2) `shiftL` 8) .|.
+ (fromIntegral (xs `Lazy.index` 3))
+
+getBinary :: Binary.Binary a => Unpacker s a
+getBinary = do s <- get id
+ let (a, rest, bytes) = Bin.runGetState Binary.get (stSource s) (stBytesRead s)
+ set $ \ st -> st {
+ stSource = rest
+ , stBytesRead = bytes
+ }
+ return a
+
+
+liftToBinary :: s -> Unpacker s a -> Bin.Get a
+liftToBinary s m
+ = do bytes <- Bin.bytesRead
+ src <- Bin.getRemainingLazyByteString
+
+ let (a, s') = unU m (mkState src bytes s)
+
+ -- These bytes was consumed by the unpacker.
+ Bin.skip (fromIntegral (stBytesRead s' - bytes))
+
+ return a