]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
Introduce Packer monad so that we can compress binary packets.
[haskell-dns.git] / Network / DNS / Message.hs
index a3c04821f0f90b60a838ad7fe7fa8e10093dd20c..5c537956bd1657284a8c3b15db3dd9e9941cce62 100644 (file)
@@ -1,38 +1,95 @@
 module Network.DNS.Message
-    ( Header(..)
+    ( Message(..)
+    , MessageID
+    , MessageType(..)
+    , Header(..)
     , Opcode(..)
     , ResponseCode(..)
+    , Question(..)
+    , ResourceRecord(..)
+    , DomainName
+    , DomainLabel
+    , TTL
+    , RecordType
+    , RecordClass(..)
+
+    , SomeQT
+    , SomeRR
+    , SomeRT
+
+    , A(..)
+    , NS(..)
+    , CNAME(..)
+    , HINFO(..)
+
+    , mkDomainName
+    , wrapQueryType
+    , wrapRecordType
+    , wrapRecord
     )
     where
 
+import           Control.Exception
+import           Control.Monad
 import           Data.Binary
-import           Data.Binary.Get
-import           Data.Binary.Put
-import           Data.Bits
+import           Data.Binary.BitPut as BP
+import           Data.Binary.Get as G
+import           Data.Binary.Put as P'
+import           Data.Binary.Strict.BitGet as BG
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
+import           Data.Typeable
+import qualified Data.IntMap as IM
+import           Data.IntMap (IntMap)
+import qualified Data.Map as M
+import           Data.Map (Map)
 import           Data.Word
+import           Network.DNS.Packer as P
+import           Network.DNS.Unpacker as U
+import           Network.Socket
 
 
-data Header
-    = QueryHeader {
-        hdMessageID             :: !Word16
-      , hdOpcode                :: !Opcode
-      , hdIsTruncated           :: !Bool
-      , hdIsRecursionDesired    :: !Bool
+data Message
+    = Message {
+        msgHeader      :: !Header
+      , msgQuestions   :: ![Question]
+      , msgAnswers     :: ![SomeRR]
+      , msgAuthorities :: ![SomeRR]
+      , msgAdditionals :: ![SomeRR]
       }
-    | ResponseHeader {
-        hdMessageID             :: !Word16
+    deriving (Show, Eq)
+
+data Header
+    = Header {
+        hdMessageID             :: !MessageID
+      , hdMessageType           :: !MessageType
       , hdOpcode                :: !Opcode
       , hdIsAuthoritativeAnswer :: !Bool
       , hdIsTruncated           :: !Bool
       , hdIsRecursionDesired    :: !Bool
       , hdIsRecursionAvailable  :: !Bool
       , hdResponseCode          :: !ResponseCode
+
+      -- These fields are supressed in this data structure:
+      -- + QDCOUNT
+      -- + ANCOUNT
+      -- + NSCOUNT
+      -- + ARCOUNT
       }
+    deriving (Show, Eq)
+
+type MessageID = Word16
+
+data MessageType
+    = Query
+    | Response
+    deriving (Show, Eq)
 
 data Opcode
     = StandardQuery
     | InverseQuery
     | ServerStatusRequest
+    deriving (Show, Eq)
 
 data ResponseCode
     = NoError
@@ -43,69 +100,386 @@ data ResponseCode
     | Refused
     deriving (Show, Eq)
 
-hdIsResponse :: Header -> Bool
-hdIsResponse (QueryHeader    _ _ _ _      ) = False
-hdIsResponse (ResponseHeader _ _ _ _ _ _ _) = True
+data Question
+    = Question {
+        qName  :: !DomainName
+      , qType  :: !SomeQT
+      , qClass :: !RecordClass
+      }
+    deriving (Show, Eq)
+
+type SomeQT = SomeRT
+
+putQ :: Question -> Packer CompTable ()
+putQ q
+    = do putDomainName $ qName q
+         putSomeRT $ qType q
+         putBinary $ qClass q
+
+getQ :: Unpacker DecompTable Question
+getQ = do nm <- getDomainName
+          ty <- getSomeRT
+          cl <- getBinary
+          return Question {
+                       qName  = nm
+                     , qType  = ty
+                     , qClass = cl
+                     }
+
+
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
+type DomainLabel    = BS.ByteString
+
+rootName :: DomainName
+rootName = DN [BS.empty]
+
+isRootName :: DomainName -> Bool
+isRootName (DN [_]) = True
+isRootName _        = False
+
+consLabel :: DomainLabel -> DomainName -> DomainName
+consLabel x (DN ys) = DN (x:ys)
+
+unconsLabel :: DomainName -> (DomainLabel, DomainName)
+unconsLabel (DN (x:xs)) = (x, DN xs)
+unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
+
+mkDomainName :: String -> DomainName
+mkDomainName = DN . mkLabels [] . notEmpty
+    where
+      notEmpty :: String -> String
+      notEmpty xs = assert (not $ null xs) xs
+
+      mkLabels :: [DomainLabel] -> String -> [DomainLabel]
+      mkLabels soFar [] = reverse (C8.empty : soFar)
+      mkLabels soFar xs = case break (== '.') xs of
+                            (l, ('.':rest))
+                                -> mkLabels (C8.pack l : soFar) rest
+                            _   -> error ("Illegal domain name: " ++ xs)
+
+data RecordClass
+    = IN
+    | CS -- Obsolete
+    | CH
+    | HS
+    | AnyClass -- Only for queries
+    deriving (Show, Eq)
+
+
+data RecordType rt dt => ResourceRecord rt dt
+    = ResourceRecord {
+        rrName  :: !DomainName
+      , rrType  :: !rt
+      , rrClass :: !RecordClass
+      , rrTTL   :: !TTL
+      , rrData  :: !dt
+      }
+    deriving (Show, Eq, Typeable)
+
+
+data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
+
+instance Show SomeRR where
+    show (SomeRR rr) = show rr
+
+instance Eq SomeRR where
+    (SomeRR a) == (SomeRR b) = Just a == cast b
+
+
+putSomeRR :: SomeRR -> Packer CompTable ()
+putSomeRR (SomeRR rr) = putResourceRecord rr
+
+getSomeRR :: Unpacker DecompTable SomeRR
+getSomeRR = do srt <- U.lookAhead $
+                      do getDomainName -- skip
+                         getSomeRT
+               case srt of
+                 SomeRT rt
+                     -> getResourceRecord rt >>= return . SomeRR
+
+type CompTable   = Map DomainName Int
+type DecompTable = IntMap DomainName
+type TTL         = Word32
+
+getDomainName :: Unpacker DecompTable DomainName
+getDomainName = worker
+    where
+      worker :: Unpacker DecompTable DomainName
+      worker
+          = do offset <- U.bytesRead
+               hdr    <- getLabelHeader
+               case hdr of
+                 Offset n
+                     -> do dt <- U.getState
+                           case IM.lookup n dt of
+                             Just name
+                                 -> return name
+                             Nothing
+                                 -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
+                 Length 0
+                     -> return rootName
+                 Length n
+                     -> do label <- U.getByteString n
+                           rest  <- worker
+                           let name = consLabel label rest
+                           U.modifyState $ IM.insert offset name
+                           return name
+
+      getLabelHeader :: Unpacker s LabelHeader
+      getLabelHeader
+          = do header <- U.lookAhead $ U.getByteString 1
+               let Right h
+                       = runBitGet header $
+                         do a <- getBit
+                            b <- getBit
+                            n <- liftM fromIntegral (getAsWord8 6)
+                            case (a, b) of
+                              ( True,  True) -> return $ Offset n
+                              (False, False) -> return $ Length n
+                              _              -> fail "Illegal label header"
+               case h of
+                 Offset _
+                     -> do header' <- U.getByteString 2 -- Pointers have 2 octets.
+                           let Right h'
+                                   = runBitGet header' $
+                                     do BG.skip 2
+                                        n <- liftM fromIntegral (getAsWord16 14)
+                                        return $ Offset n
+                           return h'
+                 len@(Length _)
+                     -> do U.skip 1
+                           return len
+
+
+getCharString :: Unpacker s BS.ByteString
+getCharString = do len <- U.getWord8
+                   U.getByteString (fromIntegral len)
+
+putCharString :: BS.ByteString -> Packer s ()
+putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
+                      P.putByteString xs
+
+data LabelHeader
+    = Offset !Int
+    | Length !Int
+
+putDomainName :: DomainName -> Packer CompTable ()
+putDomainName name
+    = do ct <- P.getState
+         case M.lookup name ct of
+           Just n
+               -> do let ptr = runBitPut $
+                               do putBit True
+                                  putBit True
+                                  putNBits 14 n
+                     P.putLazyByteString ptr
+           Nothing
+               -> do offset <- bytesWrote
+                     P.modifyState $ M.insert name offset
+
+                     let (label, rest) = unconsLabel name
+
+                     putCharString label
+
+                     if isRootName rest then
+                         P.putWord8 0
+                       else
+                         putDomainName rest
+
+
+class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
+    rtToInt       :: rt -> Int
+    putRecordData :: rt -> dt -> Packer CompTable ()
+    getRecordData :: rt -> Unpacker DecompTable dt
+
+    putRecordType :: rt -> Packer s ()
+    putRecordType = P.putWord16be . fromIntegral . rtToInt
+
+    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord rr
+        = do putDomainName $ rrName  rr
+             putRecordType $ rrType  rr
+             putBinary     $ rrClass rr
+             P.putWord32be $ rrTTL   rr
+
+             -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData (rrType rr) (rrData rr)
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             withOffset offset
+                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
+
+    getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
+    getResourceRecord rt
+        = do name     <- getDomainName
+             U.skip 2 -- record type
+             cl       <- getBinary
+             ttl      <- U.getWord32be
+             U.skip 2 -- data length
+             dat      <- getRecordData rt
+             return $ ResourceRecord {
+                          rrName  = name
+                        , rrType  = rt
+                        , rrClass = cl
+                        , rrTTL   = ttl
+                        , rrData  = dat
+                        }
+
+data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
+
+instance Show SomeRT where
+    show (SomeRT rt) = show rt
+
+instance Eq SomeRT where
+    (SomeRT a) == (SomeRT b) = Just a == cast b
+
+putSomeRT :: SomeRT -> Packer s ()
+putSomeRT (SomeRT rt) = putRecordType rt
+
+getSomeRT :: Unpacker s SomeRT
+getSomeRT = do n <- liftM fromIntegral U.getWord16be
+               case IM.lookup n defaultRTTable of
+                 Nothing
+                     -> fail ("Unknown resource record type: " ++ show n)
+                 Just srt
+                     -> return srt
+
+data A = A deriving (Show, Eq, Typeable)
+instance RecordType A HostAddress where
+    rtToInt       _ = 1
+    putRecordData _ = P.putWord32be
+    getRecordData _ = U.getWord32be
+
+data NS = NS deriving (Show, Eq, Typeable)
+instance RecordType NS DomainName where
+    rtToInt       _ = 2
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data CNAME = CNAME deriving (Show, Eq, Typeable)
+instance RecordType CNAME DomainName where
+    rtToInt       _ = 5
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data HINFO = HINFO deriving (Show, Eq, Typeable)
+instance RecordType HINFO (BS.ByteString, BS.ByteString) where
+    rtToInt       _           = 13
+    putRecordData _ (cpu, os) = do putCharString cpu
+                                   putCharString os
+    getRecordData _           = do cpu <- getCharString
+                                   os  <- getCharString
+                                   return (cpu, os)
+
+
+{-
+data RecordType
+    = A
+    | NS
+    | MD
+    | MF
+    | CNAME
+    | SOA
+    | MB
+    | MG
+    | MR
+    | NULL
+    | WKS
+    | PTR
+    | HINFO
+    | MINFO
+    | MX
+    | TXT
+
+    -- Only for queries:
+    | AXFR
+    | MAILB -- Obsolete
+    | MAILA -- Obsolete
+    | AnyType
+    deriving (Show, Eq)
+-}
+
+instance Binary Message where
+    put m = P.liftToBinary M.empty $
+            do putBinary $ msgHeader m
+               P.putWord16be $ fromIntegral $ length $ msgQuestions m
+               P.putWord16be $ fromIntegral $ length $ msgAnswers m
+               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
+               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ      $ msgQuestions m
+               mapM_ putSomeRR $ msgAnswers m
+               mapM_ putSomeRR $ msgAuthorities m
+               mapM_ putSomeRR $ msgAdditionals m
+
+    get = U.liftToBinary IM.empty $
+          do hdr  <- getBinary
+             nQ   <- liftM fromIntegral U.getWord16be
+             nAns <- liftM fromIntegral U.getWord16be
+             nAth <- liftM fromIntegral U.getWord16be
+             nAdd <- liftM fromIntegral U.getWord16be
+             qs   <- replicateM nQ   getQ
+             anss <- replicateM nAns getSomeRR
+             aths <- replicateM nAth getSomeRR
+             adds <- replicateM nAdd getSomeRR
+             return Message {
+                          msgHeader      = hdr
+                        , msgQuestions   = qs
+                        , msgAnswers     = anss
+                        , msgAuthorities = aths
+                        , msgAdditionals = adds
+                        }
 
 instance Binary Header where
-    put h = do putWord16be $ hdMessageID h
-               let qr    = boolToNum $ hdIsResponse h
-                   op    = fromIntegral $ fromEnum $ hdOpcode h
-                   aa    = if hdIsResponse h then
-                               boolToNum $ hdIsAuthoritativeAnswer h
-                           else
-                               0
-                   tc    = boolToNum $ hdIsTruncated h
-                   rd    = boolToNum $ hdIsRecursionDesired h
-                   ra    = if hdIsResponse h then
-                               boolToNum $ hdIsRecursionAvailable h
-                           else
-                               0
-                   rc    = if hdIsResponse h then
-                               fromIntegral $ fromEnum $ hdResponseCode h
-                           else
-                               0
-                   flags = ((qr `shiftL` 15) .&. 0x01) .|.
-                           ((op `shiftL` 11) .&. 0x0F) .|.
-                           ((aa `shiftL` 10) .&. 0x01) .|.
-                           ((tc `shiftL`  9) .&. 0x01) .|.
-                           ((rd `shiftL`  8) .&. 0x01) .|.
-                           ((ra `shiftL`  7) .&. 0x01) .|.
-                           ((rc `shiftL`  0) .&. 0x0F)
-               putWord16be flags
+    put h = do P'.putWord16be $ hdMessageID h
+               P'.putLazyByteString flags
         where
-          boolToNum :: Num a => Bool -> a
-          boolToNum True  = 1
-          boolToNum False = 0
-
-    get = do mID   <- getWord16be
-             flags <- getWord16be
-             let qr = testBit flags 15
-                 op = toEnum $ fromIntegral ((flags `shiftR` 11) .&. 0x0F)
-                 aa = testBit flags 10
-                 tc = testBit flags 9
-                 rd = testBit flags 8
-                 ra = testBit flags 7
-                 rc = toEnum $ fromIntegral (flags .&. 0x0F)
-                 hd = if qr then
-                          ResponseHeader {
-                            hdMessageID             = mID
-                          , hdOpcode                = op
-                          , hdIsAuthoritativeAnswer = aa
-                          , hdIsTruncated           = tc
-                          , hdIsRecursionDesired    = rd
-                          , hdIsRecursionAvailable  = ra
-                          , hdResponseCode          = rc
-                          }
-                      else
-                          QueryHeader {
-                            hdMessageID          = mID
-                          , hdOpcode             = op
-                          , hdIsTruncated        = tc
-                          , hdIsRecursionDesired = rd
-                          }
+          flags = runBitPut $
+                  do putNBits 1 $ fromEnum $ hdMessageType h
+                     putNBits 4 $ fromEnum $ hdOpcode h
+                     putBit $ hdIsAuthoritativeAnswer h
+                     putBit $ hdIsTruncated h
+                     putBit $ hdIsRecursionDesired h
+                     putBit $ hdIsRecursionAvailable h
+                     putNBits 3 (0 :: Int)
+                     putNBits 4 $ fromEnum $ hdResponseCode h
+
+    get = do mID   <- G.getWord16be
+             flags <- G.getByteString 2
+             let Right hd
+                     = runBitGet flags $
+                       do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
+                          op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
+                          aa <- getBit
+                          tc <- getBit
+                          rd <- getBit
+                          ra <- getBit
+                          BG.skip 3
+                          rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
+                          return Header {
+                                       hdMessageID             = mID
+                                     , hdMessageType           = qr
+                                     , hdOpcode                = op
+                                     , hdIsAuthoritativeAnswer = aa
+                                     , hdIsTruncated           = tc
+                                     , hdIsRecursionDesired    = rd
+                                     , hdIsRecursionAvailable  = ra
+                                     , hdResponseCode          = rc
+                                     }
              return hd
 
+instance Enum MessageType where
+    fromEnum Query    = 0
+    fromEnum Response = 1
+
+    toEnum 0 = Query
+    toEnum 1 = Response
+    toEnum _ = undefined
+
 instance Enum Opcode where
     fromEnum StandardQuery       = 0
     fromEnum InverseQuery        = 1
@@ -116,10 +490,6 @@ instance Enum Opcode where
     toEnum 2 = ServerStatusRequest
     toEnum _ = undefined
 
-instance Bounded Opcode where
-    minBound = StandardQuery
-    maxBound = ServerStatusRequest
-
 instance Enum ResponseCode where
     fromEnum NoError        = 0
     fromEnum FormatError    = 1
@@ -136,6 +506,88 @@ instance Enum ResponseCode where
     toEnum 5 = Refused
     toEnum _ = undefined
 
-instance Bounded ResponseCode where
-    minBound = NoError
-    maxBound = Refused
+{-
+instance Enum RecordType where
+    fromEnum A       = 1
+    fromEnum NS      = 2
+    fromEnum MD      = 3
+    fromEnum MF      = 4
+    fromEnum CNAME   = 5
+    fromEnum SOA     = 6
+    fromEnum MB      = 7
+    fromEnum MG      = 8
+    fromEnum MR      = 9
+    fromEnum NULL    = 10
+    fromEnum WKS     = 11
+    fromEnum PTR     = 12
+    fromEnum HINFO   = 13
+    fromEnum MINFO   = 14
+    fromEnum MX      = 15
+    fromEnum TXT     = 16
+    fromEnum AXFR    = 252
+    fromEnum MAILB   = 253
+    fromEnum MAILA   = 254
+    fromEnum AnyType = 255
+
+    toEnum 1  = A
+    toEnum 2  = NS
+    toEnum 3  = MD
+    toEnum 4  = MF
+    toEnum 5  = CNAME
+    toEnum 6  = SOA
+    toEnum 7  = MB
+    toEnum 8  = MG
+    toEnum 9  = MR
+    toEnum 10 = NULL
+    toEnum 11 = WKS
+    toEnum 12 = PTR
+    toEnum 13 = HINFO
+    toEnum 14 = MINFO
+    toEnum 15 = MX
+    toEnum 16 = TXT
+    toEnum 252 = AXFR
+    toEnum 253 = MAILB
+    toEnum 254 = MAILA
+    toEnum 255 = AnyType
+    toEnum _  = undefined
+-}
+
+instance Enum RecordClass where
+    fromEnum IN       = 1
+    fromEnum CS       = 2
+    fromEnum CH       = 3
+    fromEnum HS       = 4
+    fromEnum AnyClass = 255
+
+    toEnum 1   = IN
+    toEnum 2   = CS
+    toEnum 3   = CH
+    toEnum 4   = HS
+    toEnum 255 = AnyClass
+    toEnum _   = undefined
+
+instance Binary RecordClass where
+    get = liftM (toEnum . fromIntegral) G.getWord16be
+    put = P'.putWord16be . fromIntegral . fromEnum
+
+
+defaultRTTable :: IntMap SomeRT
+defaultRTTable = IM.fromList $ map toPair $
+                 [ wrapRecordType A
+                 , wrapRecordType NS
+                 , wrapRecordType CNAME
+                 , wrapRecordType HINFO
+                 ]
+    where
+      toPair :: SomeRT -> (Int, SomeRT)
+      toPair srt@(SomeRT rt) = (rtToInt rt, srt)
+
+
+wrapQueryType :: RecordType rt dt => rt -> SomeQT
+wrapQueryType = SomeRT
+
+wrapRecordType :: RecordType rt dt => rt -> SomeRT
+wrapRecordType = SomeRT
+
+wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
+wrapRecord = SomeRR
\ No newline at end of file