-- |
-- Module      : Data.ASN1.BinaryEncoding
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- A module containing ASN1 BER and DER specification encoding/decoding.
--
{-# LANGUAGE EmptyDataDecls #-}
module Data.ASN1.BinaryEncoding
    ( BER(..)
    , DER(..)
    ) where

import Data.ASN1.Stream
import Data.ASN1.Types
import Data.ASN1.Types.Lowlevel
import Data.ASN1.Error
import Data.ASN1.Encoding
import Data.ASN1.BinaryEncoding.Parse
import Data.ASN1.BinaryEncoding.Writer
import Data.ASN1.Prim
import qualified Control.Exception as E

-- | Basic Encoding Rules (BER)
--
-- Value-level tag used to select the BER instances of the decoding
-- classes defined in this module.
data BER = BER

-- | Distinguished Encoding Rules (DER)
--
-- Value-level tag used to select the DER instances of the decoding and
-- encoding classes defined in this module.
data DER = DER

-- | BER decoding that pairs every decoded element with the raw events it
-- was built from. No header policy is enforced: the header check passed to
-- 'decodeEventASN1Repr' always answers 'Nothing'.
instance ASN1DecodingRepr BER where
    decodeASN1Repr _ lbs = decodeEventASN1Repr (const Nothing) `fmap` parseLBS lbs

-- | BER decoding to plain 'ASN1' values. Same pipeline as the
-- representation-preserving decoder, with the raw events dropped ('fst').
-- No header policy is enforced.
instance ASN1Decoding BER where
    decodeASN1 _ lbs = (map fst . decodeEventASN1Repr (const Nothing)) `fmap` parseLBS lbs

-- | DER decoding that pairs every decoded element with the raw events it
-- was built from. Each header is validated with 'checkDER'.
instance ASN1DecodingRepr DER where
    decodeASN1Repr _ lbs = decodeEventASN1Repr checkDER `fmap` parseLBS lbs

-- | DER decoding to plain 'ASN1' values. Each header is validated with
-- 'checkDER'; the raw events attached to every element are dropped ('fst').
instance ASN1Decoding DER where
    decodeASN1 _ lbs = (map fst . decodeEventASN1Repr checkDER) `fmap` parseLBS lbs

-- | DER encoding: the 'ASN1' stream is turned into events by 'encodeToRaw'
-- and then serialised with 'toLazyByteString'.
instance ASN1Encoding DER where
    encodeASN1 _ l = toLazyByteString $ encodeToRaw l

-- | Map a constructed element's header to its construction type.
--
-- Universal tag @0x10@ is a 'Sequence' and universal tag @0x11@ is a 'Set';
-- any other class/tag pair becomes a generic 'Container' carrying that
-- class and tag.
decodeConstruction :: ASN1Header -> ASN1ConstructionType
decodeConstruction (ASN1Header Universal 0x10 _ _) = Sequence
decodeConstruction (ASN1Header Universal 0x11 _ _) = Set
decodeConstruction (ASN1Header c t _ _)            = Container c t

-- | Turn a flat list of parse events into a list of ASN1 elements, each
-- paired with the events it was decoded from.
--
-- The first argument is a header policy: it is applied to every header and
-- a 'Just' result aborts decoding with that error.
--
-- Note that failures (policy violations, primitive decoding errors and
-- unexpected event sequences) are raised with 'E.throw' from pure code;
-- they are not returned as a value.
decodeEventASN1Repr :: (ASN1Header -> Maybe ASN1Error) -> [ASN1Event] -> [ASN1Repr]
decodeEventASN1Repr checkHeader l = loop [] l
    where -- @acc@ is the stack of currently open constructions, innermost first.
          loop _ []     = []
          -- Constructed header followed by the start of its content: open a
          -- new construction and push its type on the stack.
          loop acc (h@(Header hdr@(ASN1Header _ _ True _)):ConstructionBegin:xs) =
                let ctype = decodeConstruction hdr in
                case checkHeader hdr of
                    Nothing  -> (Start ctype,[h,ConstructionBegin]) : loop (ctype:acc) xs
                    Just err -> E.throw err
          -- Primitive header followed by its payload: decode the value.
          loop acc (h@(Header hdr@(ASN1Header _ _ False _)):p@(Primitive prim):xs) =
                case checkHeader hdr of
                    Nothing -> case decodePrimitive hdr prim of
                        Left err  -> E.throw err
                        Right obj -> (obj, [h,p]) : loop acc xs
                    Just err -> E.throw err
          -- End of the innermost open construction: pop it from the stack.
          loop (ctype:acc) (ConstructionEnd:xs) = (End ctype, [ConstructionEnd]) : loop acc xs
          -- Any other event sequence (including an end with nothing open).
          loop _ (x:_) = E.throw $ StreamUnexpectedSituation (show x)

-- | DER headers need to be of finite size and of the minimum possible size.
--
-- Only the length field of the header is inspected. Returns 'Nothing' when
-- the length is acceptable, or 'Just' a 'PolicyFailed' error describing the
-- violation.
checkDER :: ASN1Header -> Maybe ASN1Error
checkDER (ASN1Header _ _ _ len) = checkLength len
    where checkLength :: ASN1Length -> Maybe ASN1Error
          checkLength LenIndefinite = Just $ PolicyFailed "DER" "indefinite length not allowed"
          checkLength (LenShort _)  = Nothing
          checkLength (LenLong n i)
              -- a single length byte holding a value below 0x80 should have
              -- used the short form
              | n == 1 && i < 0x80  = Just $ PolicyFailed "DER" "long length should be a short length"
              | n == 1 && i >= 0x80 = Nothing
              -- with n length bytes the value must actually need all n bytes
              | otherwise           = if i >= 2^((n-1)*8) && i < 2^(n*8)
                  then Nothing
                  else Just $ PolicyFailed "DER" "long length is not shortest"

-- | Convert a flat 'ASN1' stream into the events to be serialised.
--
-- The stream is first grouped into top-level items ('mkTree'): a 'Start'
-- element is paired with everything up to and including its matching 'End',
-- any other element is paired with an empty list. Each item is then encoded
-- with 'encodeConstructed' or 'encodePrimitive', keeping only the produced
-- events ('snd').
encodeToRaw :: [ASN1] -> [ASN1Event]
encodeToRaw = concatMap writeTree . mkTree
    where writeTree (p@(Start _),children) = snd $ encodeConstructed p children
          writeTree (p,_)                  = snd $ encodePrimitive p

          mkTree []           = []
          mkTree (x@(Start _):xs) =
              let (tree, r) = spanEnd 0 xs
               in (x,tree):mkTree r
          mkTree (p:xs)       = (p,[]) : mkTree xs

          -- Split the list just after the 'End' that closes the current
          -- construction. The counter tracks how many nested 'Start's are
          -- still open; the closing 'End' is kept in the first component.
          -- If the input runs out first, everything seen so far is returned.
          spanEnd :: Int -> [ASN1] -> ([ASN1], [ASN1])
          spanEnd _ []             = ([], [])
          spanEnd 0 (x@(End _):xs) = ([x], xs)
          spanEnd lvl (x:xs)       = case x of
                    Start _ -> let (ys, zs) = spanEnd (lvl+1) xs in (x:ys, zs)
                    End _   -> let (ys, zs) = spanEnd (lvl-1) xs in (x:ys, zs)
                    _       -> let (ys, zs) = spanEnd lvl xs in (x:ys, zs)