]> gitweb @ CieloNegro.org - haskell-dns.git/commitdiff
Make RecordClass a type class.
authorPHO <pho@cielonegro.org>
Sat, 23 May 2009 04:13:39 +0000 (13:13 +0900)
committerPHO <pho@cielonegro.org>
Sat, 23 May 2009 04:13:39 +0000 (13:13 +0900)
Network/DNS/Message.hs

index 570548ced67f03e55c1cbe42466bfb505b641e0c..48bda1d1043bbf43ded6cd074d3b3d5290c38af0 100644 (file)
@@ -11,15 +11,13 @@ module Network.DNS.Message
     , DomainLabel
     , TTL
     , RecordType
-    , RecordClass(..)
+    , RecordClass
 
     , SOAFields(..)
     , WKSFields(..)
 
     , SomeQ
-    , SomeQT
     , SomeRR
-    , SomeRT
 
     , A(..)
     , NS(..)
@@ -43,6 +41,11 @@ module Network.DNS.Message
     , MAILA(..)
     , ANY(..)
 
+    , IN(..)
+    , CS(..)
+    , CH(..)
+    , HS(..)
+
     , mkDomainName
     , wrapQuestion
     , wrapRecord
@@ -123,25 +126,25 @@ data ResponseCode
     | Refused
     deriving (Show, Eq)
 
-data QueryType qt => Question qt
+data (QueryType qt, QueryClass qc) => Question qt qc
     = Question {
         qName  :: !DomainName
       , qType  :: !qt
-      , qClass :: !RecordClass
+      , qClass :: !qc
       }
     deriving (Typeable)
 
-instance QueryType qt => Show (Question qt) where
+instance (QueryType qt, QueryClass qc) => Show (Question qt qc) where
     show q = "Question { qName = " ++ show (qName q) ++
              ", qType = " ++ show (qType q) ++
              ", qClass = " ++ show (qClass q) ++ " }"
 
-instance QueryType qt => Eq (Question qt) where
+instance (QueryType qt, QueryClass qc) => Eq (Question qt qc) where
     a == b = qName  a == qName  b &&
              qType  a == qType  b &&
              qClass a == qClass b
 
-data SomeQ = forall qt. QueryType qt => SomeQ (Question qt)
+data SomeQ = forall qt qc. (QueryType qt, QueryClass qc) => SomeQ (Question qt qc)
 
 instance Show SomeQ where
     show (SomeQ q) = show q
@@ -157,24 +160,32 @@ instance Show SomeQT where
 instance Eq SomeQT where
     (SomeQT a) == (SomeQT b) = Just a == cast b
 
+data SomeQC = forall qc. QueryClass qc => SomeQC qc
+
+instance Show SomeQC where
+    show (SomeQC qc) = show qc
+
+instance Eq SomeQC where
+    (SomeQC a) == (SomeQC b) = Just a == cast b
+
 putSomeQ :: SomeQ -> Packer CompTable ()
 putSomeQ (SomeQ q)
     = do putDomainName $ qName q
-         putQueryType $ qType q
-         putBinary $ qClass q
+         putQueryType  $ qType q
+         putQueryClass $ qClass q
 
 getSomeQ :: Unpacker DecompTable SomeQ
 getSomeQ
     = do nm <- getDomainName
          ty <- getSomeQT
-         cl <- getBinary
-         case ty of
-           SomeQT qt -> return $ SomeQ $
-                        Question {
-                          qName  = nm
-                        , qType  = qt
-                        , qClass = cl
-                        }
+         cl <- getSomeQC
+         case (ty, cl) of
+           (SomeQT qt, SomeQC qc)
+               -> return $ SomeQ $ Question {
+                         qName  = nm
+                       , qType  = qt
+                       , qClass = qc
+                       }
 
 getSomeQT :: Unpacker s SomeQT
 getSomeQT = do n <- liftM fromIntegral U.getWord16be
@@ -184,6 +195,14 @@ getSomeQT = do n <- liftM fromIntegral U.getWord16be
                  Nothing
                      -> fail ("Unknown query type: " ++ show n)
 
+getSomeQC :: Unpacker s SomeQC
+getSomeQC = do n <- liftM fromIntegral U.getWord16be
+               case IM.lookup n defaultQCTable of
+                 Just sqc
+                     -> return sqc
+                 Nothing
+                     -> fail ("Unknown query class: " ++ show n)
+
 
 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
 type DomainLabel    = BS.ByteString
@@ -215,27 +234,26 @@ mkDomainName = DN . mkLabels [] . notEmpty
                                 -> mkLabels (C8.pack l : soFar) rest
                             _   -> error ("Illegal domain name: " ++ xs)
 
-data RecordClass
-    = IN
-    | CS -- Obsolete
-    | CH
-    | HS
-    | AnyClass -- Only for queries
-    deriving (Show, Eq)
 
+class (Show rc, Eq rc, Typeable rc) => RecordClass rc where
+    rcToInt :: rc -> Int
+
+    putRecordClass :: rc -> Packer s ()
+    putRecordClass = P.putWord16be . fromIntegral . rcToInt
 
