]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
Many changes...
[haskell-dns.git] / Network / DNS / Message.hs
index 5c537956bd1657284a8c3b15db3dd9e9941cce62..570548ced67f03e55c1cbe42466bfb505b641e0c 100644 (file)
@@ -13,18 +13,38 @@ module Network.DNS.Message
     , RecordType
     , RecordClass(..)
 
+    , SOAFields(..)
+    , WKSFields(..)
+
+    , SomeQ
     , SomeQT
     , SomeRR
     , SomeRT
 
     , A(..)
     , NS(..)
+    , MD(..)
+    , MF(..)
     , CNAME(..)
+    , SOA(..)
+    , MB(..)
+    , MG(..)
+    , MR(..)
+    , NULL(..)
+    , WKS(..)
+    , PTR(..)
     , HINFO(..)
+    , MINFO(..)
+    , MX(..)
+    , TXT(..)
+
+    , AXFR(..)
+    , MAILB(..)
+    , MAILA(..)
+    , ANY(..)
 
     , mkDomainName
-    , wrapQueryType
-    , wrapRecordType
+    , wrapQuestion
     , wrapRecord
     )
     where
@@ -38,9 +58,12 @@ import           Data.Binary.Put as P'
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
+import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
+import qualified Data.IntSet as IS
+import           Data.IntSet (IntSet)
 import qualified Data.Map as M
 import           Data.Map (Map)
 import           Data.Word
