-- | A small DNS name server.
--
-- 'runNamed' answers queries received over UDP from the zones supplied by
-- the caller; TCP service is still a stub.
module Network.DNS.Named
    ( runNamed
    )
    where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.Binary
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import Data.Dynamic
import Data.Maybe
import Network.Socket
import qualified Network.Socket.ByteString as NB
import Network.DNS.Message
import Network.DNS.Named.Config
import Network.DNS.Named.ResponseBuilder
import Network.DNS.Named.Zone
import System.Posix.Signals

-- | Run the name server described by @cnf@.
--
-- SIGPIPE is ignored, a TCP listener thread is forked (currently it only
-- prints a FIXME message), and the calling thread then serves UDP queries
-- on 'cnfServerAddress' forever, handling every datagram in its own thread.
--
-- @findZone@ maps a queried name to the zone that should answer it;
-- 'Nothing' makes the server reply with 'Refused'.
runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
runNamed cnf findZone
    = withSocketsDo $
      do -- The previously installed handler returned here is not needed.
         _ <- installHandler sigPIPE Ignore Nothing
         _tcpListenerTID <- forkIO $ tcpListen
         udpListen
    where
      -- Open an IPv4 UDP socket, bind it to the configured address and
      -- serve requests on it.
      udpListen :: IO ()
      udpListen
          = do -- FIXME: we should support IPv6 when the network package supports it.
               so <- socket AF_INET Datagram defaultProtocol
               -- NOTE(review): looks like leftover debug output; confirm it is wanted.
               print cnf
               bindSocket so $ cnfServerAddress cnf
               udpLoop so

      -- Receive datagrams of at most 512 bytes forever, forking one
      -- 'udpHandler' thread per datagram.
      udpLoop :: Socket -> IO ()
      udpLoop so
          = do (packet, cameFrom) <- NB.recvFrom so 512
               _handlerTID <- forkIO $ udpHandler so packet cameFrom
               udpLoop so

      tcpListen :: IO ()
      tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."

      -- Decode one request, answer it, and send the (at most 512-byte)
      -- reply back to the sender.  If answering fails, a SERVFAIL reply is
      -- sent and the exception is re-thrown ('onException' re-raises it).
      udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
      udpHandler so packet cameFrom
          = do msg <- evaluate $ unpackMessage packet
               -- 'handleMessage' returns the reply as an unevaluated
               -- 'Message', so failures while building or encoding it would
               -- otherwise only surface in the final sendTo, outside the
               -- SERVFAIL handler.  Encode and force it in here instead.
               packedReply <- ((evaluate . packMessage (Just 512)) =<< handleMessage msg)
                              `onException`
                              do let servfail = mkErrorReply ServerFailure msg
                                 NB.sendTo so (packMessage (Just 512) servfail) cameFrom
               _sent <- NB.sendTo so packedReply cameFrom
               return ()

      -- Validate the request header; for an acceptable query, run the
      -- builders produced for every question against the request,
      -- otherwise produce an error reply carrying the validation result.
      handleMessage :: Message -> IO Message
      handleMessage msg
          = case validateQuery msg of
              NoError
                  -> do builders <- mapM handleQuestion $ msgQuestions msg
                        return $ runBuilder msg $ sequence_ builders
              err
                  -> return $ mkErrorReply err msg

      -- Look up the zone for one question; without a zone the reply is
      -- marked 'Refused' (and 'unauthorise'd).
      handleQuestion :: SomeQ -> IO (Builder ())
      handleQuestion (SomeQ q)
          = do zoneM <- findZone (qName q)
               case zoneM of
                 Nothing
                     -> return $ do unauthorise
                                    setResponseCode Refused
                 Just zone
                     -> handleQuestionForZone (SomeQ q) zone

      -- Answer one question from @zone@.  AXFR goes to 'handleAXFR'.
      -- Otherwise:
      --
      --  * answers are the zone's records for the name, filtered by the
      --    question's type and class;
      --  * at the zone apex the SOA and/or NS record is also answered for
      --    SOA, NS or ANY queries;
      --  * additionals are the A records of names referenced by the
      --    answers and by the zone's NS record;
      --  * authority is the zone's NS record, or 'unauthorise' if it has none.
      handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
      handleQuestionForZone (SomeQ q) zone
          | Just (qType q) == cast AXFR
              = handleAXFR (SomeQ q) zone
          | otherwise
              = do allRecords <- zoneResponder zone (qName q)
                   let filtered = filterRecords (SomeQ q) allRecords
                   additionals <- do xss <- mapM (getAdditionals zone) filtered
                                     ys  <- case zoneNSRecord zone of
                                              Just rr -> getAdditionals zone rr
                                              Nothing -> return []
                                     return (concat xss ++ ys)
                   return $ do mapM_ addAnswer filtered
                               when (qName q == zoneName zone) $
                                    do when (Just (qType q) == cast SOA ||
                                             Just (qType q) == cast ANY   ) $
                                            case zoneSOARecord zone of
                                              Just rr -> addAnswer rr
                                              Nothing -> return ()
                                       when (Just (qType q) == cast NS ||
                                             Just (qType q) == cast ANY  ) $
                                            case zoneNSRecord zone of
                                              Just rr -> addAnswer rr
                                              Nothing -> return ()
                               mapM_ addAdditional additionals
                               case zoneNSRecord zone of
                                 Just rr -> addAuthority rr
                                 Nothing -> unauthorise

      -- If the record's data is a domain name, return the zone's A/IN
      -- records for that name; otherwise return nothing.
      getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
      getAdditionals zone (SomeRR rr)
          = case cast (rrData rr) :: Maybe DomainName of
              Nothing
                  -> return []
              Just name
                  -> do allRecords <- zoneResponder zone name
                        let filtered = filterRecords (SomeQ q') allRecords
                            q'       = Question {
                                         qName  = name
                                       , qType  = A
                                       , qClass = IN
                                       }
                        return filtered

      -- Keep the records matching the question's type and class, where
      -- ANY matches everything and CNAME records always match the type.
      filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
      filterRecords (SomeQ q)
          | Just (qType q) == cast ANY && Just (qClass q) == cast ANY = id
          | Just (qType q) == cast ANY                                = filter matchClass
          | Just (qClass q) == cast ANY                               = filter matchType
          | otherwise                                                 = filter matchBoth
          where
            matchClass (SomeRR rr)
                = Just (qClass q) == cast (rrClass rr)
            matchType (SomeRR rr)
                = Just (qType q) == cast (rrType rr) ||
                  Just CNAME     == cast (rrType rr)
            matchBoth rr
                = matchType rr && matchClass rr

      -- Zone transfers are not implemented; this fails in IO, which the
      -- 'onException' handler in 'udpHandler' turns into a SERVFAIL reply.
      handleAXFR :: SomeQ -> Zone -> IO (Builder ())
      handleAXFR (SomeQ _q) _zone = fail "FIXME: not implemented yet"

      -- Only standard queries are served; anything else is NotImplemented.
      validateQuery :: Message -> ResponseCode
      validateQuery = validateHeader . msgHeader
          where
            validateHeader :: Header -> ResponseCode
            validateHeader hdr
                | hdMessageType hdr /= Query         = NotImplemented
                | hdOpcode      hdr /= StandardQuery = NotImplemented
                | otherwise                          = NoError

      -- Encode a message.  When @limM@ is @Just lim@ and the encoding is
      -- longer than @lim@ bytes, the last record is dropped and the message
      -- re-encoded, repeatedly: additionals first (TC is not set for
      -- these), then authorities, answers and questions (setting TC).
      packMessage :: Maybe Int -> Message -> BS.ByteString
      packMessage limM = BS.concat . LBS.toChunks . truncateMsg
          where
            truncateMsg :: Message -> LBS.ByteString
            truncateMsg msg
                = let packet    = encode msg
                      needTrunc = fromMaybe False $
                                  do lim <- limM
                                     return $ fromIntegral (LBS.length packet) > lim
                  in
                    if needTrunc then
                        truncateMsg $ trunc' msg
                    else
                        packet

            trunc' :: Message -> Message
            trunc' msg
                | notNull $ msgAdditionals msg
                    = msg {
                        msgAdditionals = truncList $ msgAdditionals msg
                      }
                | notNull $ msgAuthorities msg
                    = msg {
                        msgHeader      = setTruncFlag $ msgHeader msg
                      , msgAuthorities = truncList $ msgAuthorities msg
                      }
                | notNull $ msgAnswers msg
                    = msg {
                        msgHeader  = setTruncFlag $ msgHeader msg
                      , msgAnswers = truncList $ msgAnswers msg
                      }
                | notNull $ msgQuestions msg
                    = msg {
                        msgHeader    = setTruncFlag $ msgHeader msg
                      , msgQuestions = truncList $ msgQuestions msg
                      }
                | otherwise
                    = error ("packMessage: You are already skinny and need no diet: " ++ show msg)

            setTruncFlag :: Header -> Header
            setTruncFlag hdr = hdr { hdIsTruncated = True }

            notNull :: [a] -> Bool
            notNull = not . null

            -- Drop the last element (callers only pass non-empty lists).
            truncList :: [a] -> [a]
            truncList xs = take (length xs - 1) xs

      -- Decode a wire-format packet with the 'Binary' instance of 'Message'.
      -- NOTE(review): presumably throws on malformed input (Data.Binary's
      -- 'decode'); 'udpHandler' forces it with 'evaluate'.
      unpackMessage :: BS.ByteString -> Message
      unpackMessage = decode . LBS.fromChunks . return

      -- Build a reply to @msg@ that only carries the response code @err@
      -- (and is 'unauthorise'd).
      mkErrorReply :: ResponseCode -> Message -> Message
      mkErrorReply err msg
          = runBuilder msg $ do unauthorise
                                setResponseCode err