import Control.Exception
import Control.Monad
import Data.Binary
+import Data.Binary.Get
+import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
+import Data.Dynamic
import Data.Maybe
import Network.Socket
import qualified Network.Socket.ByteString as NB
import Network.DNS.Message
import Network.DNS.Named.Config
-import Network.DNS.Named.Responder
import Network.DNS.Named.ResponseBuilder
import Network.DNS.Named.Zone
import System.Posix.Signals
+import System.IO
-runNamed :: ZoneFinder zf => Config -> zf -> IO ()
-runNamed cnf zf
+runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
+runNamed cnf findZone
= withSocketsDo $
do installHandler sigPIPE Ignore Nothing
_tcpListenerTID <- forkIO $ tcpListen
udpListen :: IO ()
udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
so <- socket AF_INET Datagram defaultProtocol
- print cnf
bindSocket so $ cnfServerAddress cnf
udpLoop so
udpLoop so
tcpListen :: IO ()
- tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
+ tcpListen = do so <- socket AF_INET Stream defaultProtocol
+ bindSocket so $ cnfServerAddress cnf
+ listen so 255
+ tcpLoop so
+
+ tcpLoop :: Socket -> IO ()
+ tcpLoop so
+ = do (so', _) <- accept so
+ h <- socketToHandle so' ReadWriteMode
+ hSetBuffering h $ BlockBuffering Nothing
+ _handlerTID <- forkIO $ tcpHandler h
+ tcpLoop so
udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
udpHandler so packet cameFrom
_sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
return ()
+ tcpHandler :: Handle -> IO ()
+ tcpHandler h
+ = do lenB <- LBS.hGet h 2
+ if LBS.length lenB < 2 then
+ -- Got EOF
+ hClose h
+ else
+ do let len = runGet getWord16be lenB
+ packet <- BS.hGet h $ fromIntegral len
+ msg <- evaluate $ unpackMessage packet
+ msg' <- handleMessage msg
+ `onException`
+ do let servfail = mkErrorReply ServerFailure msg
+ packet' = packMessage Nothing servfail
+ len' = fromIntegral $ BS.length packet'
+ LBS.hPut h $ runPut $ putWord16be len'
+ BS.hPut h packet'
+ hClose h
+ let packet' = packMessage Nothing msg'
+ len' = fromIntegral $ BS.length packet'
+ LBS.hPut h $ runPut $ putWord16be len'
+ BS.hPut h packet'
+ hFlush h
+ tcpHandler h
+
handleMessage :: Message -> IO Message
handleMessage msg
= case validateQuery msg of
handleQuestion :: SomeQ -> IO (Builder ())
handleQuestion (SomeQ q)
- = do zone <- findZone zf (qName q)
- -- FIXME: this is merely a bogus implementation.
- -- It considers no additional or authoritative sections.
- results <- mapM (runResponder' q) (zoneResponders zone)
- return $ mapM_ addAnswer $ concat results
+ = do zoneM <- findZone (qName q)
+ case zoneM of
+ Nothing
+ -> return $ do unauthorise
+ setResponseCode Refused
+ Just zone
+ -> handleQuestionForZone (SomeQ q) zone
+
+ handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
+ handleQuestionForZone (SomeQ q) zone
+ | Just (qType q) == cast AXFR
+ = handleAXFR (SomeQ q) zone
+ | otherwise
+ = do allRecords <- zoneResponder zone (qName q)
+ let filtered = filterRecords (SomeQ q) allRecords
+
+ additionals <- do xss <- mapM (getAdditionals zone) filtered
+ ys <- case zoneNSRecord zone of
+ Just rr -> getAdditionals zone rr
+ Nothing -> return []
+ return (concat xss ++ ys)
+
+ return $ do mapM_ addAnswer filtered
+
+ when (qName q == zoneName zone) $
+ do when (Just (qType q) == cast SOA ||
+ Just (qType q) == cast ANY )
+ $ case zoneSOARecord zone of
+ Just rr -> addAnswer rr
+ Nothing -> return ()
+
+ when (Just (qType q) == cast NS ||
+ Just (qType q) == cast ANY )
+ $ case zoneNSRecord zone of
+ Just rr -> addAnswer rr
+ Nothing -> return ()
+
+ mapM_ addAdditional additionals
+
+ case zoneNSRecord zone of
+ Just rr -> addAuthority rr
+ Nothing -> unauthorise
+
+ getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
+ getAdditionals zone (SomeRR rr)
+ = case cast (rrData rr) :: Maybe DomainName of
+ Nothing
+ -> return []
+ Just name
+ -> do allRecords <- zoneResponder zone name
+
+ let filtered = filterRecords (SomeQ q') allRecords
+ q' = Question {
+ qName = name
+ , qType = A
+ , qClass = IN
+ }
+ return filtered
+
+ filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
+ filterRecords (SomeQ q)
+ | Just (qType q) == cast ANY &&
+ Just (qClass q) == cast ANY = id
+ | Just (qType q) == cast ANY = filter matchClass
+ | Just (qClass q) == cast ANY = filter matchType
+ | otherwise = filter matchBoth
+ where
+ matchClass (SomeRR rr)
+ = Just (qClass q) == cast (rrClass rr)
+
+ matchType (SomeRR rr)
+ = Just (qType q) == cast (rrType rr) ||
+ Just CNAME == cast (rrType rr)
+
+ matchBoth rr
+ = matchType rr && matchClass rr
+
+ handleAXFR :: SomeQ -> Zone -> IO (Builder ())
+ handleAXFR (SomeQ q) zone
+ | qName q == zoneName zone &&
+ isJust (zoneSOA zone) &&
+ cnfAllowTransfer cnf
+ = do names <- zoneRecordNames zone
+ allRecords <- liftM concat $ mapM (zoneResponder zone) names
+ return $ do addAnswer $ fromJust $ zoneSOARecord zone
+ addAnswer $ fromJust $ zoneNSRecord zone
+ mapM_ addAnswer allRecords
+ addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
+ | otherwise
+ = return $ return ()
validateQuery :: Message -> ResponseCode
mkErrorReply :: ResponseCode -> Message -> Message
mkErrorReply err msg
- = let header = msgHeader msg
- msg' = msg {
- msgHeader = header {
- hdMessageType = Response
- , hdIsAuthoritativeAnswer = False
- , hdIsTruncated = False
- , hdIsRecursionAvailable = False
- , hdResponseCode = err
- }
- }
- in
- msg'
+ = runBuilder msg $ do unauthorise
+ setResponseCode err