]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Named.hs
DomainMap: totally untested yet
[haskell-dns.git] / Network / DNS / Named.hs
index 13297e8ae2fbbb2bf31051255ebe23ecd8bb5d8b..4a9eaed2fa26eda15576a94673deead4c5d42a82 100644 (file)
@@ -1,10 +1,5 @@
 module Network.DNS.Named
-    ( ZoneFinder(..)
-    , Zone(..)
-
-    , runNamed
-
-    , defaultRootZone
+    ( runNamed
     )
     where
 
@@ -12,43 +7,24 @@ import           Control.Concurrent
 import           Control.Exception
 import           Control.Monad
 import           Data.Binary
+import           Data.Binary.Get
+import           Data.Binary.Put
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Lazy as LBS
+import           Data.Dynamic
 import           Data.Maybe
 import           Network.Socket
 import qualified Network.Socket.ByteString as NB
 import           Network.DNS.Message
 import           Network.DNS.Named.Config
+import           Network.DNS.Named.ResponseBuilder
+import           Network.DNS.Named.Zone
 import           System.Posix.Signals
+import           System.IO
 
 
-class ZoneFinder a where
-    findZone :: a -> DomainName -> IO Zone
-
-instance ZoneFinder (DomainName -> Zone) where
-    findZone = (return .)
-
-instance ZoneFinder (DomainName -> IO Zone) where
-    findZone = id
-
-instance ZoneFinder (DomainName -> Maybe Zone) where
-    findZone = ((return . fromMaybe defaultRootZone) .)
-
-instance ZoneFinder (DomainName -> IO (Maybe Zone)) where
-    findZone = (fmap (fromMaybe defaultRootZone) .)
-
-
-data Zone
-    = Zone {
-        zoneName :: !DomainName
-      }
-
-defaultRootZone :: Zone
-defaultRootZone = error "FIXME: defaultRootZone is not implemented yet"
-
-
-runNamed :: ZoneFinder zf => Config -> zf -> IO ()
-runNamed cnf zf
+runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
+runNamed cnf findZone
     = withSocketsDo $
       do installHandler sigPIPE Ignore Nothing
          _tcpListenerTID <- forkIO $ tcpListen
@@ -57,7 +33,6 @@ runNamed cnf zf
       udpListen :: IO ()
       udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
                      so <- socket AF_INET Datagram defaultProtocol
-                     print cnf
                      bindSocket so $ cnfServerAddress cnf
                      udpLoop so
 
@@ -68,40 +43,223 @@ runNamed cnf zf
                udpLoop so
 
       tcpListen :: IO ()
-      tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
+      tcpListen = do so <- socket AF_INET Stream defaultProtocol
+                     bindSocket so $ cnfServerAddress cnf
+                     listen so 255
+                     tcpLoop so
+
+      tcpLoop :: Socket -> IO ()
+      tcpLoop so
+          = do (so', _)    <- accept so
+               h           <- socketToHandle so' ReadWriteMode
+               hSetBuffering h $ BlockBuffering Nothing
+               _handlerTID <- forkIO $ tcpHandler h
+               tcpLoop so
 
       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
       udpHandler so packet cameFrom
           = do msg   <- evaluate $ unpackMessage packet
                msg'  <- handleMessage msg
                         `onException`
-                        NB.sendTo so (packMessage $ makeServerFailure msg) cameFrom
-               _sent <- NB.sendTo so (packMessage $ msg'                 ) cameFrom
+                        do let servfail = mkErrorReply ServerFailure msg
+                           NB.sendTo so (packMessage (Just 512) servfail) cameFrom
+               _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
                return ()
 
+      tcpHandler :: Handle -> IO ()
+      tcpHandler h
+          = do lenB   <- LBS.hGet h 2
+               if LBS.length lenB < 2 then
+                   -- Got EOF
+                   hClose h
+                 else
+                   do let len = runGet getWord16be lenB
+                      packet <- BS.hGet h $ fromIntegral len
+                      msg    <- evaluate $ unpackMessage packet
+                      msg'   <- handleMessage msg
+                                `onException`
+                                do let servfail = mkErrorReply ServerFailure msg
+                                       packet'  = packMessage Nothing servfail
+                                       len'     = fromIntegral $ BS.length packet'
+                                   LBS.hPut h $ runPut $ putWord16be len'
+                                   BS.hPut h packet'
+                                   hClose h
+                      let packet' = packMessage Nothing msg'
+                          len'    = fromIntegral $ BS.length packet'
+                      LBS.hPut h $ runPut $ putWord16be len'
+                      BS.hPut h packet'
+                      hFlush h
+                      tcpHandler h
+
       handleMessage :: Message -> IO Message
       handleMessage msg
-          = fail (show msg) -- FIXME
+          = case validateQuery msg of
+              NoError
+                  -> do builders <- mapM handleQuestion $ msgQuestions msg
+
+                        let builder = foldl (>>) (return ()) builders
+                            msg'    = runBuilder msg builder
+
+                        return msg'
+
+              err -> return $ mkErrorReply err msg
+
+      handleQuestion :: SomeQ -> IO (Builder ())
+      handleQuestion (SomeQ q)
+          = do zoneM <- findZone (qName q)
+               case zoneM of
+                 Nothing
+                     -> return $ do unauthorise
+                                    setResponseCode Refused
+                 Just zone
+                     -> handleQuestionForZone (SomeQ q) zone
+
+      handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
+      handleQuestionForZone (SomeQ q) zone
+          | Just (qType q) == cast AXFR
+              = handleAXFR (SomeQ q) zone
+          | otherwise
+              = do allRecords <- zoneResponder zone (qName q)
+                   let filtered = filterRecords (SomeQ q) allRecords
+
+                   additionals <- do xss <- mapM (getAdditionals zone) filtered
+                                     ys  <- case zoneNSRecord zone of
+                                              Just rr -> getAdditionals zone rr
+                                              Nothing -> return []
+                                     return (concat xss ++ ys)
 
+                   return $ do mapM_ addAnswer filtered
 
-packMessage :: Message -> BS.ByteString
-packMessage = BS.concat . LBS.toChunks . encode
+                               when (qName q == zoneName zone) $
+                                    do when (Just (qType q) == cast SOA ||
+                                             Just (qType q) == cast ANY   )
+                                                $ case zoneSOARecord zone of
+                                                    Just rr -> addAnswer rr
+                                                    Nothing -> return ()
+
+                                       when (Just (qType q) == cast NS ||
+                                             Just (qType q) == cast ANY  )
+                                                $ case zoneNSRecord zone of
+                                                    Just rr -> addAnswer rr
+                                                    Nothing -> return ()
+
+                               mapM_ addAdditional additionals
+
+                               case zoneNSRecord zone of
+                                 Just rr -> addAuthority rr
+                                 Nothing -> unauthorise
+
+      getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
+      getAdditionals zone (SomeRR rr)
+          = case cast (rrData rr) :: Maybe DomainName of
+              Nothing
+                  -> return []
+              Just name
+                  -> do allRecords <- zoneResponder zone name
+
+                        let filtered = filterRecords (SomeQ q') allRecords
+                            q'       = Question {
+                                         qName  = name
+                                       , qType  = A
+                                       , qClass = IN
+                                       }
+                        return filtered
+
+      filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
+      filterRecords (SomeQ q)
+          | Just (qType  q) == cast ANY &&
+            Just (qClass q) == cast ANY    = id
+          | Just (qType  q) == cast ANY    = filter matchClass
+          | Just (qClass q) == cast ANY    = filter matchType
+          | otherwise                      = filter matchBoth
+          where
+            matchClass (SomeRR rr)
+                = Just (qClass q) == cast (rrClass rr)
+
+            matchType (SomeRR rr)
+                = Just (qType  q) == cast (rrType  rr) ||
+                  Just CNAME      == cast (rrType  rr)
+
+            matchBoth rr
+                = matchType rr && matchClass rr
+
+      handleAXFR :: SomeQ -> Zone -> IO (Builder ())
+      handleAXFR (SomeQ q) zone
+          | qName q == zoneName zone &&
+            isJust (zoneSOA zone)    &&
+            cnfAllowTransfer cnf
+              = do names      <- zoneRecordNames zone
+                   allRecords <- liftM concat $ mapM (zoneResponder zone) names
+                   return $ do addAnswer $ fromJust $ zoneSOARecord zone
+                               addAnswer $ fromJust $ zoneNSRecord  zone
+                               mapM_ addAnswer allRecords
+                               addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
+          | otherwise
+              = return $ return ()
+
+
+validateQuery :: Message -> ResponseCode
+validateQuery = validateHeader . msgHeader
+    where
+      validateHeader :: Header -> ResponseCode
+      validateHeader hdr
+          | hdMessageType hdr /= Query         = NotImplemented
+          | hdOpcode      hdr /= StandardQuery = NotImplemented
+          | otherwise                          = NoError
+
+
+packMessage :: Maybe Int -> Message -> BS.ByteString
+packMessage limM = BS.concat . LBS.toChunks . truncateMsg
+    where
+      truncateMsg :: Message -> LBS.ByteString
+      truncateMsg msg
+          = let packet    = encode msg
+                needTrunc = fromMaybe False $
+                            do lim <- limM
+                               return $ fromIntegral (LBS.length packet) > lim
+            in
+              if needTrunc then
+                  truncateMsg $ trunc' msg
+              else
+                  packet
+
+      trunc' :: Message -> Message
+      trunc' msg
+          | notNull $ msgAdditionals msg
+              = msg {
+                  msgAdditionals = truncList $ msgAdditionals msg
+                }
+          | notNull $ msgAuthorities msg
+              = msg {
+                  msgHeader      = setTruncFlag $ msgHeader msg
+                , msgAuthorities = truncList $ msgAuthorities msg
+                }
+          | notNull $ msgAnswers msg
+              = msg {
+                  msgHeader      = setTruncFlag $ msgHeader msg
+                , msgAnswers     = truncList $ msgAnswers msg
+                }
+          | notNull $ msgQuestions msg
+              = msg {
+                  msgHeader      = setTruncFlag $ msgHeader msg
+                , msgQuestions   = truncList $ msgQuestions msg
+                }
+          | otherwise
+              = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
+
+      setTruncFlag :: Header -> Header
+      setTruncFlag hdr = hdr { hdIsTruncated = True }
+
+      notNull :: [a] -> Bool
+      notNull = not . null
+
+      truncList :: [a] -> [a]
+      truncList xs = take (length xs - 1) xs
 
 unpackMessage :: BS.ByteString -> Message
 unpackMessage = decode . LBS.fromChunks . return
 
-
-makeServerFailure :: Message -> Message
-makeServerFailure msg
-    = let header = msgHeader msg
-          msg'   = msg {
-                     msgHeader = header {
-                                   hdMessageType           = Response
-                                 , hdIsAuthoritativeAnswer = False
-                                 , hdIsTruncated           = False
-                                 , hdIsRecursionAvailable  = False
-                                 , hdResponseCode          = ServerFailure
-                                 }
-                   }
-      in
-        msg'
+mkErrorReply :: ResponseCode -> Message -> Message
+mkErrorReply err msg
+    = runBuilder msg $ do unauthorise
+                          setResponseCode err