DEV Community

Indira Mythili


Implementing Rate Limiting in NestJS with Custom Redis Storage

In modern APIs, rate limiting is essential to prevent abuse, ensure fair usage, and protect backend resources. This implementation builds on the official NestJS throttler package (@nestjs/throttler) and keeps the rate-limit counters in Redis. Because the counters live in Redis, every instance of the API shares the same counts, and the limits still hold when the service is scaled out.

Configuration ----

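The rate limit configuration is defined in ratelimit.config.ts. For configService.get('apiRateLimit.ttl') to find these values later, the object has to be registered with the ConfigModule, for example by passing registerAs('apiRateLimit', () => apiRateLimit) in the load array of ConfigModule.forRoot().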

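export const apiRateLimit = {
  // Window length in seconds: the Redis storage below passes it straight to
  // SETEX. (@nestjs/throttler v5+ documents its own ttl in milliseconds.)
  ttl: 60,
  limit: 5, // Maximum number of requests per window
};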
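To make the two settings concrete, here is a minimal in-memory sketch of a fixed-window limiter driven by the same ttl and limit values: each key may make limit requests, and its count starts over ttl seconds after the window's first request. The FixedWindowLimiter class and its injectable clock are illustrative assumptions, not part of @nestjs/throttler.

```typescript
// Hypothetical in-memory limiter showing what `ttl` and `limit` mean.
interface RateLimitConfig {
  ttl: number; // window length in seconds
  limit: number; // maximum requests per window
}

class FixedWindowLimiter {
  // key -> hits so far and the moment (epoch ms) the current window ends
  private readonly windows = new Map<string, { hits: number; expiresAt: number }>();

  // `now` is injectable so the behaviour can be checked without waiting.
  constructor(
    private readonly config: RateLimitConfig,
    private readonly now: () => number = () => Date.now(),
  ) {}

  /** Records one request for `key`; returns true if it is within the limit. */
  hit(key: string): boolean {
    const t = this.now();
    let state = this.windows.get(key);
    // The first request, or the first one after expiry, opens a new window.
    if (state === undefined || t >= state.expiresAt) {
      state = { hits: 0, expiresAt: t + this.config.ttl * 1000 };
      this.windows.set(key, state);
    }
    state.hits += 1;
    return state.hits <= this.config.limit;
  }
}

// Usage with the values from ratelimit.config.ts and a fake clock (ms).
let clock = 0;
const limiter = new FixedWindowLimiter({ ttl: 60, limit: 5 }, () => clock);
const results: boolean[] = [];
for (let i = 0; i < 6; i++) results.push(limiter.hit('203.0.113.7'));
console.log(results); // five `true`, then `false`
clock = 60000; // one full window later
console.log(limiter.hit('203.0.113.7')); // true again
```

Because each window starts at its first request, a client can fit almost twice the limit into a short span: one request at 0 s, four at 59 s and five more at 60 s gives nine accepted requests within one second. That boundary case is worth covering in tests.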

Throttler Module ----

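The ThrottlerModule is configured in throttler.module.ts using the ThrottlerModule.forRootAsync method to load configuration dynamically. The factory passed to forRootAsync can only inject providers exported by the modules listed in its own imports array (or by global modules), not the providers of the surrounding module, so the storage is constructed inside the factory from an injected CacheService. ConfigService comes from ConfigModule instead of being re-provided locally:

import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { ThrottlerModule } from '@nestjs/throttler';
// Project-local modules (paths are illustrative)
import { CacheModule } from '../cache/cache.module';
import { CacheService } from '../cache/cache.service';
import { CustomRedisThrottlerStorage } from './custom-redis-throttler.storage';
import { CustomThrottlerInterceptor } from './custom-throttler.interceptor';

@Module({
  imports: [
    ConfigModule,
    CacheModule,
    ThrottlerModule.forRootAsync({
      // The factory can only inject providers from these modules (or global ones)
      imports: [ConfigModule, CacheModule],
      inject: [ConfigService, CacheService],
      useFactory: (configService: ConfigService, cacheService: CacheService) => ({
        storage: new CustomRedisThrottlerStorage(cacheService),
        throttlers: [
          {
            ttl: configService.get<number>('apiRateLimit.ttl'),
            limit: configService.get<number>('apiRateLimit.limit'),
          },
        ],
      }),
    }),
  ],
  // The interceptor injects the storage directly, so both are provided here too
  providers: [CustomRedisThrottlerStorage, CustomThrottlerInterceptor],
  exports: [CustomRedisThrottlerStorage, CustomThrottlerInterceptor],
})
export class CustomThrottlerModule {}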
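ConfigService.get returns undefined when a path such as apiRateLimit.ttl has not been registered, and the factory above would hand that undefined to the throttler without complaint. A small fail-fast reader makes the mistake visible at startup. The sketch below resolves dotted paths against a plain object standing in for the loaded configuration; getPath, readPositiveInt and buildThrottlers are hypothetical helpers, not @nestjs/config APIs.

```typescript
// Hypothetical stand-in for the configuration object ConfigService would expose.
type ConfigTree = { [key: string]: unknown };

/** Resolves a dotted path such as 'apiRateLimit.ttl' against a nested object. */
function getPath(config: ConfigTree, path: string): unknown {
  let current: unknown = config;
  for (const segment of path.split('.')) {
    if (typeof current !== 'object' || current === null) return undefined;
    current = (current as ConfigTree)[segment];
  }
  return current;
}

/** Reads a positive integer or throws, so a missing setting fails at startup. */
function readPositiveInt(config: ConfigTree, path: string): number {
  const value = getPath(config, path);
  if (typeof value !== 'number' || !Number.isInteger(value) || value <= 0) {
    throw new Error(`Invalid or missing config value at "${path}": ${String(value)}`);
  }
  return value;
}

/** Builds the `throttlers` array that the module factory passes to ThrottlerModule. */
function buildThrottlers(config: ConfigTree): { ttl: number; limit: number }[] {
  return [
    {
      ttl: readPositiveInt(config, 'apiRateLimit.ttl'),
      limit: readPositiveInt(config, 'apiRateLimit.limit'),
    },
  ];
}

const loaded: ConfigTree = { apiRateLimit: { ttl: 60, limit: 5 } };
console.log(buildThrottlers(loaded)); // [ { ttl: 60, limit: 5 } ]
```

Inside useFactory, the same check could wrap the two configService.get calls, so a missing registration stops the app at boot instead of passing undefined to the throttler.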

Custom Redis Throttler Storage ----

The CustomRedisThrottlerStorage class implements the ThrottlerStorage interface to store rate limit data in Redis:

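import { Injectable } from '@nestjs/common';
import { ThrottlerStorage } from '@nestjs/throttler';
// Record type used by @nestjs/throttler v6; adjust the path if your version exports it elsewhere
import { ThrottlerStorageRecord } from '@nestjs/throttler/dist/throttler-storage-record.interface';
import { CacheService } from '../cache/cache.service'; // project's Redis wrapper (path illustrative)

@Injectable()
export class CustomRedisThrottlerStorage implements ThrottlerStorage {
  constructor(private readonly cacheService: CacheService) {}

  // Custom helper (not part of the ThrottlerStorage interface): fetch the
  // record stored under a key without incrementing it.
  async getRecord(key: string): Promise<ThrottlerStorageRecord | undefined> {
    const value = await this.cacheService.get(key);
    if (!value) return undefined;
    // Remaining lifetime of the key in seconds, read from Redis
    const ttl = await this.cacheService.ttl(key);
    return {
      totalHits: parseInt(value, 10),
      timeToExpire: ttl,
      isBlocked: false,
      timeToBlockExpire: 0,
    };
  }

  // Increment the hit count each time the API is called.
  // `limit` is optional so the interceptor's two-argument call still compiles;
  // @nestjs/throttler v6 passes it as the third argument.
  async increment(key: string, ttl: number, limit?: number): Promise<ThrottlerStorageRecord> {
    // Current request count for this key (0 if the key is missing or expired)
    const val = await this.cacheService.get(key);
    const oldValue = val ? parseInt(val, 10) : 0;
    const newValue = oldValue + 1;

    // Keep the window's original expiry: only the first hit gets the full TTL.
    // Writing `ttl` on every hit would restart the window each time.
    const remaining = oldValue > 0 ? await this.cacheService.ttl(key) : ttl;
    const expiresIn = remaining > 0 ? remaining : ttl;

    // GET followed by SETEX is not atomic: concurrent requests may lose increments.
    await this.cacheService.setex(key, newValue.toString(), expiresIn);

    const isBlocked = limit !== undefined && newValue > limit;
    return {
      totalHits: newValue,
      timeToExpire: expiresIn,
      isBlocked,
      timeToBlockExpire: isBlocked ? expiresIn : 0,
    };
  }
}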
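Reading the counter with GET and writing it back with SETEX takes two round trips, so two concurrent requests can read the same value and one increment is lost. Redis's INCR command increments atomically and creates a missing key at 1; setting an expiry only when INCR returns 1 pins the window to the first request. The sketch below shows that pattern against a minimal in-memory stand-in. The CounterStore interface, MemoryCounterStore and recordHit are assumptions for illustration; with a real Redis client the two commands would normally be sent together (for example in a MULTI transaction or a Lua script) so a crash between them cannot leave a key without an expiry.

```typescript
// Hypothetical subset of Redis commands, mirroring INCR / EXPIRE / TTL semantics.
interface CounterStore {
  incr(key: string): number; // atomically add 1; a missing key starts at 0
  expire(key: string, seconds: number): void;
  ttl(key: string): number; // seconds left, -1 if no expiry, -2 if key missing
}

// In-memory stand-in with an injectable clock (epoch milliseconds).
class MemoryCounterStore implements CounterStore {
  private readonly data = new Map<string, { value: number; expiresAt?: number }>();
  constructor(private readonly now: () => number = () => Date.now()) {}

  // Returns the entry if it exists and has not expired (expired ones are dropped).
  private live(key: string) {
    const entry = this.data.get(key);
    if (entry?.expiresAt !== undefined && this.now() >= entry.expiresAt) {
      this.data.delete(key);
      return undefined;
    }
    return entry;
  }

  incr(key: string): number {
    const entry = this.live(key) ?? { value: 0 };
    entry.value += 1;
    this.data.set(key, entry);
    return entry.value;
  }

  expire(key: string, seconds: number): void {
    const entry = this.live(key);
    if (entry) entry.expiresAt = this.now() + seconds * 1000;
  }

  ttl(key: string): number {
    const entry = this.live(key);
    if (!entry) return -2;
    if (entry.expiresAt === undefined) return -1;
    return Math.ceil((entry.expiresAt - this.now()) / 1000);
  }
}

interface HitRecord {
  totalHits: number;
  timeToExpire: number; // seconds
  isBlocked: boolean;
}

/** Fixed-window increment: only the first hit in a window sets the expiry. */
function recordHit(store: CounterStore, key: string, ttl: number, limit: number): HitRecord {
  const totalHits = store.incr(key);
  if (totalHits === 1) store.expire(key, ttl);
  return { totalHits, timeToExpire: store.ttl(key), isBlocked: totalHits > limit };
}

let nowMs = 0;
const store = new MemoryCounterStore(() => nowMs);
recordHit(store, 'login:203.0.113.7', 60, 5); // first hit starts a 60 s window
nowMs = 30000;
console.log(recordHit(store, 'login:203.0.113.7', 60, 5).timeToExpire); // 30, not 60
```

Because later hits never rewrite the expiry, they cannot push the window forward, and the counter really does reset ttl seconds after the first request.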

Throttler Guard ----

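The CustomThrottlerGuard class extends ThrottlerGuard from @nestjs/throttler and overrides getTracker, which returns the identifier that requests are counted against (by default, the client IP in req.ip). The guard only takes effect once it is registered, either globally with { provide: APP_GUARD, useClass: CustomThrottlerGuard } in a module's providers (APP_GUARD comes from @nestjs/core) or per controller with @UseGuards(CustomThrottlerGuard).

We can use any unique identifier as the key. Here I have used "IP address + user-related info" as the unique tracker; fetchIpAddress is a helper (not shown in this post) that builds it from the request.

import { Injectable } from '@nestjs/common';
import { ThrottlerGuard } from '@nestjs/throttler';
import { fetchIpAddress } from './fetch-ip-address'; // project helper, path illustrative

@Injectable()
export class CustomThrottlerGuard extends ThrottlerGuard {
  // Called for every request; hits are counted against the returned string
  protected async getTracker(req: Record<string, any>): Promise<string> {
    return fetchIpAddress(req);
  }
}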
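The post does not show fetchIpAddress. As a rough idea of what such a helper might do, the hypothetical buildTracker below takes the first address from the X-Forwarded-For header when present, falls back to the socket address, and appends a user id when one is available. The TrackerRequest shape is an assumption. X-Forwarded-For is set by the client unless a proxy you control overwrites it, so only trust it behind such a proxy (in Express, the trust proxy setting makes req.ip honour it for you).

```typescript
// Hypothetical minimal request shape; real Express/Fastify requests carry more fields.
interface TrackerRequest {
  headers: Record<string, string | string[] | undefined>;
  ip?: string;
  user?: { id: string };
}

/** Builds a rate-limit tracker from the client IP plus, if present, the user id. */
function buildTracker(req: TrackerRequest): string {
  const forwarded = req.headers['x-forwarded-for'];
  const header = Array.isArray(forwarded) ? forwarded[0] : forwarded;
  // X-Forwarded-For reads "client, proxy1, proxy2"; the first entry is the original client.
  const ip = header?.split(',')[0]?.trim() || req.ip || 'unknown';
  return req.user ? `${ip}:${req.user.id}` : ip;
}

console.log(buildTracker({ headers: { 'x-forwarded-for': '203.0.113.7, 10.0.0.1' } })); // 203.0.113.7
console.log(buildTracker({ headers: {}, ip: '198.51.100.2', user: { id: 'u42' } })); // 198.51.100.2:u42
```

Adding the user id means two users behind the same NAT get separate budgets, while anonymous traffic from one address still shares a single counter.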

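Throttler Interceptor ----

The CustomThrottlerInterceptor counts failed requests per client and URL in the same Redis storage (for example, failed login attempts, as the example key suggests) and rejects further attempts with a ThrottlerException (HTTP 429) once the limit is reached. The check runs before the route handler, so a blocked client cannot keep trying and succeed with a correct password, and errors below the limit are re-thrown unchanged instead of being swallowed:

import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { ThrottlerException } from '@nestjs/throttler';
import { Observable } from 'rxjs';
import { catchError } from 'rxjs/operators';
import { CustomRedisThrottlerStorage } from './custom-redis-throttler.storage';
import { fetchIpAddress } from './fetch-ip-address'; // project helper, path illustrative

@Injectable()
export class CustomThrottlerInterceptor implements NestInterceptor {
  constructor(
    private readonly configService: ConfigService,
    private readonly customRedisThrottlerStorage: CustomRedisThrottlerStorage,
  ) {}

  // True once `limit` failed requests were recorded for this key in the current window
  private async isRateLimited(key: string): Promise<boolean> {
    const limit = this.configService.get<number>('apiRateLimit.limit');
    const record = await this.customRedisThrottlerStorage.getRecord(key);
    return record !== undefined && record.totalHits >= limit;
  }

  async intercept(context: ExecutionContext, next: CallHandler): Promise<Observable<unknown>> {
    const request = context.switchToHttp().getRequest();
    const key = `login:${fetchIpAddress(request)}:${request.url}`; // Example Redis key
    const ttl = this.configService.get<number>('apiRateLimit.ttl');

    // Reject before the handler runs, so a blocked client cannot keep trying
    if (await this.isRateLimited(key)) {
      throw new ThrottlerException();
    }

    return next.handle().pipe(
      catchError(async (err: unknown) => {
        // Only failed requests count towards the limit
        await this.customRedisThrottlerStorage.increment(key, ttl);
        // Re-throw so the client still receives the original error
        throw err;
      }),
    );
  }
}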
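Stripped of NestJS and Redis, the interceptor's job is to count only failed attempts per key and to refuse further attempts once the window's budget is spent. The framework-free sketch below models that with a plain Map and an async handler; FailureLimiter and TooManyAttemptsError are illustrative assumptions, not code from the article. The order matters: if the limit were checked only after a failure (a catchError-only design), a correct password submitted after the limit would still go through.

```typescript
// Thrown when a key has used up its failure budget (stands in for ThrottlerException / HTTP 429).
class TooManyAttemptsError extends Error {
  constructor(key: string) {
    super(`Too many failed attempts for ${key}`);
    this.name = 'TooManyAttemptsError';
  }
}

// Hypothetical failure counter: only attempts whose handler throws use up the budget.
class FailureLimiter {
  private readonly failures = new Map<string, { count: number; expiresAt: number }>();

  constructor(
    private readonly limit: number, // failed attempts allowed per window
    private readonly ttlSeconds: number, // window length
    private readonly now: () => number = () => Date.now(),
  ) {}

  private failuresFor(key: string): number {
    const entry = this.failures.get(key);
    return entry !== undefined && this.now() < entry.expiresAt ? entry.count : 0;
  }

  async run<T>(key: string, handler: () => Promise<T>): Promise<T> {
    // Check before running the handler, so a blocked client cannot keep trying.
    if (this.failuresFor(key) >= this.limit) throw new TooManyAttemptsError(key);
    try {
      return await handler();
    } catch (err) {
      const entry = this.failures.get(key);
      if (entry !== undefined && this.now() < entry.expiresAt) {
        entry.count += 1;
      } else {
        // First failure (or first after expiry) opens a new window.
        this.failures.set(key, { count: 1, expiresAt: this.now() + this.ttlSeconds * 1000 });
      }
      throw err; // the caller still sees the original error
    }
  }
}
```

Successful calls never touch the counter, and a request refused by the pre-check is not counted either, so refused retries do not extend the lockout.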

Testing Tips:

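  • We can use Postman to send repeated requests and inspect the Redis entries with redis-cli; for example, GET <key> shows the current count and TTL <key> shows how many seconds remain in the window.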

Test edge cases like:

  • Bursts of requests within the TTL.
  • The request that reaches the limit exactly, and the first request past it.

I used a simple flowchart to illustrate the entire process, showing how the throttler mechanism works step by step. If you have any suggestions, edits, or ideas for improvement, feel free to share them—I’d love to hear your feedback!

[Flowchart: how the throttling mechanism works, step by step]

Thank you for your time! I hope you found this article helpful!
