{-# LANGUAGE CPP #-}
-- | A small state monad for decoding binary data out of a lazy
-- 'Lazy.ByteString'.
--
-- An 'Unpacker' threads three pieces of state: the input that has not been
-- consumed yet, a counter of consumed bytes, and an arbitrary user state.
-- Failures are reported by calling 'error'; the message includes the value
-- of the byte counter at the point of failure.
module Network.DNS.Unpacker
    ( Unpacker
    , UnpackingState(..)

    , unpack
    , unpack'

    , getState
    , setState
    , modifyState

    , skip
    , lookAhead
    , bytesRead

    , getByteString
    , getLazyByteString
    , getWord8
    , getWord16be
    , getWord32be

    , getBinary
    , liftToBinary
    )
    where

import Control.Applicative (Applicative (..))
#if __GLASGOW_HASKELL__ >= 800
import qualified Control.Monad.Fail as Fail
#endif
import qualified Data.Binary as Binary
import qualified Data.Binary.Get as Bin
import qualified Data.ByteString as Strict
import qualified Data.ByteString.Lazy as Lazy
import Data.Bits
import Data.Int
import Data.Word

-- | Everything an 'Unpacker' carries from one step to the next.
data UnpackingState s
    = UnpackingState
      { stSource    :: !Lazy.ByteString -- ^ Input that has not been consumed yet.
      , stBytesRead :: !Int64           -- ^ Running count of consumed bytes.
      , stUserState :: s                -- ^ Caller-defined state.
      }

-- | A decoder with user state @s@ producing a value of type @a@.
newtype Unpacker s a
    = U { unU :: UnpackingState s -> (a, UnpackingState s) }

-- The tuple patterns below are deliberately bound with @let@ so that the
-- instances stay as lazy as the original monadic bind.
instance Functor (Unpacker s) where
    fmap f m = U $ \ st ->
        let (a, st') = unU m st
        in (f a, st')

instance Applicative (Unpacker s) where
    pure a    = U $ \ st -> (a, st)
    mf <*> ma = U $ \ st ->
        let (f, st' ) = unU mf st
            (a, st'') = unU ma st'
        in (f a, st'')

instance Monad (Unpacker s) where
    return  = pure
    m >>= k = U $ \ st ->
        let (a, st') = unU m st
        in unU (k a) st'
-- 'fail' stopped being a method of 'Monad' in base 4.13 (GHC 8.8).
#if __GLASGOW_HASKELL__ < 808
    fail    = failUnpacking
#endif

#if __GLASGOW_HASKELL__ >= 800
instance Fail.MonadFail (Unpacker s) where
    fail = failUnpacking
#endif

-- | Abort unpacking by calling 'error'. The given message is extended with
-- the value of the byte counter at the point of failure.
failUnpacking :: String -> Unpacker s a
failUnpacking err
    = U $ \ st ->
      error (err ++ ". Failed unpacking at byte position " ++ show (stBytesRead st))

-- | Project a value out of the current state without changing it.
get :: (UnpackingState s -> a) -> Unpacker s a
get f = U $ \ st -> (f st, st)

-- | Replace the current state with the result of applying the function.
set :: (UnpackingState s -> UnpackingState s) -> Unpacker s ()
set f = U $ \ st -> ((), f st)

-- | Build a state from the remaining input, a byte counter and a user
-- state.
mkState :: Lazy.ByteString -> Int64 -> s -> UnpackingState s
mkState xs n s
    = UnpackingState
      { stSource    = xs
      , stBytesRead = n
      , stUserState = s
      }

-- | Run an 'Unpacker' over the given input with an initial user state and a
-- byte counter of zero. Returns the result together with the final user
-- state; unconsumed input is dropped.
unpack' :: Unpacker s a -> s -> Lazy.ByteString -> (a, s)
unpack' m s xs
    = let (a, s') = unU m (mkState xs 0 s)
      in (a, stUserState s')

-- | Like 'unpack'' but returns only the result.
unpack :: Unpacker s a -> s -> Lazy.ByteString -> a
unpack m s xs = fst (unpack' m s xs)

-- | Read the user state.
getState :: Unpacker s s
getState = get stUserState

-- | Overwrite the user state.
setState :: s -> Unpacker s ()
setState = modifyState . const

-- | Apply a function to the user state.
modifyState :: (s -> s) -> Unpacker s ()
modifyState f
    = set $ \ st -> st { stUserState = f (stUserState st) }

-- | Consume and discard @n@ bytes. Fails like 'getLazyByteString' when
-- that many bytes are not available.
skip :: Int64 -> Unpacker s ()
skip n = getLazyByteString n >> return ()

-- | Run an 'Unpacker' and return its result, then restore the state it
-- started from. Every state change made by the argument is discarded:
-- consumed input, the byte counter and the user state alike.
lookAhead :: Unpacker s a -> Unpacker s a
lookAhead m = U $ \ st ->
    let (a, _) = unU m st
    in (a, st)

-- | The current value of the byte counter, converted to the requested
-- integral type.
bytesRead :: Integral i => Unpacker s i
bytesRead = fmap fromIntegral (get stBytesRead)

-- | Consume @n@ bytes and return them as one strict 'Strict.ByteString'.
-- Fails like 'getLazyByteString' when that many bytes are not available.
getByteString :: Int -> Unpacker s Strict.ByteString
getByteString n
    = fmap (Strict.concat . Lazy.toChunks) (getLazyByteString (fromIntegral n))

-- | Consume exactly @n@ bytes and advance the byte counter by @n@.
--
-- Fails with @\"Too few bytes\"@ when the length of the prefix taken from
-- the input differs from @n@; that covers both a short input and a
-- negative @n@.
getLazyByteString :: Int64 -> Unpacker s Lazy.ByteString
getLazyByteString n
    = do src <- get stSource
         let (xs, ys) = Lazy.splitAt n src
         if Lazy.length xs /= n
             -- Call the helper directly rather than 'fail' so the message
             -- carries the byte position no matter which 'fail' is in scope.
             then failUnpacking "Too few bytes"
             else do set $ \ st -> st { stSource    = ys
                                      , stBytesRead = stBytesRead st + n
                                      }
                     return xs

-- | Consume one byte.
getWord8 :: Unpacker s Word8
getWord8 = fmap (`Lazy.index` 0) (getLazyByteString 1)

-- | Consume two bytes and combine them most significant byte first.
getWord16be :: Unpacker s Word16
getWord16be
    = do xs <- getLazyByteString 2
         let byte i = fromIntegral (xs `Lazy.index` i)
         return $ (byte 0 `shiftL` 8) .|. byte 1

-- | Consume four bytes and combine them most significant byte first.
getWord32be :: Unpacker s Word32
getWord32be
    = do xs <- getLazyByteString 4
         let byte i = fromIntegral (xs `Lazy.index` i)
         return $ (byte 0 `shiftL` 24) .|.
                  (byte 1 `shiftL` 16) .|.
                  (byte 2 `shiftL`  8) .|.
                   byte 3

-- | Decode a value with its 'Binary.Binary' instance. The decoder is run on
-- the remaining input starting from the current byte counter, and both the
-- remaining input and the counter are replaced by what the decoder reports
-- back.
getBinary :: Binary.Binary a => Unpacker s a
getBinary
    = do st <- get id
         let (a, rest, bytes)
                 = Bin.runGetState Binary.get (stSource st) (stBytesRead st)
         set $ \ st' -> st' { stSource    = rest
                            , stBytesRead = bytes
                            }
         return a

-- | Embed an 'Unpacker' into 'Bin.Get'. The unpacker starts with the given
-- user state and with its byte counter set to the offset reported by
-- 'Bin.bytesRead'. Its final user state is discarded.
liftToBinary :: s -> Unpacker s a -> Bin.Get a
liftToBinary s m
    = do bytes <- Bin.bytesRead
         -- Only peek at the remaining input. In binary 0.6 and later
         -- 'Bin.getRemainingLazyByteString' consumes everything it returns,
         -- which would leave nothing for the skip below or for any parser
         -- that runs after this one.
         src   <- Bin.lookAhead Bin.getRemainingLazyByteString
         let (a, s') = unU m (mkState src bytes s)
         -- These bytes were consumed by the unpacker.
         Bin.skip (fromIntegral (stBytesRead s' - bytes))
         return a