]> gitweb @ CieloNegro.org - haskell-dns.git/commitdiff
Introduce Packer monad so that we can compress binary packets.
authorPHO <pho@cielonegro.org>
Fri, 22 May 2009 07:35:49 +0000 (16:35 +0900)
committerPHO <pho@cielonegro.org>
Fri, 22 May 2009 07:35:49 +0000 (16:35 +0900)
Network/DNS/Message.hs
Network/DNS/Packer.hs [new file with mode: 0644]
Network/DNS/Unpacker.hs
dns.cabal

index be0b79a33800a32dcc769c7ff68129f322261bf4..5c537956bd1657284a8c3b15db3dd9e9941cce62 100644 (file)
@@ -34,15 +34,17 @@ import           Control.Monad
 import           Data.Binary
 import           Data.Binary.BitPut as BP
 import           Data.Binary.Get as G
-import           Data.Binary.Put as P
+import           Data.Binary.Put as P'
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
-import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
+import qualified Data.Map as M
+import           Data.Map (Map)
 import           Data.Word
+import           Network.DNS.Packer as P
 import           Network.DNS.Unpacker as U
 import           Network.Socket
 
@@ -108,11 +110,11 @@ data Question
 
 type SomeQT = SomeRT
 
-putQ :: Question -> Put
+putQ :: Question -> Packer CompTable ()
 putQ q
     = do putDomainName $ qName q
          putSomeRT $ qType q
-         put $ qClass q
+         putBinary $ qClass q
 
 getQ :: Unpacker DecompTable Question
 getQ = do nm <- getDomainName
@@ -125,23 +127,25 @@ getQ = do nm <- getDomainName
                      }
 
 
-newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
 type DomainLabel    = BS.ByteString
 
-nameToLabels :: DomainName -> [DomainLabel]
-nameToLabels (DN ls) = ls
-
-labelsToName :: [DomainLabel] -> DomainName
-labelsToName = DN
-
 rootName :: DomainName
 rootName = DN [BS.empty]
 
+isRootName :: DomainName -> Bool
+isRootName (DN [_]) = True
+isRootName _        = False
+
 consLabel :: DomainLabel -> DomainName -> DomainName
 consLabel x (DN ys) = DN (x:ys)
 
+unconsLabel :: DomainName -> (DomainLabel, DomainName)
+unconsLabel (DN (x:xs)) = (x, DN xs)
+unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
+
 mkDomainName :: String -> DomainName
-mkDomainName = labelsToName . mkLabels [] . notEmpty
+mkDomainName = DN . mkLabels [] . notEmpty
     where
       notEmpty :: String -> String
       notEmpty xs = assert (not $ null xs) xs
@@ -173,18 +177,6 @@ data RecordType rt dt => ResourceRecord rt dt
     deriving (Show, Eq, Typeable)
 
 
-putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
-putRR rr = do putDomainName $ rrName rr
-              putRecordType $ rrType  rr
-              put $ rrClass rr
-              putWord32be $ rrTTL rr
-
-              let dat = runPut $
-                        putRecordData (undefined :: rt) (rrData rr)
-              putWord16be $ fromIntegral $ LBS.length dat
-              putLazyByteString dat
-
-
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
@@ -194,8 +186,8 @@ instance Eq SomeRR where
     (SomeRR a) == (SomeRR b) = Just a == cast b
 
 
-putSomeRR :: SomeRR -> Put
-putSomeRR (SomeRR rr) = putRR rr
+putSomeRR :: SomeRR -> Packer CompTable ()
+putSomeRR (SomeRR rr) = putResourceRecord rr
 
 getSomeRR :: Unpacker DecompTable SomeRR
 getSomeRR = do srt <- U.lookAhead $
@@ -205,8 +197,9 @@ getSomeRR = do srt <- U.lookAhead $
                  SomeRT rt
                      -> getResourceRecord rt >>= return . SomeRR
 
+type CompTable   = Map DomainName Int
 type DecompTable = IntMap DomainName
-type TTL = Word32
+type TTL         = Word32
 
 getDomainName :: Unpacker DecompTable DomainName
 getDomainName = worker
@@ -217,7 +210,7 @@ getDomainName = worker
                hdr    <- getLabelHeader
                case hdr of
                  Offset n
-                     -> do dt <- getState
+                     -> do dt <- U.getState
                            case IM.lookup n dt of
                              Just name
                                  -> return name
@@ -229,7 +222,7 @@ getDomainName = worker
                      -> do label <- U.getByteString n
                            rest  <- worker
                            let name = consLabel label rest
-                           modifyState $ IM.insert offset name
+                           U.modifyState $ IM.insert offset name
                            return name
 
       getLabelHeader :: Unpacker s LabelHeader
@@ -262,28 +255,64 @@ getCharString :: Unpacker s BS.ByteString
 getCharString = do len <- U.getWord8
                    U.getByteString (fromIntegral len)
 
-putCharString :: BS.ByteString -> Put
-putCharString = putDomainLabel
+putCharString :: BS.ByteString -> Packer s ()
+putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
+                      P.putByteString xs
 
 data LabelHeader
     = Offset !Int
     | Length !Int
 
