module Network.DNS.Message
    ( Message(..)
    , MessageID
    , MessageType(..)
    , Header(..)
    , Opcode(..)
    , ResponseCode(..)
    , Question(..)
    , ResourceRecord(..)
    , DomainName
    , DomainLabel
    , TTL
    , RecordType
    , RecordClass(..)
    , SomeQT
    , SomeRR
    , SomeRT
    , A(..)
    , NS(..)
    , CNAME(..)
    , HINFO(..)
    , mkDomainName
    , wrapQueryType
    , wrapRecordType
    , wrapRecord
    )
    where

import Control.Exception
import Control.Monad
import Data.Binary
import Data.Binary.BitPut as BP
import Data.Binary.Get as G
import Data.Binary.Put as P
import Data.Binary.Strict.BitGet as BG
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
import qualified Data.ByteString.Lazy as LBS
import Data.Typeable
import qualified Data.IntMap as IM
import Data.IntMap (IntMap)
import Data.Word
import Network.Socket


-- | Like 'replicateM', but threads a state value through every step and
-- returns the final state together with the results (in order).
-- A non-positive count performs no steps.
replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
replicateM' = worker []
    where
      worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
      worker soFar n f a
          | n <= 0    = return (reverse soFar, a)
          | otherwise = do (b, a') <- f a
                           worker (b : soFar) (n - 1) f a'


-- | A complete DNS message: header plus the four record sections.
data Message
    = Message {
        msgHeader      :: !Header
      , msgQuestions   :: ![Question]
      , msgAnswers     :: ![SomeRR]
      , msgAuthorities :: ![SomeRR]
      , msgAdditionals :: ![SomeRR]
      }
    deriving (Show, Eq)

-- | The fixed-size message header.
data Header
    = Header {
        hdMessageID             :: !MessageID
      , hdMessageType           :: !MessageType
      , hdOpcode                :: !Opcode
      , hdIsAuthoritativeAnswer :: !Bool
      , hdIsTruncated           :: !Bool
      , hdIsRecursionDesired    :: !Bool
      , hdIsRecursionAvailable  :: !Bool
      , hdResponseCode          :: !ResponseCode
        -- These fields are suppressed in this data structure; the
        -- 'Binary' instance of 'Message' writes them from the lengths of
        -- the corresponding section lists:
        -- + QDCOUNT
        -- + ANCOUNT
        -- + NSCOUNT
        -- + ARCOUNT
      }
    deriving (Show, Eq)

type MessageID = Word16

data MessageType
    = Query
    | Response
    deriving (Show, Eq)

data Opcode
    = StandardQuery
    | InverseQuery
    | ServerStatusRequest
    deriving (Show, Eq)

data ResponseCode
    = NoError
    | FormatError
    | ServerFailure
    | NameError
    | NotImplemented
    | Refused
    deriving (Show, Eq)

-- | An entry of the question section.
data Question
    = Question {
        qName  :: !DomainName
      , qType  :: !SomeQT
      , qClass :: !RecordClass
      }
    deriving (Show, Eq)

type SomeQT = SomeRT

-- | Serialise a question entry (name, type, class).
putQ :: Question -> Put
putQ q
    = do putDomainName $ qName q
         putSomeRT     $ qType  q
         put           $ qClass q

-- | Parse a question entry, extending the decompression table with the
-- labels of its name.
getQ :: DecompTable -> Get (Question, DecompTable)
getQ dt
    = do (nm, dt') <- getDomainName dt
         ty        <- getSomeRT
         cl        <- get
         let q = Question {
                   qName  = nm
                 , qType  = ty
                 , qClass = cl
                 }
         return (q, dt')


-- | A domain name as a list of labels; a fully-qualified name ends with
-- the empty (root) label.
newtype DomainName = DN [DomainLabel]
    deriving (Eq, Show, Typeable)

type DomainLabel = BS.ByteString

nameToLabels :: DomainName -> [DomainLabel]
nameToLabels (DN ls) = ls

labelsToName :: [DomainLabel] -> DomainName
labelsToName = DN

-- | The root name, consisting only of the empty label.
rootName :: DomainName
rootName = DN [BS.empty]

consLabel :: DomainLabel -> DomainName -> DomainName
consLabel x (DN ys) = DN (x:ys)

-- | Build a 'DomainName' from its dotted textual form.
--
-- Every label must be terminated by a dot (e.g. @\"www.example.com.\"@);
-- the lone string @\".\"@ denotes the root name. Empty labels (as in
-- @\"a..b.\"@ or @\".a.\"@) and names without the trailing dot are
-- rejected with an 'error'.
mkDomainName :: String -> DomainName
mkDomainName = labelsToName . toLabels . notEmpty
    where
      notEmpty :: String -> String
      notEmpty xs = assert (not $ null xs) xs

      toLabels :: String -> [DomainLabel]
      -- "." alone is the root; without this case it became two root labels.
      toLabels "." = [C8.empty]
      toLabels xs  = mkLabels [] xs

      mkLabels :: [DomainLabel] -> String -> [DomainLabel]
      mkLabels soFar [] = reverse (C8.empty : soFar)
      mkLabels soFar xs
          = case break (== '.') xs of
              -- An empty label would be written as a zero-length octet,
              -- i.e. a premature end-of-name marker, so refuse it.
              (l@(_:_), '.':rest) -> mkLabels (C8.pack l : soFar) rest
              _                   -> error ("Illegal domain name: " ++ xs)


data RecordClass
    = IN
    | CS       -- Obsolete
    | CH
    | HS
    | AnyClass -- Only for queries
    deriving (Show, Eq)


-- | A resource record whose type tag is @rt@ and whose RDATA is @dt@.
data RecordType rt dt => ResourceRecord rt dt
    = ResourceRecord {
        rrName  :: !DomainName
      , rrType  :: !rt
      , rrClass :: !RecordClass
      , rrTTL   :: !TTL
      , rrData  :: !dt
      }
    deriving (Show, Eq, Typeable)

-- | Serialise a resource record. The RDATA is rendered on its own first
-- so that its length (RDLENGTH) can be written in front of it.
putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
putRR rr
    = do putDomainName $ rrName rr
         putRecordType $ rrType rr
         put           $ rrClass rr
         putWord32be   $ rrTTL rr
         let dat = runPut $ putRecordData (rrType rr) (rrData rr)
         putWord16be $ fromIntegral $ LBS.length dat
         putLazyByteString dat

-- | Parse a resource record of the given type. The type field on the wire
-- is skipped (the caller has already peeked it to choose @rt@). Fails if
-- the parsed RDATA does not occupy exactly RDLENGTH octets, so a bad
-- record cannot silently desynchronise the rest of the message.
getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
getRR dt rt
    = do (nm, dt1) <- getDomainName dt
         G.skip 2 -- record type
         cl         <- get
         ttl        <- G.getWord32be
         rdLength   <- liftM fromIntegral G.getWord16be :: Get Int
         start      <- liftM fromIntegral bytesRead     :: Get Int
         (dat, dt2) <- getRecordData rt dt1
         end        <- liftM fromIntegral bytesRead     :: Get Int
         let consumed = end - start
         when (consumed /= rdLength)
             $ fail ("Record data length mismatch: RDLENGTH is " ++ show rdLength
                     ++ " but " ++ show consumed ++ " octets were parsed")
         let rr = ResourceRecord {
                    rrName  = nm
                  , rrType  = rt
                  , rrClass = cl
                  , rrTTL   = ttl
                  , rrData  = dat
                  }
         return (rr, dt2)

-- | A resource record of any registered type.
data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)

instance Show SomeRR where
    show (SomeRR rr) = show rr

instance Eq SomeRR where
    (SomeRR a) == (SomeRR b) = Just a == cast b

putSomeRR :: SomeRR -> Put
putSomeRR (SomeRR rr) = putRR rr

-- | Parse a resource record, first peeking past its owner name to read
-- the type code and select the matching 'RecordType' instance.
getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
getSomeRR dt
    = do srt <- lookAhead $ do _ <- getDomainName dt -- skip
                               getSomeRT
         case srt of
           SomeRT rt -> do (rr, dt') <- getRR dt rt
                           return (SomeRR rr, dt')

-- | Maps byte offsets (as reported by 'bytesRead') of already-parsed
-- names to the names found there; used to resolve compression pointers.
type DecompTable = IntMap DomainName

type TTL = Word32

-- | Parse a (possibly compressed) domain name.
--
-- Every label position parsed here is recorded in the returned table so
-- later pointers can refer back to it. A pointer must target an offset
-- already in the table; otherwise parsing fails.
getDomainName :: DecompTable -> Get (DomainName, DecompTable)
getDomainName = worker
    where
      worker :: DecompTable -> Get (DomainName, DecompTable)
      worker dt
          = do offset <- liftM fromIntegral bytesRead
               hdr    <- getLabelHeader
               case hdr of
                 Offset n
                     -> case IM.lookup n dt of
                          Just name
                              -> return (name, dt)
                          Nothing
                              -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
                 Length 0
                     -> return (rootName, dt)
                 Length n
                     -> do label       <- getByteString n
                           (rest, dt') <- worker dt
                           let name = consLabel label rest
                               dt'' = IM.insert offset name dt'
                           return (name, dt'')

-- | Run a 'BG.BitGet' over a strict byte string inside 'Get', turning a
-- bit-level decoding failure into a 'Get' failure (prefixed by @ctx@)
-- rather than a lazy pattern-match crash.
runBitGetG :: String -> BS.ByteString -> BG.BitGet a -> Get a
runBitGetG ctx bs bg
    = case runBitGet bs bg of
        Left  err -> fail (ctx ++ ": " ++ err)
        Right x   -> return x

-- | Read the header of the next label: either a length octet (top bits
-- 00) or a two-octet compression pointer (top bits 11). Other prefixes
-- make the parser fail.
getLabelHeader :: Get LabelHeader
getLabelHeader
    = do header <- lookAhead $ getByteString 1
         h <- runBitGetG "Illegal label header" header $
              do a <- getBit
                 b <- getBit
                 n <- liftM fromIntegral (getAsWord8 6)
                 case (a, b) of
                   ( True,  True) -> return $ Offset n
                   (False, False) -> return $ Length n
                   _              -> fail "Illegal label header"
         case h of
           Offset _
               -> do -- Pointers have 2 octets: 2 flag bits + 14-bit offset.
                     header' <- getByteString 2
                     runBitGetG "Illegal label pointer" header' $
                         do BG.skip 2
                            n <- liftM fromIntegral (getAsWord16 14)
                            return $ Offset n
           len@(Length _)
               -> do G.skip 1
                     return len

-- | Parse a length-prefixed character-string.
getCharString :: Get BS.ByteString
getCharString = do len <- G.getWord8
                   getByteString (fromIntegral len)

-- | Write a length-prefixed character-string (at most 255 octets).
putCharString :: BS.ByteString -> Put
putCharString = putLengthPrefixed "character-string" 255

data LabelHeader
    = Offset !Int
    | Length !Int

putDomainName :: DomainName -> Put
putDomainName = mapM_ putDomainLabel . nameToLabels

-- | Write one label (at most 63 octets; longer lengths would set the
-- top bits of the length octet and be misread as a pointer).
putDomainLabel :: DomainLabel -> Put
putDomainLabel = putLengthPrefixed "domain label" 63

-- | Write a one-octet length followed by the bytes. Input longer than
-- @maxLen@ is an 'error' instead of a silently truncated length octet.
putLengthPrefixed :: String -> Int -> BS.ByteString -> Put
putLengthPrefixed what maxLen bs
    | len > maxLen = error ("Network.DNS.Message: " ++ what ++ " too long ("
                            ++ show len ++ " > " ++ show maxLen ++ " octets)")
    | otherwise    = do putWord8 $ fromIntegral len
                        P.putByteString bs
    where
      len = BS.length bs


-- | A resource record type tag @rt@ and the RDATA type @dt@ it carries.
class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
    rtToInt       :: rt -> Int
    putRecordType :: rt -> Put
    putRecordData :: rt -> dt -> Put
    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)

    putRecordType = putWord16be . fromIntegral . rtToInt

-- | A record type tag of any registered type.
data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt

instance Show SomeRT where
    show (SomeRT rt) = show rt

instance Eq SomeRT where
    (SomeRT a) == (SomeRT b) = Just a == cast b

putSomeRT :: SomeRT -> Put
putSomeRT (SomeRT rt) = putRecordType rt

-- | Read a 16-bit type code and look it up in 'defaultRTTable'.
getSomeRT :: Get SomeRT
getSomeRT = do n <- liftM fromIntegral G.getWord16be
               case IM.lookup n defaultRTTable of
                 Nothing
                     -> fail ("Unknown resource record type: " ++ show n)
                 Just srt
                     -> return srt

data A = A deriving (Show, Eq, Typeable)
instance RecordType A HostAddress where
    rtToInt       _ = 1
    putRecordData _ = putWord32be
    getRecordData _ = \ dt ->
                      do addr <- G.getWord32be
                         return (addr, dt)

data NS = NS deriving (Show, Eq, Typeable)
instance RecordType NS DomainName where
    rtToInt       _ = 2
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data CNAME = CNAME deriving (Show, Eq, Typeable)
instance RecordType CNAME DomainName where
    rtToInt       _ = 5
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data HINFO = HINFO deriving (Show, Eq, Typeable)
instance RecordType HINFO (BS.ByteString, BS.ByteString) where
    rtToInt       _ = 13
    putRecordData _ (cpu, os)
        = do putCharString cpu
             putCharString os
    getRecordData _ dt
        = do cpu <- getCharString
             os  <- getCharString
             return ((cpu, os), dt)

{-
data RecordType
    = A
    | NS
    | MD
    | MF
    | CNAME
    | SOA
    | MB
    | MG
    | MR
    | NULL
    | WKS
    | PTR
    | HINFO
    | MINFO
    | MX
    | TXT
    -- Only for queries:
    | AXFR
    | MAILB -- Obsolete
    | MAILA -- Obsolete
    | AnyType
    deriving (Show, Eq)
-}

instance Binary Message where
    put m = do put $ msgHeader m
               putWord16be $ fromIntegral $ length $ msgQuestions m
               putWord16be $ fromIntegral $ length $ msgAnswers m
               putWord16be $ fromIntegral $ length $ msgAuthorities m
               putWord16be $ fromIntegral $ length $ msgAdditionals m
               mapM_ putQ      $ msgQuestions m
               mapM_ putSomeRR $ msgAnswers m
               mapM_ putSomeRR $ msgAuthorities m
               mapM_ putSomeRR $ msgAdditionals m

    -- The decompression table is threaded through every section so that
    -- pointers may refer to names from any earlier part of the message.
    get = do hdr  <- get
             nQ   <- liftM fromIntegral G.getWord16be
             nAns <- liftM fromIntegral G.getWord16be
             nAth <- liftM fromIntegral G.getWord16be
             nAdd <- liftM fromIntegral G.getWord16be
             (qs  , dt1) <- replicateM' nQ   getQ      IM.empty
             (anss, dt2) <- replicateM' nAns getSomeRR dt1
             (aths, dt3) <- replicateM' nAth getSomeRR dt2
             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
             return Message {
                          msgHeader      = hdr
                        , msgQuestions   = qs
                        , msgAnswers     = anss
                        , msgAuthorities = aths
                        , msgAdditionals = adds
                        }

-- | Decode a wire code into one of the listed constructors (matched via
-- 'fromEnum'), failing inside the parser on codes this module does not
-- know, rather than crashing later on a strict field.
decodeEnum :: Enum a => String -> [a] -> Int -> Get a
decodeEnum what valid n
    = case lookup n [ (fromEnum c, c) | c <- valid ] of
        Just c  -> return c
        Nothing -> fail ("Unknown " ++ what ++ ": " ++ show n)

instance Binary Header where
    put h = do putWord16be $ hdMessageID h
               putLazyByteString flags
        where
          flags = runBitPut $
                  do putNBits 1 $ fromEnum $ hdMessageType h
                     putNBits 4 $ fromEnum $ hdOpcode h
                     putBit $ hdIsAuthoritativeAnswer h
                     putBit $ hdIsTruncated h
                     putBit $ hdIsRecursionDesired h
                     putBit $ hdIsRecursionAvailable h
                     putNBits 3 (0 :: Int)
                     putNBits 4 $ fromEnum $ hdResponseCode h

    get = do mID   <- G.getWord16be
             flags <- getByteString 2
             (qr, op, aa, tc, rd, ra, rc)
                 <- runBitGetG "Illegal header flags" flags $
                    do qr' <- getAsWord8 1
                       op' <- getAsWord8 4
                       aa' <- getBit
                       tc' <- getBit
                       rd' <- getBit
                       ra' <- getBit
                       BG.skip 3
                       rc' <- getAsWord8 4
                       return (qr', op', aa', tc', rd', ra', rc')
             msgType <- decodeEnum "message type"
                            [Query, Response]
                            (fromIntegral qr)
             opcode  <- decodeEnum "opcode"
                            [StandardQuery, InverseQuery, ServerStatusRequest]
                            (fromIntegral op)
             rcode   <- decodeEnum "response code"
                            [NoError, FormatError, ServerFailure, NameError, NotImplemented, Refused]
                            (fromIntegral rc)
             return Header {
                          hdMessageID             = mID
                        , hdMessageType           = msgType
                        , hdOpcode                = opcode
                        , hdIsAuthoritativeAnswer = aa
                        , hdIsTruncated           = tc
                        , hdIsRecursionDesired    = rd
                        , hdIsRecursionAvailable  = ra
                        , hdResponseCode          = rcode
                        }

instance Enum MessageType where
    fromEnum Query    = 0
    fromEnum Response = 1

    toEnum 0 = Query
    toEnum 1 = Response
    toEnum n = error ("Network.DNS.Message: invalid MessageType code " ++ show n)

instance Enum Opcode where
    fromEnum StandardQuery       = 0
    fromEnum InverseQuery        = 1
    fromEnum ServerStatusRequest = 2

    toEnum 0 = StandardQuery
    toEnum 1 = InverseQuery
    toEnum 2 = ServerStatusRequest
    toEnum n = error ("Network.DNS.Message: invalid Opcode code " ++ show n)

instance Enum ResponseCode where
    fromEnum NoError        = 0
    fromEnum FormatError    = 1
    fromEnum ServerFailure  = 2
    fromEnum NameError      = 3
    fromEnum NotImplemented = 4
    fromEnum Refused        = 5

    toEnum 0 = NoError
    toEnum 1 = FormatError
    toEnum 2 = ServerFailure
    toEnum 3 = NameError
    toEnum 4 = NotImplemented
    toEnum 5 = Refused
    toEnum n = error ("Network.DNS.Message: invalid ResponseCode code " ++ show n)

{-
instance Enum RecordType where
    fromEnum A       = 1
    fromEnum NS      = 2
    fromEnum MD      = 3
    fromEnum MF      = 4
    fromEnum CNAME   = 5
    fromEnum SOA     = 6
    fromEnum MB      = 7
    fromEnum MG      = 8
    fromEnum MR      = 9
    fromEnum NULL    = 10
    fromEnum WKS     = 11
    fromEnum PTR     = 12
    fromEnum HINFO   = 13
    fromEnum MINFO   = 14
    fromEnum MX      = 15
    fromEnum TXT     = 16
    fromEnum AXFR    = 252
    fromEnum MAILB   = 253
    fromEnum MAILA   = 254
    fromEnum AnyType = 255

    toEnum 1   = A
    toEnum 2   = NS
    toEnum 3   = MD
    toEnum 4   = MF
    toEnum 5   = CNAME
    toEnum 6   = SOA
    toEnum 7   = MB
    toEnum 8   = MG
    toEnum 9   = MR
    toEnum 10  = NULL
    toEnum 11  = WKS
    toEnum 12  = PTR
    toEnum 13  = HINFO
    toEnum 14  = MINFO
    toEnum 15  = MX
    toEnum 16  = TXT
    toEnum 252 = AXFR
    toEnum 253 = MAILB
    toEnum 254 = MAILA
    toEnum 255 = AnyType
    toEnum _   = undefined
-}

instance Enum RecordClass where
    fromEnum IN       = 1
    fromEnum CS       = 2
    fromEnum CH       = 3
    fromEnum HS       = 4
    fromEnum AnyClass = 255

    toEnum 1   = IN
    toEnum 2   = CS
    toEnum 3   = CH
    toEnum 4   = HS
    toEnum 255 = AnyClass
    toEnum n   = error ("Network.DNS.Message: invalid RecordClass code " ++ show n)

instance Binary RecordClass where
    -- Unknown classes fail the parser instead of producing 'undefined'.
    get = G.getWord16be >>= decodeEnum "record class" [IN, CS, CH, HS, AnyClass] . fromIntegral
    put = putWord16be . fromIntegral . fromEnum


-- | Record types understood by the decoder, keyed by their type code.
defaultRTTable :: IntMap SomeRT
defaultRTTable = IM.fromList $ map toPair $
                 [ wrapRecordType A
                 , wrapRecordType NS
                 , wrapRecordType CNAME
                 , wrapRecordType HINFO
                 ]
    where
      toPair :: SomeRT -> (Int, SomeRT)
      toPair srt@(SomeRT rt) = (rtToInt rt, srt)


wrapQueryType :: RecordType rt dt => rt -> SomeQT
wrapQueryType = SomeRT

wrapRecordType :: RecordType rt dt => rt -> SomeRT
wrapRecordType = SomeRT

wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
wrapRecord = SomeRR