]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Named.hs
Code clean up
[haskell-dns.git] / Network / DNS / Named.hs
index 57d9ea4b78b6a765bce5a044444503211c1a8b21..3ce2a9ad7d1018a86c78bb304e4d4d28d28e6c41 100644 (file)
@@ -7,32 +7,44 @@ import           Control.Concurrent
 import           Control.Exception
 import           Control.Monad
 import           Data.Binary
+import           Data.Binary.Get
+import           Data.Binary.Put
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
+import           Data.Dynamic
 import           Data.Maybe
 import           Network.Socket
 import qualified Network.Socket.ByteString as NB
 import           Network.DNS.Message
 import           Network.DNS.Named.Config
-import           Network.DNS.Named.Responder
 import           Network.DNS.Named.ResponseBuilder
 import           Network.DNS.Named.Zone
 import           System.Posix.Signals
+import           System.IO
 
 
 runNamed :: ZoneFinder zf => Config -> zf -> IO ()
 runNamed cnf zf
     = withSocketsDo $
       do installHandler sigPIPE Ignore Nothing
-         _tcpListenerTID <- forkIO $ tcpListen
-         udpListen
+
+         let hint = defaultHints {
+                      addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
+                    , addrFamily     = AF_INET6
+                    , addrSocketType = NoSocketType
+                    , addrProtocol   = defaultProtocol
+                    }
+         (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
+
+         _tcpListenerTID <- forkIO $ tcpListen ai
+         udpListen ai
     where
-      udpListen :: IO ()
-      udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
-                     so <- socket AF_INET Datagram defaultProtocol
-                     print cnf
-                     bindSocket so $ cnfServerAddress cnf
-                     udpLoop so
+      udpListen :: AddrInfo -> IO ()
+      udpListen ai
+          = do so <- socket (addrFamily ai) Datagram defaultProtocol
+               setSocketOption so ReuseAddr 1
+               bindSocket so (addrAddress ai)
+               udpLoop so
 
       udpLoop :: Socket -> IO ()
       udpLoop so
@@ -40,8 +52,21 @@ runNamed cnf zf
                _handlerTID <- forkIO $ udpHandler so packet cameFrom
                udpLoop so
 
-      tcpListen :: IO ()
-      tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
+      tcpListen :: AddrInfo -> IO ()
+      tcpListen ai
+          = do so <- socket (addrFamily ai) Stream defaultProtocol
+               setSocketOption so ReuseAddr 1
+               bindSocket so (addrAddress ai)
+               listen so 255
+               tcpLoop so
+
+      tcpLoop :: Socket -> IO ()
+      tcpLoop so
+          = do (so', _)    <- accept so
+               h           <- socketToHandle so' ReadWriteMode
+               hSetBuffering h $ BlockBuffering Nothing
+               _handlerTID <- forkIO $ tcpHandler h
+               tcpLoop so
 
       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
       udpHandler so packet cameFrom
@@ -53,6 +78,31 @@ runNamed cnf zf
                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
                return ()
 
+      tcpHandler :: Handle -> IO ()
+      tcpHandler h
+          = do lenB   <- LBS.hGet h 2
+               if LBS.length lenB < 2 then
+                   -- Got EOF
+                   hClose h
+                 else
+                   do let len = runGet getWord16be lenB
+                      packet <- BS.hGet h $ fromIntegral len
+                      msg    <- evaluate $ unpackMessage packet
+                      msg'   <- handleMessage msg
+                                `onException`
+                                do let servfail = mkErrorReply ServerFailure msg
+                                       packet'  = packMessage Nothing servfail
+                                       len'     = fromIntegral $ BS.length packet'
+                                   LBS.hPut h $ runPut $ putWord16be len'
+                                   BS.hPut h packet'
+                                   hClose h
+                      let packet' = packMessage Nothing msg'
+                          len'    = fromIntegral $ BS.length packet'
+                      LBS.hPut h $ runPut $ putWord16be len'
+                      BS.hPut h packet'
+                      hFlush h
+                      tcpHandler h
+
       handleMessage :: Message -> IO Message
       handleMessage msg
           = case validateQuery msg of
@@ -68,11 +118,45 @@ runNamed cnf zf
 
       handleQuestion :: SomeQ -> IO (Builder ())
       handleQuestion (SomeQ q)
-          = do zone    <- findZone zf (qName q)
-               -- FIXME: this is merely a bogus implementation.
-               -- It considers no additional or authoritative sections.
-               results <- mapM (runResponder' q) (zoneResponders zone)
-               return $ mapM_ addAnswer $ concat results
+          = do zoneM <- findZone zf (qName q)
+               case zoneM of
+                 Nothing
+                     -> return $ do unauthorise
+                                    setResponseCode Refused
+                 Just zone
+                     -> handleQuestionForZone q zone
+
+      handleQuestionForZone :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
+      handleQuestionForZone q zone
+          | Just (qType q) == cast AXFR
+              = handleAXFR q zone
+          | otherwise
+              = do answers     <- getRecords zone q
+                   authority   <- getRecords zone (Question (zoneName zone) NS IN)
+                   additionals <- liftM concat $ mapM (getAdditionals zone) (answers ++ authority)
+                   isAuth      <- isAuthoritativeZone zone
+                   return $ do mapM_ addAnswer     answers
+                               mapM_ addAuthority  authority
+                               mapM_ addAdditional additionals
+                               unless isAuth unauthorise
+
+      getAdditionals :: Zone z => z -> SomeRR -> IO [SomeRR]
+      getAdditionals zone (SomeRR rr)
+          = case cast (rrData rr) :: Maybe DomainName of
+              Nothing
+                  -> return []
+              Just name
+                  -> do rrA    <- getRecords zone (Question name A    IN)
+                        rrAAAA <- getRecords zone (Question name AAAA IN)
+                        return (rrA ++ rrAAAA)
+
+      handleAXFR :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
+      handleAXFR q zone
+          | cnfAllowTransfer cnf
+              = do rs <- getRecords zone q
+                   return $ mapM_ addAnswerNonuniquely rs
+          | otherwise
+              = return $ return ()
 
 
 validateQuery :: Message -> ResponseCode
@@ -138,15 +222,5 @@ unpackMessage = decode . LBS.fromChunks . return
 
 mkErrorReply :: ResponseCode -> Message -> Message
 mkErrorReply err msg
-    = let header = msgHeader msg
-          msg'   = msg {
-                     msgHeader = header {
-                                   hdMessageType           = Response
-                                 , hdIsAuthoritativeAnswer = False
-                                 , hdIsTruncated           = False
-                                 , hdIsRecursionAvailable  = False
-                                 , hdResponseCode          = err
-                                 }
-                   }
-      in
-        msg'
+    = runBuilder msg $ do unauthorise
+                          setResponseCode err