X-Git-Url: http://git.cielonegro.org/gitweb.cgi?a=blobdiff_plain;f=Network%2FDNS%2FNamed.hs;h=4a9eaed2fa26eda15576a94673deead4c5d42a82;hb=248b1c63284bbe00550bf2402ee6a9da6997143e;hp=57d9ea4b78b6a765bce5a044444503211c1a8b21;hpb=5015e5caa39e015e6ffa28a87fc5f189e7ba3c71;p=haskell-dns.git diff --git a/Network/DNS/Named.hs b/Network/DNS/Named.hs index 57d9ea4..4a9eaed 100644 --- a/Network/DNS/Named.hs +++ b/Network/DNS/Named.hs @@ -7,21 +7,24 @@ import Control.Concurrent import Control.Exception import Control.Monad import Data.Binary +import Data.Binary.Get +import Data.Binary.Put import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS +import Data.Dynamic import Data.Maybe import Network.Socket import qualified Network.Socket.ByteString as NB import Network.DNS.Message import Network.DNS.Named.Config -import Network.DNS.Named.Responder import Network.DNS.Named.ResponseBuilder import Network.DNS.Named.Zone import System.Posix.Signals +import System.IO -runNamed :: ZoneFinder zf => Config -> zf -> IO () -runNamed cnf zf +runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO () +runNamed cnf findZone = withSocketsDo $ do installHandler sigPIPE Ignore Nothing _tcpListenerTID <- forkIO $ tcpListen @@ -30,7 +33,6 @@ runNamed cnf zf udpListen :: IO () udpListen = do -- FIXME: we should support IPv6 when the network package supports it. so <- socket AF_INET Datagram defaultProtocol - print cnf bindSocket so $ cnfServerAddress cnf udpLoop so @@ -41,7 +43,18 @@ runNamed cnf zf udpLoop so tcpListen :: IO () - tcpListen = putStrLn "FIXME: tcpListen is not implemented yet." + tcpListen = do so <- socket AF_INET Stream defaultProtocol + bindSocket so $ cnfServerAddress cnf + listen so 255 + tcpLoop so + + tcpLoop :: Socket -> IO () + tcpLoop so + = do (so', _) <- accept so + h <- socketToHandle so' ReadWriteMode + hSetBuffering h $ BlockBuffering Nothing + _handlerTID <- forkIO $ tcpHandler h + tcpLoop so udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO () udpHandler so packet cameFrom @@ -53,6 +66,31 @@ runNamed cnf zf _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom return () + tcpHandler :: Handle -> IO () + tcpHandler h + = do lenB <- LBS.hGet h 2 + if LBS.length lenB < 2 then + -- Got EOF + hClose h + else + do let len = runGet getWord16be lenB + packet <- BS.hGet h $ fromIntegral len + msg <- evaluate $ unpackMessage packet + msg' <- handleMessage msg + `onException` + do let servfail = mkErrorReply ServerFailure msg + packet' = packMessage Nothing servfail + len' = fromIntegral $ BS.length packet' + LBS.hPut h $ runPut $ putWord16be len' + BS.hPut h packet' + hClose h + let packet' = packMessage Nothing msg' + len' = fromIntegral $ BS.length packet' + LBS.hPut h $ runPut $ putWord16be len' + BS.hPut h packet' + hFlush h + tcpHandler h + handleMessage :: Message -> IO Message handleMessage msg = case validateQuery msg of @@ -68,11 +106,96 @@ runNamed cnf zf handleQuestion :: SomeQ -> IO (Builder ()) handleQuestion (SomeQ q) - = do zone <- findZone zf (qName q) - -- FIXME: this is merely a bogus implementation. - -- It considers no additional or authoritative sections. - results <- mapM (runResponder' q) (zoneResponders zone) - return $ mapM_ addAnswer $ concat results + = do zoneM <- findZone (qName q) + case zoneM of + Nothing + -> return $ do unauthorise + setResponseCode Refused + Just zone + -> handleQuestionForZone (SomeQ q) zone + + handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ()) + handleQuestionForZone (SomeQ q) zone + | Just (qType q) == cast AXFR + = handleAXFR (SomeQ q) zone + | otherwise + = do allRecords <- zoneResponder zone (qName q) + let filtered = filterRecords (SomeQ q) allRecords + + additionals <- do xss <- mapM (getAdditionals zone) filtered + ys <- case zoneNSRecord zone of + Just rr -> getAdditionals zone rr + Nothing -> return [] + return (concat xss ++ ys) + + return $ do mapM_ addAnswer filtered + + when (qName q == zoneName zone) $ + do when (Just (qType q) == cast SOA || + Just (qType q) == cast ANY ) + $ case zoneSOARecord zone of + Just rr -> addAnswer rr + Nothing -> return () + + when (Just (qType q) == cast NS || + Just (qType q) == cast ANY ) + $ case zoneNSRecord zone of + Just rr -> addAnswer rr + Nothing -> return () + + mapM_ addAdditional additionals + + case zoneNSRecord zone of + Just rr -> addAuthority rr + Nothing -> unauthorise + + getAdditionals :: Zone -> SomeRR -> IO [SomeRR] + getAdditionals zone (SomeRR rr) + = case cast (rrData rr) :: Maybe DomainName of + Nothing + -> return [] + Just name + -> do allRecords <- zoneResponder zone name + + let filtered = filterRecords (SomeQ q') allRecords + q' = Question { + qName = name + , qType = A + , qClass = IN + } + return filtered + + filterRecords :: SomeQ -> [SomeRR] -> [SomeRR] + filterRecords (SomeQ q) + | Just (qType q) == cast ANY && + Just (qClass q) == cast ANY = id + | Just (qType q) == cast ANY = filter matchClass + | Just (qClass q) == cast ANY = filter matchType + | otherwise = filter matchBoth + where + matchClass (SomeRR rr) + = Just (qClass q) == cast (rrClass rr) + + matchType (SomeRR rr) + = Just (qType q) == cast (rrType rr) || + Just CNAME == cast (rrType rr) + + matchBoth rr + = matchType rr && matchClass rr + + handleAXFR :: SomeQ -> Zone -> IO (Builder ()) + handleAXFR (SomeQ q) zone + | qName q == zoneName zone && + isJust (zoneSOA zone) && + cnfAllowTransfer cnf + = do names <- zoneRecordNames zone + allRecords <- liftM concat $ mapM (zoneResponder zone) names + return $ do addAnswer $ fromJust $ zoneSOARecord zone + addAnswer $ fromJust $ zoneNSRecord zone + mapM_ addAnswer allRecords + addAnswerNonuniquely $ fromJust $ zoneSOARecord zone + | otherwise + = return $ return () validateQuery :: Message -> ResponseCode @@ -138,15 +261,5 @@ unpackMessage = decode . LBS.fromChunks . return mkErrorReply :: ResponseCode -> Message -> Message mkErrorReply err msg - = let header = msgHeader msg - msg' = msg { - msgHeader = header { - hdMessageType = Response - , hdIsAuthoritativeAnswer = False - , hdIsTruncated = False - , hdIsRecursionAvailable = False - , hdResponseCode = err - } - } - in - msg' + = runBuilder msg $ do unauthorise + setResponseCode err