DEV Community

Cover image for Deploy a Cognito Secured WebSocket API with AWS CDK
Wesley Cheek
Wesley Cheek

Posted on • Edited on

Deploy a Cognito Secured WebSocket API with AWS CDK

API Gateway is great for knitting together serverless applications, but have you ever wanted a little more? WebSocket APIs allow active two-way communication between your clients and your backend, so consider one when building real-time apps.

This post shows you how to deploy a WebSocket API that is secured with AWS Cognito. After a successful login, clients connect to the WebSocket by passing the `idToken` they received from Cognito as a query-string parameter. This post is heavily inspired by this AWS sample and this dev.to article. I have simplified the code and refactored it a bit for my own purposes and understanding.

Let's start by creating our authentication backend: a Cognito user pool and a user pool app client.

import {
  aws_cognito as cognito,
  RemovalPolicy,
} from "aws-cdk-lib";
import { Construct } from "constructs";

/**
 * Authentication backend: a Cognito user pool plus an app client that can
 * sign in with SRP or username/password.
 */
export class Auth extends Construct {
  userPool: cognito.UserPool;
  userPoolClient: cognito.UserPoolClient;

  constructor(scope: Construct, id: string) {
    super(scope, id);
    const pool = this._userPool();
    this.userPool = pool;
    this.userPoolClient = this._userPoolClient(pool);
  }

  /** Creates the user pool; it is deleted together with the stack. */
  _userPool(): cognito.UserPool {
    return new cognito.UserPool(this, "userPool", {
      removalPolicy: RemovalPolicy.DESTROY,
    });
  }

  /** Creates the app client clients use to log in and obtain ID tokens. */
  _userPoolClient(userPool: cognito.UserPool): cognito.UserPoolClient {
    return new cognito.UserPoolClient(this, "userPoolClient", {
      userPool,
      authFlows: { userSrp: true, userPassword: true },
    });
  }
}
Enter fullscreen mode Exit fullscreen mode

Now let's use these to protect the WebSocket `$connect` route with a custom Lambda authorizer. Follow along with the comments.

import {
  aws_cognito as cognito,
  aws_apigatewayv2 as apigatewayv2,
  aws_apigatewayv2_integrations as apigatewayv2_integrations,
  aws_apigatewayv2_authorizers as apigatewayv2_auth,
  aws_lambda_nodejs as lambdaNode,
  aws_lambda as lambda,
  aws_dynamodb as dynamo,
} from "aws-cdk-lib";
import { RemovalPolicy } from "aws-cdk-lib";
import { Construct } from "constructs";
import * as path from "path";

/**
 * Cognito-secured WebSocket API: a connection table, a Lambda authorizer on
 * `$connect`, and a single Lambda that handles every route.
 */
export class API extends Construct {
  /**
   * Placeholder for an HTTP API. This construct never creates one, so the
   * field is optional (a required, never-assigned field fails
   * `strictPropertyInitialization`).
   */
  apiv2?: apigatewayv2.HttpApi;
  /** The Cognito-secured WebSocket API created by this construct. */
  websocket: apigatewayv2.WebSocketApi;

  /**
   * @param scope - Parent construct.
   * @param id - Construct ID.
   * @param userPool - Cognito user pool whose ID tokens the authorizer accepts.
   * @param userPoolClient - App client the ID tokens must have been issued for.
   */
  constructor(
    scope: Construct,
    id: string,
    // We get these from our `Auth` construct above
    userPool: cognito.UserPool,
    userPoolClient: cognito.UserPoolClient,
  ) {
    super(scope, id);
    this.websocket = this._webSocket(userPool, userPoolClient);
  }

