DEV Community

Maxim Grankin


OCaml GADTs for Authentication Tokens

In this article, I will present and explain a real-world usage of Generalized Algebraic Data Types (GADTs). GADTs will be used for safely parsing complicated objects with a schema, for example, authentication tokens like JWTs. This approach ensures type safety and robustness when dealing with complex token schemas, making your libraries more reliable and less error-prone.

Before reading, please make sure you are familiar with GADTs: this article explains a specific use case, not GADTs themselves.

In this article, I will focus on JWT, but the technique is applicable to any format.

Authentication tokens 🛡️

Feel free to skip to the next section if you are comfortable with JWTs.

There are many different authentication token formats in the wild. Some popular ones are JWT and SAML, but probably every big tech company also has its own internal authentication format.

It is important to note that while JWTs are widely used, they can be controversial because they are easy to misuse. Since anyone holding a JWT can decode its payload, it is not recommended to send private data in a JWT. However, JWTs can be effective in a microservice architecture where they are only passed internally between services, so nothing is publicly exposed.

What is JWT?

JSON Web Tokens (JWTs) are an open, industry-standard method (RFC 7519) for representing claims securely between two parties.

An encoded token looks like this:

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6Ik9DYW1sIPCfkKshIiwibmFtZSI6Ik1heGltIEdyYW5raW4iLCJzdWIiOiJtYXhpbV9ncmFua2luX2lkZW50aWZpZXIiLCJhZ2UiOjIzfQ.DNlMGsF5HJ43gi4j5E-4y-YSs5osAMP-ps7jLsiG6jE

The general idea is to encode data into the token, for example the user's age, localisation, identifier in your system, the company they are registered with, etc. In JWT terms, each piece of information is called a claim.

The payload of the JWT above is:

{
  "language": "OCaml 🐫!",
  "name": "Maxim Grankin",
  "sub": "maxim_grankin_identifier",
  "age": 23
}

I don't think my name should be considered private data, but do not encode your users' names in a JWT!

Here, the key-value pair "language": "OCaml 🐫!" is a claim, as is every other key-value pair in the payload.

You can also use the debugger at jwt.io to verify that the payload is the one stated above.
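Since the header and payload are just base64url-encoded JSON (RFC 4648, section 5, with the padding stripped), you can also check this without a debugger. Below is a minimal sketch using only the OCaml standard library; `base64url_decode` and `jwt_segments` are hypothetical helper names, and the sketch does not verify the signature, so it must never be used to decide whether a token is trusted.

```ocaml
(* Value of one base64url character (RFC 4648, section 5). *)
let b64url_value c =
  match c with
  | 'A' .. 'Z' -> Char.code c - Char.code 'A'
  | 'a' .. 'z' -> Char.code c - Char.code 'a' + 26
  | '0' .. '9' -> Char.code c - Char.code '0' + 52
  | '-' -> 62
  | '_' -> 63
  | _ -> invalid_arg "b64url_value: not a base64url character"

(* Hypothetical helper: decode an unpadded base64url string.
   Each character contributes 6 bits; a byte is emitted as soon as 8 bits
   are available, and the few bits left over at the end are dropped. *)
let base64url_decode (s : string) : string =
  let buf = Buffer.create (String.length s * 3 / 4) in
  let acc = ref 0 and bits = ref 0 in
  String.iter
    (fun c ->
      if c <> '=' then begin
        acc := (!acc lsl 6) lor b64url_value c;
        bits := !bits + 6;
        if !bits >= 8 then begin
          bits := !bits - 8;
          Buffer.add_char buf (Char.chr ((!acc lsr !bits) land 0xFF));
          (* keep only the bits that have not been emitted yet *)
          acc := !acc land ((1 lsl !bits) - 1)
        end
      end)
    s;
  Buffer.contents buf

(* Hypothetical helper: split a compact JWT into header, payload, signature.
   The signature is returned untouched and is NOT verified here. *)
let jwt_segments (token : string) =
  match String.split_on_char '.' token with
  | [ header; payload; signature ] -> Some (header, payload, signature)
  | _ -> None

let () =
  let token =
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6Ik9DYW1sIPCfkKshIiwibmFtZSI6Ik1heGltIEdyYW5raW4iLCJzdWIiOiJtYXhpbV9ncmFua2luX2lkZW50aWZpZXIiLCJhZ2UiOjIzfQ.DNlMGsF5HJ43gi4j5E-4y-YSs5osAMP-ps7jLsiG6jE"
  in
  match jwt_segments token with
  | Some (header, payload, _signature) ->
      print_endline (base64url_decode header); (* {"alg":"HS256","typ":"JWT"} *)
      print_endline (base64url_decode payload)
  | None -> prerr_endline "not a compact JWT"
```

This is also why private data does not belong in a JWT: reading the payload requires no key at all.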

In the real world, JWT claims usually follow some format, schema, or strongly typed representation of arbitrary complexity. In our imaginary example, we assume that our JWT has a list of claims where each claim is described by a JSON schema.

Let's look at a simple JSON payload:

{
  "user_info": {
    "name": "Maxim Grankin",
    "system_id": "my_unique_identifier",
    "age": 23
  },
  "caller_info": {
    "language": "OCaml 🐫!",
    "from": "Web UI"
  },
  "localisation": {
    "language": "Russian",
    "timezone": "Etc/UTC"
  },
  "iat": "2024-11-14T08:15:30-05:00",
  "sub": "subject_identifier"
}

And the JSON schema for it would look like this:

{
  "$schema": "https://example.com/jwt.schema",
  "$id": "https://example.com/jwt.schema.json",
  "title": "JWT OCaml example",
  "description": "JWT OCaml example using GADTs",
  "type": "object",
  "properties": {
    "user_info": {
      "type": "object",
      "properties": {
        "name": {
          "type": "string"
        },
        "system_id": {
          "type": "string"
        },
        "age": {
          "type": "integer"
        }
      },
      "required": ["name", "system_id"]
    },
    "caller_info": {
      "type": "object",
      "properties": {
        "language": {
          "type": "string"
        },
        "from": {
          "type": "string"
        }
      },
      "required": ["language", "from"]
    },
    "localisation": {
      "type": "object",
      "properties": {
        "language": {
          "type": "string"
        },
        "timezone": {
          "type": "string"
        }
      },
      "required": ["language", "timezone"]
    },
    "iat": {
      "type": "string",
      "format": "date-time"
    },
    "sub": {
      "type": "string"
    }
  },
  "required": ["iat", "sub"]
}

In practice, for maintenance purposes it makes more sense to describe each claim with its own JSON schema and reference it, so the schema would actually look like this:

{
  "$schema": "https://example.com/jwt.schema",
  "$id": "https://example.com/jwt.schema.json",
  "title": "JWT OCaml example",
  "description": "JWT OCaml example using GADTs",
  "type": "object",
  "properties": {
    "user_info": { "$ref": "/schemas/user_info" },
    "caller_info": { "$ref": "/schemas/caller_info" },
    "localisation": { "$ref": "/schemas/localisation" },
    "iat": {
      "type": "string",
      "format": "date-time"
    },
    "sub": {
      "type": "string"
    }
  },
  "required": ["iat", "sub", "user_info"]
}
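The `required` lists are what decide the shape of the OCaml types later on. As a sketch, assuming the referenced schemas match the inline definitions shown earlier (the module names here are hypothetical, and `iat` is simply kept as a string), one common mapping is: a field listed in `required` becomes a plain field, and any other field becomes an `option`.

```ocaml
(* Hypothetical OCaml mirror of the schemas above. Each claim gets its own
   module, like each schema gets its own file, so that field names such as
   [language] can appear in more than one claim without clashing. *)
module User_info = struct
  type t = {
    name : string;       (* required *)
    system_id : string;  (* required *)
    age : int option;    (* not listed in required *)
  }
end

module Caller_info = struct
  type t = { language : string; from : string }  (* both required *)
end

module Localisation = struct
  type t = { language : string; timezone : string }  (* both required *)
end

(* Top level: iat, sub and user_info are required, the rest are optional. *)
type payload = {
  user_info : User_info.t;
  caller_info : Caller_info.t option;
  localisation : Localisation.t option;
  iat : string;  (* RFC 3339 date-time, left as a string in this sketch *)
  sub : string;
}

let example : payload = {
  user_info =
    { User_info.name = "Maxim Grankin"; system_id = "my_unique_identifier"; age = Some 23 };
  caller_info = Some { Caller_info.language = "OCaml 🐫!"; from = "Web UI" };
  localisation = Some { Localisation.language = "Russian"; timezone = "Etc/UTC" };
  iat = "2024-11-14T08:15:30-05:00";
  sub = "subject_identifier";
}
```

Making `user_info` non-optional at the top level mirrors the second schema, where it was added to `required`.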

OK, so how do we decode a JWT? 🤔

Now that we share some context, let's start simple and create types for the token that needs to be decoded:

(* Header that contains some metadata *)
type jwt_header = {...}

type jwt_token = {
  header : jwt_header;
  data_encoding : string;
  data : string;
}

As a client, you would like a more convenient API to get a claim from a jwt_token. Currently, you would have to decode the data string, parse it, and go through all the claims to find the one you need. Let's make jwt_token more structured by adding types that represent claims:

module Map_string = Map.Make (String)
type jwt_header = {...}

type jwt_schema = {
  name : string;
  version : string;
}

type jwt_claim = {
  data : string;
  schema : jwt_schema;
}

type jwt_data = {
  encoding : string;
  claims : jwt_claim Map_string.t;
}

type jwt_token = {
  header : jwt_header;
  data : jwt_data;
}

Great: now, given a token, a client can find a claim, for example localisation, and decode it using the encoding provided in data!

val retrieve_claim :
  claim_key:string ->
  jwt_data:jwt_data ->
  (string * jwt_schema) option

The implementation is pretty straightforward:

let retrieve_claim ~claim_key ~jwt_data =
  let claim_data = Map_string.find_opt claim_key jwt_data.claims in
  claim_data |>
  Option.map (fun cd ->
    (Decoder.decode_by_encoding ~encoding:jwt_data.encoding cd.data, cd.schema))
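To check that this signature and implementation fit together, here is a self-contained sketch. The `Decoder` module is a hypothetical stand-in that only understands an `identity` encoding, and `jwt_header` is left out because its fields are not specified here.

```ocaml
module Map_string = Map.Make (String)

type jwt_schema = { name : string; version : string }
type jwt_claim = { data : string; schema : jwt_schema }
type jwt_data = { encoding : string; claims : jwt_claim Map_string.t }

(* Hypothetical stand-in for the Decoder module: only the "identity"
   encoding is supported, any other encoding is rejected. *)
module Decoder = struct
  let decode_by_encoding ~encoding data =
    match encoding with
    | "identity" -> data
    | other -> invalid_arg ("unsupported encoding: " ^ other)
end

(* A pair of values is written [string * jwt_schema] in OCaml types. *)
let retrieve_claim ~claim_key ~(jwt_data : jwt_data) : (string * jwt_schema) option =
  Map_string.find_opt claim_key jwt_data.claims
  |> Option.map (fun cd ->
         (Decoder.decode_by_encoding ~encoding:jwt_data.encoding cd.data, cd.schema))
```

The client still receives a raw string plus a schema and has to pick the right decoder itself, which is exactly the step the rest of this article tries to remove.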

Now our client can get raw claim data and a schema to decode it! But how can we make their life easier?

If there are not many different schemas, I'd just create a function for each claim type that does the decoding under the hood:

val retrieve_localisation :
  jwt_data -> localisation option

val retrieve_user_info :
  jwt_data -> user_info option

...

And in most cases this approach is great. No need to read below, no need to think about GADTs, no need to overcomplicate things! Done ✅!

But there are situations where the number of unique schemas grows past 50, 100, or 1000, and it keeps growing together with your product. Implementing and maintaining a thousand boilerplate claim extractor functions sounds like a lot to me. It would be very easy to make a typo, miss it, and get an error at runtime.

Boilerplate examples
let retrieve_localisation jwt_data =
  let localisation_data =
    Map_string.find_opt "localisation" jwt_data.claims
  in
  localisation_data
  |> Option.map (fun cd -> Decoder.decode_localisation cd.data)

let retrieve_caller_info jwt_data =
  let caller_info_data =
    Map_string.find_opt "caller_info" jwt_data.claims
  in
  caller_info_data
  |> Option.map (fun cd -> Decoder.decode_caller_info cd.data)

let retrieve_user_info jwt_data =
  let user_info_data =
    Map_string.find_opt "user_info" jwt_data.claims
  in
  user_info_data
  |> Option.map (fun cd -> Decoder.decode_user_info cd.data)

let retrieve_user_id jwt_data =
  let user_id_data =
    Map_string.find_opt "user_id" jwt_data.claims
  in
  user_id_data
  |> Option.map (fun cd -> Decoder.decode_user_id cd.data)
...

These functions are very similar. You can create a generic helper to remove some of the repetition, but the boilerplate still stays:

let retrieve ~jwt_data ~key ~decoder_fun =
  let data =
    Map_string.find_opt key jwt_data.claims
  in
  data
  |> Option.map (fun cd -> decoder_fun cd.data)

let retrieve_localisation jwt_data =
  retrieve ~jwt_data ~key:"localisation" ~decoder_fun:Decoder.decode_localisation

let retrieve_caller_info jwt_data =
  retrieve ~jwt_data ~key:"caller_info" ~decoder_fun:Decoder.decode_caller_info

let retrieve_user_info jwt_data =
  retrieve ~jwt_data ~key:"user_info" ~decoder_fun:Decoder.decode_user_info

let retrieve_user_id jwt_data =
  retrieve ~jwt_data ~key:"user_id" ~decoder_fun:Decoder.decode_user_id

...

And this is the classic example of copy-pasted-forgot-to-change situation, isn't it?
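To make that concrete, here is a minimal sketch in which a copy-pasted `retrieve_user_info` still reads the `localisation` key. The decoders are hypothetical and simplified to `string -> string`; the type checker accepts the mistake because a claim key is just a `string`.

```ocaml
module Map_string = Map.Make (String)

type jwt_claim = { data : string }
type jwt_data = { claims : jwt_claim Map_string.t }

(* Hypothetical decoders: both have type [string -> string]. *)
let decode_localisation s = "localisation:" ^ s
let decode_user_info s = "user_info:" ^ s

let retrieve ~jwt_data ~key ~decoder_fun =
  Map_string.find_opt key jwt_data.claims
  |> Option.map (fun cd -> decoder_fun cd.data)

let retrieve_localisation jwt_data =
  retrieve ~jwt_data ~key:"localisation" ~decoder_fun:decode_localisation

(* Copy-pasted from above; the key was never changed. This compiles fine. *)
let retrieve_user_info jwt_data =
  retrieve ~jwt_data ~key:"localisation" ~decoder_fun:decode_user_info
```

With real decoders returning different record types, the wrong key would still compile and would only show up as a decoding failure at runtime.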


OCaml's more powerful features can help make your library's API concise and extensible.

GADTs attack JWT decoding ⚔️

The main job of our API is retrieving data from a JWT. Clients would also like to get only the data they need; for example, a localisation module or service doesn't need to know the user's age.

The proposed API looks like this:

(* Decode JWT represented as string. *)
val decode_jwt :
  jwt:string ->
  (jwt_token, [> error]) Result.t

(* Retrieve claim *)
val retrieve_claim :
  claim_typ:'a claim_typ ->
  claims:jwt_claim Map_string.t ->
  ('a, [> error]) Result.t

decode_jwt is not that interesting, so it won't be covered here. I'll just assume that it is given to us.

Here I introduced several things:

  • Replaced option with Result.t to pass a detailed error to the client.
    • The [> error] syntax may look unusual, but it is just polymorphic variants used for errors. I will not cover it here, but I highly recommend reading Composable Error Handling in OCaml to understand this approach.
  • I introduced 'a claim_typ, but what is it? Well, claim_typ is a GADT in which each constructor stands for exactly one type: our claim! (Some refer to this as a singleton type.) In our case it may look like:
type _ claim_typ =
  | Caller_info : caller_info_claim claim_typ
  | Localisation : localisation_claim claim_typ
  | User_info : user_info_claim claim_typ
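The retrieve_claim implementation that follows uses two helpers that are not spelled out in the text: `claim_name_of_claim_typ` and the `Claim_type` wrapper used inside the error value. The definitions here are a hypothetical sketch with simplified claim records; `placeholder` exists only to show how matching on a constructor tells the type checker what that branch must return.

```ocaml
(* Simplified stand-ins for the real claim records. *)
type caller_info_claim = { from : string }
type localisation_claim = { timezone : string }
type user_info_claim = { name : string; age : int option }

type _ claim_typ =
  | Caller_info : caller_info_claim claim_typ
  | Localisation : localisation_claim claim_typ
  | User_info : user_info_claim claim_typ

(* Hypothetical helper: the key each claim is stored under in the payload.
   The [type a.] annotation lets one function accept every constructor,
   even though each constructor has a different type index. *)
let claim_name_of_claim_typ : type a. a claim_typ -> string = function
  | Caller_info -> "caller_info"
  | Localisation -> "localisation"
  | User_info -> "user_info"

(* Matching on a constructor refines [a], so each branch returns its own type. *)
let placeholder : type a. a claim_typ -> a = function
  | Caller_info -> { from = "unknown" }
  | Localisation -> { timezone = "Etc/UTC" }
  | User_info -> { name = "anonymous"; age = None }

(* Existential wrapper: hides the index so that claim types of different
   claims can live in one list, or in one error value. *)
type any_claim_typ = Claim_type : 'a claim_typ -> any_claim_typ

let all_claim_names =
  List.map
    (function Claim_type t -> claim_name_of_claim_typ t)
    [ Claim_type Caller_info; Claim_type Localisation; Claim_type User_info ]
```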

Each claim has its own type, either generated from its JSON schema or written manually. I'll define one type manually and use ppx_deriving_yojson to encode/decode it.

(* Inside user_info.mli *)
type user_info_claim = {
  name : string [@key "name"];
  system_id : string [@key "system_id"];
  age : int option [@key "age"] [@default None];
}
[@@deriving yojson {strict = false}]

As said above, claim_typ is a GADT in which each constructor stands for exactly one type, so it can serve as a runtime representation of a claim's type. This lets us write a retrieve_claim that is generic over every claim type:

let retrieve_claim :
    type a.
    claim_typ:a claim_typ ->
    claims:jwt_claim Map_string.t ->
    (a, [> error]) Result.t =
 fun ~claim_typ ~claims ->
  let claim_name = claim_name_of_claim_typ claim_typ in
  let claim_result =
    claims
    |> Map_string.find_opt claim_name
    |> Option.to_result ~none:(`ClaimNotFoundError (Claim_type claim_typ))
  in
  (* Each branch refines [a], so each may return a different claim type *)
  match claim_typ with
  | Caller_info ->
    Result.bind claim_result (fun claim ->
        Caller_info_claim.of_raw_string claim.data )
  | Localisation ->
    Result.bind claim_result (fun claim ->
        Localisation_claim.of_raw_string claim.data )
  | User_info ->
    Result.bind claim_result (fun claim ->
        User_info_claim.of_raw_string claim.data )
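For readers who want to run the idea end to end, here is a self-contained sketch of the same retrieve_claim. It assumes hypothetical `*_claim` modules whose `of_raw_string` parses a made-up pipe-separated text format instead of JSON, so that only the standard library is needed. The claim types are named `Caller_info_claim.t` and so on rather than `caller_info_claim`, and the `schema` field of `jwt_claim` is omitted.

```ocaml
module Map_string = Map.Make (String)

type jwt_claim = { data : string }  (* schema field omitted in this sketch *)

(* Hypothetical claim modules; of_raw_string stands in for the
   yojson-derived decoders and splits the raw data on '|'. *)
module Caller_info_claim = struct
  type t = { language : string; from : string }

  let of_raw_string s =
    match String.split_on_char '|' s with
    | [ language; from ] -> Ok { language; from }
    | _ -> Error (`ClaimDecodeError "caller_info")
end

module Localisation_claim = struct
  type t = { language : string; timezone : string }

  let of_raw_string s =
    match String.split_on_char '|' s with
    | [ language; timezone ] -> Ok { language; timezone }
    | _ -> Error (`ClaimDecodeError "localisation")
end

module User_info_claim = struct
  type t = { name : string; system_id : string; age : int option }

  let of_raw_string s =
    match String.split_on_char '|' s with
    | [ name; system_id ] -> Ok { name; system_id; age = None }
    | [ name; system_id; age ] -> (
        match int_of_string_opt age with
        | Some age -> Ok { name; system_id; age = Some age }
        | None -> Error (`ClaimDecodeError "user_info"))
    | _ -> Error (`ClaimDecodeError "user_info")
end

type _ claim_typ =
  | Caller_info : Caller_info_claim.t claim_typ
  | Localisation : Localisation_claim.t claim_typ
  | User_info : User_info_claim.t claim_typ

type any_claim_typ = Claim_type : 'a claim_typ -> any_claim_typ

type error =
  [ `ClaimNotFoundError of any_claim_typ | `ClaimDecodeError of string ]

let claim_name_of_claim_typ : type a. a claim_typ -> string = function
  | Caller_info -> "caller_info"
  | Localisation -> "localisation"
  | User_info -> "user_info"

let retrieve_claim :
    type a.
    claim_typ:a claim_typ ->
    claims:jwt_claim Map_string.t ->
    (a, [> error ]) result =
 fun ~claim_typ ~claims ->
  let claim_result =
    claims
    |> Map_string.find_opt (claim_name_of_claim_typ claim_typ)
    |> Option.to_result ~none:(`ClaimNotFoundError (Claim_type claim_typ))
  in
  (* In each branch the type checker learns what [a] is. *)
  match claim_typ with
  | Caller_info ->
      Result.bind claim_result (fun claim -> Caller_info_claim.of_raw_string claim.data)
  | Localisation ->
      Result.bind claim_result (fun claim -> Localisation_claim.of_raw_string claim.data)
  | User_info ->
      Result.bind claim_result (fun claim -> User_info_claim.of_raw_string claim.data)
```

With this sketch, `retrieve_claim ~claim_typ:User_info ~claims` has type `(User_info_claim.t, _) result`, so the compiler knows the `Ok` payload has a `name` field, while the same call with `Localisation` yields a `Localisation_claim.t`.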

On the client side, the workflow would look like this:

let print_user_name () =
  let ( let* ) = Result.bind in
  let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6Ik9DYW1sIPCfkKshIiwibmFtZSI6Ik1heGltIEdyYW5raW4iLCJzdWIiOiJtYXhpbV9ncmFua2luX2lkZW50aWZpZXIiLCJhZ2UiOjIzfQ.DNlMGsF5HJ43gi4j5E-4y-YSs5osAMP-ps7jLsiG6jE" in
  let* jwt_token = decode_jwt ~jwt in
  let* user_info =
    retrieve_claim ~claim_typ:User_info ~claims:jwt_token.data.claims
  in
  (* user_info has type user_info_claim here, so .name is known to exist *)
  print_endline user_info.name;
  Ok ()

Comparison and Extensibility ⚖️

Let's look at how the GADT approach makes the API extensible, type-safe, and less error-prone compared to the non-GADT one:

Add new claim 🌚

Non-GADT

(* New type inside user_info.mli *)
type user_info_claim = {
  name : string [@key "name"];
  system_id : string [@key "system_id"];
  age : int option [@key "age"] [@default None];
}
[@@deriving yojson {strict = false}]

GADT

(* New type inside user_info.mli *)
type user_info_claim = {
  name : string [@key "name"];
  system_id : string [@key "system_id"];
  age : int option [@key "age"] [@default None];
}
[@@deriving yojson {strict = false}]

(* Inside claims.mli *)
type _ claim_typ =
  | Caller_info : caller_info_claim claim_typ
  | Localisation : localisation_claim claim_typ
  (* New constructor *)
  | User_info : user_info_claim claim_typ

Retrieve claim 🔍

Non-GADT

let retrieve_localisation jwt_data =
  let localisation_data =
    Map_string.find_opt "localisation" jwt_data.claims
  in
  localisation_data
  |> Option.map (fun cd -> Localisation_claim.of_raw_string cd.data)

let retrieve_caller_info jwt_data =
  let caller_info_data =
    Map_string.find_opt "caller_info" jwt_data.claims
  in
  caller_info_data
  |> Option.map (fun cd -> Caller_info_claim.of_raw_string cd.data)

(* New function *)
let retrieve_user_info jwt_data =
  let user_info_data =
    Map_string.find_opt "user_info" jwt_data.claims
  in
  user_info_data
  |> Option.map (fun cd -> User_info_claim.of_raw_string cd.data)

GADT

let retrieve_claim :
    type a.
    claim_typ:a claim_typ ->
    claims:jwt_claim Map_string.t ->
    (a, [> error]) result =
fun ~claim_typ ~claims ->
  let claim_name = claim_name_of_claim_typ claim_typ in
  let claim_result =
    claims
    |> Map_string.find_opt claim_name
    |> Option.to_result ~none:(`ClaimNotFoundError (Claim_type claim_typ))
  in
  match claim_typ with
  | Caller_info ->
    Result.bind claim_result (fun claim ->
        Caller_info_claim.of_raw_string claim.data )
  | Localisation ->
    Result.bind claim_result (fun claim ->
        Localisation_claim.of_raw_string claim.data )
  (* New claim typ *)
  | User_info ->
    Result.bind claim_result (fun claim ->
        User_info_claim.of_raw_string claim.data )
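Every branch above repeats the same `Result.bind`. One possible refinement, a sketch rather than the approach used in this article, keeps the per-claim knowledge in two small GADT matches and writes the lookup only once. The claim types and decoders below are hypothetical and deliberately tiny.

```ocaml
module Map_string = Map.Make (String)

type jwt_claim = { data : string }

(* Hypothetical, deliberately tiny claim types and decoders. *)
module Localisation_claim = struct
  type t = { timezone : string }

  let of_raw_string s =
    if s = "" then Error (`ClaimDecodeError "localisation") else Ok { timezone = s }
end

module User_info_claim = struct
  type t = { age : int }

  let of_raw_string s =
    match int_of_string_opt s with
    | Some age -> Ok { age }
    | None -> Error (`ClaimDecodeError "user_info")
end

type _ claim_typ =
  | Localisation : Localisation_claim.t claim_typ
  | User_info : User_info_claim.t claim_typ

type any_claim_typ = Claim_type : 'a claim_typ -> any_claim_typ

(* The only per-claim code: a name and a decoder. Adding a constructor to
   claim_typ makes both matches below non-exhaustive (compiler warning),
   and each new branch must return a decoder for exactly that claim type. *)
let claim_name_of_claim_typ : type a. a claim_typ -> string = function
  | Localisation -> "localisation"
  | User_info -> "user_info"

let decoder_of_claim_typ :
    type a. a claim_typ -> string -> (a, [> `ClaimDecodeError of string ]) result =
  function
  | Localisation -> Localisation_claim.of_raw_string
  | User_info -> User_info_claim.of_raw_string

(* Generic part, written once for every claim. *)
let retrieve_claim ~claim_typ ~claims =
  match Map_string.find_opt (claim_name_of_claim_typ claim_typ) claims with
  | None -> Error (`ClaimNotFoundError (Claim_type claim_typ))
  | Some claim -> decoder_of_claim_typ claim_typ claim.data
```

Here `retrieve_claim` needs no type annotation, because it never matches on the GADT constructors itself; all the matching happens in `claim_name_of_claim_typ` and `decoder_of_claim_typ`.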

Conclusion 🚀

  1. The GADT-based API is smaller, meaning that it is much easier to reason about. There are no problems with naming conventions, and it is much easier to maintain. Also, clients need to use only two functions in their codebase.
  2. The GADT-based API is type-safe and less prone to errors.

    With GADTs, if you add a new claim_typ and try to return the wrong type, your code will fail to compile with an error such as:

    Error: This expression has type caller_info_claim
       but an expression was expected of type a = localisation_claim
    

    A non-GADT solution gives you no such guarantees. It is also easier to make typos; instead of:

    val retrieve_user_info : jwt_data -> user_info option
    

    You can accidentally write:

    val retrieve_user_language : jwt_data -> user_info option
    

    Such a typo is hard to spot, and code review is realistically the only chance to catch it.

  3. The GADT-based API forces handling of new claims:

    • In order to support a new claim, you have to add a new constructor.
    • When you add a new constructor, you are now forced to handle it. Moreover, you are forced to return the expected type for your specific claim; you cannot accidentally return user_info instead of localisation, it will fail to compile as shown above!

More about GADTs 📚

If you would like to learn more about GADTs, I recommend taking a look at posts by others:

Subscribe 🔔

If you liked this post, consider following me on social networks and GitHub:

Top comments (1)

Dmitrii Kovanikov

Really nice and pragmatic use case for GADTs in OCaml!

We need more of such content.