module Network.DNS.Named
- ( ZoneFinder(..)
- , Zone(..)
-
- , runNamed
-
- , defaultRootZone
+ ( runNamed
)
where
import Control.Exception
import Control.Monad
import Data.Binary
+import Data.Binary.Get
+import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
+import Data.Dynamic
import Data.Maybe
import Network.Socket
import qualified Network.Socket.ByteString as NB
import Network.DNS.Message
import Network.DNS.Named.Config
+import Network.DNS.Named.ResponseBuilder
+import Network.DNS.Named.Zone
import System.Posix.Signals
-
-
-class ZoneFinder a where
- findZone :: a -> DomainName -> IO Zone
-
-instance ZoneFinder (DomainName -> Zone) where
- findZone = (return .)
-
-instance ZoneFinder (DomainName -> IO Zone) where
- findZone = id
-
-instance ZoneFinder (DomainName -> Maybe Zone) where
- findZone = ((return . fromMaybe defaultRootZone) .)
-
-instance ZoneFinder (DomainName -> IO (Maybe Zone)) where
- findZone = (fmap (fromMaybe defaultRootZone) .)
-
-
-data Zone
- = Zone {
- zoneName :: !DomainName
- }
-
-defaultRootZone :: Zone
-defaultRootZone = error "FIXME: defaultRootZone is not implemented yet"
+import System.IO
runNamed :: ZoneFinder zf => Config -> zf -> IO ()
runNamed cnf zf
= withSocketsDo $
do installHandler sigPIPE Ignore Nothing
- _tcpListenerTID <- forkIO $ tcpListen
- udpListen
+
+ let hint = defaultHints {
+ addrFlags = [AI_PASSIVE, AI_V4MAPPED]
+ , addrFamily = AF_INET6
+ , addrSocketType = NoSocketType
+ , addrProtocol = defaultProtocol
+ }
+ (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
+
+ _tcpListenerTID <- forkIO $ tcpListen ai
+ udpListen ai
where
- udpListen :: IO ()
- udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
- so <- socket AF_INET Datagram defaultProtocol
- print cnf
- bindSocket so $ cnfServerAddress cnf
- udpLoop so
+ udpListen :: AddrInfo -> IO ()
+ udpListen ai
+ = do so <- socket (addrFamily ai) Datagram defaultProtocol
+ setSocketOption so ReuseAddr 1
+ bindSocket so (addrAddress ai)
+ udpLoop so
udpLoop :: Socket -> IO ()
udpLoop so
_handlerTID <- forkIO $ udpHandler so packet cameFrom
udpLoop so
- tcpListen :: IO ()
- tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
+ tcpListen :: AddrInfo -> IO ()
+ tcpListen ai
+ = do so <- socket (addrFamily ai) Stream defaultProtocol
+ setSocketOption so ReuseAddr 1
+ bindSocket so (addrAddress ai)
+ listen so 255
+ tcpLoop so
+
+ tcpLoop :: Socket -> IO ()
+ tcpLoop so
+ = do (so', _) <- accept so
+ h <- socketToHandle so' ReadWriteMode
+ hSetBuffering h $ BlockBuffering Nothing
+ _handlerTID <- forkIO $ tcpHandler h
+ tcpLoop so
udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
udpHandler so packet cameFrom
= do msg <- evaluate $ unpackMessage packet
msg' <- handleMessage msg
`onException`
- NB.sendTo so (packMessage $ makeServerFailure msg) cameFrom
- _sent <- NB.sendTo so (packMessage $ msg' ) cameFrom
+ do let servfail = mkErrorReply ServerFailure msg
+ NB.sendTo so (packMessage (Just 512) servfail) cameFrom
+ _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
return ()
+ tcpHandler :: Handle -> IO ()
+ tcpHandler h
+ = do lenB <- LBS.hGet h 2
+ if LBS.length lenB < 2 then
+ -- Got EOF
+ hClose h
+ else
+ do let len = runGet getWord16be lenB
+ packet <- BS.hGet h $ fromIntegral len
+ msg <- evaluate $ unpackMessage packet
+ msg' <- handleMessage msg
+ `onException`
+ do let servfail = mkErrorReply ServerFailure msg
+ packet' = packMessage Nothing servfail
+ len' = fromIntegral $ BS.length packet'
+ LBS.hPut h $ runPut $ putWord16be len'
+ BS.hPut h packet'
+ hClose h
+ let packet' = packMessage Nothing msg'
+ len' = fromIntegral $ BS.length packet'
+ LBS.hPut h $ runPut $ putWord16be len'
+ BS.hPut h packet'
+ hFlush h
+ tcpHandler h
+
handleMessage :: Message -> IO Message
handleMessage msg
- = fail (show msg) -- FIXME
+ = case validateQuery msg of
+ NoError
+ -> do builders <- mapM handleQuestion $ msgQuestions msg
+
+ let builder = foldl (>>) (return ()) builders
+ msg' = runBuilder msg builder
+
+ return msg'
+
+ err -> return $ mkErrorReply err msg
+
+ handleQuestion :: SomeQ -> IO (Builder ())
+ handleQuestion (SomeQ q)
+ = do zoneM <- findZone zf (qName q)
+ case zoneM of
+ Nothing
+ -> return $ do unauthorise
+ setResponseCode Refused
+ Just zone
+ -> handleQuestionForZone q zone
+
+ handleQuestionForZone :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
+ handleQuestionForZone q zone
+ | Just (qType q) == cast AXFR
+ = handleAXFR q zone
+ | otherwise
+ = do answers <- getRecords zone q
+ authority <- getRecords zone (Question (zoneName zone) NS IN)
+ additionals <- liftM concat $ mapM (getAdditionals zone) (answers ++ authority)
+ isAuth <- isAuthoritativeZone zone
+ return $ do mapM_ addAnswer answers
+ mapM_ addAuthority authority
+ mapM_ addAdditional additionals
+ unless isAuth unauthorise
+
+ getAdditionals :: Zone z => z -> SomeRR -> IO [SomeRR]
+ getAdditionals zone (SomeRR rr)
+ = case cast (rrData rr) :: Maybe DomainName of
+ Nothing
+ -> return []
+ Just name
+ -> do rrA <- getRecords zone (Question name A IN)
+ rrAAAA <- getRecords zone (Question name AAAA IN)
+ return (rrA ++ rrAAAA)
+
+ handleAXFR :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
+ handleAXFR q zone
+ | cnfAllowTransfer cnf
+ = do rs <- getRecords zone q
+ return $ mapM_ addAnswerNonuniquely rs
+ | otherwise
+ = return $ return ()
+
+
+validateQuery :: Message -> ResponseCode
+validateQuery = validateHeader . msgHeader
+ where
+ validateHeader :: Header -> ResponseCode
+ validateHeader hdr
+ | hdMessageType hdr /= Query = NotImplemented
+ | hdOpcode hdr /= StandardQuery = NotImplemented
+ | otherwise = NoError
-packMessage :: Message -> BS.ByteString
-packMessage = BS.concat . LBS.toChunks . encode
+packMessage :: Maybe Int -> Message -> BS.ByteString
+packMessage limM = BS.concat . LBS.toChunks . truncateMsg
+ where
+ truncateMsg :: Message -> LBS.ByteString
+ truncateMsg msg
+ = let packet = encode msg
+ needTrunc = fromMaybe False $
+ do lim <- limM
+ return $ fromIntegral (LBS.length packet) > lim
+ in
+ if needTrunc then
+ truncateMsg $ trunc' msg
+ else
+ packet
+
+ trunc' :: Message -> Message
+ trunc' msg
+ | notNull $ msgAdditionals msg
+ = msg {
+ msgAdditionals = truncList $ msgAdditionals msg
+ }
+ | notNull $ msgAuthorities msg
+ = msg {
+ msgHeader = setTruncFlag $ msgHeader msg
+ , msgAuthorities = truncList $ msgAuthorities msg
+ }
+ | notNull $ msgAnswers msg
+ = msg {
+ msgHeader = setTruncFlag $ msgHeader msg
+ , msgAnswers = truncList $ msgAnswers msg
+ }
+ | notNull $ msgQuestions msg
+ = msg {
+ msgHeader = setTruncFlag $ msgHeader msg
+ , msgQuestions = truncList $ msgQuestions msg
+ }
+ | otherwise
+ = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
+
+ setTruncFlag :: Header -> Header
+ setTruncFlag hdr = hdr { hdIsTruncated = True }
+
+ notNull :: [a] -> Bool
+ notNull = not . null
+
+ truncList :: [a] -> [a]
+ truncList xs = take (length xs - 1) xs
unpackMessage :: BS.ByteString -> Message
unpackMessage = decode . LBS.fromChunks . return
-
-makeServerFailure :: Message -> Message
-makeServerFailure msg
- = let header = msgHeader msg
- msg' = msg {
- msgHeader = header {
- hdMessageType = Response
- , hdIsAuthoritativeAnswer = False
- , hdIsTruncated = False
- , hdIsRecursionAvailable = False
- , hdResponseCode = ServerFailure
- }
- }
- in
- msg'
+mkErrorReply :: ResponseCode -> Message -> Message
+mkErrorReply err msg
+ = runBuilder msg $ do unauthorise
+ setResponseCode err