  /**
   * Builds the connection table, the authorizer and handler Lambdas, the
   * WebSocket API with its routes, and a `dev` stage.
   *
   * @returns The created WebSocket API.
   */
  _webSocket(
    userPool: cognito.UserPool,
    userPoolClient: cognito.UserPoolClient,
  ): apigatewayv2.WebSocketApi {
    // This DynamoDB table saves the `connectionId` when clients connect.
    // `removedAt` holds an epoch-seconds expiry so DynamoDB TTL eventually
    // removes rows whose `$disconnect` never ran.
    const connectionIdTable = new dynamo.Table(this, "ConnectionIdTable", {
      partitionKey: { name: "connectionId", type: dynamo.AttributeType.STRING },
      timeToLiveAttribute: "removedAt",
      billingMode: dynamo.BillingMode.PAY_PER_REQUEST,
      removalPolicy: RemovalPolicy.DESTROY,
    });

    // The `userId` (Cognito `sub`) index lets you query all connections that
    // belong to one user. `$disconnect` cleanup itself deletes by the
    // `connectionId` partition key and does not need this index.
    connectionIdTable.addGlobalSecondaryIndex({
      partitionKey: { name: "userId", type: dynamo.AttributeType.STRING },
      indexName: "userIdIndex",
    });

    // This Lambda authorizes the connection. On success it saves user
    // information into the DynamoDB table.
    const authHandler = new lambdaNode.NodejsFunction(this, "AuthHandler", {
      runtime: lambda.Runtime.NODEJS_18_X,
      entry: path.join(__dirname, "../lambdas/websocket/authorizer/index.ts"),
      environment: {
        USER_POOL_ID: userPool.userPoolId,
        APP_CLIENT_ID: userPoolClient.userPoolClientId,
        CONNECTION_TABLE_NAME: connectionIdTable.tableName,
      },
    });

    // This Lambda handles all routes (in our simple example).
    const websocketHandler = new lambdaNode.NodejsFunction(
      this,
      "WebSocketHandler",
      {
        runtime: lambda.Runtime.NODEJS_18_X,
        entry: path.join(__dirname, "../lambdas/websocket/handler/index.ts"),
        environment: {
          CONNECTION_TABLE_NAME: connectionIdTable.tableName,
        },
      },
    );

    connectionIdTable.grantReadWriteData(websocketHandler);
    // Least privilege: the authorizer only writes (PutItem) session rows.
    connectionIdTable.grantWriteData(authHandler);

    // `identitySource` is what the client must supply when connecting:
    // the Cognito `idToken` as a query-string parameter.
    const authorizer = new apigatewayv2_auth.WebSocketLambdaAuthorizer(
      "Authorizer",
      authHandler,
      {
        identitySource: ["route.request.querystring.idToken"],
      },
    );

    const apiv2Websocket = new apigatewayv2.WebSocketApi(
      this,
      "knowledgeBaseWebSocket",
      {
        description: "Used by your serverless app to integrate services",
        // `$connect` goes through the authorizer before reaching the
        // `websocketHandler` integration.
        connectRouteOptions: {
          authorizer,
          integration: new apigatewayv2_integrations.WebSocketLambdaIntegration(
            "ConnectIntegration",
            websocketHandler,
          ),
        },
        disconnectRouteOptions: {
          integration: new apigatewayv2_integrations.WebSocketLambdaIntegration(
            "DisconnectIntegration",
            websocketHandler,
          ),
        },
        // For this example, all other traffic also goes to `websocketHandler`.
        defaultRouteOptions: {
          integration: new apigatewayv2_integrations.WebSocketLambdaIntegration(
            "DefaultIntegration",
            websocketHandler,
          ),
        },
      },
    );

    // WebSocket APIs need an explicitly created stage.
    new apigatewayv2.WebSocketStage(this, "dev", {
      webSocketApi: apiv2Websocket,
      stageName: "dev",
      autoDeploy: true,
    });

    // Let `websocketHandler` post to and manage client connections.
    apiv2Websocket.grantManageConnections(websocketHandler);

    return apiv2Websocket;
  }
}
Enter fullscreen mode Exit fullscreen mode

OKAY! Now what about the authHandler and websocketHandler functions?

