]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
Add DNSUnitTest.hs
[haskell-dns.git] / Network / DNS / Message.hs
index 71fec55814cda85735cb528f6f5b262f50383616..6144d13766e037d551838526f18804bf0b896451 100644 (file)
@@ -1,35 +1,67 @@
 module Network.DNS.Message
     ( Message(..)
+    , MessageID
     , MessageType(..)
     , Header(..)
     , Opcode(..)
     , ResponseCode(..)
     , Question(..)
     , ResourceRecord(..)
-    , RecordType(..)
+    , DomainName
+    , DomainLabel
+    , TTL
+    , RecordType
     , RecordClass(..)
+
+    , SomeRR(..)
+    , SomeRT(..)
+
+    , CNAME(..)
+    , HINFO(..)
+
+    , mkQueryType
+    , mkDomainName
     )
     where
 
+import           Control.Exception
+import           Control.Monad
 import           Data.Binary
-import           Data.Binary.Get
-import           Data.Binary.Put
-import           Data.Bits
+import           Data.Binary.BitPut as BP
+import           Data.Binary.Get as G
+import           Data.Binary.Put as P
+import           Data.Binary.Strict.BitGet as BG
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
+import qualified Data.ByteString.Lazy as LBS
+import           Data.Typeable
+import qualified Data.IntMap as IM
+import           Data.IntMap (IntMap)
 import           Data.Word
 
 
+replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
+replicateM' = worker []
+    where
+      worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
+      worker soFar 0 _ a = return (reverse soFar, a)
+      worker soFar n f a = do (b, a') <- f a
+                              worker (b : soFar) (n - 1) f a'
+
+
 data Message
     = Message {
         msgHeader      :: !Header
       , msgQuestions   :: ![Question]
-      , msgAnswers     :: ![ResourceRecord]
-      , msgAuthorities :: ![ResourceRecord]
-      , msgAdditionals :: ![ResourceRecord]
+      , msgAnswers     :: ![SomeRR]
+      , msgAuthorities :: ![SomeRR]
+      , msgAdditionals :: ![SomeRR]
       }
+    deriving (Show, Eq)
 
 data Header
     = Header {
-        hdMessageID             :: !Word16
+        hdMessageID             :: !MessageID
       , hdMessageType           :: !MessageType
       , hdOpcode                :: !Opcode
       , hdIsAuthoritativeAnswer :: !Bool
@@ -39,11 +71,14 @@ data Header
       , hdResponseCode          :: !ResponseCode
 
       -- These fields are supressed in this data structure:
-      -- * QDCOUNT
-      -- * ANCOUNT
-      -- * NSCOUNT
-      -- * ARCOUNT
+      -- + QDCOUNT
+      -- + ANCOUNT
+      -- + NSCOUNT
+      -- + ARCOUNT
       }
+    deriving (Show, Eq)
+
+type MessageID = Word16
 
 data MessageType
     = Query
@@ -68,12 +103,55 @@ data ResponseCode
 data Question
     = Question {
         qName  :: !DomainName
-      , qType  :: !RecordType
+      , qType  :: !SomeQT
       , qClass :: !RecordClass
       }
     deriving (Show, Eq)
 
-type DomainName = [[Word8]]
+type SomeQT = SomeRT
+
+mkQueryType :: RecordType rt dt => rt -> SomeQT
+mkQueryType = SomeRT
+
+putQ :: Question -> Put
+putQ q
+    = do putDomainName $ qName q
+         putSomeRT $ qType q
+         put $ qClass q
+
+getQ :: DecompTable -> Get (Question, DecompTable)
+getQ dt
+    = do (nm, dt') <- getDomainName dt
+         ty        <- getSomeRT
+         cl        <- get
+         let q = Question {
+                   qName  = nm
+                 , qType  = ty
+                 , qClass = cl
+                 }
+         return (q, dt')
+
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+type DomainLabel    = BS.ByteString
+
+nameToLabels :: DomainName -> [DomainLabel]
+nameToLabels (DN ls) = ls
+
+labelsToName :: [DomainLabel] -> DomainName
+labelsToName = DN
+
+mkDomainName :: String -> DomainName
+mkDomainName = labelsToName . mkLabels [] . notEmpty
+    where
+      notEmpty :: String -> String
+      notEmpty xs = assert (not $ null xs) xs
+
+      mkLabels :: [DomainLabel] -> String -> [DomainLabel]
+      mkLabels soFar [] = reverse (C8.empty : soFar)
+      mkLabels soFar xs = case break (== '.') xs of
+                            (l, ('.':rest))
+                                -> mkLabels (C8.pack l : soFar) rest
+                            _   -> error ("Illegal domain name: " ++ xs)
 
 data RecordClass
     = IN
@@ -83,16 +161,167 @@ data RecordClass
     | AnyClass -- Only for queries
     deriving (Show, Eq)
 
-data ResourceRecord
+
+data RecordType rt dt => ResourceRecord rt dt
     = ResourceRecord {
         rrName  :: !DomainName
-      , rrType  :: !RecordType
+      , rrType  :: !rt
       , rrClass :: !RecordClass
-      , rrTTL   :: !Word32
-      , rrData  :: ![Word8]
+      , rrTTL   :: !TTL
+      , rrData  :: !dt
       }
-    deriving (Show, Eq)
+    deriving (Show, Eq, Typeable)
+
+
+putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
+putRR rr = do putDomainName $ rrName rr
+              putRecordType $ rrType  rr
+              put $ rrClass rr
+              putWord32be $ rrTTL rr
+
+              let dat = runPut $
+                        putRecordData (undefined :: rt) (rrData rr)
+              putWord16be $ fromIntegral $ LBS.length dat
+              putLazyByteString dat
+
+
+getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
+getRR dt rt
+    = do (nm, dt1)  <- getDomainName dt
+         G.skip 2   -- record type
+         cl         <- get
+         ttl        <- G.getWord32be
+         G.skip 2   -- data length
+         (dat, dt2) <- getRecordData (undefined :: rt) dt1
+
+         let rr = ResourceRecord {
+                    rrName  = nm
+                  , rrType  = rt
+                  , rrClass = cl
+                  , rrTTL   = ttl
+                  , rrData  = dat
+                  }
+         return (rr, dt2)
+
+
+data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
+
+instance Show SomeRR where
+    show (SomeRR rr) = show rr
+
+instance Eq SomeRR where
+    (SomeRR a) == (SomeRR b) = Just a == cast b
 
+
+putSomeRR :: SomeRR -> Put
+putSomeRR (SomeRR rr) = putRR rr
+
+getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
+getSomeRR dt
+    = do srt <- lookAhead $
+                do getDomainName dt -- skip
+                   getSomeRT
+         case srt of
+           SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
+
+
+type DecompTable = IntMap BS.ByteString
+type TTL = Word32
+
+getDomainName :: DecompTable -> Get (DomainName, DecompTable)
+getDomainName = flip worker []
+    where
+      worker :: DecompTable -> [DomainLabel] -> Get (DomainName, DecompTable)
+      worker dt soFar
+          = do (l, dt') <- getDomainLabel dt
+               case BS.null l of
+                 True  -> return (labelsToName (reverse (l : soFar)), dt')
+                 False -> worker dt' (l : soFar)
+
+getDomainLabel :: DecompTable -> Get (DomainLabel, DecompTable)
+getDomainLabel dt
+    = do header <- getByteString 1
+         let Right h
+                 = runBitGet header $
+                   do a <- getBit
+                      b <- getBit
+                      n <- liftM fromIntegral (getAsWord8 6)
+                      case (a, b) of
+                        ( True,  True) -> return $ Offset n
+                        (False, False) -> return $ Length n
+                        _              -> fail "Illegal label header"
+         case h of
+           Offset n
+               -> do let Just l = IM.lookup n dt
+                     return (l, dt)
+           Length n
+               -> do offset <- liftM fromIntegral bytesRead
+                     label  <- getByteString n
+                     let dt' = IM.insert offset label dt
+                     return (label, dt')
+
+getCharString :: Get BS.ByteString
+getCharString = do len <- G.getWord8
+                   getByteString (fromIntegral len)
+
+putCharString :: BS.ByteString -> Put
+putCharString = putDomainLabel
+
+data LabelHeader
+    = Offset !Int
+    | Length !Int
+
+putDomainName :: DomainName -> Put
+putDomainName = mapM_ putDomainLabel . nameToLabels
+
+putDomainLabel :: DomainLabel -> Put
+putDomainLabel l
+    = do putWord8 $ fromIntegral $ BS.length l
+         P.putByteString l
+
+class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
+    rtToInt       :: rt -> Int
+    putRecordType :: rt -> Put
+    putRecordData :: rt -> dt -> Put
+    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
+
+    putRecordType = putWord16be . fromIntegral . rtToInt
+
+data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
+
+instance Show SomeRT where
+    show (SomeRT rt) = show rt
+
+instance Eq SomeRT where
+    (SomeRT a) == (SomeRT b) = Just a == cast b
+
+putSomeRT :: SomeRT -> Put
+putSomeRT (SomeRT rt) = putRecordType rt
+
+getSomeRT :: Get SomeRT
+getSomeRT = do n <- liftM fromIntegral G.getWord16be
+               case IM.lookup n defaultRTTable of
+                 Nothing
+                     -> fail ("Unknown resource record type: " ++ show n)
+                 Just srt
+                     -> return srt
+
+data CNAME = CNAME deriving (Show, Eq, Typeable)
+instance RecordType CNAME DomainName where
+    rtToInt       _ = 5
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data HINFO = HINFO deriving (Show, Eq, Typeable)
+instance RecordType HINFO (BS.ByteString, BS.ByteString) where
+    rtToInt       _           = 13
+    putRecordData _ (cpu, os) = do putCharString cpu
+                                   putCharString os
+    getRecordData _ dt        = do cpu <- getCharString
+                                   os  <- getCharString
+                                   return ((cpu, os), dt)
+
+{-
 data RecordType
     = A
     | NS
@@ -117,48 +346,72 @@ data RecordType
     | MAILA -- Obsolete
     | AnyType
     deriving (Show, Eq)
+-}
+
+instance Binary Message where
+    put m = do put $ msgHeader m
+               putWord16be $ fromIntegral $ length $ msgQuestions m
+               putWord16be $ fromIntegral $ length $ msgAnswers m
+               putWord16be $ fromIntegral $ length $ msgAuthorities m
+               putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ  $ msgQuestions m
+               mapM_ putSomeRR $ msgAnswers m
+               mapM_ putSomeRR $ msgAuthorities m
+               mapM_ putSomeRR $ msgAdditionals m
+
+    get = do hdr  <- get
+             nQ   <- liftM fromIntegral G.getWord16be
+             nAns <- liftM fromIntegral G.getWord16be
+             nAth <- liftM fromIntegral G.getWord16be
+             nAdd <- liftM fromIntegral G.getWord16be
+             (qs  , dt1) <- replicateM' nQ   getQ IM.empty
+             (anss, dt2) <- replicateM' nAns getSomeRR dt1
+             (aths, dt3) <- replicateM' nAth getSomeRR dt2
+             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
+             return Message {
+                          msgHeader      = hdr
+                        , msgQuestions   = qs
+                        , msgAnswers     = anss
+                        , msgAuthorities = aths
+                        , msgAdditionals = adds
+                        }
 
 instance Binary Header where
     put h = do putWord16be $ hdMessageID h
-               let qr    = fromIntegral $ fromEnum $ hdMessageType h
-                   op    = fromIntegral $ fromEnum $ hdOpcode h
-                   aa    = boolToNum $ hdIsAuthoritativeAnswer h
-                   tc    = boolToNum $ hdIsTruncated h
-                   rd    = boolToNum $ hdIsRecursionDesired h
-                   ra    = boolToNum $ hdIsRecursionAvailable h
-                   rc    = fromIntegral $ fromEnum $ hdResponseCode h
-                   flags = ((qr `shiftL` 15) .&. 0x01) .|.
-                           ((op `shiftL` 11) .&. 0x0F) .|.
-                           ((aa `shiftL` 10) .&. 0x01) .|.
-                           ((tc `shiftL`  9) .&. 0x01) .|.
-                           ((rd `shiftL`  8) .&. 0x01) .|.
-                           ((ra `shiftL`  7) .&. 0x01) .|.
-                           ((rc `shiftL`  0) .&. 0x0F)
-               putWord16be flags
+               putLazyByteString flags
         where
-          boolToNum :: Num a => Bool -> a
-          boolToNum True  = 1
-          boolToNum False = 0
-
-    get = do mID   <- getWord16be
-             flags <- getWord16be
-             let qr = toEnum $ fromIntegral ((flags `shiftR` 15) .&. 0x01)
-                 op = toEnum $ fromIntegral ((flags `shiftR` 11) .&. 0x0F)
-                 aa = testBit flags 10
-                 tc = testBit flags 9
-                 rd = testBit flags 8
-                 ra = testBit flags 7
-                 rc = toEnum $ fromIntegral (flags .&. 0x0F)
-                 hd = Header {
-                         hdMessageID             = mID
-                       , hdMessageType           = qr
-                       , hdOpcode                = op
-                       , hdIsAuthoritativeAnswer = aa
-                       , hdIsTruncated           = tc
-                       , hdIsRecursionDesired    = rd
-                       , hdIsRecursionAvailable  = ra
-                       , hdResponseCode          = rc
-                       }
+          flags = runBitPut $
+                  do putNBits 1 $ fromEnum $ hdMessageType h
+                     putNBits 4 $ fromEnum $ hdOpcode h
+                     putBit $ hdIsAuthoritativeAnswer h
+                     putBit $ hdIsTruncated h
+                     putBit $ hdIsRecursionDesired h
+                     putBit $ hdIsRecursionAvailable h
+                     putNBits 3 (0 :: Int)
+                     putNBits 4 $ fromEnum $ hdResponseCode h
+
+    get = do mID   <- G.getWord16be
+             flags <- getByteString 2
+             let Right hd
+                     = runBitGet flags $
+                       do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
+                          op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
+                          aa <- getBit
+                          tc <- getBit
+                          rd <- getBit
+                          ra <- getBit
+                          BG.skip 3
+                          rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
+                          return Header {
+                                       hdMessageID             = mID
+                                     , hdMessageType           = qr
+                                     , hdOpcode                = op
+                                     , hdIsAuthoritativeAnswer = aa
+                                     , hdIsTruncated           = tc
+                                     , hdIsRecursionDesired    = rd
+                                     , hdIsRecursionAvailable  = ra
+                                     , hdResponseCode          = rc
+                                     }
              return hd
 
 instance Enum MessageType where
@@ -195,6 +448,7 @@ instance Enum ResponseCode where
     toEnum 5 = Refused
     toEnum _ = undefined
 
+{-
 instance Enum RecordType where
     fromEnum A       = 1
     fromEnum NS      = 2
@@ -238,8 +492,9 @@ instance Enum RecordType where
     toEnum 254 = MAILA
     toEnum 255 = AnyType
     toEnum _  = undefined
+-}
 
-instance Enum RecordClass
+instance Enum RecordClass where
     fromEnum IN       = 1
     fromEnum CS       = 2
     fromEnum CH       = 3
@@ -252,3 +507,16 @@ instance Enum RecordClass
     toEnum 4   = HS
     toEnum 255 = AnyClass
     toEnum _   = undefined
+
+instance Binary RecordClass where
+    get = liftM (toEnum . fromIntegral) G.getWord16be
+    put = putWord16be . fromIntegral . fromEnum
+
+
+defaultRTTable :: IntMap SomeRT
+defaultRTTable = IM.fromList $ map toPair $
+                 [ SomeRT CNAME
+                 ]
+    where
+      toPair :: SomeRT -> (Int, SomeRT)
+      toPair srt@(SomeRT rt) = (rtToInt rt, srt)