-putDomainName :: DomainName -> Put
-putDomainName = mapM_ putDomainLabel . nameToLabels
+putDomainName :: DomainName -> Packer CompTable ()
+putDomainName name
+    = do ct <- P.getState
+         case M.lookup name ct of
+           Just n
+               -> do let ptr = runBitPut $
+                               do putBit True
+                                  putBit True
+                                  putNBits 14 n
+                     P.putLazyByteString ptr
+           Nothing
+               -> do offset <- bytesWrote
+                     P.modifyState $ M.insert name offset
+
+                     let (label, rest) = unconsLabel name
+
+                     putCharString label
+
+                     if isRootName rest then
+                         P.putWord8 0
+                       else
+                         putDomainName rest
 
-putDomainLabel :: DomainLabel -> Put
-putDomainLabel l
-    = do putWord8 $ fromIntegral $ BS.length l
-         P.putByteString l
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
-    putRecordData :: rt -> dt -> Put
+    putRecordData :: rt -> dt -> Packer CompTable ()
     getRecordData :: rt -> Unpacker DecompTable dt
 
-    putRecordType :: rt -> Put
-    putRecordType = putWord16be . fromIntegral . rtToInt
+    putRecordType :: rt -> Packer s ()
+    putRecordType = P.putWord16be . fromIntegral . rtToInt
+
+    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord rr
+        = do putDomainName $ rrName  rr
+             putRecordType $ rrType  rr
+             putBinary     $ rrClass rr
+             P.putWord32be $ rrTTL   rr
+
+             -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData (rrType rr) (rrData rr)
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             withOffset offset
+                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
 
     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
     getResourceRecord rt
@@ -309,7 +338,7 @@ instance Show SomeRT where
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
-putSomeRT :: SomeRT -> Put
+putSomeRT :: SomeRT -> Packer s ()
 putSomeRT (SomeRT rt) = putRecordType rt
 
 getSomeRT :: Unpacker s SomeRT
@@ -323,7 +352,7 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be
 data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
-    putRecordData _ = putWord32be
+    putRecordData _ = P.putWord32be
     getRecordData _ = U.getWord32be
 
 data NS = NS deriving (Show, Eq, Typeable)
@@ -376,17 +405,18 @@ data RecordType
 -}
 
 instance Binary Message where
-    put m = do put $ msgHeader m
-               putWord16be $ fromIntegral $ length $ msgQuestions m
-               putWord16be $ fromIntegral $ length $ msgAnswers m
-               putWord16be $ fromIntegral $ length $ msgAuthorities m
-               putWord16be $ fromIntegral $ length $ msgAdditionals m
-               mapM_ putQ  $ msgQuestions m
+    put m = P.liftToBinary M.empty $
+            do putBinary $ msgHeader m
+               P.putWord16be $ fromIntegral $ length $ msgQuestions m
+               P.putWord16be $ fromIntegral $ length $ msgAnswers m
+               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
+               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ      $ msgQuestions m
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
-    get = liftToBinary IM.empty $
+    get = U.liftToBinary IM.empty $
           do hdr  <- getBinary
              nQ   <- liftM fromIntegral U.getWord16be
              nAns <- liftM fromIntegral U.getWord16be
@@ -405,8 +435,8 @@ instance Binary Message where
                         }
 
 instance Binary Header where
-    put h = do putWord16be $ hdMessageID h
-               putLazyByteString flags
+    put h = do P'.putWord16be $ hdMessageID h
+               P'.putLazyByteString flags
         where
           flags = runBitPut $
                   do putNBits 1 $ fromEnum $ hdMessageType h
@@ -538,7 +568,7 @@ instance Enum RecordClass where
 
 instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = putWord16be . fromIntegral . fromEnum
+    put = P'.putWord16be . fromIntegral . fromEnum
 
 
 defaultRTTable :: IntMap SomeRT