// `authHandler`
import { APIGatewayRequestAuthorizerHandler } from "aws-lambda";
import { CognitoJwtVerifier } from "aws-jwt-verify";
import { DynamoDBClient } from "@aws-sdk/client-dynamodb";
import { DynamoDBDocumentClient, PutCommand } from "@aws-sdk/lib-dynamodb";
import { CognitoIdTokenPayload } from "aws-jwt-verify/jwt-model";

const UserPoolId = process.env.USER_POOL_ID!;
const AppClientId = process.env.APP_CLIENT_ID!;

const client = DynamoDBDocumentClient.from(new DynamoDBClient({}));
const ConnectionTableName = process.env.CONNECTION_TABLE_NAME!;

// Created once per Lambda execution environment rather than per invocation,
// so the user pool's JWKS that the verifier downloads stays cached across
// warm invocations instead of being re-fetched on every connection attempt.
const verifier = CognitoJwtVerifier.create({
  userPoolId: UserPoolId,
  tokenUse: "id",
  clientId: AppClientId,
});

/**
 * WebSocket `$connect` Lambda authorizer.
 *
 * Verifies that the `idToken` query-string parameter is an ID token issued by
 * our user pool for our app client. On success it records the session and
 * returns an Allow policy for the route. On any failure it returns a Deny policy.
 */
export const handler: APIGatewayRequestAuthorizerHandler = async (
  event,
  _context,
) => {
  try {
    const connectionId = event.requestContext.connectionId;
    const encodedToken = event.queryStringParameters?.idToken;
    if (!connectionId || !encodedToken) {
      console.log("Missing connectionId or idToken query parameter");
      return denyAllPolicy();
    }

    const payload = await verifier.verify(encodedToken);

    // Successfully authenticated. Log only the subject, not the full payload,
    // to keep PII such as email out of CloudWatch logs.
    console.log("Token is valid. sub:", payload.sub);
    await saveSession(payload, connectionId);
    console.log("User information saved to database!");
    return allowPolicy(event.methodArn, payload);
  } catch (error: unknown) {
    console.log(error instanceof Error ? error.message : error);
    return denyAllPolicy();
  }
};

/** How long a session row lives before its `removedAt` expiry (3 hours). */
const SESSION_TTL_SECONDS = 3 * 60 * 60;

/**
 * Records an authenticated connection in the connection table.
 *
 * Best-effort: a write failure is logged and swallowed, so the caller still
 * proceeds.
 *
 * @param payload - Verified Cognito ID token payload.
 * @param connectionId - API Gateway WebSocket connection ID (table partition key).
 */
async function saveSession(
  payload: CognitoIdTokenPayload,
  connectionId: string,
): Promise<void> {
  try {
    await client.send(
      new PutCommand({
        TableName: ConnectionTableName,
        Item: {
          userName: payload["cognito:username"] ?? "",
          // `cognito:groups` is a string list; default to an empty list (not
          // "") so the attribute always has the same DynamoDB type.
          userGroups: payload["cognito:groups"] ?? [],
          userEmail: payload["email"] ?? "",
          userId: payload.sub,
          connectionId: connectionId,
          // Epoch seconds, as DynamoDB TTL expects.
          removedAt: Math.ceil(Date.now() / 1000) + SESSION_TTL_SECONDS,
        },
      }),
    );
  } catch (err: unknown) {
    console.error(`Failed to save session for connection ${connectionId}:`, err);
  }
}

/**
 * Builds the authorizer response that allows an authenticated client to connect.
 *
 * @param methodArn - ARN of the route being authorized; the only resource allowed.
 * @param idToken - Verified ID token payload. Only its `sub` claim is used.
 * @returns Allow policy with `principalId` and `context.userId` set to `sub`.
 */
function allowPolicy(methodArn: string, idToken: { sub: string }) {
  const userId = idToken.sub;
  return {
    principalId: userId,
    policyDocument: {
      Version: "2012-10-17",
      Statement: [
        {
          Action: "execute-api:Invoke",
          Effect: "Allow",
          Resource: methodArn,
        },
      ],
    },
    // API Gateway forwards this context to the route integration.
    context: {
      userId,
    },
  };
}

