From: PHO Date: Fri, 22 May 2009 05:13:11 +0000 (+0900) Subject: Introduce Unpacker monad to clean up things. X-Git-Url: http://git.cielonegro.org/gitweb.cgi?p=haskell-dns.git;a=commitdiff_plain;h=298473c933e7ad1e101f4db7a7ee115745098235 Introduce Unpacker monad to clean up things. --- diff --git a/DNSUnitTest.hs b/DNSUnitTest.hs index 07d3adf..76a677d 100644 --- a/DNSUnitTest.hs +++ b/DNSUnitTest.hs @@ -7,101 +7,111 @@ import System.IO.Unsafe import Test.HUnit -parseMsg :: [Word8] -> Message -parseMsg = decode . LBS.pack - - -testData :: [Test] -testData = [ (parseMsg [ 0x22, 0x79, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00 - , 0x00, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x69 - , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E - , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67 - , 0x00, 0x00, 0x05, 0x00, 0x01 - ] - ~?= - Message { - msgHeader = Header { - hdMessageID = 8825 - , hdMessageType = Query - , hdOpcode = StandardQuery - , hdIsAuthoritativeAnswer = False - , hdIsTruncated = False - , hdIsRecursionDesired = True - , hdIsRecursionAvailable = False - , hdResponseCode = NoError - } - , msgQuestions = [ Question { - qName = mkDomainName "mail.cielonegro.org." - , qType = wrapQueryType CNAME - , qClass = IN - } - ] - , msgAnswers = [] - , msgAuthorities = [] - , msgAdditionals = [] - } +messages :: [([Word8], Message)] +messages = [ ( [ 0x22, 0x79, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00 + , 0x00, 0x00, 0x00, 0x00, 0x04, 0x6D, 0x61, 0x69 + , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E + , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67 + , 0x00, 0x00, 0x05, 0x00, 0x01 + ] + , Message { + msgHeader = Header { + hdMessageID = 8825 + , hdMessageType = Query + , hdOpcode = StandardQuery + , hdIsAuthoritativeAnswer = False + , hdIsTruncated = False + , hdIsRecursionDesired = True + , hdIsRecursionAvailable = False + , hdResponseCode = NoError + } + , msgQuestions = [ Question { + qName = mkDomainName "mail.cielonegro.org." + , qType = wrapQueryType CNAME + , qClass = IN + } + ] + , msgAnswers = [] + , msgAuthorities = [] + , msgAdditionals = [] + } ) - , (parseMsg [ 0x22, 0x79, 0x85, 0x00, 0x00, 0x01, 0x00, 0x01 - , 0x00, 0x01, 0x00, 0x01, 0x04, 0x6D, 0x61, 0x69 - , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E - , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67 - , 0x00, 0x00, 0x05, 0x00, 0x01, 0xC0, 0x0C, 0x00 - , 0x05, 0x00, 0x01, 0x00, 0x01, 0x51, 0x80, 0x00 - , 0x06, 0x03, 0x6E, 0x65, 0x6D, 0xC0, 0x11, 0xC0 - , 0x11, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0E - , 0x10, 0x00, 0x02, 0xC0, 0x31, 0xC0, 0x31, 0x00 - , 0x01, 0x00, 0x01, 0x00, 0x00, 0x0E, 0x10, 0x00 - , 0x04, 0xDB, 0x5E, 0x82, 0x8B - ] - ~?= - Message { - msgHeader = Header { - hdMessageID = 8825 - , hdMessageType = Response - , hdOpcode = StandardQuery - , hdIsAuthoritativeAnswer = True - , hdIsTruncated = False - , hdIsRecursionDesired = True - , hdIsRecursionAvailable = False - , hdResponseCode = NoError - } - , msgQuestions = [ Question { - qName = mkDomainName "mail.cielonegro.org." - , qType = wrapQueryType CNAME - , qClass = IN - } - ] - , msgAnswers = [ wrapRecord $ - ResourceRecord { - rrName = mkDomainName "mail.cielonegro.org." - , rrType = CNAME - , rrClass = IN - , rrTTL = 86400 - , rrData = mkDomainName "nem.cielonegro.org." - } - ] - , msgAuthorities = [ wrapRecord $ - ResourceRecord { - rrName = mkDomainName "cielonegro.org." - , rrType = NS - , rrClass = IN - , rrTTL = 3600 - , rrData = mkDomainName "nem.cielonegro.org." - } - ] - , msgAdditionals = [ wrapRecord $ - ResourceRecord { - rrName = mkDomainName "nem.cielonegro.org." - , rrType = A - , rrClass = IN - , rrTTL = 3600 - , rrData = unsafePerformIO (inet_addr "219.94.130.139") - } - ] - } + , ( [ 0x22, 0x79, 0x85, 0x00, 0x00, 0x01, 0x00, 0x01 + , 0x00, 0x01, 0x00, 0x01, 0x04, 0x6D, 0x61, 0x69 + , 0x6C, 0x0A, 0x63, 0x69, 0x65, 0x6C, 0x6F, 0x6E + , 0x65, 0x67, 0x72, 0x6F, 0x03, 0x6F, 0x72, 0x67 + , 0x00, 0x00, 0x05, 0x00, 0x01, 0xC0, 0x0C, 0x00 + , 0x05, 0x00, 0x01, 0x00, 0x01, 0x51, 0x80, 0x00 + , 0x06, 0x03, 0x6E, 0x65, 0x6D, 0xC0, 0x11, 0xC0 + , 0x11, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0E + , 0x10, 0x00, 0x02, 0xC0, 0x31, 0xC0, 0x31, 0x00 + , 0x01, 0x00, 0x01, 0x00, 0x00, 0x0E, 0x10, 0x00 + , 0x04, 0xDB, 0x5E, 0x82, 0x8B + ] + , Message { + msgHeader = Header { + hdMessageID = 8825 + , hdMessageType = Response + , hdOpcode = StandardQuery + , hdIsAuthoritativeAnswer = True + , hdIsTruncated = False + , hdIsRecursionDesired = True + , hdIsRecursionAvailable = False + , hdResponseCode = NoError + } + , msgQuestions = [ Question { + qName = mkDomainName "mail.cielonegro.org." + , qType = wrapQueryType CNAME + , qClass = IN + } + ] + , msgAnswers = [ wrapRecord $ + ResourceRecord { + rrName = mkDomainName "mail.cielonegro.org." + , rrType = CNAME + , rrClass = IN + , rrTTL = 86400 + , rrData = mkDomainName "nem.cielonegro.org." + } + ] + , msgAuthorities = [ wrapRecord $ + ResourceRecord { + rrName = mkDomainName "cielonegro.org." + , rrType = NS + , rrClass = IN + , rrTTL = 3600 + , rrData = mkDomainName "nem.cielonegro.org." + } + ] + , msgAdditionals = [ wrapRecord $ + ResourceRecord { + rrName = mkDomainName "nem.cielonegro.org." + , rrType = A + , rrClass = IN + , rrTTL = 3600 + , rrData = unsafePerformIO (inet_addr "219.94.130.139") + } + ] + } ) ] +packMsg :: Message -> [Word8] +packMsg = LBS.unpack . encode + +unpackMsg :: [Word8] -> Message +unpackMsg = decode . LBS.pack + +testData :: [Test] +testData = map mkPackTest messages + ++ + map mkUnpackTest messages + where + mkPackTest :: ([Word8], Message) -> Test + mkPackTest (bin, msg) = packMsg msg ~?= bin + + mkUnpackTest :: ([Word8], Message) -> Test + mkUnpackTest (bin, msg) = unpackMsg bin ~?= msg main :: IO () main = runTestTT (test testData) >> return () \ No newline at end of file diff --git a/Network/DNS/Message.hs b/Network/DNS/Message.hs index 7bedacf..be0b79a 100644 --- a/Network/DNS/Message.hs +++ b/Network/DNS/Message.hs @@ -43,18 +43,10 @@ import Data.Typeable import qualified Data.IntMap as IM import Data.IntMap (IntMap) import Data.Word +import Network.DNS.Unpacker as U import Network.Socket -replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a) -replicateM' = worker [] - where - worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a) - worker soFar 0 _ a = return (reverse soFar, a) - worker soFar n f a = do (b, a') <- f a - worker (b : soFar) (n - 1) f a' - - data Message = Message { msgHeader :: !Header @@ -122,17 +114,16 @@ putQ q putSomeRT $ qType q put $ qClass q -getQ :: DecompTable -> Get (Question, DecompTable) -getQ dt - = do (nm, dt') <- getDomainName dt - ty <- getSomeRT - cl <- get - let q = Question { - qName = nm - , qType = ty - , qClass = cl - } - return (q, dt') +getQ :: Unpacker DecompTable Question +getQ = do nm <- getDomainName + ty <- getSomeRT + cl <- getBinary + return Question { + qName = nm + , qType = ty + , qClass = cl + } + newtype DomainName = DN [DomainLabel] deriving (Eq, Show, Typeable) type DomainLabel = BS.ByteString @@ -194,25 +185,6 @@ putRR rr = do putDomainName $ rrName rr putLazyByteString dat -getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable) -getRR dt rt - = do (nm, dt1) <- getDomainName dt - G.skip 2 -- record type - cl <- get - ttl <- G.getWord32be - G.skip 2 -- data length - (dat, dt2) <- getRecordData (undefined :: rt) dt1 - - let rr = ResourceRecord { - rrName = nm - , rrType = rt - , rrClass = cl - , rrTTL = ttl - , rrData = dat - } - return (rr, dt2) - - data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt) instance Show SomeRR where @@ -225,44 +197,44 @@ instance Eq SomeRR where putSomeRR :: SomeRR -> Put putSomeRR (SomeRR rr) = putRR rr -getSomeRR :: DecompTable -> Get (SomeRR, DecompTable) -getSomeRR dt - = do srt <- lookAhead $ - do getDomainName dt -- skip - getSomeRT - case srt of - SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt') - +getSomeRR :: Unpacker DecompTable SomeRR +getSomeRR = do srt <- U.lookAhead $ + do getDomainName -- skip + getSomeRT + case srt of + SomeRT rt + -> getResourceRecord rt >>= return . SomeRR type DecompTable = IntMap DomainName type TTL = Word32 -getDomainName :: DecompTable -> Get (DomainName, DecompTable) +getDomainName :: Unpacker DecompTable DomainName getDomainName = worker where - worker :: DecompTable -> Get (DomainName, DecompTable) - worker dt - = do offset <- liftM fromIntegral bytesRead + worker :: Unpacker DecompTable DomainName + worker + = do offset <- U.bytesRead hdr <- getLabelHeader case hdr of Offset n - -> case IM.lookup n dt of - Just name - -> return (name, dt) - Nothing - -> fail ("Illegal offset of label pointer: " ++ show (n, dt)) + -> do dt <- getState + case IM.lookup n dt of + Just name + -> return name + Nothing + -> fail ("Illegal offset of label pointer: " ++ show (n, dt)) Length 0 - -> return (rootName, dt) + -> return rootName Length n - -> do label <- getByteString n - (rest, dt') <- worker dt + -> do label <- U.getByteString n + rest <- worker let name = consLabel label rest - dt'' = IM.insert offset name dt' - return (name, dt'') + modifyState $ IM.insert offset name + return name - getLabelHeader :: Get LabelHeader + getLabelHeader :: Unpacker s LabelHeader getLabelHeader - = do header <- lookAhead $ getByteString 1 + = do header <- U.lookAhead $ U.getByteString 1 let Right h = runBitGet header $ do a <- getBit @@ -274,7 +246,7 @@ getDomainName = worker _ -> fail "Illegal label header" case h of Offset _ - -> do header' <- getByteString 2 -- Pointers have 2 octets. + -> do header' <- U.getByteString 2 -- Pointers have 2 octets. let Right h' = runBitGet header' $ do BG.skip 2 @@ -282,13 +254,13 @@ getDomainName = worker return $ Offset n return h' len@(Length _) - -> do G.skip 1 + -> do U.skip 1 return len -getCharString :: Get BS.ByteString -getCharString = do len <- G.getWord8 - getByteString (fromIntegral len) +getCharString :: Unpacker s BS.ByteString +getCharString = do len <- U.getWord8 + U.getByteString (fromIntegral len) putCharString :: BS.ByteString -> Put putCharString = putDomainLabel @@ -307,12 +279,28 @@ putDomainLabel l class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where rtToInt :: rt -> Int - putRecordType :: rt -> Put putRecordData :: rt -> dt -> Put - getRecordData :: rt -> DecompTable -> Get (dt, DecompTable) + getRecordData :: rt -> Unpacker DecompTable dt + putRecordType :: rt -> Put putRecordType = putWord16be . fromIntegral . rtToInt + getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt) + getResourceRecord rt + = do name <- getDomainName + U.skip 2 -- record type + cl <- getBinary + ttl <- U.getWord32be + U.skip 2 -- data length + dat <- getRecordData rt + return $ ResourceRecord { + rrName = name + , rrType = rt + , rrClass = cl + , rrTTL = ttl + , rrData = dat + } + data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt instance Show SomeRT where @@ -324,8 +312,8 @@ instance Eq SomeRT where putSomeRT :: SomeRT -> Put putSomeRT (SomeRT rt) = putRecordType rt -getSomeRT :: Get SomeRT -getSomeRT = do n <- liftM fromIntegral G.getWord16be +getSomeRT :: Unpacker s SomeRT +getSomeRT = do n <- liftM fromIntegral U.getWord16be case IM.lookup n defaultRTTable of Nothing -> fail ("Unknown resource record type: " ++ show n) @@ -336,9 +324,7 @@ data A = A deriving (Show, Eq, Typeable) instance RecordType A HostAddress where rtToInt _ = 1 putRecordData _ = putWord32be - getRecordData _ = \ dt -> - do addr <- G.getWord32be - return (addr, dt) + getRecordData _ = U.getWord32be data NS = NS deriving (Show, Eq, Typeable) instance RecordType NS DomainName where @@ -357,9 +343,9 @@ instance RecordType HINFO (BS.ByteString, BS.ByteString) where rtToInt _ = 13 putRecordData _ (cpu, os) = do putCharString cpu putCharString os - getRecordData _ dt = do cpu <- getCharString + getRecordData _ = do cpu <- getCharString os <- getCharString - return ((cpu, os), dt) + return (cpu, os) {- @@ -400,15 +386,16 @@ instance Binary Message where mapM_ putSomeRR $ msgAuthorities m mapM_ putSomeRR $ msgAdditionals m - get = do hdr <- get - nQ <- liftM fromIntegral G.getWord16be - nAns <- liftM fromIntegral G.getWord16be - nAth <- liftM fromIntegral G.getWord16be - nAdd <- liftM fromIntegral G.getWord16be - (qs , dt1) <- replicateM' nQ getQ IM.empty - (anss, dt2) <- replicateM' nAns getSomeRR dt1 - (aths, dt3) <- replicateM' nAth getSomeRR dt2 - (adds, _ ) <- replicateM' nAdd getSomeRR dt3 + get = liftToBinary IM.empty $ + do hdr <- getBinary + nQ <- liftM fromIntegral U.getWord16be + nAns <- liftM fromIntegral U.getWord16be + nAth <- liftM fromIntegral U.getWord16be + nAdd <- liftM fromIntegral U.getWord16be + qs <- replicateM nQ getQ + anss <- replicateM nAns getSomeRR + aths <- replicateM nAth getSomeRR + adds <- replicateM nAdd getSomeRR return Message { msgHeader = hdr , msgQuestions = qs @@ -432,7 +419,7 @@ instance Binary Header where putNBits 4 $ fromEnum $ hdResponseCode h get = do mID <- G.getWord16be - flags <- getByteString 2 + flags <- G.getByteString 2 let Right hd = runBitGet flags $ do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1 diff --git a/Network/DNS/Unpacker.hs b/Network/DNS/Unpacker.hs new file mode 100644 index 0000000..db34946 --- /dev/null +++ b/Network/DNS/Unpacker.hs @@ -0,0 +1,150 @@ +module Network.DNS.Unpacker + ( Unpacker + , UnpackingState(..) + + , unpack + , unpack' + + , getState + , setState + , modifyState + + , skip + , lookAhead + , bytesRead + + , getByteString + , getLazyByteString + , getWord8 + , getWord16be + , getWord32be + + , getBinary + , liftToBinary + ) + where + +import qualified Data.Binary as Binary +import qualified Data.Binary.Get as Bin +import qualified Data.ByteString as Strict +import qualified Data.ByteString.Lazy as Lazy +import Data.Bits +import Data.Int +import Data.Word + + +data UnpackingState s + = UnpackingState { + stSource :: !Lazy.ByteString + , stBytesRead :: !Int64 + , stUserState :: s + } + +newtype Unpacker s a = U { unU :: UnpackingState s -> (a, UnpackingState s) } + +instance Monad (Unpacker s) where + return a = U (\ s -> (a, s)) + m >>= k = U (\ s -> let (a, s') = unU m s + in + unU (k a) s') + fail err = do bytes <- get stBytesRead + U (error (err + ++ ". Failed unpacking at byte position " + ++ show bytes)) + +get :: (UnpackingState s -> a) -> Unpacker s a +get f = U (\ s -> (f s, s)) + +set :: (UnpackingState s -> UnpackingState s) -> Unpacker s () +set f = U (\ s -> ((), f s)) + +mkState :: Lazy.ByteString -> Int64 -> s -> UnpackingState s +mkState xs n s + = UnpackingState { + stSource = xs + , stBytesRead = n + , stUserState = s + } + +unpack' :: Unpacker s a -> s -> Lazy.ByteString -> (a, s) +unpack' m s xs + = let (a, s') = unU m (mkState xs 0 s) + in + (a, stUserState s') + +unpack :: Unpacker s a -> s -> Lazy.ByteString -> a +unpack = ((fst .) .) . unpack' + +getState :: Unpacker s s +getState = get stUserState + +setState :: s -> Unpacker s () +setState = modifyState . const + +modifyState :: (s -> s) -> Unpacker s () +modifyState f + = set $ \ st -> st { stUserState = f (stUserState st) } + +skip :: Int64 -> Unpacker s () +skip n = getLazyByteString n >> return () + +lookAhead :: Unpacker s a -> Unpacker s a +lookAhead m = U (\ s -> let (a, _) = unU m s + in + (a, s)) + +bytesRead :: Integral i => Unpacker s i +bytesRead = get stBytesRead >>= return . fromIntegral + +getByteString :: Int -> Unpacker s Strict.ByteString +getByteString n = getLazyByteString (fromIntegral n) >>= return . Strict.concat . Lazy.toChunks + +getLazyByteString :: Int64 -> Unpacker s Lazy.ByteString +getLazyByteString n + = do src <- get stSource + let (xs, ys) = Lazy.splitAt n src + if Lazy.length xs /= n then + fail "Too few bytes" + else + do set $ \ st -> st { + stSource = ys + , stBytesRead = stBytesRead st + n + } + return xs + +getWord8 :: Unpacker s Word8 +getWord8 = getLazyByteString 1 >>= return . (`Lazy.index` 0) + +getWord16be :: Unpacker s Word16 +getWord16be = do xs <- getLazyByteString 2 + return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 8) .|. + (fromIntegral (xs `Lazy.index` 1)) + +getWord32be :: Unpacker s Word32 +getWord32be = do xs <- getLazyByteString 4 + return $ (fromIntegral (xs `Lazy.index` 0) `shiftL` 24) .|. + (fromIntegral (xs `Lazy.index` 1) `shiftL` 16) .|. + (fromIntegral (xs `Lazy.index` 2) `shiftL` 8) .|. + (fromIntegral (xs `Lazy.index` 3)) + +getBinary :: Binary.Binary a => Unpacker s a +getBinary = do s <- get id + let (a, rest, bytes) = Bin.runGetState Binary.get (stSource s) (stBytesRead s) + set $ \ st -> st { + stSource = rest + , stBytesRead = bytes + } + return a + + +liftToBinary :: s -> Unpacker s a -> Bin.Get a +liftToBinary s m + = do bytes <- Bin.bytesRead + src <- Bin.getRemainingLazyByteString + + let (a, s') = unU m (mkState src bytes s) + + -- These bytes was consumed by the unpacker. + Bin.skip (fromIntegral (stBytesRead s' - bytes)) + + return a diff --git a/dns.cabal b/dns.cabal index e257dc7..3f60912 100644 --- a/dns.cabal +++ b/dns.cabal @@ -23,6 +23,9 @@ Library Exposed-Modules: Network.DNS.Message + Other-Modules: + Network.DNS.Unpacker + Extensions: DeriveDataTypeable, ExistentialQuantification, FlexibleInstances, FunctionalDependencies, MultiParamTypeClasses,