module Network.DNS.Named
    ( runNamed
    )
    where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Dynamic
import Data.Maybe
import Network.Socket
import qualified Network.Socket.ByteString as NB
import Network.DNS.Message
import Network.DNS.Named.Config
import Network.DNS.Named.ResponseBuilder
import Network.DNS.Named.Zone
import System.Posix.Signals
import System.IO

-- | Run a DNS name server. Under normal operation this never returns.
--
-- The port from 'cnfServerPort' is resolved as a passive IPv6 address
-- (IPv4-mapped addresses are requested too). TCP is served from a
-- background thread and UDP from the calling thread. Every UDP datagram
-- and every accepted TCP connection is handled in its own thread.
-- Questions are answered from the zones the given 'ZoneFinder' returns.
--
-- SIGPIPE is ignored, so a peer that drops its TCP connection does not
-- terminate the process via the signal.
runNamed :: ZoneFinder zf => Config -> zf -> IO ()
runNamed cnf zf
    = withSocketsDo $
      do _ <- installHandler sigPIPE Ignore Nothing
         let hint = defaultHints {
                      addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
                    , addrFamily     = AF_INET6
                    , addrSocketType = NoSocketType
                    , addrProtocol   = defaultProtocol
                    }
         (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
         _tcpListenerTID <- forkIO $ tcpListen ai
         udpListen ai
    where
      -- Bind a datagram socket to the resolved address and serve it forever.
      udpListen :: AddrInfo -> IO ()
      udpListen ai
          = do so <- socket (addrFamily ai) Datagram defaultProtocol
               setSocketOption so ReuseAddr 1
               bindSocket so (addrAddress ai)
               udpLoop so

      -- Receive datagrams of at most 512 bytes. Each one is handed to a
      -- fresh 'udpHandler' thread.
      udpLoop :: Socket -> IO ()
      udpLoop so
          = do (packet, cameFrom) <- NB.recvFrom so 512
               _handlerTID <- forkIO $ udpHandler so packet cameFrom
               udpLoop so

      -- Bind and listen on a stream socket, then accept connections forever.
      tcpListen :: AddrInfo -> IO ()
      tcpListen ai
          = do so <- socket (addrFamily ai) Stream defaultProtocol
               setSocketOption so ReuseAddr 1
               bindSocket so (addrAddress ai)
               listen so 255
               tcpLoop so

      -- Accept a connection, wrap it in a block-buffered handle and serve
      -- it in a separate thread.
      tcpLoop :: Socket -> IO ()
      tcpLoop so
          = do (so', _) <- accept so
               h <- socketToHandle so' ReadWriteMode
               hSetBuffering h $ BlockBuffering Nothing
               _handlerTID <- forkIO $ tcpHandler h
               tcpLoop so

      -- Answer one UDP request; the reply is truncated to 512 bytes.
      -- If the packet cannot be decoded, 'evaluate' throws and the thread
      -- ends without replying. If answering throws, a SERVFAIL reply is
      -- sent and the exception is re-raised.
      udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
      udpHandler so packet cameFrom
          = do msg  <- evaluate $ unpackMessage packet
               msg' <- handleMessage msg
                           `onException` sendReply (mkErrorReply ServerFailure msg)
               sendReply msg'
          where
            sendReply :: Message -> IO ()
            sendReply m
                = do _sent <- NB.sendTo so (packMessage (Just 512) m) cameFrom
                     return ()

      -- Serve DNS-over-TCP on one connection. Each message carries a
      -- 16-bit big-endian length prefix. The loop runs until the peer
      -- closes the connection or an exception escapes, and the handle is
      -- closed on every exit path. This includes undecodable requests and
      -- I/O errors, which would otherwise leave the handle open until GC.
      tcpHandler :: Handle -> IO ()
      tcpHandler h
          = serve `finally` hClose h
          where
            serve :: IO ()
            serve
                = do lenB <- LBS.hGet h 2
                     -- A short read means the peer closed the connection,
                     -- either between messages or in the middle of one.
                     unless (LBS.length lenB < 2) $
                         do let len = fromIntegral (runGet getWord16be lenB) :: Int
                            packet <- BS.hGet h len
                            unless (BS.length packet < len) $
                                do answer packet
                                   serve

            -- Decode one request and write its reply, or a SERVFAIL if
            -- answering throws. In that case the exception is re-raised
            -- and the connection is closed by the 'finally' above.
            answer :: BS.ByteString -> IO ()
            answer packet
                = do msg  <- evaluate $ unpackMessage packet
                     msg' <- handleMessage msg
                                 `onException` sendReply (mkErrorReply ServerFailure msg)
                     sendReply msg'
                     hFlush h

            -- Write one length-prefixed message. The prefix is 16 bits, so
            -- the message is capped at 65535 bytes. Without the cap a
            -- larger reply (e.g. a big AXFR) would wrap the prefix and
            -- desynchronise the stream. With it, oversized replies are
            -- truncated by 'packMessage' instead.
            sendReply :: Message -> IO ()
            sendReply m
                = do let packet' = packMessage (Just 65535) m
                         len'    = fromIntegral $ BS.length packet'
                     LBS.hPut h $ runPut $ putWord16be len'
                     BS.hPut h packet'

      -- Build the reply to a request. Invalid requests get an error reply
      -- carrying the code from 'validateQuery'. Otherwise the per-question
      -- builders are run in order against the request.
      handleMessage :: Message -> IO Message
      handleMessage msg
          = case validateQuery msg of
              NoError
                  -> do builders <- mapM handleQuestion $ msgQuestions msg
                        return $ runBuilder msg $ sequence_ builders
              err
                  -> return $ mkErrorReply err msg

      -- Look up the zone for a question's name. Names outside every known
      -- zone are REFUSED and the reply is marked non-authoritative.
      handleQuestion :: SomeQ -> IO (Builder ())
      handleQuestion (SomeQ q)
          = do zoneM <- findZone zf (qName q)
               case zoneM of
                 Nothing
                     -> return $ do unauthorise
                                    setResponseCode Refused
                 Just zone
                     -> handleQuestionForZone q zone

      -- Answer a question from a zone. AXFR goes to 'handleAXFR'.
      -- Otherwise the reply gets the matching records, the zone's NS
      -- records as authority, and A/AAAA additionals for any of those
      -- records whose data is a domain name.
      handleQuestionForZone :: (Zone z, QueryType qt, QueryClass qc)
                            => Question qt qc
                            -> z
                            -> IO (Builder ())
      handleQuestionForZone q zone
          | Just (qType q) == cast AXFR
              = handleAXFR q zone
          | otherwise
              = do answers     <- getRecords zone q
                   authority   <- getRecords zone (Question (zoneName zone) NS IN)
                   additionals <- liftM concat $ mapM (getAdditionals zone) (answers ++ authority)
                   isAuth      <- isAuthoritativeZone zone
                   return $ do mapM_ addAnswer     answers
                               mapM_ addAuthority  authority
                               mapM_ addAdditional additionals
                               unless isAuth unauthorise

      -- When a record's data is a 'DomainName', fetch the A and AAAA
      -- records for that name from the same zone. Otherwise nothing.
      getAdditionals :: Zone z => z -> SomeRR -> IO [SomeRR]
      getAdditionals zone (SomeRR rr)
          = case cast (rrData rr) :: Maybe DomainName of
              Nothing
                  -> return []
              Just name
                  -> do rrA    <- getRecords zone (Question name A    IN)
                        rrAAAA <- getRecords zone (Question name AAAA IN)
                        return (rrA ++ rrAAAA)

      -- Zone transfer. All records go into a single message, and only
      -- when 'cnfAllowTransfer' is set. Otherwise the reply is left
      -- untouched.
      handleAXFR :: (Zone z, QueryType qt, QueryClass qc)
                 => Question qt qc
                 -> z
                 -> IO (Builder ())
      handleAXFR q zone
          | cnfAllowTransfer cnf
              = do rs <- getRecords zone q
                   return $ mapM_ addAnswerNonuniquely rs
          | otherwise
              = return $ return ()

-- | Accept only standard queries. Anything that is not a query, or that
-- uses an opcode other than 'StandardQuery', gets 'NotImplemented'.
--
-- NOTE(review): this also produces an error reply for messages that are
-- themselves responses. Many servers drop those silently to avoid reply
-- loops; confirm this is intended.
validateQuery :: Message -> ResponseCode
validateQuery = validateHeader . msgHeader
    where
      validateHeader :: Header -> ResponseCode
      validateHeader hdr
          | hdMessageType hdr /= Query         = NotImplemented
          | hdOpcode      hdr /= StandardQuery = NotImplemented
          | otherwise                          = NoError

-- | Encode a message, optionally keeping it within a byte limit.
--
-- While the encoding exceeds the limit, one record is dropped at a time
-- and the message is re-encoded. Records come off the end of the
-- additional section first, then the authority, answer and question
-- sections. The TC flag is set once anything beyond the additional
-- section has to go. 'Nothing' means no limit.
packMessage :: Maybe Int -> Message -> BS.ByteString
packMessage limM = BS.concat . LBS.toChunks . truncateMsg
    where
      truncateMsg :: Message -> LBS.ByteString
      truncateMsg msg
          = let packet    = encode msg
                needTrunc = fromMaybe False $
                            do lim <- limM
                               return $ fromIntegral (LBS.length packet) > lim
            in
              if needTrunc then
                  truncateMsg $ trunc' msg
              else
                  packet

      trunc' :: Message -> Message
      trunc' msg
          | notNull $ msgAdditionals msg
              = msg { msgAdditionals = truncList $ msgAdditionals msg }
          | notNull $ msgAuthorities msg
              = msg { msgHeader      = setTruncFlag $ msgHeader msg
                    , msgAuthorities = truncList $ msgAuthorities msg
                    }
          | notNull $ msgAnswers msg
              = msg { msgHeader  = setTruncFlag $ msgHeader msg
                    , msgAnswers = truncList $ msgAnswers msg
                    }
          | notNull $ msgQuestions msg
              = msg { msgHeader    = setTruncFlag $ msgHeader msg
                    , msgQuestions = truncList $ msgQuestions msg
                    }
          | otherwise
              = error ("packMessage: You are already skinny and need no diet: " ++ show msg)

      setTruncFlag :: Header -> Header
      setTruncFlag hdr = hdr { hdIsTruncated = True }

      notNull :: [a] -> Bool
      notNull = not . null

      -- Drop the last element. Total: the empty list maps to itself.
      truncList :: [a] -> [a]
      truncList xs = take (length xs - 1) xs

-- | Decode a wire-format message. This throws (via 'decode') when the
-- bytes are malformed, so callers force it with 'evaluate' to surface
-- the failure in 'IO'.
unpackMessage :: BS.ByteString -> Message
unpackMessage = decode . LBS.fromChunks . return

-- | Turn a request into a non-authoritative reply carrying the given
-- response code.
mkErrorReply :: ResponseCode -> Message -> Message
mkErrorReply err msg
    = runBuilder msg $ do unauthorise
                          setResponseCode err