-data RecordType rt dt => ResourceRecord rt dt
+
+data (RecordType rt dt, RecordClass rc) => ResourceRecord rt rc dt
     = ResourceRecord {
         rrName  :: !DomainName
       , rrType  :: !rt
-      , rrClass :: !RecordClass
+      , rrClass :: !rc
       , rrTTL   :: !TTL
       , rrData  :: !dt
       }
     deriving (Show, Eq, Typeable)
 
 
-data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
+data SomeRR = forall rt rc dt. (RecordType rt dt, RecordClass rc) => SomeRR (ResourceRecord rt rc dt)
 
 instance Show SomeRR where
     show (SomeRR rr) = show rr
@@ -248,12 +266,14 @@ putSomeRR :: SomeRR -> Packer CompTable ()
 putSomeRR (SomeRR rr) = putResourceRecord rr
 
 getSomeRR :: Unpacker DecompTable SomeRR
-getSomeRR = do srt <- U.lookAhead $
-                      do getDomainName -- skip
-                         getSomeRT
-               case srt of
-                 SomeRT rt
-                     -> getResourceRecord rt >>= return . SomeRR
+getSomeRR = do (srt, src) <- U.lookAhead $
+                             do getDomainName -- skip
+                                srt <- getSomeRT
+                                src <- getSomeRC
+                                return (srt, src)
+               case (srt, src) of
+                 (SomeRT rt, SomeRC rc)
+                     -> getResourceRecord rt rc >>= return . SomeRR
 
 type CompTable   = Map DomainName Int
 type DecompTable = IntMap DomainName
@@ -353,6 +373,16 @@ class (Show qt, Eq qt, Typeable qt) => QueryType qt where
 instance RecordType rt dt => QueryType rt where
     qtToInt = rtToInt
 
+class (Show qc, Eq qc, Typeable qc) => QueryClass qc where
+    qcToInt :: qc -> Int
+
+    putQueryClass :: qc -> Packer s ()
+    putQueryClass = P.putWord16be . fromIntegral . qcToInt
+
+instance RecordClass rc => QueryClass rc where
+    qcToInt = rcToInt
+
+
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
     putRecordData :: rt -> dt -> Packer CompTable ()
@@ -380,12 +410,12 @@ class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType r
                  fail ("putRecordData " ++ show rt ++ " wrote " ++ show len
                        ++ " bytes, which is way too long")
 
-    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord :: RecordClass rc => ResourceRecord rt rc dt -> Packer CompTable ()
     putResourceRecord rr
-        = do putDomainName $ rrName  rr
-             putRecordType $ rrType  rr
-             putBinary     $ rrClass rr
-             P.putWord32be $ rrTTL   rr
+        = do putDomainName  $ rrName  rr
+             putRecordType  $ rrType  rr
+             putRecordClass $ rrClass rr
+             P.putWord32be  $ rrTTL   rr
              putRecordDataWithLength (rrType rr) (rrData rr)
 
     getRecordDataWithLength :: rt -> Unpacker DecompTable dt
@@ -402,21 +432,22 @@ class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType r
 
              return dat
 
-    getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
-    getResourceRecord rt
+    getResourceRecord :: RecordClass rc => rt -> rc -> Unpacker DecompTable (ResourceRecord rt rc dt)
+    getResourceRecord rt rc
         = do name     <- getDomainName
              U.skip 2 -- record type
-             cl       <- getBinary
+             U.skip 2 -- record class
              ttl      <- U.getWord32be
              dat      <- getRecordDataWithLength rt
              return $ ResourceRecord {
                           rrName  = name
                         , rrType  = rt
-                        , rrClass = cl
+                        , rrClass = rc
                         , rrTTL   = ttl
                         , rrData  = dat
                         }
 
+
 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
 
 instance Show SomeRT where
@@ -433,6 +464,22 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be
                  Just srt
                      -> return srt
 
