]> gitweb @ CieloNegro.org - haskell-dns.git/commitdiff
Introduce Unpacker monad to clean up things.
authorPHO <pho@cielonegro.org>
Fri, 22 May 2009 05:13:11 +0000 (14:13 +0900)
committerPHO <pho@cielonegro.org>
Fri, 22 May 2009 05:13:11 +0000 (14:13 +0900)
DNSUnitTest.hs
Network/DNS/Message.hs
Network/DNS/Unpacker.hs [new file with mode: 0644]
dns.cabal

index 07d3adff52fdbe6a28ecf3ab62482a35c495b602..76a677d9e4deb963e2c3dd062f1f07ea55c81310 100644 (file)
@@ -7,101 +7,111 @@ import           System.IO.Unsafe
 import           Test.HUnit
 
 
-parseMsg :: [Word8] -> Message
-parseMsg = decode . LBS.pack
-
-
-testData :: [Test]
-testData = [ (parseMsg [ 0x22, 0x79, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00
-                       , 0x00, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x69
-                       , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E
-                       , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67
-                       , 0x00, 0x00, 0x05, 0x00, 0x01
-                       ]
-              ~?=
-              Message {
-                msgHeader = Header {
-                              hdMessageID             = 8825
-                            , hdMessageType           = Query
-                            , hdOpcode                = StandardQuery
-                            , hdIsAuthoritativeAnswer = False
-                            , hdIsTruncated           = False
-                            , hdIsRecursionDesired    = True
-                            , hdIsRecursionAvailable  = False
-                            , hdResponseCode          = NoError
-                            }
-              , msgQuestions   = [ Question {
-                                     qName  = mkDomainName "mail.cielonegro.org."
-                                   , qType  = wrapQueryType CNAME
-                                   , qClass = IN
-                                   }
-                                 ]
-              , msgAnswers     = []
-              , msgAuthorities = []
-              , msgAdditionals = []
-              }
+messages :: [([Word8], Message)]
+messages = [ ( [ 0x22, 0x79, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00
+               , 0x00, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x69
+               , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E
+               , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67
+               , 0x00, 0x00, 0x05, 0x00, 0x01
+               ]
+             , Message {
+                 msgHeader = Header {
+                               hdMessageID             = 8825
+                             , hdMessageType           = Query
+                             , hdOpcode                = StandardQuery
+                             , hdIsAuthoritativeAnswer = False
+                             , hdIsTruncated           = False
+                             , hdIsRecursionDesired    = True
+                             , hdIsRecursionAvailable  = False
+                             , hdResponseCode          = NoError
+                             }
+               , msgQuestions   = [ Question {
+                                      qName  = mkDomainName "mail.cielonegro.org."
+                                    , qType  = wrapQueryType CNAME
+                                    , qClass = IN
+                                    }
+                                  ]
+               , msgAnswers     = []
+               , msgAuthorities = []
+               , msgAdditionals = []
+               }
              )
-           , (parseMsg [ 0x22, 0x79, 0x85, 0x00, 0x00, 0x01, 0x00, 0x01
-                       , 0x00, 0x01, 0x00, 0x01, 0x04, 0x6D, 0x61, 0x69
-                       , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E
-                       , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67
-                       , 0x00, 0x00, 0x05, 0x00, 0x01, 0xC0, 0x0C, 0x00
-                       , 0x05, 0x00, 0x01, 0x00, 0x01, 0x51, 0x80, 0x00
-                       , 0x06, 0x03, 0x6E, 0x65, 0x6D, 0xC0, 0x11, 0xC0
-                       , 0x11, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0E
-                       , 0x10, 0x00, 0x02, 0xC0, 0x31, 0xC0, 0x31, 0x00
-                       , 0x01, 0x00, 0x01, 0x00, 0x00, 0x0E, 0x10, 0x00
-                       , 0x04, 0xDB, 0x5E, 0x82, 0x8B
-                       ]
-              ~?=
-              Message {
-                msgHeader = Header {
-                              hdMessageID             = 8825
-                            , hdMessageType           = Response
-                            , hdOpcode                = StandardQuery
-                            , hdIsAuthoritativeAnswer = True
-                            , hdIsTruncated           = False
-                            , hdIsRecursionDesired    = True
-                            , hdIsRecursionAvailable  = False
-                            , hdResponseCode          = NoError
-                            }
-              , msgQuestions   = [ Question {
-                                     qName  = mkDomainName "mail.cielonegro.org."
-                                   , qType  = wrapQueryType CNAME
-                                   , qClass = IN
-                                   }
-                                 ]
-              , msgAnswers     = [ wrapRecord $
-                                   ResourceRecord {
-                                     rrName  = mkDomainName "mail.cielonegro.org."
-                                   , rrType  = CNAME
-                                   , rrClass = IN
-                                   , rrTTL   = 86400
-                                   , rrData  = mkDomainName "nem.cielonegro.org."
-                                   }
-                                 ]
-              , msgAuthorities = [ wrapRecord $
-                                   ResourceRecord {
-                                     rrName  = mkDomainName "cielonegro.org."
-                                   , rrType  = NS
-                                   , rrClass = IN
-                                   , rrTTL   = 3600
-                                   , rrData  = mkDomainName "nem.cielonegro.org."
-                                   }
-                                 ]
-              , msgAdditionals = [ wrapRecord $
-                                   ResourceRecord {
-                                     rrName  = mkDomainName "nem.cielonegro.org."
-                                   , rrType  = A
-                                   , rrClass = IN
-                                   , rrTTL   = 3600
-                                   , rrData  = unsafePerformIO (inet_addr "219.94.130.139")
-                                   }
-                                 ]
-              }
+           , ( [ 0x22, 0x79, 0x85, 0x00, 0x00, 0x01, 0x00, 0x01
+               , 0x00, 0x01, 0x00, 0x01, 0x04, 0x6D, 0x61, 0x69
+               , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E
+               , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67
+               , 0x00, 0x00, 0x05, 0x00, 0x01, 0xC0, 0x0C, 0x00
+               , 0x05, 0x00, 0x01, 0x00, 0x01, 0x51, 0x80, 0x00
+               , 0x06, 0x03, 0x6E, 0x65, 0x6D, 0xC0, 0x11, 0xC0
+               , 0x11, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0E
+               , 0x10, 0x00, 0x02, 0xC0, 0x31, 0xC0, 0x31, 0x00
+               , 0x01, 0x00, 0x01, 0x00, 0x00, 0x0E, 0x10, 0x00
+               , 0x04, 0xDB, 0x5E, 0x82, 0x8B
+               ]
+             , Message {
+                 msgHeader = Header {
+                               hdMessageID             = 8825
+                             , hdMessageType           = Response
+                             , hdOpcode                = StandardQuery
+                             , hdIsAuthoritativeAnswer = True
+                             , hdIsTruncated           = False
+                             , hdIsRecursionDesired    = True
+                             , hdIsRecursionAvailable  = False
+                             , hdResponseCode          = NoError
+                             }
+               , msgQuestions   = [ Question {
+                                      qName  = mkDomainName "mail.cielonegro.org."
+                                    , qType  = wrapQueryType CNAME
+                                    , qClass = IN
+                                    }
+                                  ]
+               , msgAnswers     = [ wrapRecord $
+                                    ResourceRecord {
+                                      rrName  = mkDomainName "mail.cielonegro.org."
+                                    , rrType  = CNAME
+                                    , rrClass = IN
+                                    , rrTTL   = 86400
+                                    , rrData  = mkDomainName "nem.cielonegro.org."
+                                    }
+                                  ]
+               , msgAuthorities = [ wrapRecord $
+                                    ResourceRecord {
+                                      rrName  = mkDomainName "cielonegro.org."
+                                    , rrType  = NS
+                                    , rrClass = IN
+                                    , rrTTL   = 3600
+                                    , rrData  = mkDomainName "nem.cielonegro.org."
+                                    }
+                                  ]
+               , msgAdditionals = [ wrapRecord $
+                                    ResourceRecord {
+                                      rrName  = mkDomainName "nem.cielonegro.org."
+                                    , rrType  = A
+                                    , rrClass = IN
+                                    , rrTTL   = 3600
+                                    , rrData  = unsafePerformIO (inet_addr "219.94.130.139")
+                                    }
+                                  ]
+               }
              )
            ]
 
+packMsg :: Message -> [Word8]
+packMsg = LBS.unpack . encode
+
+unpackMsg :: [Word8] -> Message
+unpackMsg = decode . LBS.pack
+
+testData :: [Test]
+testData = map mkPackTest messages
+           ++
+           map mkUnpackTest messages
+    where
+      mkPackTest :: ([Word8], Message) -> Test
+      mkPackTest (bin, msg) = packMsg msg ~?= bin
+
+      mkUnpackTest :: ([Word8], Message) -> Test
+      mkUnpackTest (bin, msg) = unpackMsg bin ~?= msg
 
 main :: IO ()
 main = runTestTT (test testData) >> return ()
\ No newline at end of file
index 7bedacf5a0922b1816280a8ae4162f9aaf3ff698..be0b79a33800a32dcc769c7ff68129f322261bf4 100644 (file)
@@ -43,18 +43,10 @@ import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
 import           Data.Word
+import           Network.DNS.Unpacker as U
 import           Network.Socket
 
 
-replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
-replicateM' = worker []
-    where
-      worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
-      worker soFar 0 _ a = return (reverse soFar, a)
-      worker soFar n f a = do (b, a') <- f a
-                              worker (b : soFar) (n - 1) f a'
-
-
 data Message
     = Message {
         msgHeader      :: !Header
@@ -122,17 +114,16 @@ putQ q
          putSomeRT $ qType q
          put $ qClass q
 
-getQ :: DecompTable -> Get (Question, DecompTable)
-getQ dt
-    = do (nm, dt') <- getDomainName dt
-         ty        <- getSomeRT
-         cl        <- get
-         let q = Question {
-                   qName  = nm
-                 , qType  = ty
-                 , qClass = cl
-                 }
-         return (q, dt')
+getQ :: Unpacker DecompTable Question
+getQ = do nm <- getDomainName
+          ty <- getSomeRT
+          cl <- getBinary
+          return Question {
+                       qName  = nm
+                     , qType  = ty
+                     , qClass = cl
+                     }
+
 
 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
 type DomainLabel    = BS.ByteString
@@ -194,25 +185,6 @@ putRR rr = do putDomainName $ rrName rr
               putLazyByteString dat
 
 
-getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
-getRR dt rt
-    = do (nm, dt1)  <- getDomainName dt
-         G.skip 2   -- record type
-         cl         <- get
-         ttl        <- G.getWord32be
-         G.skip 2   -- data length
-         (dat, dt2) <- getRecordData (undefined :: rt) dt1
-
-         let rr = ResourceRecord {
-                    rrName  = nm
-                  , rrType  = rt
-                  , rrClass = cl
-                  , rrTTL   = ttl
-                  , rrData  = dat
-                  }
-         return (rr, dt2)
-
-
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
@@ -225,44 +197,44 @@ instance Eq SomeRR where
 putSomeRR :: SomeRR -> Put
 putSomeRR (SomeRR rr) = putRR rr
 
-getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
-getSomeRR dt
-    = do srt <- lookAhead $
-                do getDomainName dt -- skip
-                   getSomeRT
-         case srt of
-           SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
-
+getSomeRR :: Unpacker DecompTable SomeRR
+getSomeRR = do srt <- U.lookAhead $
+                      do getDomainName -- skip
+                         getSomeRT
+               case srt of
+                 SomeRT rt
+                     -> getResourceRecord rt >>= return . SomeRR
 
 type DecompTable = IntMap DomainName
 type TTL = Word32
 
-getDomainName :: DecompTable -> Get (DomainName, DecompTable)
+getDomainName :: Unpacker DecompTable DomainName
 getDomainName = worker
     where
-      worker :: DecompTable -> Get (DomainName, DecompTable)
-      worker dt
-          = do offset <- liftM fromIntegral bytesRead
+      worker :: Unpacker DecompTable DomainName
+      worker
+          = do offset <- U.bytesRead
                hdr    <- getLabelHeader
                case hdr of
                  Offset n
-                     -> case IM.lookup n dt of
-                          Just name
-                              -> return (name, dt)
-                          Nothing
-                              -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
+                     -> do dt <- getState
+                           case IM.lookup n dt of
+                             Just name
+                                 -> return name
+                             Nothing
+                                 -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
                  Length 0
-                     -> return (rootName, dt)
+                     -> return rootName
                  Length n
-                     -> do label       <- getByteString n
-                           (rest, dt') <- worker dt
+                     -> do label <- U.getByteString n
+                           rest  <- worker
                            let name = consLabel label rest
-                               dt'' = IM.insert offset name dt'
-                           return (name, dt'')
+                           modifyState $ IM.insert offset name
+                           return name
 
-      getLabelHeader :: Get LabelHeader
+      getLabelHeader :: Unpacker s LabelHeader
       getLabelHeader
-          = do header <- lookAhead $ getByteString 1
+          = do header <- U.lookAhead $ U.getByteString 1
                let Right h
                        = runBitGet header $
                          do a <- getBit
@@ -274,7 +246,7 @@ getDomainName = worker
                               _              -> fail "Illegal label header"
                case h of
                  Offset _
-                     -> do header' <- getByteString 2 -- Pointers have 2 octets.
+                     -> do header' <- U.getByteString 2 -- Pointers have 2 octets.
                            let Right h'
                                    = runBitGet header' $
                                      do BG.skip 2
@@ -282,13 +254,13 @@ getDomainName = worker
                                         return $ Offset n
                            return h'
                  len@(Length _)
-                     -> do G.skip 1
+                     -> do U.skip 1
                            return len
 
 
-getCharString :: Get BS.ByteString
-getCharString = do len <- G.getWord8
-                   getByteString (fromIntegral len)
+getCharString :: Unpacker s BS.ByteString
+getCharString = do len <- U.getWord8
+                   U.getByteString (fromIntegral len)
 
 putCharString :: BS.ByteString -> Put
 putCharString = putDomainLabel
@@ -307,12 +279,28 @@ putDomainLabel l
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
-    putRecordType :: rt -> Put
     putRecordData :: rt -> dt -> Put
-    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
+    getRecordData :: rt -> Unpacker DecompTable dt
 
+    putRecordType :: rt -> Put
     putRecordType = putWord16be . fromIntegral . rtToInt
 
+    getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
+    getResourceRecord rt
+        = do name     <- getDomainName
+             U.skip 2 -- record type
+             cl       <- getBinary
+             ttl      <- U.getWord32be
+             U.skip 2 -- data length
+             dat      <- getRecordData rt
+             return $ ResourceRecord {
+                          rrName  = name
+                        , rrType  = rt
+                        , rrClass = cl
+                        , rrTTL   = ttl
+                        , rrData  = dat
+                        }
+
 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
 
 instance Show SomeRT where
@@ -324,8 +312,8 @@ instance Eq SomeRT where
 putSomeRT :: SomeRT -> Put
 putSomeRT (SomeRT rt) = putRecordType rt
 
-getSomeRT :: Get SomeRT
-getSomeRT = do n <- liftM fromIntegral G.getWord16be
+getSomeRT :: Unpacker s SomeRT
+getSomeRT = do n <- liftM fromIntegral U.getWord16be
                case IM.lookup n defaultRTTable of
                  Nothing
                      -> fail ("Unknown resource record type: " ++ show n)
@@ -336,9 +324,7 @@ data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
     putRecordData _ = putWord32be
-    getRecordData _ = \ dt ->
-                      do addr <- G.getWord32be
-                         return (addr, dt)
+    getRecordData _ = U.getWord32be
 
 data NS = NS deriving (Show, Eq, Typeable)
 instance RecordType NS DomainName where
@@ -357,9 +343,9 @@ instance RecordType HINFO (BS.ByteString, BS.ByteString) where
     rtToInt       _           = 13
     putRecordData _ (cpu, os) = do putCharString cpu
                                    putCharString os
-    getRecordData _ dt        = do cpu <- getCharString
+    getRecordData _           = do cpu <- getCharString
                                    os  <- getCharString
-                                   return ((cpu, os), dt)
+                                   return (cpu, os)
 
 
 {-
@@ -400,15 +386,16 @@ instance Binary Message where
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
-    get = do hdr  <- get
-             nQ   <- liftM fromIntegral G.getWord16be
-             nAns <- liftM fromIntegral G.getWord16be
-             nAth <- liftM fromIntegral G.getWord16be
-             nAdd <- liftM fromIntegral G.getWord16be
-             (qs  , dt1) <- replicateM' nQ   getQ IM.empty
-             (anss, dt2) <- replicateM' nAns getSomeRR dt1
-             (aths, dt3) <- replicateM' nAth getSomeRR dt2
-             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
+    get = liftToBinary IM.empty $
+          do hdr  <- getBinary
+             nQ   <- liftM fromIntegral U.getWord16be
+             nAns <- liftM fromIntegral U.getWord16be
+             nAth <- liftM fromIntegral U.getWord16be
+             nAdd <- liftM fromIntegral U.getWord16be
+             qs   <- replicateM nQ   getQ
+             anss <- replicateM nAns getSomeRR
+             aths <- replicateM nAth getSomeRR
+             adds <- replicateM nAdd getSomeRR
              return Message {
                           msgHeader      = hdr
                         , msgQuestions   = qs
@@ -432,7 +419,7 @@ instance Binary Header where
                      putNBits 4 $ fromEnum $ hdResponseCode h
 
     get = do mID   <- G.getWord16be
-             flags <- getByteString 2
+             flags <- G.getByteString 2
              let Right hd
                      = runBitGet flags $
                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
diff --git a/Network/DNS/Unpacker.hs b/Network/DNS/Unpacker.hs
new file mode 100644 (file)
index 0000000..db34946
--- /dev/null
@@ -0,0 +1,150 @@
+module Network.DNS.Unpacker
+    ( Unpacker
+    , UnpackingState(..)
+
+    , unpack
+    , unpack'
+
+    , getState
+    , setState
+    , modifyState
+
+    , skip
+    , lookAhead
+    , bytesRead
+
+    , getByteString
+    , getLazyByteString
+    , getWord8
+    , getWord16be
+    , getWord32be
+
+    , getBinary
+    , liftToBinary
+    )
+    where
+
+import qualified Data.Binary as Binary
+import qualified Data.Binary.Get as Bin
+import qualified Data.ByteString as Strict
+import qualified Data.ByteString.Lazy as Lazy
+import           Data.Bits
+import           Data.Int
+import           Data.Word
+
+
+data UnpackingState s
+    = UnpackingState {
+        stSource    :: !Lazy.ByteString
+      , stBytesRead :: !Int64
+      , stUserState :: s
+      }
+
+newtype Unpacker s a = U { unU :: UnpackingState s -> (a, UnpackingState s) }
+
+instance Monad (Unpacker s) where
+    return a = U (\ s -> (a, s))
+    m >>= k  = U (\ s -> let (a, s') = unU m s
+                         in
+                           unU (k a) s')
+    fail err = do bytes <- get stBytesRead
+                  U (error (err
+                            ++ ". Failed unpacking at byte position "
+                            ++ show bytes))
+
+get :: (UnpackingState s -> a) -> Unpacker s a
+get f = U (\ s -> (f s, s))
+
+set :: (UnpackingState s -> UnpackingState s) -> Unpacker s ()
+set f = U (\ s -> ((), f s))
+
+mkState :: Lazy.ByteString -> Int64 -> s -> UnpackingState s
+mkState xs n s
+    = UnpackingState {
+        stSource    = xs
+      , stBytesRead = n
+      , stUserState = s
+      }
+
+unpack' :: Unpacker s a -> s -> Lazy.ByteString -> (a, s)
+unpack' m s xs
+    = let (a, s') = unU m (mkState xs 0 s)
+      in
+        (a, stUserState s')
+
+unpack :: Unpacker s a -> s -> Lazy.ByteString -> a
+unpack = ((fst .) .) . unpack'
+
+getState :: Unpacker s s
+getState = get stUserState
+
+setState :: s -> Unpacker s ()
+setState = modifyState . const
+
+modifyState :: (s -> s) -> Unpacker s ()
+modifyState f
+    = set $ \ st -> st { stUserState = f (stUserState st) }
+
+skip :: Int64 -> Unpacker s ()
+skip n = getLazyByteString n >> return ()
+
+lookAhead :: Unpacker s a -> Unpacker s a
+lookAhead m = U (\ s -> let (a, _) = unU m s
+                        in
+                          (a, s))
+
+bytesRead :: Integral i => Unpacker s i
+bytesRead = get stBytesRead >>= return . fromIntegral
+
+getByteString :: Int -> Unpacker s Strict.ByteString
+getByteString n = getLazyByteString (fromIntegral n) >>= return . Strict.concat . Lazy.toChunks
+
+getLazyByteString :: Int64 -> Unpacker s Lazy.ByteString
+getLazyByteString n
+    = do src <- get stSource
+         let (xs, ys) = Lazy.splitAt n src
+         if Lazy.length xs /= n then
+             fail "Too few bytes"
+           else
+             do set $ \ st -> st {
+                                stSource    = ys
+                              , stBytesRead = stBytesRead st + n
+                              }
+                return xs
+
+getWord8 :: Unpacker s Word8
+getWord8 = getLazyByteString 1 >>= return . (`Lazy.index` 0)
+
+getWord16be :: Unpacker s Word16
+getWord16be = do xs <- getLazyByteString 2
+                 return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 8) .|.
+                          (fromIntegral (xs `Lazy.index` 1))
+
+getWord32be :: Unpacker s Word32
+getWord32be = do xs <- getLazyByteString 4
+                 return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 24) .|.
+                          (fromIntegral (xs `Lazy.index` 1) `shiftL` 16) .|.
+                          (fromIntegral (xs `Lazy.index` 2) `shiftL`  8) .|.
+                          (fromIntegral (xs `Lazy.index` 3))
+
+getBinary :: Binary.Binary a => Unpacker s a
+getBinary = do s <- get id
+               let (a, rest, bytes) = Bin.runGetState Binary.get (stSource s) (stBytesRead s)
+               set $ \ st -> st {
+                               stSource    = rest
+                             , stBytesRead = bytes
+                             }
+               return a
+
+
+liftToBinary :: s -> Unpacker s a -> Bin.Get a
+liftToBinary s m
+    = do bytes <- Bin.bytesRead
+         src   <- Bin.getRemainingLazyByteString
+
+         let (a, s') = unU m (mkState src bytes s)
+
+         -- These bytes was consumed by the unpacker.
+         Bin.skip (fromIntegral (stBytesRead s' - bytes))
+
+         return a
index e257dc7ea75791ac2864b9a8879c0214f31fc9e7..3f609129c303318b7b38eab2481378c51a3087a0 100644 (file)
--- a/dns.cabal
+++ b/dns.cabal
@@ -23,6 +23,9 @@ Library
     Exposed-Modules:
         Network.DNS.Message
 
+    Other-Modules:
+        Network.DNS.Unpacker
+
     Extensions:
         DeriveDataTypeable, ExistentialQuantification,
         FlexibleInstances, FunctionalDependencies, MultiParamTypeClasses,