X-Git-Url: http://git.cielonegro.org/gitweb.cgi?a=blobdiff_plain;f=Network%2FDNS%2FNamed.hs;h=57570cfffdc87a39de49cf065571198e08144f9e;hb=845dca95afa7e073e62520ef3c4840b3b078bdad;hp=13297e8ae2fbbb2bf31051255ebe23ecd8bb5d8b;hpb=5d250da422c01c7aab948ebdda5ef618f18e0f39;p=haskell-dns.git diff --git a/Network/DNS/Named.hs b/Network/DNS/Named.hs index 13297e8..57570cf 100644 --- a/Network/DNS/Named.hs +++ b/Network/DNS/Named.hs @@ -1,10 +1,5 @@ module Network.DNS.Named - ( ZoneFinder(..) - , Zone(..) - - , runNamed - - , defaultRootZone + ( runNamed ) where @@ -12,54 +7,44 @@ import Control.Concurrent import Control.Exception import Control.Monad import Data.Binary +import Data.Binary.Get +import Data.Binary.Put import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS +import Data.Dynamic import Data.Maybe import Network.Socket import qualified Network.Socket.ByteString as NB import Network.DNS.Message import Network.DNS.Named.Config +import Network.DNS.Named.ResponseBuilder +import Network.DNS.Named.Zone import System.Posix.Signals +import System.IO -class ZoneFinder a where - findZone :: a -> DomainName -> IO Zone - -instance ZoneFinder (DomainName -> Zone) where - findZone = (return .) - -instance ZoneFinder (DomainName -> IO Zone) where - findZone = id - -instance ZoneFinder (DomainName -> Maybe Zone) where - findZone = ((return . fromMaybe defaultRootZone) .) - -instance ZoneFinder (DomainName -> IO (Maybe Zone)) where - findZone = (fmap (fromMaybe defaultRootZone) .) - - -data Zone - = Zone { - zoneName :: !DomainName - } - -defaultRootZone :: Zone -defaultRootZone = error "FIXME: defaultRootZone is not implemented yet" - - -runNamed :: ZoneFinder zf => Config -> zf -> IO () -runNamed cnf zf +runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO () +runNamed cnf findZone = withSocketsDo $ do installHandler sigPIPE Ignore Nothing - _tcpListenerTID <- forkIO $ tcpListen - udpListen + + let hint = defaultHints { + addrFlags = [AI_PASSIVE, AI_V4MAPPED] + , addrFamily = AF_INET6 + , addrSocketType = NoSocketType + , addrProtocol = defaultProtocol + } + (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf) + + _tcpListenerTID <- forkIO $ tcpListen ai + udpListen ai where - udpListen :: IO () - udpListen = do -- FIXME: we should support IPv6 when the network package supports it. - so <- socket AF_INET Datagram defaultProtocol - print cnf - bindSocket so $ cnfServerAddress cnf - udpLoop so + udpListen :: AddrInfo -> IO () + udpListen ai + = do so <- socket (addrFamily ai) Datagram defaultProtocol + setSocketOption so ReuseAddr 1 + bindSocket so (addrAddress ai) + udpLoop so udpLoop :: Socket -> IO () udpLoop so @@ -67,41 +52,241 @@ runNamed cnf zf _handlerTID <- forkIO $ udpHandler so packet cameFrom udpLoop so - tcpListen :: IO () - tcpListen = putStrLn "FIXME: tcpListen is not implemented yet." + tcpListen :: AddrInfo -> IO () + tcpListen ai + = do so <- socket (addrFamily ai) Stream defaultProtocol + setSocketOption so ReuseAddr 1 + bindSocket so (addrAddress ai) + listen so 255 + tcpLoop so + + tcpLoop :: Socket -> IO () + tcpLoop so + = do (so', _) <- accept so + h <- socketToHandle so' ReadWriteMode + hSetBuffering h $ BlockBuffering Nothing + _handlerTID <- forkIO $ tcpHandler h + tcpLoop so udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO () udpHandler so packet cameFrom = do msg <- evaluate $ unpackMessage packet msg' <- handleMessage msg `onException` - NB.sendTo so (packMessage $ makeServerFailure msg) cameFrom - _sent <- NB.sendTo so (packMessage $ msg' ) cameFrom + do let servfail = mkErrorReply ServerFailure msg + NB.sendTo so (packMessage (Just 512) servfail) cameFrom + _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom return () + tcpHandler :: Handle -> IO () + tcpHandler h + = do lenB <- LBS.hGet h 2 + if LBS.length lenB < 2 then + -- Got EOF + hClose h + else + do let len = runGet getWord16be lenB + packet <- BS.hGet h $ fromIntegral len + msg <- evaluate $ unpackMessage packet + msg' <- handleMessage msg + `onException` + do let servfail = mkErrorReply ServerFailure msg + packet' = packMessage Nothing servfail + len' = fromIntegral $ BS.length packet' + LBS.hPut h $ runPut $ putWord16be len' + BS.hPut h packet' + hClose h + let packet' = packMessage Nothing msg' + len' = fromIntegral $ BS.length packet' + LBS.hPut h $ runPut $ putWord16be len' + BS.hPut h packet' + hFlush h + tcpHandler h + handleMessage :: Message -> IO Message handleMessage msg - = fail (show msg) -- FIXME + = case validateQuery msg of + NoError + -> do builders <- mapM handleQuestion $ msgQuestions msg + let builder = foldl (>>) (return ()) builders + msg' = runBuilder msg builder -packMessage :: Message -> BS.ByteString -packMessage = BS.concat . LBS.toChunks . encode + return msg' -unpackMessage :: BS.ByteString -> Message -unpackMessage = decode . LBS.fromChunks . return + err -> return $ mkErrorReply err msg + + handleQuestion :: SomeQ -> IO (Builder ()) + handleQuestion (SomeQ q) + = do zoneM <- findZone (qName q) + case zoneM of + Nothing + -> return $ do unauthorise + setResponseCode Refused + Just zone + -> handleQuestionForZone (SomeQ q) zone + + handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ()) + handleQuestionForZone (SomeQ q) zone + | Just (qType q) == cast AXFR + = handleAXFR (SomeQ q) zone + | otherwise + = do allRecords <- zoneResponder zone (qName q) + let filtered = filterRecords (SomeQ q) allRecords + + additionals <- do xss <- mapM (getAdditionals zone) filtered + ys <- case zoneNSRecord zone of + Just rr -> getAdditionals zone rr + Nothing -> return [] + return (concat xss ++ ys) + + return $ do mapM_ addAnswer filtered + + when (qName q == zoneName zone) $ + do when (Just (qType q) == cast SOA || + Just (qType q) == cast ANY ) + $ case zoneSOARecord zone of + Just rr -> addAnswer rr + Nothing -> return () + + when (Just (qType q) == cast NS || + Just (qType q) == cast ANY ) + $ case zoneNSRecord zone of + Just rr -> addAnswer rr + Nothing -> return () + mapM_ addAdditional additionals -makeServerFailure :: Message -> Message -makeServerFailure msg - = let header = msgHeader msg - msg' = msg { - msgHeader = header { - hdMessageType = Response - , hdIsAuthoritativeAnswer = False - , hdIsTruncated = False - , hdIsRecursionAvailable = False - , hdResponseCode = ServerFailure + case zoneNSRecord zone of + Just rr -> addAuthority rr + Nothing -> unauthorise + + getAdditionals :: Zone -> SomeRR -> IO [SomeRR] + getAdditionals zone (SomeRR rr) + = case cast (rrData rr) :: Maybe DomainName of + Nothing + -> return [] + Just name + -> do allRecords <- zoneResponder zone name + + let rA = filterRecords (SomeQ qA) allRecords + rB = filterRecords (SomeQ qB) allRecords + qA = Question { + qName = name + , qType = A + , qClass = IN + } + qB = Question { + qName = name + , qType = AAAA + , qClass = IN } - } - in - msg' + return (rA ++ rB) + + filterRecords :: SomeQ -> [SomeRR] -> [SomeRR] + filterRecords (SomeQ q) = filter predicate + where + predicate rr + = predForType rr && predForClass rr + + predForType (SomeRR rr) + | typeOf (qType q) == typeOf ANY + = True + + | typeOf (qType q) == typeOf MAILB + = typeOf (rrType rr) == typeOf MR || + typeOf (rrType rr) == typeOf MB || + typeOf (rrType rr) == typeOf MG || + typeOf (rrType rr) == typeOf MINFO + + | otherwise + = typeOf (rrType rr) == typeOf (qType q) || + typeOf (rrType rr) == typeOf CNAME + + predForClass (SomeRR rr) + | typeOf (qClass q) == typeOf ANY + = True + + | otherwise + = typeOf (rrClass rr) == typeOf (qClass q) + + handleAXFR :: SomeQ -> Zone -> IO (Builder ()) + handleAXFR (SomeQ q) zone + | qName q == zoneName zone && + isJust (zoneSOA zone) && + cnfAllowTransfer cnf + = do names <- zoneRecordNames zone + allRecords <- liftM concat $ mapM (zoneResponder zone) names + return $ do addAnswer $ fromJust $ zoneSOARecord zone + addAnswer $ fromJust $ zoneNSRecord zone + mapM_ addAnswer allRecords + addAnswerNonuniquely $ fromJust $ zoneSOARecord zone + | otherwise + = return $ return () + + +validateQuery :: Message -> ResponseCode +validateQuery = validateHeader . msgHeader + where + validateHeader :: Header -> ResponseCode + validateHeader hdr + | hdMessageType hdr /= Query = NotImplemented + | hdOpcode hdr /= StandardQuery = NotImplemented + | otherwise = NoError + + +packMessage :: Maybe Int -> Message -> BS.ByteString +packMessage limM = BS.concat . LBS.toChunks . truncateMsg + where + truncateMsg :: Message -> LBS.ByteString + truncateMsg msg + = let packet = encode msg + needTrunc = fromMaybe False $ + do lim <- limM + return $ fromIntegral (LBS.length packet) > lim + in + if needTrunc then + truncateMsg $ trunc' msg + else + packet + + trunc' :: Message -> Message + trunc' msg + | notNull $ msgAdditionals msg + = msg { + msgAdditionals = truncList $ msgAdditionals msg + } + | notNull $ msgAuthorities msg + = msg { + msgHeader = setTruncFlag $ msgHeader msg + , msgAuthorities = truncList $ msgAuthorities msg + } + | notNull $ msgAnswers msg + = msg { + msgHeader = setTruncFlag $ msgHeader msg + , msgAnswers = truncList $ msgAnswers msg + } + | notNull $ msgQuestions msg + = msg { + msgHeader = setTruncFlag $ msgHeader msg + , msgQuestions = truncList $ msgQuestions msg + } + | otherwise + = error ("packMessage: You are already skinny and need no diet: " ++ show msg) + + setTruncFlag :: Header -> Header + setTruncFlag hdr = hdr { hdIsTruncated = True } + + notNull :: [a] -> Bool + notNull = not . null + + truncList :: [a] -> [a] + truncList xs = take (length xs - 1) xs + +unpackMessage :: BS.ByteString -> Message +unpackMessage = decode . LBS.fromChunks . return + +mkErrorReply :: ResponseCode -> Message -> Message +mkErrorReply err msg + = runBuilder msg $ do unauthorise + setResponseCode err