]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Named.hs
ZoneFinder comes back
[haskell-dns.git] / Network / DNS / Named.hs
index 1eaf27af89975b0d4df987c8aed4c3180607235a..7abdcd083461ba6720975c2645cba0f403271949 100644 (file)
@@ -7,6 +7,8 @@ import           Control.Concurrent
 import           Control.Exception
 import           Control.Monad
 import           Data.Binary
+import           Data.Binary.Get
+import           Data.Binary.Put
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
 import           Data.Dynamic
@@ -18,21 +20,31 @@ import           Network.DNS.Named.Config
 import           Network.DNS.Named.ResponseBuilder
 import           Network.DNS.Named.Zone
 import           System.Posix.Signals
+import           System.IO
 
 
-runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
-runNamed cnf findZone
+runNamed :: ZoneFinder zf => Config -> zf -> IO ()
+runNamed cnf zf
     = withSocketsDo $
       do installHandler sigPIPE Ignore Nothing
-         _tcpListenerTID <- forkIO $ tcpListen
-         udpListen
+
+         let hint = defaultHints {
+                      addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
+                    , addrFamily     = AF_INET6
+                    , addrSocketType = NoSocketType
+                    , addrProtocol   = defaultProtocol
+                    }
+         (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
+
+         _tcpListenerTID <- forkIO $ tcpListen ai
+         udpListen ai
     where
-      udpListen :: IO ()
-      udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
-                     so <- socket AF_INET Datagram defaultProtocol
-                     print cnf
-                     bindSocket so $ cnfServerAddress cnf
-                     udpLoop so
+      udpListen :: AddrInfo -> IO ()
+      udpListen ai
+          = do so <- socket (addrFamily ai) Datagram defaultProtocol
+               setSocketOption so ReuseAddr 1
+               bindSocket so (addrAddress ai)
+               udpLoop so
 
       udpLoop :: Socket -> IO ()
       udpLoop so
@@ -40,8 +52,21 @@ runNamed cnf findZone
                _handlerTID <- forkIO $ udpHandler so packet cameFrom
                udpLoop so
 
-      tcpListen :: IO ()
-      tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
+      tcpListen :: AddrInfo -> IO ()
+      tcpListen ai
+          = do so <- socket (addrFamily ai) Stream defaultProtocol
+               setSocketOption so ReuseAddr 1
+               bindSocket so (addrAddress ai)
+               listen so 255
+               tcpLoop so
+
+      tcpLoop :: Socket -> IO ()
+      tcpLoop so
+          = do (so', _)    <- accept so
+               h           <- socketToHandle so' ReadWriteMode
+               hSetBuffering h $ BlockBuffering Nothing
+               _handlerTID <- forkIO $ tcpHandler h
+               tcpLoop so
 
       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
       udpHandler so packet cameFrom
@@ -53,6 +78,31 @@ runNamed cnf findZone
                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
                return ()
 
+      tcpHandler :: Handle -> IO ()
+      tcpHandler h
+          = do lenB   <- LBS.hGet h 2
+               if LBS.length lenB < 2 then
+                   -- Got EOF
+                   hClose h
+                 else
+                   do let len = runGet getWord16be lenB
+                      packet <- BS.hGet h $ fromIntegral len
+                      msg    <- evaluate $ unpackMessage packet
+                      msg'   <- handleMessage msg
+                                `onException`
+                                do let servfail = mkErrorReply ServerFailure msg
+                                       packet'  = packMessage Nothing servfail
+                                       len'     = fromIntegral $ BS.length packet'
+                                   LBS.hPut h $ runPut $ putWord16be len'
+                                   BS.hPut h packet'
+                                   hClose h
+                      let packet' = packMessage Nothing msg'
+                          len'    = fromIntegral $ BS.length packet'
+                      LBS.hPut h $ runPut $ putWord16be len'
+                      BS.hPut h packet'
+                      hFlush h
+                      tcpHandler h
+
       handleMessage :: Message -> IO Message
       handleMessage msg
           = case validateQuery msg of
@@ -68,7 +118,7 @@ runNamed cnf findZone
 
       handleQuestion :: SomeQ -> IO (Builder ())
       handleQuestion (SomeQ q)
-          = do zoneM <- findZone (qName q)
+          = do zoneM <- findZone zf (qName q)
                case zoneM of
                  Nothing
                      -> return $ do unauthorise
@@ -119,35 +169,60 @@ runNamed cnf findZone
               Just name
                   -> do allRecords <- zoneResponder zone name
 
-                        let filtered = filterRecords (SomeQ q') allRecords
-                            q'       = Question {
-                                         qName  = name
-                                       , qType  = A
-                                       , qClass = IN
-                                       }
-                        return filtered
+                        let rA = filterRecords (SomeQ qA) allRecords
+                            rB = filterRecords (SomeQ qB) allRecords
+                            qA = Question {
+                                   qName  = name
+                                 , qType  = A
+                                 , qClass = IN
+                                 }
+                            qB = Question {
+                                   qName  = name
+                                 , qType  = AAAA
+                                 , qClass = IN
+                                 }
+                        return (rA ++ rB)
 
       filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
-      filterRecords (SomeQ q)
-          | Just (qType  q) == cast ANY &&
-            Just (qClass q) == cast ANY    = id
-          | Just (qType  q) == cast ANY    = filter matchClass
-          | Just (qClass q) == cast ANY    = filter matchType
-          | otherwise                      = filter matchBoth
+      filterRecords (SomeQ q) = filter predicate
           where
-            matchClass (SomeRR rr)
-                = Just (qClass q) == cast (rrClass rr)
+            predicate rr
+                = predForType rr && predForClass rr
 
-            matchType (SomeRR rr)
-                = Just (qType  q) == cast (rrType  rr) ||
-                  Just CNAME      == cast (rrType  rr)
+            predForType (SomeRR rr)
+                | typeOf (qType q) == typeOf ANY
+                    = True
 
-            matchBoth rr
-                = matchType rr && matchClass rr
+                | typeOf (qType q) == typeOf MAILB
+                    = typeOf (rrType rr) == typeOf MR ||
+                      typeOf (rrType rr) == typeOf MB ||
+                      typeOf (rrType rr) == typeOf MG ||
+                      typeOf (rrType rr) == typeOf MINFO
+
+                | otherwise
+                    = typeOf (rrType rr) == typeOf (qType q) ||
+                      typeOf (rrType rr) == typeOf CNAME
+
+            predForClass (SomeRR rr)
+                | typeOf (qClass q) == typeOf ANY
+                    = True
+
+                | otherwise
+                    = typeOf (rrClass rr) == typeOf (qClass q)
 
       handleAXFR :: SomeQ -> Zone -> IO (Builder ())
-      handleAXFR (SomeQ _q) _zone
-          = fail "FIXME: not implemented yet"
+      handleAXFR (SomeQ q) zone
+          | qName q == zoneName zone &&
+            isJust (zoneSOA zone)    &&
+            cnfAllowTransfer cnf
+              = do names      <- zoneRecordNames zone
+                   allRecords <- liftM concat $ mapM (zoneResponder zone) names
+                   return $ do addAnswer $ fromJust $ zoneSOARecord zone
+                               addAnswer $ fromJust $ zoneNSRecord  zone
+                               mapM_ addAnswer allRecords
+                               addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
+          | otherwise
+              = return $ return ()
 
 
 validateQuery :: Message -> ResponseCode