DEV Community

Antidisestablishmentarianism
Antidisestablishmentarianism

Posted on • Edited on

C# SIMD byte array compare

Here is the byte-array comparison routine I recently posted on Stack Overflow:
https://stackoverflow.com/a/69280107/13843929

Quoting my own post on Stack Overflow:
"This is similar to other answers, but the difference here is that it never falls through to progressively smaller chunk sizes. For example, with 63 bytes (in my SIMD example) I compare the first 32 bytes and then the last 32 bytes, letting the two comparisons overlap. That is faster than checking 32 bytes, then 16 bytes, then 8 bytes, and so on. The first size check you enter is the only one you need to compare all of the bytes."

It was the fastest implementation in my tests.

using System.Runtime.Intrinsics.X86;

internal class Program
{
    /// <summary>
    /// Compares two byte arrays for content equality using the widest comparison the
    /// array length and the CPU allow: AVX2 (32 bytes), SSE2 (16 bytes), then scalar
    /// 8-, 4-, 2- and 1-byte reads.
    /// </summary>
    /// <remarks>
    /// Within a tier, the array is walked in full-width steps and then finished with a
    /// single full-width comparison anchored to the end of the array. That final
    /// comparison may overlap bytes already checked. This is cheaper than falling through
    /// to progressively narrower comparisons for the remainder.
    /// A SIMD tier is only taken when the running CPU reports support for it
    /// (<see cref="Avx2.IsSupported"/>, <see cref="Sse2.IsSupported"/>). Otherwise the next
    /// narrower tier handles the whole array, so the method never throws
    /// <c>PlatformNotSupportedException</c>.
    /// </remarks>
    /// <param name="arr1">First array to compare.</param>
    /// <param name="arr2">Second array to compare.</param>
    /// <returns>
    /// <c>true</c> if both arrays are non-null, have the same length and contain the same
    /// bytes; <c>false</c> otherwise, including when either argument is <c>null</c>.
    /// </returns>
    public unsafe bool SIMDNoFallThrough(byte[] arr1, byte[] arr2)
    {
        if (arr1 == null || arr2 == null)
            return false;

        int length = arr1.Length;

        if (length != arr2.Length)
            return false;

        // Trivially equal. The zero-length check must happen before `fixed`: pinning an
        // empty array yields a null pointer, which the single-byte path would dereference.
        if (length == 0 || ReferenceEquals(arr1, arr2))
            return true;

        fixed (byte* p0 = arr1, p1 = arr2)
        {
            if (length >= 32 && Avx2.IsSupported)
            {
                // Offset of the last full 32-byte block; it is compared after the loop.
                int tail = length - 32;
                for (int i = 0; i < tail; i += 32)
                {
                    if (Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(p0 + i), Avx.LoadVector256(p1 + i))) != -1)
                        return false;
                }
                // All 32 mask bits set (== -1) means every byte in the block matched.
                return Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(p0 + tail), Avx.LoadVector256(p1 + tail))) == -1;
            }

            if (length >= 16 && Sse2.IsSupported)
            {
                int tail = length - 16;
                for (int i = 0; i < tail; i += 16)
                {
                    if (Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(p0 + i), Sse2.LoadVector128(p1 + i))) != 0xFFFF)
                        return false;
                }
                // Low 16 mask bits set (0xFFFF) means every byte in the block matched.
                return Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(p0 + tail), Sse2.LoadVector128(p1 + tail))) == 0xFFFF;
            }

            if (length >= 8)
            {
                int tail = length - 8;
                for (int i = 0; i < tail; i += 8)
                {
                    if (*(ulong*)(p0 + i) != *(ulong*)(p1 + i))
                        return false;
                }
                return *(ulong*)(p0 + tail) == *(ulong*)(p1 + tail);
            }

            // 4..7 bytes: the head and end-anchored reads overlap and cover everything.
            if (length >= 4)
                return *(uint*)p0 == *(uint*)p1
                    && *(uint*)(p0 + length - 4) == *(uint*)(p1 + length - 4);

            // 2..3 bytes: same overlapping head/tail idea.
            if (length >= 2)
                return *(ushort*)p0 == *(ushort*)p1
                    && *(ushort*)(p0 + length - 2) == *(ushort*)(p1 + length - 2);

            // Exactly one byte (zero length returned earlier).
            return *p0 == *p1;
        }
    }

    static void Main(string[] args)
    {
    }
}
Enter fullscreen mode Exit fullscreen mode

Top comments (0)