/**
 * Builds the authorizer response that rejects the connection: a single
 * statement denying every action on every resource.
 */
function denyAllPolicy() {
  const denyEverything = {
    Action: "*",
    Effect: "Deny",
    Resource: "*",
  };
  return {
    principalId: "*",
    policyDocument: {
      Version: "2012-10-17",
      Statement: [denyEverything],
    },
  };
}
Enter fullscreen mode Exit fullscreen mode

Once the client is authenticated, the connection request is routed to the websocketHandler function:

// `websocketHandler`
import { APIGatewayProxyHandler } from "aws-lambda";
import { DynamoDBClient } from "@aws-sdk/client-dynamodb";
import { DeleteCommand, DynamoDBDocumentClient } from "@aws-sdk/lib-dynamodb";
import {
  ApiGatewayManagementApiClient,
  PostToConnectionCommand,
} from "@aws-sdk/client-apigatewaymanagementapi";

const client = DynamoDBDocumentClient.from(new DynamoDBClient({}));
const ConnectionTableName = process.env.CONNECTION_TABLE_NAME!;

/**
 * True when a management-API error means the target connection is gone (HTTP 410).
 * AWS SDK v3 errors carry the status in `$metadata.httpStatusCode` and name
 * this case `GoneException`; they have no top-level `statusCode`.
 */
function isGoneError(e: unknown): boolean {
  if (typeof e !== "object" || e === null) return false;
  const err = e as { name?: unknown; $metadata?: { httpStatusCode?: unknown } };
  return err.name === "GoneException" || err.$metadata?.httpStatusCode === 410;
}

/**
 * Handles every WebSocket route.
 *
 * `$connect` is accepted because the authorizer already ran. `$disconnect`
 * deletes the connection row. Any other route echoes the message body back
 * to the sender.
 */
export const handler: APIGatewayProxyHandler = async (event, _context) => {
  console.log(event);
  const routeKey = event.requestContext.routeKey!;
  const connectionId = event.requestContext.connectionId!;

  if (routeKey === "$connect") {
    return { statusCode: 200, body: "Connected." };
  }

  if (routeKey === "$disconnect") {
    try {
      await removeConnectionId(connectionId);
      return { statusCode: 200, body: "Disconnected." };
    } catch (err) {
      console.error(err);
      return { statusCode: 500, body: "Disconnection failed." };
    }
  }

  // Default route: echo the message back (for testing purposes).
  const domainName = event.requestContext.domainName!;
  // With a custom domain the stage is part of the mapping, so no stage suffix.
  const endpoint = domainName.endsWith("amazonaws.com")
    ? `https://${domainName}/${event.requestContext.stage}`
    : `https://${domainName}`;

  const managementApi = new ApiGatewayManagementApiClient({
    endpoint,
  });

  try {
    await managementApi.send(
      new PostToConnectionCommand({
        ConnectionId: connectionId,
        Data: Buffer.from(
          JSON.stringify({
            message: event.body,
          }),
          "utf-8",
        ),
      }),
    );
  } catch (e: unknown) {
    if (isGoneError(e)) {
      // The client disconnected without a clean `$disconnect`; drop its row.
      await removeConnectionId(connectionId);
    } else {
      console.log(e);
      throw e;
    }
  }

  return { statusCode: 200, body: "Received." };
};

/**
 * Deletes the row for `connectionId` from the connection table and resolves
 * with the DynamoDB `DeleteCommand` output.
 */
const removeConnectionId = async (connectionId: string) => {
  const command = new DeleteCommand({
    TableName: ConnectionTableName,
    Key: { connectionId },
  });
  const output = await client.send(command);
  return output;
};
Enter fullscreen mode Exit fullscreen mode

Hope somebody found this useful! In the end, using the built-in IAM authorizer is probably simpler, but if you want clients to connect with their Cognito login credentials, this example should help!

Top comments (0)