+data SomeRC = forall rc. RecordClass rc => SomeRC rc
+
+instance Show SomeRC where
+    show (SomeRC rc) = show rc
+
+instance Eq SomeRC where
+    (SomeRC a) == (SomeRC b) = Just a == cast b
+
+getSomeRC :: Unpacker s SomeRC
+getSomeRC = do n <- liftM fromIntegral U.getWord16be
+               case IM.lookup n defaultRCTable of
+                 Nothing
+                     -> fail ("Unknown resource record class: " ++ show n)
+                 Just src
+                     -> return src
+
 
 data SOAFields
     = SOAFields {
@@ -653,6 +700,24 @@ instance QueryType MAILA where
 data ANY = ANY deriving (Show, Eq, Typeable)
 instance QueryType ANY where
     qtToInt _ = 255
+instance QueryClass ANY where
+    qcToInt _ = 255
+
+data IN = IN deriving (Show, Eq, Typeable)
+instance RecordClass IN where
+    rcToInt _ = 1
+
+data CS = CS deriving (Show, Eq, Typeable)
+instance RecordClass CS where
+    rcToInt _ = 2
+
+data CH = CH deriving (Show, Eq, Typeable)
+instance RecordClass CH where
+    rcToInt _ = 3
+
+data HS = HS deriving (Show, Eq, Typeable)
+instance RecordClass HS where
+    rcToInt _ = 4
 
 
 instance Binary Message where
@@ -757,55 +822,36 @@ instance Enum ResponseCode where
     toEnum 5 = Refused
     toEnum _ = undefined
 
-instance Enum RecordClass where
-    fromEnum IN       = 1
-    fromEnum CS       = 2
-    fromEnum CH       = 3
-    fromEnum HS       = 4
-    fromEnum AnyClass = 255
-
-    toEnum 1   = IN
-    toEnum 2   = CS
-    toEnum 3   = CH
-    toEnum 4   = HS
-    toEnum 255 = AnyClass
-    toEnum _   = undefined
-
-instance Binary RecordClass where
-    get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = P'.putWord16be . fromIntegral . fromEnum
-
 
 defaultRTTable :: IntMap SomeRT
 defaultRTTable = IM.fromList $ map toPair $
-                 [ wrapRecordType A
-                 , wrapRecordType NS
-                 , wrapRecordType MD
-                 , wrapRecordType MF
-                 , wrapRecordType CNAME
-                 , wrapRecordType SOA
-                 , wrapRecordType MB
-                 , wrapRecordType MG
-                 , wrapRecordType MR
-                 , wrapRecordType NULL
-                 , wrapRecordType WKS
-                 , wrapRecordType PTR
-                 , wrapRecordType HINFO
-                 , wrapRecordType MINFO
-                 , wrapRecordType MX
-                 , wrapRecordType TXT
+                 [ SomeRT A
+                 , SomeRT NS
+                 , SomeRT MD
+                 , SomeRT MF
+                 , SomeRT CNAME
+                 , SomeRT SOA
+                 , SomeRT MB
+                 , SomeRT MG
+                 , SomeRT MR
+                 , SomeRT NULL
+                 , SomeRT WKS
+                 , SomeRT PTR
+                 , SomeRT HINFO
+                 , SomeRT MINFO
+                 , SomeRT MX
+                 , SomeRT TXT
                  ]
     where
       toPair :: SomeRT -> (Int, SomeRT)
       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
 
-
 defaultQTTable :: IntMap SomeQT
 defaultQTTable = mergeWithRTTable defaultRTTable $ IM.fromList $ map toPair $
-                 [ wrapQueryType AXFR
-                 , wrapQueryType MAILB
-                 , wrapQueryType MAILA
-                 , wrapQueryType ANY
+                 [ SomeQT AXFR
+                 , SomeQT MAILB
+                 , SomeQT MAILA
+                 , SomeQT ANY
                  ]
     where
       toPair :: SomeQT -> (Int, SomeQT)
@@ -821,15 +867,38 @@ defaultQTTable = mergeWithRTTable defaultRTTable $ IM.fromList $ map toPair $
       toSomeQT :: SomeRT -> SomeQT
       toSomeQT (SomeRT rt) = SomeQT rt
 
+defaultRCTable :: IntMap SomeRC
+defaultRCTable = IM.fromList $ map toPair $
+                 [ SomeRC IN
+                 , SomeRC CS
+                 , SomeRC CH
+                 , SomeRC HS
+                 ]
+    where
+      toPair :: SomeRC -> (Int, SomeRC)
+      toPair src@(SomeRC rc) = (rcToInt rc, src)
+
+defaultQCTable :: IntMap SomeQC
+defaultQCTable = mergeWithRCTable defaultRCTable $ IM.fromList $ map toPair $
+                 [ SomeQC ANY
+                 ]
+    where
+      toPair :: SomeQC -> (Int, SomeQC)
+      toPair sqc@(SomeQC qc) = (qcToInt qc, sqc)
+
+      mergeWithRCTable :: IntMap SomeRC -> IntMap SomeQC -> IntMap SomeQC
+      mergeWithRCTable rcs qcs
+          = IM.union (toQCTable rcs) qcs
+
+      toQCTable :: IntMap SomeRC -> IntMap SomeQC
+      toQCTable = IM.map toSomeQC
 
-wrapQueryType :: QueryType qt => qt -> SomeQT
-wrapQueryType = SomeQT
+      toSomeQC :: SomeRC -> SomeQC
+      toSomeQC (SomeRC rc) = SomeQC rc
 
-wrapRecordType :: RecordType rt dt => rt -> SomeRT
-wrapRecordType = SomeRT
 
-wrapQuestion :: QueryType qt => Question qt -> SomeQ
+wrapQuestion :: (QueryType qt, QueryClass qc) => Question qt qc -> SomeQ
 wrapQuestion = SomeQ
 
-wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
+wrapRecord :: (RecordType rt dt, RecordClass rc) => ResourceRecord rt rc dt -> SomeRR
 wrapRecord = SomeRR
\ No newline at end of file