From 6423ccc375d8b7d61707de4c6e7b2ace5971be0f Mon Sep 17 00:00:00 2001 From: PHO Date: Fri, 22 May 2009 16:35:49 +0900 Subject: [PATCH] Introduce Packer monad so that we can compress binary packets. --- Network/DNS/Message.hs | 134 ++++++++++++++++++++-------------- Network/DNS/Packer.hs | 154 ++++++++++++++++++++++++++++++++++++++++ Network/DNS/Unpacker.hs | 2 +- dns.cabal | 3 +- 4 files changed, 238 insertions(+), 55 deletions(-) create mode 100644 Network/DNS/Packer.hs diff --git a/Network/DNS/Message.hs b/Network/DNS/Message.hs index be0b79a..5c53795 100644 --- a/Network/DNS/Message.hs +++ b/Network/DNS/Message.hs @@ -34,15 +34,17 @@ import Control.Monad import Data.Binary import Data.Binary.BitPut as BP import Data.Binary.Get as G -import Data.Binary.Put as P +import Data.Binary.Put as P' import Data.Binary.Strict.BitGet as BG import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as C8 hiding (ByteString) -import qualified Data.ByteString.Lazy as LBS import Data.Typeable import qualified Data.IntMap as IM import Data.IntMap (IntMap) +import qualified Data.Map as M +import Data.Map (Map) import Data.Word +import Network.DNS.Packer as P import Network.DNS.Unpacker as U import Network.Socket @@ -108,11 +110,11 @@ data Question type SomeQT = SomeRT -putQ :: Question -> Put +putQ :: Question -> Packer CompTable () putQ q = do putDomainName $ qName q putSomeRT $ qType q - put $ qClass q + putBinary $ qClass q getQ :: Unpacker DecompTable Question getQ = do nm <- getDomainName @@ -125,23 +127,25 @@ getQ = do nm <- getDomainName } -newtype DomainName = DN [DomainLabel] deriving (Eq, Show, Typeable) +newtype DomainName = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable) type DomainLabel = BS.ByteString -nameToLabels :: DomainName -> [DomainLabel] -nameToLabels (DN ls) = ls - -labelsToName :: [DomainLabel] -> DomainName -labelsToName = DN - rootName :: DomainName rootName = DN [BS.empty] +isRootName :: DomainName -> Bool +isRootName (DN [_]) = True +isRootName _ = False + consLabel :: DomainLabel -> DomainName -> DomainName consLabel x (DN ys) = DN (x:ys) +unconsLabel :: DomainName -> (DomainLabel, DomainName) +unconsLabel (DN (x:xs)) = (x, DN xs) +unconsLabel x = error ("Illegal use of unconsLabel: " ++ show x) + mkDomainName :: String -> DomainName -mkDomainName = labelsToName . mkLabels [] . notEmpty +mkDomainName = DN . mkLabels [] . notEmpty where notEmpty :: String -> String notEmpty xs = assert (not $ null xs) xs @@ -173,18 +177,6 @@ data RecordType rt dt => ResourceRecord rt dt deriving (Show, Eq, Typeable) -putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put -putRR rr = do putDomainName $ rrName rr - putRecordType $ rrType rr - put $ rrClass rr - putWord32be $ rrTTL rr - - let dat = runPut $ - putRecordData (undefined :: rt) (rrData rr) - putWord16be $ fromIntegral $ LBS.length dat - putLazyByteString dat - - data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt) instance Show SomeRR where @@ -194,8 +186,8 @@ instance Eq SomeRR where (SomeRR a) == (SomeRR b) = Just a == cast b -putSomeRR :: SomeRR -> Put -putSomeRR (SomeRR rr) = putRR rr +putSomeRR :: SomeRR -> Packer CompTable () +putSomeRR (SomeRR rr) = putResourceRecord rr getSomeRR :: Unpacker DecompTable SomeRR getSomeRR = do srt <- U.lookAhead $ @@ -205,8 +197,9 @@ getSomeRR = do srt <- U.lookAhead $ SomeRT rt -> getResourceRecord rt >>= return . SomeRR +type CompTable = Map DomainName Int type DecompTable = IntMap DomainName -type TTL = Word32 +type TTL = Word32 getDomainName :: Unpacker DecompTable DomainName getDomainName = worker @@ -217,7 +210,7 @@ getDomainName = worker hdr <- getLabelHeader case hdr of Offset n - -> do dt <- getState + -> do dt <- U.getState case IM.lookup n dt of Just name -> return name @@ -229,7 +222,7 @@ getDomainName = worker -> do label <- U.getByteString n rest <- worker let name = consLabel label rest - modifyState $ IM.insert offset name + U.modifyState $ IM.insert offset name return name getLabelHeader :: Unpacker s LabelHeader @@ -262,28 +255,64 @@ getCharString :: Unpacker s BS.ByteString getCharString = do len <- U.getWord8 U.getByteString (fromIntegral len) -putCharString :: BS.ByteString -> Put -putCharString = putDomainLabel +putCharString :: BS.ByteString -> Packer s () +putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs + P.putByteString xs data LabelHeader = Offset !Int | Length !Int -putDomainName :: DomainName -> Put -putDomainName = mapM_ putDomainLabel . nameToLabels +putDomainName :: DomainName -> Packer CompTable () +putDomainName name + = do ct <- P.getState + case M.lookup name ct of + Just n + -> do let ptr = runBitPut $ + do putBit True + putBit True + putNBits 14 n + P.putLazyByteString ptr + Nothing + -> do offset <- bytesWrote + P.modifyState $ M.insert name offset + + let (label, rest) = unconsLabel name + + putCharString label + + if isRootName rest then + P.putWord8 0 + else + putDomainName rest -putDomainLabel :: DomainLabel -> Put -putDomainLabel l - = do putWord8 $ fromIntegral $ BS.length l - P.putByteString l class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where rtToInt :: rt -> Int - putRecordData :: rt -> dt -> Put + putRecordData :: rt -> dt -> Packer CompTable () getRecordData :: rt -> Unpacker DecompTable dt - putRecordType :: rt -> Put - putRecordType = putWord16be . fromIntegral . rtToInt + putRecordType :: rt -> Packer s () + putRecordType = P.putWord16be . fromIntegral . rtToInt + + putResourceRecord :: ResourceRecord rt dt -> Packer CompTable () + putResourceRecord rr + = do putDomainName $ rrName rr + putRecordType $ rrType rr + putBinary $ rrClass rr + P.putWord32be $ rrTTL rr + + -- First, write a dummy data length. + offset <- bytesWrote + P.putWord16be 0 + + -- Second, write data. + putRecordData (rrType rr) (rrData rr) + + -- Third, rewrite the dummy length to an actual value. + offset' <- bytesWrote + withOffset offset + $ P.putWord16be (fromIntegral (offset' - offset - 2)) getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt) getResourceRecord rt @@ -309,7 +338,7 @@ instance Show SomeRT where instance Eq SomeRT where (SomeRT a) == (SomeRT b) = Just a == cast b -putSomeRT :: SomeRT -> Put +putSomeRT :: SomeRT -> Packer s () putSomeRT (SomeRT rt) = putRecordType rt getSomeRT :: Unpacker s SomeRT @@ -323,7 +352,7 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be data A = A deriving (Show, Eq, Typeable) instance RecordType A HostAddress where rtToInt _ = 1 - putRecordData _ = putWord32be + putRecordData _ = P.putWord32be getRecordData _ = U.getWord32be data NS = NS deriving (Show, Eq, Typeable) @@ -376,17 +405,18 @@ data RecordType -} instance Binary Message where - put m = do put $ msgHeader m - putWord16be $ fromIntegral $ length $ msgQuestions m - putWord16be $ fromIntegral $ length $ msgAnswers m - putWord16be $ fromIntegral $ length $ msgAuthorities m - putWord16be $ fromIntegral $ length $ msgAdditionals m - mapM_ putQ $ msgQuestions m + put m = P.liftToBinary M.empty $ + do putBinary $ msgHeader m + P.putWord16be $ fromIntegral $ length $ msgQuestions m + P.putWord16be $ fromIntegral $ length $ msgAnswers m + P.putWord16be $ fromIntegral $ length $ msgAuthorities m + P.putWord16be $ fromIntegral $ length $ msgAdditionals m + mapM_ putQ $ msgQuestions m mapM_ putSomeRR $ msgAnswers m mapM_ putSomeRR $ msgAuthorities m mapM_ putSomeRR $ msgAdditionals m - get = liftToBinary IM.empty $ + get = U.liftToBinary IM.empty $ do hdr <- getBinary nQ <- liftM fromIntegral U.getWord16be nAns <- liftM fromIntegral U.getWord16be @@ -405,8 +435,8 @@ instance Binary Message where } instance Binary Header where - put h = do putWord16be $ hdMessageID h - putLazyByteString flags + put h = do P'.putWord16be $ hdMessageID h + P'.putLazyByteString flags where flags = runBitPut $ do putNBits 1 $ fromEnum $ hdMessageType h @@ -538,7 +568,7 @@ instance Enum RecordClass where instance Binary RecordClass where get = liftM (toEnum . fromIntegral) G.getWord16be - put = putWord16be . fromIntegral . fromEnum + put = P'.putWord16be . fromIntegral . fromEnum defaultRTTable :: IntMap SomeRT diff --git a/Network/DNS/Packer.hs b/Network/DNS/Packer.hs new file mode 100644 index 0000000..7f8f895 --- /dev/null +++ b/Network/DNS/Packer.hs @@ -0,0 +1,154 @@ +module Network.DNS.Packer + ( Packer + + , pack + , pack' + + , getState + , setState + , modifyState + + , bytesWrote + , withOffset + + , putByteString + , putLazyByteString + + , putWord8 + , putWord16be + , putWord32be + + , putBinary + , liftToBinary + ) + where + +import qualified Data.Binary as Binary +import qualified Data.Binary.Put as Bin +import Data.Bits +import qualified Data.ByteString as Strict +import qualified Data.ByteString.Lazy as Lazy +import Data.Int +import Data.Word + + +data PackingState s + = PackingState { + stResult :: !Lazy.ByteString + , stBytesWrote :: !Int64 + , stUserState :: s + } + +newtype Packer s a = P { unP :: PackingState s -> (a, PackingState s) } + +instance Monad (Packer s) where + return a = P (\ s -> (a, s)) + m >>= k = P (\ s -> let (a, s') = unP m s + in + unP (k a) s') + fail err = do bytes <- get stBytesWrote + P (error (err + ++ ". Failed packing at byte position " + ++ show bytes)) + +get :: (PackingState s -> a) -> Packer s a +get f = P (\ s -> (f s, s)) + +set :: (PackingState s -> PackingState s) -> Packer s () +set f = P (\ s -> ((), f s)) + +mkState :: Lazy.ByteString -> Int64 -> s -> PackingState s +mkState xs n s + = PackingState { + stResult = xs + , stBytesWrote = n + , stUserState = s + } + +pack' :: Packer s a -> s -> (Lazy.ByteString, s, a) +pack' m s + = let (a, s') = unP m (mkState Lazy.empty 0 s) + in + (stResult s', stUserState s', a) + +pack :: Packer s a -> s -> Lazy.ByteString +pack = (fst' .) . pack' + where + fst' (xs, _, _) = xs + +getState :: Packer s s +getState = get stUserState + +setState :: s -> Packer s () +setState = modifyState . const + +modifyState :: (s -> s) -> Packer s () +modifyState f + = set $ \ st -> st { stUserState = f (stUserState st) } + +bytesWrote :: Integral i => Packer s i +bytesWrote = get stBytesWrote >>= return . fromIntegral + +withOffset :: Int64 -> Packer s a -> Packer s a +withOffset n m + = P $ \ s -> let (taken, dropped) = Lazy.splitAt n (stResult s) + padded = Lazy.take n (taken `Lazy.append` Lazy.repeat 0) + tempState = s { + stResult = padded + , stBytesWrote = stBytesWrote s - Lazy.length dropped + } + (a, tempState') = unP m tempState + newState = tempState { + stResult = replaceHead (stResult s) (stResult tempState') + , stBytesWrote = max (stBytesWrote s) (stBytesWrote tempState') + } + in + (a, newState) + where + replaceHead :: Lazy.ByteString -> Lazy.ByteString -> Lazy.ByteString + replaceHead world newHead + = let rest = Lazy.drop (Lazy.length newHead) world + in + newHead `Lazy.append` rest + + +putByteString :: Strict.ByteString -> Packer s () +putByteString = putLazyByteString . Lazy.fromChunks . (:[]) + +putLazyByteString :: Lazy.ByteString -> Packer s () +putLazyByteString xs + = set $ \ st -> st { + stResult = stResult st `Lazy.append` xs + , stBytesWrote = stBytesWrote st + Lazy.length xs + } + +putWord8 :: Word8 -> Packer s () +putWord8 w + = set $ \ st -> st { + stResult = stResult st `Lazy.snoc` w + , stBytesWrote = stBytesWrote st + 1 + } + +putWord16be :: Word16 -> Packer s () +putWord16be w + = do putWord8 $ fromIntegral $ (w `shiftR` 8) .&. 0xFF + putWord8 $ fromIntegral $ w .&. 0xFF + +putWord32be :: Word32 -> Packer s () +putWord32be w + = do putWord8 $ fromIntegral $ (w `shiftR` 24) .&. 0xFF + putWord8 $ fromIntegral $ (w `shiftR` 16) .&. 0xFF + putWord8 $ fromIntegral $ (w `shiftR` 8) .&. 0xFF + putWord8 $ fromIntegral $ w .&. 0xFF + + +putBinary :: Binary.Binary a => a -> Packer s () +putBinary = putLazyByteString . Binary.encode + + +liftToBinary :: s -> Packer s a -> Bin.PutM a +liftToBinary s m + = do let (a, s') = unP m (mkState Lazy.empty 0 s) + + Bin.putLazyByteString (stResult s') + return a diff --git a/Network/DNS/Unpacker.hs b/Network/DNS/Unpacker.hs index db34946..3610650 100644 --- a/Network/DNS/Unpacker.hs +++ b/Network/DNS/Unpacker.hs @@ -26,9 +26,9 @@ module Network.DNS.Unpacker import qualified Data.Binary as Binary import qualified Data.Binary.Get as Bin +import Data.Bits import qualified Data.ByteString as Strict import qualified Data.ByteString.Lazy as Lazy -import Data.Bits import Data.Int import Data.Word diff --git a/dns.cabal b/dns.cabal index 3f60912..b62c871 100644 --- a/dns.cabal +++ b/dns.cabal @@ -22,8 +22,7 @@ Library Exposed-Modules: Network.DNS.Message - - Other-Modules: + Network.DNS.Packer Network.DNS.Unpacker Extensions: -- 2.40.0