@@ -52,7 +75,7 @@ import           Network.Socket
 data Message
     = Message {
         msgHeader      :: !Header
-      , msgQuestions   :: ![Question]
+      , msgQuestions   :: ![SomeQ]
       , msgAnswers     :: ![SomeRR]
       , msgAuthorities :: ![SomeRR]
       , msgAdditionals :: ![SomeRR]
@@ -100,31 +123,66 @@ data ResponseCode
     | Refused
     deriving (Show, Eq)
 
-data Question
+data QueryType qt => Question qt
     = Question {
         qName  :: !DomainName
-      , qType  :: !SomeQT
+      , qType  :: !qt
       , qClass :: !RecordClass
       }
-    deriving (Show, Eq)
+    deriving (Typeable)
+
+instance QueryType qt => Show (Question qt) where
+    show q = "Question { qName = " ++ show (qName q) ++
+             ", qType = " ++ show (qType q) ++
+             ", qClass = " ++ show (qClass q) ++ " }"
+
+instance QueryType qt => Eq (Question qt) where
+    a == b = qName  a == qName  b &&
+             qType  a == qType  b &&
+             qClass a == qClass b
+
+data SomeQ = forall qt. QueryType qt => SomeQ (Question qt)
+
+instance Show SomeQ where
+    show (SomeQ q) = show q
+
+instance Eq SomeQ where
+    (SomeQ a) == (SomeQ b) = Just a == cast b
+
+data SomeQT = forall qt. QueryType qt => SomeQT qt
+
+instance Show SomeQT where
+    show (SomeQT qt) = show qt
 
-type SomeQT = SomeRT
+instance Eq SomeQT where
+    (SomeQT a) == (SomeQT b) = Just a == cast b
 
-putQ :: Question -> Packer CompTable ()
-putQ q
+putSomeQ :: SomeQ -> Packer CompTable ()
+putSomeQ (SomeQ q)
     = do putDomainName $ qName q
-         putSomeRT $ qType q
+         putQueryType $ qType q
          putBinary $ qClass q
 
-getQ :: Unpacker DecompTable Question
-getQ = do nm <- getDomainName
-          ty <- getSomeRT
-          cl <- getBinary
-          return Question {
-                       qName  = nm
-                     , qType  = ty
-                     , qClass = cl
-                     }
+getSomeQ :: Unpacker DecompTable SomeQ
+getSomeQ
+    = do nm <- getDomainName
+         ty <- getSomeQT
+         cl <- getBinary
+         case ty of
+           SomeQT qt -> return $ SomeQ $
+                        Question {
+                          qName  = nm
+                        , qType  = qt
+                        , qClass = cl
+                        }
+
+getSomeQT :: Unpacker s SomeQT
+getSomeQT = do n <- liftM fromIntegral U.getWord16be
+               case IM.lookup n defaultQTTable of
+                 Just sqt
+                     -> return sqt
+                 Nothing
+                     -> fail ("Unknown query type: " ++ show n)
 
 
 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
@@ -286,6 +344,14 @@ putDomainName name
                        else
                          putDomainName rest
 
+class (Show qt, Eq qt, Typeable qt) => QueryType qt where
+    qtToInt :: qt -> Int
+
+    putQueryType :: qt -> Packer s ()
+    putQueryType = P.putWord16be . fromIntegral . qtToInt
+
+instance RecordType rt dt => QueryType rt where
+    qtToInt = rtToInt
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
@@ -295,24 +361,46 @@ class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType r
     putRecordType :: rt -> Packer s ()
     putRecordType = P.putWord16be . fromIntegral . rtToInt
 
+    putRecordDataWithLength :: rt -> dt -> Packer CompTable ()
+    putRecordDataWithLength rt dt
+        = do -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData rt dt
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             let len = offset' - offset - 2
+             if len <= 0xFFFF then
+                 withOffset offset
+                    $ P.putWord16be $ fromIntegral len
+               else
+                 fail ("putRecordData " ++ show rt ++ " wrote " ++ show len
+                       ++ " bytes, which is way too long")
+
     putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
     putResourceRecord rr
         = do putDomainName $ rrName  rr
              putRecordType $ rrType  rr
              putBinary     $ rrClass rr
              P.putWord32be $ rrTTL   rr
+             putRecordDataWithLength (rrType rr) (rrData rr)
 
-             -- First, write a dummy data length.
-             offset <- bytesWrote
-             P.putWord16be 0
+    getRecordDataWithLength :: rt -> Unpacker DecompTable dt
+    getRecordDataWithLength rt
+        = do len     <- U.getWord16be
+             offset  <- U.bytesRead
+             dat     <- getRecordData rt
+             offset' <- U.bytesRead
 
-             -- Second, write data.
-             putRecordData (rrType rr) (rrData rr)
+             let consumed = offset' - offset
+             when (consumed /= len)
+                      $ fail ("getRecordData " ++ show rt ++ " consumed " ++ show consumed ++
+                              " bytes but it had to consume " ++ show len ++ " bytes")
 
-             -- Third, rewrite the dummy length to an actual value.
-             offset' <- bytesWrote
-             withOffset offset
-                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
+             return dat
 
     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
     getResourceRecord rt
@@ -320,8 +408,7 @@ class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType r
              U.skip 2 -- record type
              cl       <- getBinary
              ttl      <- U.getWord32be
-             U.skip 2 -- data length
-             dat      <- getRecordData rt
+             dat      <- getRecordDataWithLength rt
              return $ ResourceRecord {
                           rrName  = name
                         , rrType  = rt
@@ -338,9 +425,6 @@ instance Show SomeRT where
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
-putSomeRT :: SomeRT -> Packer s ()
-putSomeRT (SomeRT rt) = putRecordType rt
-
 getSomeRT :: Unpacker s SomeRT
 getSomeRT = do n <- liftM fromIntegral U.getWord16be
                case IM.lookup n defaultRTTable of
@@ -349,6 +433,28 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be
                  Just srt
                      -> return srt
 
+
+data SOAFields
+    = SOAFields {
+        soaMasterNameServer   :: !DomainName
+      , soaResponsibleMailbox :: !DomainName
+      , soaSerialNumber       :: !Word32
+      , soaRefreshInterval    :: !Word32
+      , soaRetryInterval      :: !Word32
+      , soaExpirationLimit    :: !Word32
+      , soaMinimumTTL         :: !Word32
+      }
+    deriving (Show, Eq, Typeable)
+
+data WKSFields
+    = WKSFields {
+        wksAddress  :: !HostAddress
+      , wksProtocol :: !ProtocolNumber
+      , wksServices :: !IntSet
+      }
+    deriving (Show, Eq, Typeable)
+
+
 data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
@@ -361,48 +467,193 @@ instance RecordType NS DomainName where
     putRecordData _ = putDomainName
     getRecordData _ = getDomainName
 
+data MD = MD deriving (Show, Eq, Typeable)
+instance RecordType MD DomainName where
+    rtToInt       _ = 3
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MF = MF deriving (Show, Eq, Typeable)
+instance RecordType MF DomainName where
+    rtToInt       _ = 4
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
 data CNAME = CNAME deriving (Show, Eq, Typeable)
 instance RecordType CNAME DomainName where
     rtToInt       _ = 5
     putRecordData _ = putDomainName
     getRecordData _ = getDomainName
 
+data SOA = SOA deriving (Show, Eq, Typeable)
+instance RecordType SOA SOAFields where
+    rtToInt       _ = 6
+    putRecordData _ = \ soa ->
+                      do putDomainName $ soaMasterNameServer soa
+                         putDomainName $ soaResponsibleMailbox soa
+                         P.putWord32be $ soaSerialNumber soa
+                         P.putWord32be $ soaRefreshInterval soa
+                         P.putWord32be $ soaRetryInterval soa
+                         P.putWord32be $ soaExpirationLimit soa
+                         P.putWord32be $ soaMinimumTTL soa
+    getRecordData _ = do master  <- getDomainName
+                         mail    <- getDomainName
+                         serial  <- U.getWord32be
+                         refresh <- U.getWord32be
+                         retry   <- U.getWord32be
+                         expire  <- U.getWord32be
+                         ttl     <- U.getWord32be
+                         return SOAFields {
+                                      soaMasterNameServer   = master
+                                    , soaResponsibleMailbox = mail
+                                    , soaSerialNumber       = serial
+                                    , soaRefreshInterval    = refresh
+                                    , soaRetryInterval      = retry
+                                    , soaExpirationLimit    = expire
+                                    , soaMinimumTTL         = ttl
+                                    }
+
+data MB = MB deriving (Show, Eq, Typeable)
+instance RecordType MB DomainName where
+    rtToInt       _ = 7
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MG = MG deriving (Show, Eq, Typeable)
+instance RecordType MG DomainName where
+    rtToInt       _ = 8
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MR = MR deriving (Show, Eq, Typeable)
+instance RecordType MR DomainName where
+    rtToInt       _ = 9
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data NULL = NULL deriving (Show, Eq, Typeable)
+instance RecordType NULL BS.ByteString where
+    rtToInt                 _ = 10
+    putRecordData         _ _ = fail "putRecordData NULL can't be defined"
+    getRecordData           _ = fail "getRecordData NULL can't be defined"
+    putRecordDataWithLength _ = \ dat ->
+                                do P.putWord16be $ fromIntegral $ BS.length dat
+                                   P.putByteString dat
+    getRecordDataWithLength _ = do len <- U.getWord16be
+                                   U.getByteString $ fromIntegral len
+
+data WKS = WKS deriving (Show, Eq, Typeable)
+instance RecordType WKS WKSFields where
+    rtToInt       _ = 11
+    putRecordData _ = \ wks ->
+                      do P.putWord32be $ wksAddress wks
+                         P.putWord8 $ fromIntegral $ wksProtocol wks
+                         P.putLazyByteString $ toBitmap $ wksServices wks
+        where
+          toBitmap :: IntSet -> LBS.ByteString
+          toBitmap is
+              = let maxPort   = IS.findMax is
+                    range     = [0 .. maxPort]
+                    isAvail p = p `IS.member` is
+                in
+                  runBitPut $ mapM_ putBit $ map isAvail range
+    getRecordData _ = fail "getRecordData WKS can't be defined"
+
+    getRecordDataWithLength _
+        = do len   <- U.getWord16be
+             addr  <- U.getWord32be
+             proto <- liftM fromIntegral U.getWord8
+             bits  <- U.getByteString $ fromIntegral $ len - 4 - 1
+             return WKSFields {
+                          wksAddress  = addr
+                        , wksProtocol = proto
+                        , wksServices = fromBitmap bits
+                        }
+        where
+          fromBitmap :: BS.ByteString -> IntSet
+          fromBitmap bs
+              = let Right is = runBitGet bs $ worker 0 IS.empty
+                in
+                  is
+
+          worker :: Int -> IntSet -> BitGet IntSet
+          worker pos is
+              = do remain <- BG.remaining
+                   if remain == 0 then
+                       return is
+                     else
+                       do bit <- getBit
+                          if bit then
+                              worker (pos + 1) (IS.insert pos is)
+                            else
+                              worker (pos + 1) is
+
+
+data PTR = PTR deriving (Show, Eq, Typeable)
+instance RecordType PTR DomainName where
+    rtToInt       _ = 12
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
 data HINFO = HINFO deriving (Show, Eq, Typeable)
 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
-    rtToInt       _           = 13
-    putRecordData _ (cpu, os) = do putCharString cpu
-                                   putCharString os
-    getRecordData _           = do cpu <- getCharString
-                                   os  <- getCharString
-                                   return (cpu, os)
-
-
-{-
-data RecordType
-    = A
-    | NS
-    | MD
-    | MF
-    | CNAME
-    | SOA
-    | MB
-    | MG
-    | MR
-    | NULL
-    | WKS
-    | PTR
-    | HINFO
-    | MINFO
-    | MX
-    | TXT
-
-    -- Only for queries:
-    | AXFR
-    | MAILB -- Obsolete
-    | MAILA -- Obsolete
-    | AnyType
-    deriving (Show, Eq)
--}
+    rtToInt       _ = 13
+    putRecordData _ = \ (cpu, os) ->
+                      do putCharString cpu
+                         putCharString os
+    getRecordData _ = do cpu <- getCharString
+                         os  <- getCharString
+                         return (cpu, os)
+
+data MINFO = MINFO deriving (Show, Eq, Typeable)
+instance RecordType MINFO (DomainName, DomainName) where
+    rtToInt       _ = 14
+    putRecordData _ = \ (r, e) ->
+                      do putDomainName r
+                         putDomainName e
+    getRecordData _ = do r <- getDomainName
+                         e <- getDomainName
+                         return (r, e)
+
+data MX = MX deriving (Show, Eq, Typeable)
+instance RecordType MX (Word16, DomainName) where
+    rtToInt       _ = 15
+    putRecordData _ = \ (pref, exch) ->
+                      do P.putWord16be pref
+                         putDomainName exch
+    getRecordData _ = do pref <- U.getWord16be
+                         exch <- getDomainName
+                         return (pref, exch)
+
+data TXT = TXT deriving (Show, Eq, Typeable)
+instance RecordType TXT [BS.ByteString] where
+    rtToInt       _ = 16
+    putRecordData _ = mapM_ putCharString
+    getRecordData _ = fail "getRecordData TXT can't be defined"
+
+    getRecordDataWithLength _ = U.getWord16be >>= worker [] . fromIntegral
+        where
+          worker :: [BS.ByteString] -> Int -> Unpacker s [BS.ByteString]
+          worker soFar 0 = return (reverse soFar)
+          worker soFar n = do str <- getCharString
+                              worker (str : soFar) (0 `max` n - 1 - BS.length str)
+
+data AXFR = AXFR deriving (Show, Eq, Typeable)
+instance QueryType AXFR where
+    qtToInt _ = 252
+
+data MAILB = MAILB deriving (Show, Eq, Typeable)
+instance QueryType MAILB where
+    qtToInt _ = 253
+
+data MAILA = MAILA deriving (Show, Eq, Typeable)
+instance QueryType MAILA where
+    qtToInt _ = 254
+
+data ANY = ANY deriving (Show, Eq, Typeable)
+instance QueryType ANY where
+    qtToInt _ = 255
+
 
 instance Binary Message where
     put m = P.liftToBinary M.empty $
@@ -411,7 +662,7 @@ instance Binary Message where
                P.putWord16be $ fromIntegral $ length $ msgAnswers m
                P.putWord16be $ fromIntegral $ length $ msgAuthorities m
                P.putWord16be $ fromIntegral $ length $ msgAdditionals m
-               mapM_ putQ      $ msgQuestions m
+               mapM_ putSomeQ  $ msgQuestions m
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
@@ -422,7 +673,7 @@ instance Binary Message where
              nAns <- liftM fromIntegral U.getWord16be
              nAth <- liftM fromIntegral U.getWord16be
              nAdd <- liftM fromIntegral U.getWord16be
-             qs   <- replicateM nQ   getQ
+             qs   <- replicateM nQ   getSomeQ
              anss <- replicateM nAns getSomeRR
              aths <- replicateM nAth getSomeRR
              adds <- replicateM nAdd getSomeRR
@@ -506,52 +757,6 @@ instance Enum ResponseCode where
     toEnum 5 = Refused
     toEnum _ = undefined
 
-{-
-instance Enum RecordType where
-    fromEnum A       = 1
-    fromEnum NS      = 2
-    fromEnum MD      = 3
-    fromEnum MF      = 4
-    fromEnum CNAME   = 5
-    fromEnum SOA     = 6
-    fromEnum MB      = 7
-    fromEnum MG      = 8
-    fromEnum MR      = 9
-    fromEnum NULL    = 10
-    fromEnum WKS     = 11
-    fromEnum PTR     = 12
-    fromEnum HINFO   = 13
-    fromEnum MINFO   = 14
-    fromEnum MX      = 15
-    fromEnum TXT     = 16
-    fromEnum AXFR    = 252
-    fromEnum MAILB   = 253
-    fromEnum MAILA   = 254
-    fromEnum AnyType = 255
-
-    toEnum 1  = A
-    toEnum 2  = NS
-    toEnum 3  = MD
-    toEnum 4  = MF
-    toEnum 5  = CNAME
-    toEnum 6  = SOA
-    toEnum 7  = MB
-    toEnum 8  = MG
-    toEnum 9  = MR
-    toEnum 10 = NULL
-    toEnum 11 = WKS
-    toEnum 12 = PTR
-    toEnum 13 = HINFO
-    toEnum 14 = MINFO
-    toEnum 15 = MX
-    toEnum 16 = TXT
-    toEnum 252 = AXFR
-    toEnum 253 = MAILB
-    toEnum 254 = MAILA
-    toEnum 255 = AnyType
-    toEnum _  = undefined
--}
-
 instance Enum RecordClass where
     fromEnum IN       = 1
     fromEnum CS       = 2
@@ -575,19 +780,56 @@ defaultRTTable :: IntMap SomeRT
 defaultRTTable = IM.fromList $ map toPair $
                  [ wrapRecordType A
                  , wrapRecordType NS
+                 , wrapRecordType MD
+                 , wrapRecordType MF
                  , wrapRecordType CNAME
+                 , wrapRecordType SOA
+                 , wrapRecordType MB
+                 , wrapRecordType MG
+                 , wrapRecordType MR
+                 , wrapRecordType NULL
+                 , wrapRecordType WKS
+                 , wrapRecordType PTR
                  , wrapRecordType HINFO
+                 , wrapRecordType MINFO
+                 , wrapRecordType MX
+                 , wrapRecordType TXT
                  ]
     where
       toPair :: SomeRT -> (Int, SomeRT)
       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
 
 
-wrapQueryType :: RecordType rt dt => rt -> SomeQT
-wrapQueryType = SomeRT
+defaultQTTable :: IntMap SomeQT
+defaultQTTable = mergeWithRTTable defaultRTTable $ IM.fromList $ map toPair $
+                 [ wrapQueryType AXFR
+                 , wrapQueryType MAILB
+                 , wrapQueryType MAILA
+                 , wrapQueryType ANY
+                 ]
+    where
+      toPair :: SomeQT -> (Int, SomeQT)
+      toPair sqt@(SomeQT qt) = (qtToInt qt, sqt)
+
+      mergeWithRTTable :: IntMap SomeRT -> IntMap SomeQT -> IntMap SomeQT
+      mergeWithRTTable rts qts
+          = IM.union (toQTTable rts) qts
+
+      toQTTable :: IntMap SomeRT -> IntMap SomeQT
+      toQTTable = IM.map toSomeQT
+
+      toSomeQT :: SomeRT -> SomeQT
+      toSomeQT (SomeRT rt) = SomeQT rt
+
+
+wrapQueryType :: QueryType qt => qt -> SomeQT
+wrapQueryType = SomeQT
 
 wrapRecordType :: RecordType rt dt => rt -> SomeRT
 wrapRecordType = SomeRT
 
+wrapQuestion :: QueryType qt => Question qt -> SomeQ
+wrapQuestion = SomeQ
+
 wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
 wrapRecord = SomeRR
\ No newline at end of file