diff --git a/Network/DNS/Packer.hs b/Network/DNS/Packer.hs
new file mode 100644 (file)
index 0000000..7f8f895
--- /dev/null
@@ -0,0 +1,154 @@
+module Network.DNS.Packer
+    ( Packer
+
+    , pack
+    , pack'
+
+    , getState
+    , setState
+    , modifyState
+
+    , bytesWrote
+    , withOffset
+
+    , putByteString
+    , putLazyByteString
+
+    , putWord8
+    , putWord16be
+    , putWord32be
+
+    , putBinary
+    , liftToBinary
+    )
+    where
+
+import qualified Data.Binary as Binary
+import qualified Data.Binary.Put as Bin
+import           Data.Bits
+import qualified Data.ByteString as Strict
+import qualified Data.ByteString.Lazy as Lazy
+import           Data.Int
+import           Data.Word
+
+
+data PackingState s
+    = PackingState {
+        stResult     :: !Lazy.ByteString
+      , stBytesWrote :: !Int64
+      , stUserState  :: s
+      }
+
+newtype Packer s a = P { unP :: PackingState s -> (a, PackingState s) }
+
+instance Monad (Packer s) where
+    return a = P (\ s -> (a, s))
+    m >>= k  = P (\ s -> let (a, s') = unP m s
+                         in
+                           unP (k a) s')
+    fail err = do bytes <- get stBytesWrote
+                  P (error (err
+                            ++ ". Failed packing at byte position "
+                            ++ show bytes))
+
+get :: (PackingState s -> a) -> Packer s a
+get f = P (\ s -> (f s, s))
+
+set :: (PackingState s -> PackingState s) -> Packer s ()
+set f = P (\ s -> ((), f s))
+
+mkState :: Lazy.ByteString -> Int64 -> s -> PackingState s
+mkState xs n s
+    = PackingState {
+        stResult     = xs
+      , stBytesWrote = n
+      , stUserState  = s
+      }
+
+pack' :: Packer s a -> s -> (Lazy.ByteString, s, a)
+pack' m s
+    = let (a, s') = unP m (mkState Lazy.empty 0 s)
+      in
+        (stResult s', stUserState s', a)
+
+pack :: Packer s a -> s -> Lazy.ByteString
+pack = (fst' .) . pack'
+    where
+      fst' (xs, _, _) = xs
+
+getState :: Packer s s
+getState = get stUserState
+
+setState :: s -> Packer s ()
+setState = modifyState . const
+
+modifyState :: (s -> s) -> Packer s ()
+modifyState f
+    = set $ \ st -> st { stUserState = f (stUserState st) }
+
+bytesWrote :: Integral i => Packer s i
+bytesWrote = get stBytesWrote >>= return . fromIntegral
+
+withOffset :: Int64 -> Packer s a -> Packer s a
+withOffset n m
+    = P $ \ s -> let (taken, dropped) = Lazy.splitAt n (stResult s)
+                     padded           = Lazy.take n (taken `Lazy.append` Lazy.repeat 0)
+                     tempState        = s {
+                                          stResult     = padded
+                                        , stBytesWrote = stBytesWrote s - Lazy.length dropped
+                                        }
+                     (a, tempState')  = unP m tempState
+                     newState         = tempState {
+                                          stResult     = replaceHead (stResult s) (stResult tempState')
+                                        , stBytesWrote = max (stBytesWrote s) (stBytesWrote tempState')
+                                        }
+                 in
+                   (a, newState)
+      where
+        replaceHead :: Lazy.ByteString -> Lazy.ByteString -> Lazy.ByteString
+        replaceHead world newHead
+            = let rest = Lazy.drop (Lazy.length newHead) world
+              in
+                newHead `Lazy.append` rest
+
+
+putByteString :: Strict.ByteString -> Packer s ()
+putByteString = putLazyByteString . Lazy.fromChunks . (:[])
+
+putLazyByteString :: Lazy.ByteString -> Packer s ()
+putLazyByteString xs
+    = set $ \ st -> st {
+                      stResult     = stResult st `Lazy.append` xs
+                    , stBytesWrote = stBytesWrote st + Lazy.length xs
+                    }
+
+putWord8 :: Word8 -> Packer s ()
+putWord8 w
+    = set $ \ st -> st {
+                      stResult     = stResult st `Lazy.snoc` w
+                    , stBytesWrote = stBytesWrote st + 1
+                    }
+
+putWord16be :: Word16 -> Packer s ()
+putWord16be w
+    = do putWord8 $ fromIntegral $ (w `shiftR`  8) .&. 0xFF
+         putWord8 $ fromIntegral $  w              .&. 0xFF
+
+putWord32be :: Word32 -> Packer s ()
+putWord32be w
+    = do putWord8 $ fromIntegral $ (w `shiftR` 24) .&. 0xFF
+         putWord8 $ fromIntegral $ (w `shiftR` 16) .&. 0xFF
+         putWord8 $ fromIntegral $ (w `shiftR`  8) .&. 0xFF
+         putWord8 $ fromIntegral $  w              .&. 0xFF
+
+
+putBinary :: Binary.Binary a => a -> Packer s ()
+putBinary = putLazyByteString . Binary.encode
+
+
+liftToBinary :: s -> Packer s a -> Bin.PutM a
+liftToBinary s m
+    = do let (a, s') = unP m (mkState Lazy.empty 0 s)
+
+         Bin.putLazyByteString (stResult s')
+         return a
index db349468262015599ccaf996bd98ceda693ff0c1..36106501e2f432a02f427b4830b8ff76a983099f 100644 (file)
@@ -26,9 +26,9 @@ module Network.DNS.Unpacker
 
 import qualified Data.Binary as Binary
 import qualified Data.Binary.Get as Bin
+import           Data.Bits
 import qualified Data.ByteString as Strict
 import qualified Data.ByteString.Lazy as Lazy
-import           Data.Bits
 import           Data.Int
 import           Data.Word
 
index 3f609129c303318b7b38eab2481378c51a3087a0..b62c8714a7d20dc838cd9eeaff9a50bbfc79fadd 100644 (file)
--- a/dns.cabal
+++ b/dns.cabal
@@ -22,8 +22,7 @@ Library
 
     Exposed-Modules:
         Network.DNS.Message
-
-    Other-Modules:
+        Network.DNS.Packer
         Network.DNS.Unpacker
 
     Extensions: