DEV Community

John Owolabi Idogun

Originally published at johnowolabiidogun.dev

Asynchronous Server: Building and Rigorously Testing a WebSocket and HTTP Server

Introduction

Part 5 is already out here: https://johnowolabiidogun.dev/blog/building-an-ai-powered-financial-behavior-analyzer-with-nodejs-python-sveltekit-and-tailwindcss-part-5-dashboard-fc55f4/67b6cf47fe10666dd67e149d

In this article, a detour from the ongoing series on building a financial data analyzer, I focus on a critical aspect: rigorously testing the server, because appropriate and accurate data handling is paramount. Although Python doesn't enforce type safety, I'll demonstrate how tools like mypy, bandit, black, and isort (and, optionally, prospector) help maintain basic code quality and standards.

Prerequisite

To follow along, you should first read the article on building the preliminary AI service; this article builds on the concepts and code established there.

Source code

Sirneij / finance-analyzer (GitHub)

An AI-powered financial behavior analyzer and advisor written in Python (aiohttp) and TypeScript (ExpressJS & SvelteKit with Svelte 5)




Implementation

Step 1: Improving the Initial AI Service

The AI service described in the previous article has several areas for improvement:

  1. Testability: The current structure makes automated testing (both integration and unit) difficult.
  2. Model Accuracy: The zero-shot classification model, originally designed for sentiment analysis, isn't optimal for categorizing financial transactions. A more suitable model is needed.
  3. Code Quality: The code requires refactoring, cleanup, and the addition of new features.
  4. Type Consistency: Type annotations need to be consistently applied and enforced throughout the codebase.

To address these, we will adopt this structure:

.
├── README.md
├── mypy.ini
├── requirements.dev.txt
├── requirements.txt
├── run.py
├── scripts
│   └── test_app.sh
├── src
│   ├── __init__.py
│   ├── app
│   │   ├── __init__.py
│   │   └── app_instance.py
│   ├── models
│   │   ├── __init__.py
│   │   └── base.py
│   └── utils
│       ├── __init__.py
│       ├── analyzer.py
│       ├── base.py
│       ├── extract_text.py
│       ├── resume_parser.py
│       ├── settings.py
│       ├── summarize.py
│       └── websocket.py
└── tests
    ├── __init__.py

We introduced the src/ directory to house the entire application. The aiohttp server setup was refactored into src/app/app_instance.py, with run.py simply responsible for running the created app instance:

import os

from aiohttp import web

from src.app.app_instance import init_app
from src.utils.settings import base_settings

if __name__ == '__main__':
    app = init_app()
    try:
        web.run_app(
            app,
            host='0.0.0.0',
            port=int(os.environ.get('PORT', 5173)),
        )
    except KeyboardInterrupt:
        base_settings.logger.info('Received keyboard interrupt...')
    except Exception as e:
        base_settings.logger.error(f'Server error: {e}')
    finally:
        base_settings.logger.info('Server shutdown complete.')

The run.py file initializes and starts the aiohttp application.

The key changes in app_instance.py are highlighted below:

+ import asyncio
+ from weakref import WeakSet
...
- from utils.analyzer import analyze_transactions
- from utils.extract_text import extract_text_from_pdf
- from utils.resume_parser import extract_text_with_pymupdf, parse_resume_text
- from utils.settings import base_settings
- from utils.summarize import summarize_transactions
- from utils.websocket import WebSocketManager
+ from src.utils.analyzer import analyze_transactions
+ from src.utils.extract_text import extract_text_from_pdf
+ from src.utils.resume_parser import extract_text_with_pymupdf, parse_resume_text
+ from src.utils.settings import base_settings
+ from src.utils.summarize import summarize_transactions
+ from src.utils.websocket import WebSocketManager

# Replace global ws_connections with typed version
- ws_connections: set[WebSocketResponse] = set()
- ws_lock = Lock()
+ WEBSOCKETS = web.AppKey("websockets", WeakSet[WebSocketResponse])


- async def start_background_tasks(app):
+ async def start_background_tasks(app: web.Application) -> None:
    """Initialize application background tasks."""
-    app['ws_connections'] = ws_connections
-    app['ws_lock'] = ws_lock
+    app[WEBSOCKETS] = WeakSet()


- async def cleanup_background_tasks(app):
-     """Cleanup application resources."""
-     await cleanup_ws(app)


- async def cleanup_ws(app):
+ async def cleanup_ws(app: web.Application) -> None:
    """Cleanup WebSocket connections on shutdown."""
-     async with ws_lock:
-         connections = set(ws_connections)  # Create a copy to iterate safely
-         for ws in connections:
-             await ws.close(code=WSMsgType.CLOSE, message='Server shutdown')
-         ws_connections.clear()
+     for websocket in set(app[WEBSOCKETS]):  # type: ignore
+         await websocket.close(code=WSCloseCode.GOING_AWAY, message=b'Server shutdown')

async def websocket_handler(request: Request) -> WebSocketResponse:
    """WebSocket handler for real-time communication."""
    ws = web.WebSocketResponse()
    await ws.prepare(request)

-    async with ws_lock:
-        ws_connections.add(ws)
+    request.app[WEBSOCKETS].add(ws)
    ws_manager = WebSocketManager(ws)
    await ws_manager.prepare()

+    async def ping_server(ws: WebSocketResponse) -> None:
+        try:
+            while True:
+                await ws.ping()
+                await asyncio.sleep(25)
+        except ConnectionResetError:
+            base_settings.logger.info("Client disconnected")
+        finally:
+            await ws.close()
+
+    asyncio.create_task(ping_server(ws))

    base_settings.logger.info('WebSocket connection established')

    try:
        async for msg in ws:
+            if msg.type == WSMsgType.PING:
+                base_settings.logger.info('Intercepted PING from client')
+                await ws.pong(msg.data)
+            elif msg.type == WSMsgType.PONG:
+                base_settings.logger.info('Intercepted PONG from client')
            if msg.type == WSMsgType.TEXT:
                ...
-            elif msg.type == WSMsgType.ERROR:
+            elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
-                base_settings.logger.error(f'WebSocket error: {ws.exception()}')
+                base_settings.logger.info(
+                    'WebSocket is closing or encountered an error',
+                )
+                break
    except Exception as e:
        base_settings.logger.error(f'WebSocket handler error: {str(e)}')
    finally:
-        async with ws_lock:
-            ws_connections.remove(ws)
-        if not ws.closed:
-            await ws.close()
+        request.app[WEBSOCKETS].discard(ws)
+        if not ws.closed:
+            await ws.close()
        base_settings.logger.info('WebSocket connection closed')

    return ws

def init_app() -> web.Application:
    ...
    # Add startup/cleanup handlers
    app.on_startup.append(start_background_tasks)
-    app.on_cleanup.append(cleanup_background_tasks)
+    app.on_shutdown.append(cleanup_ws)

    return app

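We applied type annotations consistently across the codebase, falling back to # type: ignore only where necessary. We also replaced the module-level set of WebSocket connections (and its lock) with a weakref.WeakSet stored on the application under a typed web.AppKey. This makes shutdown cleanup more robust: connections that have been garbage-collected drop out of the set automatically. To keep connections alive during long-running work such as zero-shot classification, we added a ping/pong mechanism: a background task pings the client every 25 seconds, and the handler answers PING frames from the client with a PONG.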
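The ping_server loop above is started with a bare asyncio.create_task(...) whose result is never stored. The asyncio documentation warns that the event loop keeps only weak references to tasks, so it is safer to hold on to them. Below is a minimal sketch of the same heartbeat idea that keeps a strong reference and can be cancelled when the handler exits. The names ping_loop, start_heartbeat, PingableSocket, and FakeSocket are hypothetical; the 25-second default mirrors the article. If you don't need custom logic inside the loop, aiohttp can also do this for you via web.WebSocketResponse(heartbeat=25.0), which sends pings automatically and closes the connection when no pong arrives.

```python
import asyncio
from typing import Protocol


class PingableSocket(Protocol):
    """The two members of aiohttp's WebSocketResponse this loop relies on."""

    @property
    def closed(self) -> bool: ...

    async def ping(self) -> None: ...


# Strong references to running heartbeat tasks, so they can't be garbage-collected mid-flight.
_heartbeats: set[asyncio.Task[None]] = set()


async def ping_loop(ws: PingableSocket, interval: float) -> None:
    """Ping the client every `interval` seconds until the socket closes or the client disappears."""
    try:
        while not ws.closed:
            await ws.ping()
            await asyncio.sleep(interval)
    except ConnectionResetError:
        pass  # client vanished; the handler's `finally` block does the cleanup


def start_heartbeat(ws: PingableSocket, interval: float = 25.0) -> asyncio.Task[None]:
    task = asyncio.create_task(ping_loop(ws, interval))
    _heartbeats.add(task)
    task.add_done_callback(_heartbeats.discard)  # forget the task once it finishes
    return task


class FakeSocket:
    """Hypothetical stand-in that just counts pings."""

    def __init__(self) -> None:
        self.closed = False
        self.pings = 0

    async def ping(self) -> None:
        self.pings += 1


async def demo() -> tuple[int, bool, int]:
    ws = FakeSocket()
    task = start_heartbeat(ws, interval=0.01)
    await asyncio.sleep(0.05)
    task.cancel()  # in websocket_handler this would go in the `finally` block
    try:
        await task
    except asyncio.CancelledError:
        pass
    await asyncio.sleep(0)  # let any pending done-callbacks run
    return ws.pings, task.cancelled(), len(_heartbeats)


if __name__ == '__main__':
    print(asyncio.run(demo()))
```

Cancelling the task in the handler's cleanup, rather than relying on a failed ping, also stops the loop promptly when the client closes the connection normally.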

Next, we consolidated common utility functions into a new src/utils/base.py file. These include validate_and_convert_transactions, get_device, detect_anomalies, analyze_spending, predict_trends, calculate_trend, and calculate_percentage_change, previously located in utils/summarize.py and utils/analyzer.py. We also introduced new functions to estimate financial health (calculate_financial_health) and detect recurring transactions (analyze_recurring_transactions). Anomaly detection was enhanced to flag single-instance anomalies, and transaction grouping now uses difflib for fuzzy matching of descriptions. For example, after lowercasing, difflib rates "Target T-12345 Anytown USA" and "Target 12345 Anytown USA" at a similarity ratio of 0.96, well above the 0.69 cutoff used below, so both land in the same group:

import difflib


def group_transactions_by_description(transactions: list[Transaction], cutoff: float = 0.69) -> dict[str, list[float]]:
    """
    Group transactions by description using fuzzy matching with difflib.

    Returns a dictionary mapping a representative description (the group key)
    to a list of transaction amounts. Two descriptions are grouped together if
    their similarity is above a certain threshold.
    """
    groups: dict[str, list[float]] = {}

    for tx in transactions:
        desc = tx.description.lower().strip()
        # Try to find an existing key similar to desc.
        # difflib.get_close_matches returns a list of close matches.
        close_matches = difflib.get_close_matches(desc, groups.keys(), n=1, cutoff=cutoff)
        if close_matches:
            matched_key = close_matches[0]
        else:
            matched_key = None

        if matched_key:
            groups[matched_key].append(tx.amount)
        else:
            groups[desc] = [tx.amount]

    return groups


def find_group_key(description: str, group_keys: list[str], cutoff: float = 0.69) -> str:
    """
    Find the best matching key from group_keys for the given description using difflib.
    Returns the matched key if similarity is above cutoff; otherwise, returns the description.
    """
    desc = description.lower().strip()
    matches = difflib.get_close_matches(desc, group_keys, n=1, cutoff=cutoff)
    if matches:
        return matches[0]
    return desc
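To see how the 0.69 cutoff behaves, here is a small standalone check using only the standard library. It applies the same normalization as the grouping code (lowercase plus strip). The "Uber" pair is made-up illustrative data, not from the project; it shows the flip side of fuzzy matching: short descriptions sharing a prefix and a reference number can clear the cutoff even when they are different kinds of spending (transportation vs. dining out).

```python
import difflib


def similarity(a: str, b: str) -> float:
    """SequenceMatcher ratio after the normalization group_transactions_by_description applies."""
    return difflib.SequenceMatcher(None, a.lower().strip(), b.lower().strip()).ratio()


# Existing group keys are stored lowercased, as in group_transactions_by_description.
existing_keys = ['target t-12345 anytown usa', 'netflix.com']

print(round(similarity('Target T-12345 Anytown USA', 'Target 12345 Anytown USA'), 2))  # 0.96
print(difflib.get_close_matches('target 12345 anytown usa', existing_keys, n=1, cutoff=0.69))
# ['target t-12345 anytown usa']

# Illustrative false positive: same prefix and number, different spending category.
print(round(similarity('Uber Trip 1234', 'Uber Eats 1234'), 2))  # 0.79, above the 0.69 cutoff
```

If false positives like the last one matter, raising the cutoff or stripping reference numbers before matching are two knobs worth trying.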

We also encapsulated sending progress reports in a reusable function, update_progress.
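The article doesn't show update_progress itself, so here is a hypothetical sketch of what such a helper can look like. The signature is an assumption; the message fields (action, message, progress, taskType) are the ones the integration tests later in this article assert on. Clamping progress to 0..1 is also an assumption about what the client expects.

```python
import asyncio
from typing import Any, Protocol


class JSONSocket(Protocol):
    async def send_json(self, data: Any) -> None: ...


async def update_progress(ws: JSONSocket, message: str, progress: float, task_type: str) -> None:
    """Send one progress frame to the client (hypothetical signature)."""
    await ws.send_json(
        {
            'action': 'progress',
            'message': message,
            'progress': max(0.0, min(1.0, progress)),  # keep the value in 0..1
            'taskType': task_type,
        }
    )


class RecordingSocket:
    """Stand-in for aiohttp's WebSocketResponse that records what would be sent."""

    def __init__(self) -> None:
        self.sent: list[Any] = []

    async def send_json(self, data: Any) -> None:
        self.sent.append(data)


if __name__ == '__main__':
    sock = RecordingSocket()
    asyncio.run(update_progress(sock, 'Analysis complete', 1.0, 'Analysis'))
    print(sock.sent)
```

Typing the parameter as a small Protocol rather than WebSocketResponse also lets a fake like the tests' FakeWebSocket be passed in without upsetting mypy.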

In src/utils/analyzer.py, the major improvements are:

  1. Improved Model Accuracy: We switched from the yiyanghkust/finbert-tone model to facebook/bart-large-mnli for zero-shot classification. This significantly improves accuracy, although at the cost of speed. For multilingual support, joeddav/xlm-roberta-large-xnli is another option.
  2. Hybrid Classification Approach: Transactions are first classified with pattern matching, and only those that remain unclassified are sent to the ML model. To keep the server responsive, we process these transactions in batches and yield control to the event loop after each batch, so other tasks (such as WebSocket pings) can run and memory from the previous batch can be released.
  3. Offloading Calculations: To lighten the classification pipeline, we moved the calculation of anomalies, spending_analysis, spending_trends, recurring_transactions, and financial_health into src/utils/summarize.py, which runs significantly faster than classification.

Step 2: Enforcing Type Safety, Security, and Style

Python does not check type annotations at runtime, so ours are currently only decorative. To enforce type safety, ensure code security, and maintain consistent code style, we'll use the following tools:

  • mypy: A static type checker.
  • bandit: A security linter.
  • black: An uncompromising code formatter.
  • isort: A tool for sorting imports.

Tip: Consider using Prospector

Prospector bundles several static analyzers and checks that your code conforms to PEP 8 and other style guidelines. It's highly recommended for in-depth code quality checks.

Install these tools and add them to requirements.dev.txt:

(virtualenv) pip install mypy bandit black isort

Create a mypy.ini file at the root of the project with the following configuration:

# some config from:
# https://www.ralphminderhoud.com/blog/django-mypy-check-runs/
[mypy]
# The mypy configurations: https://mypy.readthedocs.io/en/latest/config_file.html
python_version = 3.13

check_untyped_defs = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
disallow_any_generics = True
disallow_untyped_calls = True
# allow untyped decorators from third-party libraries that ship without type hints
disallow_untyped_decorators = False
ignore_errors = False
ignore_missing_imports = True
implicit_reexport = False
strict_optional = True
strict_equality = True
no_implicit_optional = True
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unreachable = True
warn_no_return = True
# these two options were added in mypy 0.800; they let mypy run on our code base
explicit_package_bases = True
namespace_packages = True

[mypy-*.migrations.*]
ignore_errors = True

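These options bring mypy close to its --strict mode: every function must be fully annotated (disallow_untyped_defs, disallow_incomplete_defs), generic types must carry type parameters (disallow_any_generics), and Optional must be written out explicitly (no_implicit_optional). ignore_missing_imports = True keeps mypy from failing on third-party packages that ship without type information.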
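To make the annotation and generics rules concrete, here is a small hypothetical example (Txn and totals_by_kind are illustrative names, not from the project). The commented-out signatures are the kind this configuration rejects, tagged with mypy's error codes for those checks; the active version passes.

```python
from typing import NamedTuple

# Rejected by the mypy.ini above (kept as comments so the snippet still runs):
#   def totals_by_kind(txns):                  # no annotations      -> [no-untyped-def]
#   def totals_by_kind(txns: list) -> dict:    # bare generic types  -> [type-arg]


class Txn(NamedTuple):
    kind: str  # e.g. 'income' or 'expense'
    amount: float


def totals_by_kind(txns: list[Txn]) -> dict[str, float]:
    """Fully annotated with parameterized generics, so mypy accepts it under this config."""
    totals: dict[str, float] = {}
    for txn in txns:
        totals[txn.kind] = totals.get(txn.kind, 0.0) + txn.amount
    return totals


print(totals_by_kind([Txn('expense', 10.0), Txn('income', 200.0), Txn('expense', 5.5)]))
# {'expense': 15.5, 'income': 200.0}
```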

Next, create a bash script (scripts/static_check.sh) to automate the static analysis process:

#!/usr/bin/env bash

set -e

# run black - make sure everyone uses same python style
black --skip-string-normalization --line-length 120 --check src/
black --skip-string-normalization --line-length 120 --check run.py
black --skip-string-normalization --line-length 120 --check tests/

# run isort for import structure checkup with black profile
isort --atomic --profile black -c src/
isort --atomic --profile black -c run.py
isort --atomic --profile black -c tests/

# run mypy
mypy src/

# run bandit - A security linter from OpenStack Security
bandit -r src/

# python static analysis
# prospector  --profile=.prospector.yml --path=src --ignore-patterns=static
# prospector  --profile=.prospector.yml --path=tests --ignore-patterns=static

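This script checks the code against the defined standards and stops at the first failure (set -e). Make it executable with chmod +x scripts/static_check.sh, since the CI workflow below invokes it directly. To make your code pass these checks, run the following formatters before committing: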

black --skip-string-normalization --line-length 120  src tests *.py

isort --atomic --profile black src tests *.py

To enforce these rules in a team environment, we'll run them in a CI/CD pipeline on GitHub Actions. Combined with a branch protection rule that requires this check to pass, any failure blocks the pull request from being merged. Create a .github/workflows/aiohttp.yml file:

name: UTILITY-SERVER CI

on:
  push:
    branches: [utility]
  pull_request:
    branches: [utility]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 4
      matrix:
        python-version: ['3.13']

    steps:
      - uses: actions/checkout@v4
      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y \
            poppler-utils \
            tesseract-ocr \
            libtesseract-dev \
            libglib2.0-0

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.dev.txt
      - name: Debug Environment
        run: |
          python -c "import sys; print(sys.version)"
          python -c "import platform; print(platform.platform())"
          env | sort
      - name: Run static analysis
        run: ./scripts/static_check.sh
      - name: Run Tests
        env:
          LC_ALL: en_NG.UTF-8
          LANG: en_NG.UTF-8
          LABELS: groceries,school,housing,transportation,gadgets,entertainment,utilities,credit cards,other,dining out,healthcare,insurance,savings,investments,childcare,travel,personal care,debts,charity,taxes,subscriptions,streaming services,home maintenance,shopping,pets,fitness,hobbies,gifts
        run: |
          coverage run --parallel-mode -m unittest discover tests && coverage combine && coverage report -m && coverage html

GitHub Actions defines workflows in YAML files (.yml or .yaml), much like docker-compose.yml defines services. This workflow runs on the ubuntu-latest runner image and uses version 4 of the actions/checkout action to check out the repository. It then installs system packages that some of our Python dependencies need: poppler-utils for pdf2image, and tesseract-ocr and libtesseract-dev for pytesseract. Since the project has no database interaction, there is no services section. After setting up Python, installing the development requirements, and printing some environment details for debugging, the workflow runs our static-check script. Finally, it supplies the environment variables the app expects and runs the tests (which we'll write next) under coverage. The pipeline runs on every push or pull request to the utility branch.

Step 3: Writing the tests

The final step of our CI/CD pipeline runs the tests and produces coverage reports. In the Python ecosystem, pytest is an extremely popular testing framework. Tempting as it is (and we may still adopt it later), we'll stick with Python's built-in testing library, unittest, and use coverage to measure how much of the program the tests exercise. Let's start with the shared test setup:

import unittest
import uuid
from datetime import datetime

from aiohttp.test_utils import AioHTTPTestCase

from src.app.app_instance import init_app
from src.utils.websocket import WebSocketManager


class Base:
    """Base class for tests."""

    def create_transaction_dict(
        self,
        date: datetime | str,
        description: str,
        amount: float,
        balance: float,
        type_str='expense',
        include_v: bool = True,
    ) -> dict:
        txn = {
            '_id': str(uuid.uuid4()),
            'date': date.isoformat() if isinstance(date, datetime) else date,
            'createdAt': date.isoformat() if isinstance(date, datetime) else date,
            'updatedAt': date.isoformat() if isinstance(date, datetime) else date,
            'description': description,
            'amount': amount,
            'balance': balance,
            'type': type_str,
            'userId': '1',
        }
        if include_v:
            txn['__v'] = 0

        return txn


# A simple fake WebSocketResponse to simulate aiohttp behavior.
class FakeWebSocket:
    def __init__(self, raise_on_send=False):
        self.messages = []  # will store the JSON messages sent
        self.closed = False
        self.raise_on_send = raise_on_send
        self.close_code = None
        self.close_message = None

    async def send_json(self, data):
        if self.raise_on_send:
            raise Exception('send_json error')
        self.messages.append(data)

    async def close(self, code=None, message=None):
        self.closed = True
        self.close_code = code
        self.close_message = message


class BaseAsyncTestClass(Base, unittest.IsolatedAsyncioTestCase):
    """Base class for async tests."""

    async def asyncSetUp(self):
        # Create a FakeWebSocket for each test.
        self.fake_ws = FakeWebSocket()
        self.websocket_manager = WebSocketManager(self.fake_ws)


class BaseTestClass(Base, unittest.TestCase):
    """Base class for sync tests."""


class BaseAioHTTPTestCase(Base, AioHTTPTestCase):
    """Base class for aiohttp tests."""

    async def get_application(self):
        return init_app()

    async def asyncSetUp(self):
        await super().asyncSetUp()
        # Create a FakeWebSocket for each test.
        self.fake_ws = FakeWebSocket()
        self.websocket_manager = WebSocketManager(self.fake_ws)


if __name__ == '__main__':
    unittest.main()

These classes are blueprints for the actual tests, and test modules import them directly from the tests package. The Base class gives all its children the create_transaction_dict method, which simplifies building transaction data. FakeWebSocket mimics the parts of aiohttp's WebSocketResponse that our WebSocket utilities rely on (send_json and close), which is essential for unit testing them. Asynchronous unit tests inherit from BaseAsyncTestClass and synchronous ones from BaseTestClass, while integration-style tests that exercise the aiohttp application inherit from BaseAioHTTPTestCase. AioHTTPTestCase requires get_application to be overridden so that it returns the application under test.

Note: Unit vs Integration tests

A unit test exercises a single piece of code in isolation (for example, a function such as analyze_recurring_transactions), whereas an integration test examines how multiple units interact within the system (for example, connecting to /ws and sending it a message).

Let's look at an integration-style test for our WebSocket endpoint and, to balance things out, unit tests for some of the analysis helpers:

import asyncio
import json
from unittest.mock import AsyncMock, patch

from aiohttp import WSMsgType

from src.app.app_instance import WEBSOCKETS
from tests import BaseAioHTTPTestCase


class TestWebSocketHandler(BaseAioHTTPTestCase):
    """Exhaustively test the WebSocket handler."""

    async def setUpAsync(self):
        await super().setUpAsync()
        # Capture the original create_task function.
        self.orig_create_task = asyncio.create_task

    async def __dummy_analyze(self, transactions, ws_manager):
        """Dummy analyze implementation that returns a known result."""
        return {
            'categories': {
                'expenses': {
                    'groceries': 10.0,
                    'rent': 90.0,
                },
                'expense_percentages': {
                    'groceries': 5,
                    'rent': 45,
                },
                'income': 200.0,
            }
        }

    async def __dummy_summarize(self, transactions, ws_manager):
        """Dummy summarize implementation that returns a known result."""
        return {
            'income': {
                'total': 200.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'expenses': {
                'total': 100.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'savings': {
                'total': 100.00,
                'trend': 'neutral',
                'change': 0.0,
            },
            'total_transactions': 2,
            'expense_count': 1,
            'income_count': 1,
            'avg_expense': 100.00,
            'avg_income': 200.00,
            'start_date': '2022-01-01',
            'end_date': '2022-01-31',
            'largest_expense': 200.00,
            'largest_income': 200.00,
            'savings_rate': 50.0,
            'monthly_summary': {
                '2022-01': {
                    'income': 200.00,
                    'expenses': 100.00,
                    'savings': 100.00,
                },
            },
            'anomalies': [],
            'spending_analysis': {
                'total_spent': 100.00,
                'total_income': 200.00,
                'savings_rate': 50.0,
                'daily_summary': {
                    '2022-01-01': {
                        'total_spent': 100.00,
                        'total_income': 200.00,
                        'savings_rate': 50.0,
                    },
                },
                'cumulative_balance': {
                    '2022-01-01': 100.00,
                },
            },
            'spending_trends': {
                'total_spent': 100.00,
                'total_income': 200.00,
                'savings_rate': 50.0,
            },
            'recurring_transactions': [],
            'financial_health': {
                'debt_to_income_ratio': 0,
                'savings_rate': 0,
                'balance_growth_rate': 0,
                'financial_health_score': 0,
            },
        }

    def __dummy_create_task(self, coro):
        if hasattr(coro, 'cr_code') and 'ping_server' in coro.cr_code.co_qualname:
            # Explicitly close the ping_server coroutine so it doesn't leak.
            coro.close()
            # Return a dummy, already‐completed future.
            fut = asyncio.Future()
            fut.set_result(None)
            return fut
        return self.orig_create_task(coro)

    async def __receive_messages(self, ws, count, timeout=5):
        """Helper to collect 'count' text messages from the WebSocket."""
        messages = []
        while len(messages) < count:
            msg = await ws.receive(timeout=timeout)
            if msg.type == WSMsgType.TEXT:
                messages.append(json.loads(msg.data))
            elif msg.type == WSMsgType.CLOSE:
                break
        return messages

    async def test_analyze_action(self):
        """Test that sending an 'analyze' action yields progress and result messages."""
        self.transactions = [
            self.create_transaction_dict('2022-01-01', 'Transaction 1', -100.0, 100.0),
            self.create_transaction_dict('2022-01-02', 'Transaction 2', 200.0, 300.0),
        ]

        # Patch the analyzer so that it returns a predictable result,
        # and patch create_task with our dummy version.
        with patch(
            'src.app.app_instance.analyze_transactions',
            new=AsyncMock(side_effect=self.__dummy_analyze),
        ), patch("asyncio.create_task", self.__dummy_create_task):
            ws = await self.client.ws_connect('/ws')
            msg_data = {'action': 'analyze', 'transactions': self.transactions}
            await ws.send_str(json.dumps(msg_data))
            # This helps avoid timeout errors when the server is slow to respond.
            messages = await self.__receive_messages(ws, 2)
            self.assertEqual(len(messages), 2)

            # First response: progress message.
            progress, result = messages
            self.assertEqual(progress.get('action'), 'progress')
            self.assertEqual(progress.get('message'), 'Analysis complete')
            self.assertEqual(progress.get('progress'), 1.0)
            self.assertEqual(progress.get('taskType'), 'Analysis')

            # Second response: result message.
            self.assertEqual(result.get('action'), 'analysis_complete')
            self.assertEqual(result.get('taskType'), 'Analysis')
            expected_data = await self.__dummy_analyze(self.transactions, self.websocket_manager)
            self.assertEqual(result.get('result'), expected_data)
            await ws.close()

    async def test_summary_action(self):
        """Test that sending a 'summary' action yields progress and result messages."""
        self.transactions = [
            self.create_transaction_dict('2022-01-01', 'Transaction 1', -100.0, 100.0),
            self.create_transaction_dict('2022-01-02', 'Transaction 2', 200.0, 300.0),
        ]

        # Patch the summarizer so that it returns a predictable result,
        # and patch create_task with our dummy version.
        with patch(
            'src.app.app_instance.summarize_transactions',
            new=AsyncMock(side_effect=self.__dummy_summarize),
        ), patch("asyncio.create_task", self.__dummy_create_task):
            ws = await self.client.ws_connect('/ws')
            msg_data = {'action': 'summary', 'transactions': self.transactions}
            await ws.send_str(json.dumps(msg_data))
            # This helps avoid timeout errors when the server is slow to respond.
            messages = await self.__receive_messages(ws, 2)
            self.assertEqual(len(messages), 2)

            # First response: progress message.
            progress, result = messages
            self.assertEqual(progress.get('action'), 'progress')
            self.assertEqual(progress.get('message'), 'Summary complete')
            self.assertEqual(progress.get('progress'), 1.0)
            self.assertEqual(progress.get('taskType'), 'Summarize')

            # Second response: result message.
            self.assertEqual(result.get('action'), 'summary_complete')
            self.assertEqual(result.get('taskType'), 'Summarize')
            expected_data = await self.__dummy_summarize(self.transactions, self.websocket_manager)
            self.assertEqual(result.get('result'), expected_data)
            await ws.close()

    async def test_unknown_action(self):
        """Test that an unknown action returns an error message."""
        ws = await self.client.ws_connect('/ws')
        msg_data = {'action': 'nonexistent'}
        await ws.send_str(json.dumps(msg_data))
        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        error = messages[0]
        self.assertEqual(error.get('action'), 'error')
        self.assertEqual(error.get('taskType'), 'Error')
        self.assertEqual(error.get('result'), {'message': 'Unknown action'})
        await ws.close()

    async def test_message_processing_exception(self):
        """Test that sending invalid JSON produces an error message."""
        ws = await self.client.ws_connect('/ws')
        await ws.send_str('invalid json')
        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        error = messages[0]
        self.assertEqual(error.get('action'), 'error')
        self.assertEqual(error.get('taskType'), 'Error')
        self.assertEqual(error.get('result'), {'error': 'Expecting value: line 1 column 1 (char 0)'})
        await ws.close()

    async def test_close_on_error(self):
        """Test that when a client closes the connection, the WebSocket is removed from the app."""
        ws = await self.client.ws_connect('/ws')
        await ws.send_str('invalid json')
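        # This helps avoid timeout errors when the server is slow to respond.
        messages = await self.__receive_messages(ws, 1)
        self.assertEqual(len(messages), 1)
        await ws.close()
        # `ws` is the client-side socket, while WEBSOCKETS holds the server-side ones,
        # so check that the server discarded its socket instead of testing membership of `ws`.
        await asyncio.sleep(0.1)  # give the handler's `finally` block a chance to run
        self.assertEqual(len(self.app[WEBSOCKETS]), 0)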

Setting aside the dummy data generators, the __receive_messages helper is crucial: it keeps reading until it has collected the expected number of text messages. Without it, calling await ws.receive_json(...) several times in a row could end in timeout errors with cryptic tracebacks such as:

----------------------------------------------------------------------
Traceback (most recent call last):
  File ".../utility/virtualenv/lib/python3.13/site-packages/aiohttp/client_ws.py", line 332, in receive
    msg = await self._reader.read()
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "aiohttp/_websocket/reader_c.py", line 109, in read
  File "aiohttp/_websocket/reader_c.py", line 106, in aiohttp._websocket.reader_c.WebSocketDataQueue.read
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/base_events.py", line 725, in run_until_complete
    return future.result()
           ~~~~~~~~~~~~~^^
  File ".../utility/tests/app/websocket_handler/test_integration.py", line 114, in __receive_messages
    msg = await ws.receive_json(timeout=timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../utility/virtualenv/lib/python3.13/site-packages/aiohttp/client_ws.py", line 331, in receive
    async with async_timeout.timeout(receive_timeout):
               ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/asyncio/timeouts.py", line 116, in __aexit__
    raise TimeoutError from exc_val
TimeoutError
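
Here is one way such an accumulating helper could look. It is a minimal sketch, not the repository's __receive_messages: the name receive_messages, its predicate parameter, the message payloads, and the queue-backed FakeWebSocket standing in for aiohttp's client WebSocket are all assumptions made for illustration. The key idea is to treat a timeout as "no more messages" instead of letting it fail the test.

```python
import asyncio
from typing import Callable, Optional


async def receive_messages(
    ws,
    count: int,
    timeout: float = 1.0,
    predicate: Optional[Callable[[dict], bool]] = None,
) -> list[dict]:
    """Collect up to `count` JSON messages that satisfy `predicate`.

    A timeout ends collection early instead of propagating TimeoutError.
    """
    messages: list[dict] = []
    while len(messages) < count:
        try:
            msg = await ws.receive_json(timeout=timeout)
        except asyncio.TimeoutError:
            break  # nothing arrived in time: return what we have so far
        if predicate is None or predicate(msg):
            messages.append(msg)
    return messages


class FakeWebSocket:
    """Hypothetical queue-backed stand-in for aiohttp's client WebSocket."""

    def __init__(self, payloads: list[dict]) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        for payload in payloads:
            self._queue.put_nowait(payload)

    async def receive_json(self, timeout: Optional[float] = None) -> dict:
        return await asyncio.wait_for(self._queue.get(), timeout)


async def main() -> list[dict]:
    ws = FakeWebSocket([
        {'action': 'progress', 'message': 'Starting'},
        {'action': 'analysis_complete', 'result': {}},
        {'action': 'progress', 'message': 'Done'},
    ])
    # Ask for up to 5 progress messages; only 2 exist, so one short timeout ends the loop
    return await receive_messages(ws, 5, timeout=0.05, predicate=lambda m: m.get('action') == 'progress')


progress = asyncio.run(main())
print([m['message'] for m in progress])  # ['Starting', 'Done']
```

Passing a small timeout together with a generous count lets a test gather whatever the server actually sent without hanging, while the predicate drops messages the test doesn't care about.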

The helper also makes it easy to filter for the messages of interest. We also created a dummy version of ping_server so that it closes properly and doesn't leak memory. With the dummy functions in place, we wrote test cases that interact with our WebSocket endpoint, using async patches and mocks to feed predictable responses into the tests. Note that we passed our async dummy functions as the side_effect of each AsyncMock; using return_value instead of side_effect slowed the tests down and caused timeout errors.
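The side_effect pattern can be shown in isolation. In this sketch, app_instance is a hypothetical types.SimpleNamespace standing in for the module that uses analyze_transactions, and dummy_analyze plays the role of one of our async dummy functions. When side_effect is an async function, AsyncMock awaits it and returns its result, so the code under test receives a predictable payload.

```python
import asyncio
import types
from unittest.mock import AsyncMock, patch


async def real_analyze(transactions):
    # Stands in for the expensive ML-backed analysis we want to keep out of tests
    raise RuntimeError('would load a heavy model')


# Hypothetical namespace playing the role of the module that *uses* the function
app_instance = types.SimpleNamespace(analyze_transactions=real_analyze)


async def handle(transactions):
    # Resolves the name at call time, just as module-level code would
    return await app_instance.analyze_transactions(transactions)


async def dummy_analyze(transactions):
    # Async dummy used as side_effect: AsyncMock awaits it and returns its result
    await asyncio.sleep(0)
    return {'categories': {'groceries': sum(-t['amount'] for t in transactions)}}


async def main():
    with patch.object(
        app_instance, 'analyze_transactions', new_callable=AsyncMock, side_effect=dummy_analyze
    ) as mock_analyze:
        result = await handle([{'amount': -10.0}, {'amount': -5.0}])
        mock_analyze.assert_awaited_once()
    return result


result = asyncio.run(main())
print(result)  # {'categories': {'groceries': 15.0}}
```

Because the mock delegates to a real coroutine, the dummy can also await, yield control, or raise exactly like the production function would.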

The other test cases handle various scenarios to provide better test coverage.

Warning: The target path in patch

When supplying a target to patch, use the path where the object is used (looked up), not where it was defined. For instance, analyze_transactions is defined in src/utils/analyzer.py, but since it is imported and used in src/app/app_instance.py, we patched src.app.app_instance.analyze_transactions.
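To see why the path matters, here is a self-contained sketch that fakes the two modules with types.ModuleType. The module names demo_analyzer and demo_app_instance are hypothetical stand-ins for src/utils/analyzer.py and src/app/app_instance.py. A from ... import statement binds a second reference in the importing module's namespace, so patching the original definition doesn't affect code that resolves the name through the importing module.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical defining module, playing the role of src/utils/analyzer.py
analyzer = types.ModuleType('demo_analyzer')


def analyze_transactions(data):
    return 'real'


analyzer.analyze_transactions = analyze_transactions
sys.modules['demo_analyzer'] = analyzer

# Hypothetical using module, playing the role of src/app/app_instance.py.
# `from demo_analyzer import analyze_transactions` would bind a second name
# in this module's namespace; we reproduce that binding explicitly.
app_instance = types.ModuleType('demo_app_instance')
app_instance.analyze_transactions = analyzer.analyze_transactions
sys.modules['demo_app_instance'] = app_instance


def handler(data):
    # Like code inside app_instance, it looks the name up in app_instance's namespace
    return app_instance.analyze_transactions(data)


# Patching where the function is *defined* leaves app_instance's binding untouched
with patch('demo_analyzer.analyze_transactions', return_value='mocked'):
    wrong_target = handler([])

# Patching where it is *used* replaces the name the handler actually resolves
with patch('demo_app_instance.analyze_transactions', return_value='mocked'):
    right_target = handler([])

print(wrong_target, right_target)  # real mocked
```

Only the second patch takes effect, which is exactly the situation the warning above describes.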

However, the integration testing approach has limits: we can't reach into the internals of aiohttp's WebSocket instance. This is where unit testing comes in, since it lets us replace those internals with mocks and exercise the desired behavior in isolation. Hence the second test file for our WebSocket handler, tests/app/websocket_handler/test_ping.py.
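As an illustration of that kind of unit test, here is a sketch built around a hypothetical ping_server keep-alive loop; the real implementation in the repository may differ. Because the WebSocket is a MagicMock, we control internals such as closed and ping directly, including making ping raise ConnectionResetError on its third call, which would be hard to trigger reliably through a real connection.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def ping_server(ws, interval: float = 0.01) -> None:
    """Hypothetical keep-alive loop: ping until the socket closes or the peer resets."""
    while not ws.closed:
        try:
            await ws.ping()
        except ConnectionResetError:
            break  # peer went away; stop pinging instead of crashing
        await asyncio.sleep(interval)


async def main() -> int:
    # Fully controlled stand-in for an aiohttp WebSocketResponse
    ws = MagicMock()
    ws.closed = False
    # Two successful pings, then the third raises as if the connection dropped
    ws.ping = AsyncMock(side_effect=[None, None, ConnectionResetError()])
    await ping_server(ws)
    return ws.ping.await_count


ping_count = asyncio.run(main())
print(ping_count)  # 3
```

The same approach extends to other internals: flipping ws.closed to True, or making ping hang, lets each exit path of the loop be tested without a live server.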

To wrap up, let's see how we tested src/utils/analyzer.py, along with the helpers it imports from src/utils/base.py:

import uuid
from datetime import datetime, timedelta
from unittest.mock import patch

import torch

from src.utils.analyzer import analyze_transactions, classify_transactions
from src.utils.base import (
    analyze_recurring_transactions,
    analyze_spending,
    calculate_financial_health,
    detect_anomalies,
    predict_trends,
    validate_and_convert_transactions,
)
from tests import BaseAsyncTestClass


class TestAnalyzer(BaseAsyncTestClass):
    @patch(
        'src.utils.analyzer.pipeline', return_value=lambda *args, **kwargs: [{'labels': ['groceries'], 'scores': [1.0]}]
    )
    async def test_analyze_transactions_valid(self, mock_pipeline):
        tx_data = [
            {
                '_id': str(uuid.uuid4()),
                'date': '2024-01-01T00:00:00',
                'createdAt': '2024-01-01T00:00:00',
                'updatedAt': '2024-01-01T00:00:00',
                'description': 'Test expense',
                'amount': -100,
                'balance': 900,
                'type': 'expense',
                'userId': '1',
            },
            {
                '_id': str(uuid.uuid4()),
                'date': '2024-01-02T00:00:00',
                'createdAt': '2024-01-02T00:00:00',
                'updatedAt': '2024-01-02T00:00:00',
                'description': 'Salary',
                'amount': 2000,
                'balance': 2900,
                'type': 'income',
                'userId': '1',
            },
        ]
        result = await analyze_transactions(tx_data)
        self.assertIn('categories', result)
    ...

    async def test_classify_transactions_pattern_matching(self):
        """
        Test that transactions with descriptions matching common patterns
        are categorized correctly without invoking the ML pipeline.
        """
        # Create dummy transactions that should match predefined patterns
        tx1 = self.create_transaction_dict('2024-01-01T00:00:00', 'Walmart grocery purchase', -50.0, 950.0, 'expense')
        tx2 = self.create_transaction_dict('2024-01-02T00:00:00', 'Uber ride', -20.0, 930.0, 'expense')
        tx3 = self.create_transaction_dict('2024-01-03T00:00:00', 'Netflix subscription', -15.0, 915.0, 'expense')
        tx4 = self.create_transaction_dict('2024-01-04T00:00:00', 'Salary', 3000.0, 3915.0, 'income')

        # We assume pattern matching is applied first
        transactions = await validate_and_convert_transactions([tx1, tx2, tx3, tx4])
        # Call classify_transactions without a WebSocket manager
        result = await classify_transactions(transactions)
        categories = result.get('expenses', {})
        income_total = result.get('income', 0)

        # Check that the descriptions are mapped to expected categories:
        # "Walmart grocery" should fall under 'groceries'
        self.assertIn('groceries', categories)
        self.assertGreater(categories['groceries'], 0)

        # "Uber ride" should fall under 'transportation'
        self.assertIn('transportation', categories)
        self.assertGreater(categories['transportation'], 0)

        # "Netflix subscription" should be captured under 'subscriptions'
        self.assertIn('subscriptions', categories)
        self.assertGreater(categories['subscriptions'], 0)

        # Income should include the salary
        self.assertEqual(income_total, 3000)

    @patch.dict(
        'os.environ',
        {"LABELS": "groceries,housing,transportation,entertainment,utilities,education,credit_cards,insurance,other"},
    )
    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_ml_fallback(self, mock_pipeline):
        # Simulate a transaction with an unmatched description
        tx1 = self.create_transaction_dict(
            '2024-01-05T00:00:00', 'Unusual expense with no pattern', -75.0, 840.0, 'expense'
        )

        # Setup the fake pipeline result
        fake_result = [{'labels': ['other'], 'scores': [0.95]}]
        mock_pipeline.return_value = lambda *args, **kwargs: fake_result

        transactions = await validate_and_convert_transactions([tx1])
        result = await classify_transactions(transactions)
        categories = result.get('expenses', {})

        # Expect that the ML fallback has assigned this expense to 'other'
        self.assertIn('other', categories)
        self.assertAlmostEqual(categories['other'], 75 * 0.95, places=2)

    async def test_analyze_recurring_transactions_monthly(self):
        """
        Test that transactions with the same description and a roughly monthly interval
        are detected as recurring.
        """

        base_date = datetime(2024, 1, 1)
        # Create 3 monthly transactions (interval ~30 days)
        tx1 = self.create_transaction_dict((base_date).isoformat(), 'Gym membership', -50.0, 950.0, 'expense')
        tx2 = self.create_transaction_dict(
            (base_date + timedelta(days=30)).isoformat(), 'Gym membership', -50.0, 900.0, 'expense'
        )
        tx3 = self.create_transaction_dict(
            (base_date + timedelta(days=60)).isoformat(), 'Gym membership', -50.0, 850.0, 'expense'
        )

        transactions = await validate_and_convert_transactions([tx1, tx2, tx3])

        recurring = analyze_recurring_transactions(transactions)
        self.assertGreater(len(recurring), 0)
        monthly_recurring = next((r for r in recurring if r['frequency'] == 'Monthly'), None)
        self.assertIsNotNone(monthly_recurring)
        self.assertEqual(monthly_recurring['description'], 'gym membership')

    async def test_analyze_recurring_transactions_weekly(self):
        """
        Test that transactions with the same description and a roughly weekly interval
        are detected as recurring.
        """

        base_date = datetime(2024, 1, 1)
        # Create 3 weekly transactions (interval ~7 days)
        tx1 = self.create_transaction_dict((base_date).isoformat(), 'Weekly yoga class', -20.0, 980.0, 'expense')
        tx2 = self.create_transaction_dict(
            (base_date + timedelta(days=7)).isoformat(), 'Weekly yoga class', -20.0, 960.0, 'expense'
        )
        tx3 = self.create_transaction_dict(
            (base_date + timedelta(days=14)).isoformat(), 'Weekly yoga class', -20.0, 940.0, 'expense'
        )

        transactions = await validate_and_convert_transactions([tx1, tx2, tx3])

        recurring = analyze_recurring_transactions(transactions)
        self.assertGreater(len(recurring), 0)
        weekly_recurring = next((r for r in recurring if r['frequency'] == 'Weekly'), None)
        self.assertIsNotNone(weekly_recurring)
        self.assertEqual(weekly_recurring['description'], 'weekly yoga class')

    ...

    def test_edge_empty_transactions(self):
        """
        Ensure that functions gracefully handle an empty list of transactions.
        """
        # predict_trends should return a message indicating insufficient data
        trends = predict_trends([])
        self.assertIn('trend', trends)
        self.assertEqual(trends['trend'], 'Not enough data')

        # calculate_financial_health on empty list should not crash (might return infinity or 0)
        health = calculate_financial_health([])
        self.assertIn('debt_to_income_ratio', health)
        self.assertIn('savings_rate', health)
        self.assertIn('balance_growth_rate', health)
        self.assertIn('financial_health_score', health)

    async def test_not_transaction_analyzer(self):
        """
        Ensure that functions gracefully handle invalid transaction data.
        """
        analysis = await analyze_transactions(None, self.websocket_manager)
        self.assertIn('error', analysis)
        self.assertEqual(analysis['error'], 'No transactions provided')
        self.assertTrue(self.fake_ws.messages)

    async def test_analyze_transactions_with_websocket(self):
        """Test the analyze_transactions function with a WebSocketManager."""
        # Create an invalid transaction (note the missing 'balance' field)
        tx_1 = [
            {
                '_id': str(uuid.uuid4()),
                'date': '2024-01-01T00:00:00',
                'createdAt': '2024-01-01T00:00:00',
                'updatedAt': '2024-01-01T00:00:00',
                'description': 'Test expense',
                'amount': -100,
                'type': 'expense',
                'userId': '1',
            }
        ]

        result = await analyze_transactions(tx_1, self.websocket_manager)
        self.assertIn('error', result)
        self.assertEqual(result['error'], 'No valid transactions provided')
        # Check that progress messages were sent
        self.assertTrue(self.fake_ws.messages)

        # Create valid transactions
        tx_2 = self.create_transaction_dict('2024-01-01T00:00:00', 'Salary', 2000, 2900, 'income')
        result = await analyze_transactions([tx_2], self.websocket_manager)
        self.assertIn('categories', result)
        # Check that progress messages were sent
        self.assertTrue(self.fake_ws.messages)

    @patch('src.utils.analyzer.validate_and_convert_transactions')
    async def test_analyze_transactions_validation_exception(self, mock_validate):
        """Test analyze_transactions handling when validation fails"""
        valid_tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Salary', 2000, 2900, 'income')
        mock_validate.side_effect = ValueError('Mock validation error')
        result = await analyze_transactions(valid_tx, self.websocket_manager)
        self.assertIn('error', result)
        msg = self.fake_ws.messages[-1]
        self.assertEqual(msg['action'], 'progress')
        self.assertEqual(msg['message'], 'Analysis failed')
        self.assertEqual(msg['taskType'], 'Analysis')

    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_exception(self, mock_pipeline):
        """Test classify_transactions handling when pipeline fails"""
        tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
        transactions = await validate_and_convert_transactions([tx])
        mock_pipeline.side_effect = RuntimeError('Mock pipeline error')

        result = await classify_transactions(transactions, self.websocket_manager)

        # Check error response
        self.assertIn('error', result)
        self.assertIn('Classification failed', result['error'])

        # Check websocket message
        msg = self.fake_ws.messages[-1]
        self.assertEqual(msg['action'], 'progress')
        self.assertEqual(msg['message'], 'Analysis failed')
        self.assertEqual(msg['taskType'], 'Analysis')

    @patch('src.utils.analyzer.get_device')
    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_device_cpu(self, mock_pipeline, mock_device):
        """Test that classify_transactions uses CPU device for the pipeline."""
        tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
        transactions = await validate_and_convert_transactions([tx])
        mock_device.return_value = (torch.device('cpu'), 'CPU')
        await classify_transactions(transactions, self.websocket_manager)
        mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=-1)

    @patch('src.utils.analyzer.get_device')
    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_device_gpu(self, mock_pipeline, mock_device):
        """Test that classify_transactions uses GPU device for the pipeline."""
        tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
        transactions = await validate_and_convert_transactions([tx])
        mock_device.return_value = (torch.device('cuda'), 'GPU')
        await classify_transactions(transactions, self.websocket_manager)
        mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=0)

    @patch('src.utils.analyzer.get_device')
    @patch('src.utils.analyzer.pipeline')
    async def test_classify_transactions_device_mps(self, mock_pipeline, mock_device):
        """Test that classify_transactions uses MPS (Apple Metal) device for the pipeline."""
        tx = self.create_transaction_dict('2024-01-01T00:00:00', 'Test expense', -100, 900, 'expense')
        transactions = await validate_and_convert_transactions([tx])
        mock_device.return_value = (torch.device('mps'), 'MPS (Apple Metal)')
        await classify_transactions(transactions, self.websocket_manager)
        mock_pipeline.assert_called_once_with('zero-shot-classification', model='facebook/bart-large-mnli', device=0)

This thorough testing gives us confidence in the reliability of our code. The repository's tests folder contains other test files that rigorously exercise the rest of the implementation. Currently, the AI service has 100% test coverage, and static analysis is enforced.
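Many of the analyzer tests above build their inputs with self.create_transaction_dict(date, description, amount, balance, type), inherited from BaseAsyncTestClass in the tests package, which isn't shown in this article. A minimal hypothetical version might look like the following. It only mirrors the fields used in the hand-written tx_data above; the real base class presumably also provides the fake_ws and websocket_manager attributes the tests rely on, and those are omitted here.

```python
import unittest
import uuid
from typing import Any


class BaseAsyncTestClass(unittest.IsolatedAsyncioTestCase):
    """Hypothetical minimal base class; fake_ws and websocket_manager are omitted."""

    def create_transaction_dict(
        self, date: str, description: str, amount: float, balance: float, tx_type: str
    ) -> dict[str, Any]:
        # Mirrors the keys of the hand-written tx_data in test_analyze_transactions_valid
        return {
            '_id': str(uuid.uuid4()),
            'date': date,
            'createdAt': date,
            'updatedAt': date,
            'description': description,
            'amount': amount,
            'balance': balance,
            'type': tx_type,
            'userId': '1',
        }


base = BaseAsyncTestClass()
tx = base.create_transaction_dict('2024-01-01T00:00:00', 'Salary', 2000, 2900, 'income')
print(tx['type'], tx['amount'])  # income 2000
```

Centralizing the dictionary shape in one helper means a schema change (say, a new required field) is fixed in one place rather than in every test.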

We will stop here. In the next article, we will return to implementing the dashboard.

Outro

Enjoyed this article? I'm a Software Engineer, Technical Writer and Technical Support Engineer actively seeking new opportunities, particularly in areas related to web security, finance, healthcare, and education. If you think my expertise aligns with your team's needs, let's chat! You can find me on LinkedIn and X. I am also an email away.

If you found this article valuable, consider sharing it with your network to help spread the knowledge!
