-- | Serves DNS queries over UDP and TCP, answering them from zones that
-- are looked up through a caller-supplied function.
module Network.DNS.Named
    ( runNamed
    )
    where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Dynamic
import Data.Maybe
import Network.Socket
import qualified Network.Socket.ByteString as NB
import Network.DNS.Message
import Network.DNS.Named.Config
import Network.DNS.Named.ResponseBuilder
import Network.DNS.Named.Zone
import System.Posix.Signals
import System.IO

-- | Run the name server; this never returns under normal operation.
--
-- A TCP listener is forked onto its own thread while UDP is served on the
-- calling thread; both bind to 'cnfServerAddress'. Every incoming query is
-- handled on a freshly forked thread.
--
-- @findZone@ maps a queried name to the zone that should answer it, or
-- 'Nothing' when this server has no such zone (those questions get a
-- REFUSED response code).
runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
runNamed cnf findZone = withSocketsDo $ do
    -- Ignore SIGPIPE so that writing to a TCP peer that has already gone
    -- away fails with an IOError in that handler thread instead of
    -- terminating the whole process.
    _ <- installHandler sigPIPE Ignore Nothing
    _tcpListenerTID <- forkIO tcpListen
    udpListen
  where
    -- Largest DNS message that fits behind the 16-bit length prefix used
    -- on TCP connections.
    tcpMaxMessageSize :: Int
    tcpMaxMessageSize = 65535

    udpListen :: IO ()
    udpListen = do
        -- FIXME: we should support IPv6 when the network package supports it.
        so <- socket AF_INET Datagram defaultProtocol
        bindSocket so $ cnfServerAddress cnf
        udpLoop so

    -- Receive datagrams forever (at most 512 bytes each), answering every
    -- one on its own thread.
    udpLoop :: Socket -> IO ()
    udpLoop so = do
        (packet, cameFrom) <- NB.recvFrom so 512
        _handlerTID <- forkIO $ udpHandler so packet cameFrom
        udpLoop so

    tcpListen :: IO ()
    tcpListen = do
        so <- socket AF_INET Stream defaultProtocol
        bindSocket so $ cnfServerAddress cnf
        listen so 255
        tcpLoop so

    -- Accept connections forever, serving each one on its own thread
    -- through a block-buffered handle.
    tcpLoop :: Socket -> IO ()
    tcpLoop so = do
        (so', _) <- accept so
        h <- socketToHandle so' ReadWriteMode
        hSetBuffering h $ BlockBuffering Nothing
        _handlerTID <- forkIO $ tcpHandler h
        tcpLoop so

    -- Answer one UDP query. If building the answer throws, a SERVFAIL
    -- reply is sent before the exception propagates. Replies are
    -- truncated to 512 bytes by 'packMessage'.
    udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
    udpHandler so packet cameFrom = do
        msg  <- evaluate $ unpackMessage packet
        msg' <- handleMessage msg
                    `onException` sendReply (mkErrorReply ServerFailure msg)
        sendReply msg'
      where
        sendReply :: Message -> IO ()
        sendReply m = do
            _sent <- NB.sendTo so (packMessage (Just 512) m) cameFrom
            return ()

    -- Serve length-prefixed queries on one TCP connection until the peer
    -- closes it. The handle is closed however the loop ends, including
    -- when a query fails to decode or building an answer throws.
    tcpHandler :: Handle -> IO ()
    tcpHandler h = serveLoop `finally` hClose h
      where
        serveLoop :: IO ()
        serveLoop = do
            lenB <- LBS.hGet h 2
            -- Fewer than two bytes: the peer closed the connection.
            unless (LBS.length lenB < 2) $ do
                let len = fromIntegral (runGet getWord16be lenB) :: Int
                packet <- BS.hGet h len
                -- A short body means EOF arrived in the middle of a
                -- message; stop rather than try to decode a fragment.
                unless (BS.length packet < len) $ do
                    answerQuery packet
                    serveLoop

        -- Decode, answer and write back one query. On failure while
        -- building the answer, a SERVFAIL reply is written before the
        -- exception propagates (and 'finally' closes the handle).
        answerQuery :: BS.ByteString -> IO ()
        answerQuery packet = do
            msg  <- evaluate $ unpackMessage packet
            msg' <- handleMessage msg
                        `onException` writeReply (mkErrorReply ServerFailure msg)
            writeReply msg'
            hFlush h

        -- Write a reply preceded by its 16-bit big-endian length. The
        -- message is limited to 'tcpMaxMessageSize' so the length prefix
        -- cannot wrap around and corrupt the stream; larger replies are
        -- cut down by 'packMessage' (which sets the truncation flag).
        writeReply :: Message -> IO ()
        writeReply m = do
            let packet' = packMessage (Just tcpMaxMessageSize) m
                len'    = fromIntegral $ BS.length packet'
            LBS.hPut h $ runPut $ putWord16be len'
            BS.hPut h packet'

    -- Build the full response to a query, or an error reply when the
    -- header is not a standard query.
    handleMessage :: Message -> IO Message
    handleMessage msg
        = case validateQuery msg of
            NoError
                -> do builders <- mapM handleQuestion $ msgQuestions msg
                      return $ runBuilder msg (sequence_ builders)
            err
                -> return $ mkErrorReply err msg

    -- Produce the response fragment for one question: REFUSED (and not
    -- authoritative) when no zone is found for the name.
    handleQuestion :: SomeQ -> IO (Builder ())
    handleQuestion (SomeQ q) = do
        zoneM <- findZone (qName q)
        case zoneM of
          Nothing
              -> return $ do unauthorise
                             setResponseCode Refused
          Just zone
              -> handleQuestionForZone (SomeQ q) zone

    -- Answer a question from a known zone. AXFR is delegated to
    -- 'handleAXFR'; everything else gets the matching records, the zone's
    -- SOA/NS at the apex when asked for, additional records for names the
    -- answers (and the zone's NS record) point at, and the NS record as
    -- authority.
    handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
    handleQuestionForZone (SomeQ q) zone
        | Just (qType q) == cast AXFR
            = handleAXFR (SomeQ q) zone
        | otherwise
            = do allRecords <- zoneResponder zone (qName q)
                 let filtered = filterRecords (SomeQ q) allRecords
                 additionals <- do xss <- mapM (getAdditionals zone) filtered
                                   ys  <- maybe (return [])
                                                (getAdditionals zone)
                                                (zoneNSRecord zone)
                                   return (concat xss ++ ys)
                 return $ do mapM_ addAnswer filtered
                             when (qName q == zoneName zone) $ do
                                 when (Just (qType q) == cast SOA ||
                                       Just (qType q) == cast ANY) $
                                     maybe (return ()) addAnswer (zoneSOARecord zone)
                                 when (Just (qType q) == cast NS ||
                                       Just (qType q) == cast ANY) $
                                     maybe (return ()) addAnswer (zoneNSRecord zone)
                             mapM_ addAdditional additionals
                             maybe unauthorise addAuthority (zoneNSRecord zone)

    -- When a record's data is a domain name, look up that name's A
    -- records (class IN) in the same zone; otherwise return nothing.
    getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
    getAdditionals zone (SomeRR rr)
        = case cast (rrData rr) :: Maybe DomainName of
            Nothing
                -> return []
            Just name
                -> do allRecords <- zoneResponder zone name
                      let q' = Question { qName  = name
                                        , qType  = A
                                        , qClass = IN
                                        }
                      return $ filterRecords (SomeQ q') allRecords

    -- Keep the records matching the question's type and class, treating
    -- ANY as a wildcard. CNAME records always pass the type check.
    filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
    filterRecords (SomeQ q)
        | Just (qType q) == cast ANY && Just (qClass q) == cast ANY = id
        | Just (qType q)  == cast ANY = filter matchClass
        | Just (qClass q) == cast ANY = filter matchType
        | otherwise                   = filter matchBoth
        where
          matchClass (SomeRR rr) = Just (qClass q) == cast (rrClass rr)
          matchType  (SomeRR rr) = Just (qType q) == cast (rrType rr) ||
                                   Just CNAME     == cast (rrType rr)
          matchBoth rr = matchType rr && matchClass rr

    -- Zone transfer: SOA, NS, every record of every name in the zone,
    -- then the SOA again. Only served for the zone apex of a zone that has
    -- an SOA and when 'cnfAllowTransfer' is set. A missing NS record is
    -- skipped instead of crashing the handler.
    -- NOTE(review): a refused transfer currently yields an empty builder
    -- (no answers, response code untouched) rather than REFUSED.
    handleAXFR :: SomeQ -> Zone -> IO (Builder ())
    handleAXFR (SomeQ q) zone
        | qName q == zoneName zone
        , isJust (zoneSOA zone)
        , cnfAllowTransfer cnf
        , Just soaRR <- zoneSOARecord zone
            = do names      <- zoneRecordNames zone
                 allRecords <- liftM concat $ mapM (zoneResponder zone) names
                 return $ do addAnswer soaRR
                             maybe (return ()) addAnswer (zoneNSRecord zone)
                             mapM_ addAnswer allRecords
                             addAnswerNonuniquely soaRR
        | otherwise
            = return $ return ()

-- | Accept only standard queries; anything else is NOTIMP.
validateQuery :: Message -> ResponseCode
validateQuery = validateHeader . msgHeader
    where
      validateHeader :: Header -> ResponseCode
      validateHeader hdr
          | hdMessageType hdr /= Query         = NotImplemented
          | hdOpcode      hdr /= StandardQuery = NotImplemented
          | otherwise                          = NoError

-- | Serialise a message to wire format.
--
-- With @Just lim@, records are removed one at a time from the end of a
-- section until the encoding is at most @lim@ bytes: additional records
-- first, then authorities, answers and finally questions. Removing from
-- any section other than the additional one sets the truncation flag.
-- With 'Nothing' the message is encoded as-is.
packMessage :: Maybe Int -> Message -> BS.ByteString
packMessage limM = BS.concat . LBS.toChunks . truncateMsg
    where
      truncateMsg :: Message -> LBS.ByteString
      truncateMsg msg
          = let packet    = encode msg
                needTrunc = maybe False
                                  (\lim -> fromIntegral (LBS.length packet) > lim)
                                  limM
            in
              if needTrunc
                  then truncateMsg $ trunc' msg
                  else packet

      trunc' :: Message -> Message
      trunc' msg
          | notNull $ msgAdditionals msg
              = msg { msgAdditionals = truncList $ msgAdditionals msg }
          | notNull $ msgAuthorities msg
              = msg { msgHeader      = setTruncFlag $ msgHeader msg
                    , msgAuthorities = truncList $ msgAuthorities msg
                    }
          | notNull $ msgAnswers msg
              = msg { msgHeader  = setTruncFlag $ msgHeader msg
                    , msgAnswers = truncList $ msgAnswers msg
                    }
          | notNull $ msgQuestions msg
              = msg { msgHeader    = setTruncFlag $ msgHeader msg
                    , msgQuestions = truncList $ msgQuestions msg
                    }
          | otherwise
              = error ("packMessage: You are already skinny and need no diet: " ++ show msg)

      setTruncFlag :: Header -> Header
      setTruncFlag hdr = hdr { hdIsTruncated = True }

      notNull :: [a] -> Bool
      notNull = not . null

      -- Drop the last element (total: the empty list stays empty).
      truncList :: [a] -> [a]
      truncList xs = take (length xs - 1) xs

-- | Decode a wire-format message. Data.Binary's 'decode' is not total, so
-- malformed input presumably raises an exception when the result is
-- forced; callers force it with 'evaluate'.
unpackMessage :: BS.ByteString -> Message
unpackMessage = decode . LBS.fromChunks . return

-- | Turn a query into a non-authoritative reply carrying the given error
-- code.
mkErrorReply :: ResponseCode -> Message -> Message
mkErrorReply err msg
    = runBuilder msg $ do unauthorise
                          setResponseCode err