]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Unpacker.hs
Introduce Unpacker monad to clean up things.
[haskell-dns.git] / Network / DNS / Unpacker.hs
diff --git a/Network/DNS/Unpacker.hs b/Network/DNS/Unpacker.hs
new file mode 100644 (file)
index 0000000..db34946
--- /dev/null
@@ -0,0 +1,150 @@
+module Network.DNS.Unpacker
+    ( Unpacker
+    , UnpackingState(..)
+
+    , unpack
+    , unpack'
+
+    , getState
+    , setState
+    , modifyState
+
+    , skip
+    , lookAhead
+    , bytesRead
+
+    , getByteString
+    , getLazyByteString
+    , getWord8
+    , getWord16be
+    , getWord32be
+
+    , getBinary
+    , liftToBinary
+    )
+    where
+
+import qualified Data.Binary as Binary
+import qualified Data.Binary.Get as Bin
+import qualified Data.ByteString as Strict
+import qualified Data.ByteString.Lazy as Lazy
+import           Data.Bits
+import           Data.Int
+import           Data.Word
+
+
+data UnpackingState s
+    = UnpackingState {
+        stSource    :: !Lazy.ByteString
+      , stBytesRead :: !Int64
+      , stUserState :: s
+      }
+
+newtype Unpacker s a = U { unU :: UnpackingState s -> (a, UnpackingState s) }
+
+instance Monad (Unpacker s) where
+    return a = U (\ s -> (a, s))
+    m >>= k  = U (\ s -> let (a, s') = unU m s
+                         in
+                           unU (k a) s')
+    fail err = do bytes <- get stBytesRead
+                  U (error (err
+                            ++ ". Failed unpacking at byte position "
+                            ++ show bytes))
+
+get :: (UnpackingState s -> a) -> Unpacker s a
+get f = U (\ s -> (f s, s))
+
+set :: (UnpackingState s -> UnpackingState s) -> Unpacker s ()
+set f = U (\ s -> ((), f s))
+
+mkState :: Lazy.ByteString -> Int64 -> s -> UnpackingState s
+mkState xs n s
+    = UnpackingState {
+        stSource    = xs
+      , stBytesRead = n
+      , stUserState = s
+      }
+
+unpack' :: Unpacker s a -> s -> Lazy.ByteString -> (a, s)
+unpack' m s xs
+    = let (a, s') = unU m (mkState xs 0 s)
+      in
+        (a, stUserState s')
+
+unpack :: Unpacker s a -> s -> Lazy.ByteString -> a
+unpack = ((fst .) .) . unpack'
+
+getState :: Unpacker s s
+getState = get stUserState
+
+setState :: s -> Unpacker s ()
+setState = modifyState . const
+
+modifyState :: (s -> s) -> Unpacker s ()
+modifyState f
+    = set $ \ st -> st { stUserState = f (stUserState st) }
+
+skip :: Int64 -> Unpacker s ()
+skip n = getLazyByteString n >> return ()
+
+lookAhead :: Unpacker s a -> Unpacker s a
+lookAhead m = U (\ s -> let (a, _) = unU m s
+                        in
+                          (a, s))
+
+bytesRead :: Integral i => Unpacker s i
+bytesRead = get stBytesRead >>= return . fromIntegral
+
+getByteString :: Int -> Unpacker s Strict.ByteString
+getByteString n = getLazyByteString (fromIntegral n) >>= return . Strict.concat . Lazy.toChunks
+
+getLazyByteString :: Int64 -> Unpacker s Lazy.ByteString
+getLazyByteString n
+    = do src <- get stSource
+         let (xs, ys) = Lazy.splitAt n src
+         if Lazy.length xs /= n then
+             fail "Too few bytes"
+           else
+             do set $ \ st -> st {
+                                stSource    = ys
+                              , stBytesRead = stBytesRead st + n
+                              }
+                return xs
+
+getWord8 :: Unpacker s Word8
+getWord8 = getLazyByteString 1 >>= return . (`Lazy.index` 0)
+
+getWord16be :: Unpacker s Word16
+getWord16be = do xs <- getLazyByteString 2
+                 return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 8) .|.
+                          (fromIntegral (xs `Lazy.index` 1))
+
+getWord32be :: Unpacker s Word32
+getWord32be = do xs <- getLazyByteString 4
+                 return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 24) .|.
+                          (fromIntegral (xs `Lazy.index` 1) `shiftL` 16) .|.
+                          (fromIntegral (xs `Lazy.index` 2) `shiftL`  8) .|.
+                          (fromIntegral (xs `Lazy.index` 3))
+
+getBinary :: Binary.Binary a => Unpacker s a
+getBinary = do s <- get id
+               let (a, rest, bytes) = Bin.runGetState Binary.get (stSource s) (stBytesRead s)
+               set $ \ st -> st {
+                               stSource    = rest
+                             , stBytesRead = bytes
+                             }
+               return a
+
+
+liftToBinary :: s -> Unpacker s a -> Bin.Get a
+liftToBinary s m
+    = do bytes <- Bin.bytesRead
+         src   <- Bin.getRemainingLazyByteString
+
+         let (a, s') = unU m (mkState src bytes s)
+
+         -- These bytes was consumed by the unpacker.
+         Bin.skip (fromIntegral (stBytesRead s' - bytes))
+
+         return a