DEV Community

Josh Mo

Semantic Routing with Qdrant, Rig & Rust

Introduction

In conversational AI and intelligent systems, decision-making is often the most critical layer. On their own, LLMs can be misled by users into making the wrong decisions. This has happened many times, often to the detriment of companies that don't safeguard against it. However, what if I told you there's an easy way to help protect your model against prompt injection attacks and the like, and that it's also cheap on compute?

In this article, we’ll delve into building an efficient semantic router using Qdrant, Rig, and Rust. By combining Qdrant’s vector search capabilities, the Rig LLM framework, and Rust’s performance, we’ll craft a system that empowers your agents and model prompts to make precise, context-aware decisions.

What even is a semantic router?

A semantic router, in short, is a decision-making layer for your agents (and model prompts). It works as a lightweight abstraction: you group utterances (short example sentences) into topics, then compare each incoming query against those utterances for semantic similarity. The router returns the topic of the most similar utterance, or no topic at all if nothing is similar enough.
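To make the idea concrete, here's a minimal, std-only sketch of the routing decision. The 3-dimensional vectors are hand-written stand-ins for real embeddings (an assumption for illustration; ada-002 embeddings have 1536 dimensions), and `cosine_similarity` and `route` are hypothetical helpers, not Rig or Qdrant APIs. In the real router, Qdrant performs this similarity search for us.

```rust
/// Cosine similarity between two equal-length vectors (0.0 if either is all zeros).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

/// Returns the topic of the most similar utterance, or None if even the
/// best match is not above `threshold`.
fn route<'a>(query: &[f32], utterances: &[(&'a str, Vec<f32>)], threshold: f32) -> Option<&'a str> {
    utterances
        .iter()
        .map(|(topic, emb)| (*topic, cosine_similarity(query, emb)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .filter(|&(_, score)| score > threshold)
        .map(|(topic, _)| topic)
}

fn main() {
    // (topic, embedding) pairs: toy vectors standing in for real embeddings
    let utterances = vec![
        ("bees", vec![0.9, 0.1, 0.0]),
        ("bees", vec![0.8, 0.2, 0.1]),
        ("weather", vec![0.0, 0.2, 0.9]),
    ];
    let bee_query = [0.85, 0.15, 0.05];
    let unrelated_query = [0.0, 1.0, 0.0];

    println!("{:?}", route(&bee_query, &utterances, 0.85)); // Some("bees")
    println!("{:?}", route(&unrelated_query, &utterances, 0.85)); // None
}
```

The "or none" case is what makes this useful as a guardrail: an off-topic query never gets forced onto the nearest topic.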

Semantic routing has several advantages:

  • Your agents won't deviate from the topics you want them to talk about
  • It's scalable to however many "routes" you want to implement
  • It's fast and cheap: each query needs one embedding call plus a vector search, rather than a full LLM completion

Getting Started

Before we get started, make sure you have Rust installed! If you plan to run this, you'll also want an OpenAI API key, which you can grab from the OpenAI dashboard. The code also assumes a Qdrant instance is running locally, with its gRPC port available on localhost:6334.

Before you run cargo run, make sure you export your API key:

export OPENAI_API_KEY=<key-goes-here>

Let's get building!

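First, create the project with cargo init, then cd into it:

cargo init semantic-router
cd semantic-router

The code in this tutorial uses the following crates, so add them to your Cargo.toml: rig-core (imported as rig), rig-qdrant, qdrant-client (a version compatible with your rig-qdrant release), tokio (with the macros and rt-multi-thread features), serde (with the derive feature), serde_json, and uuid (with the v4 feature).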

The first thing we need to define is the thing we want to embed: an Utterance. Because we'll be grouping our embeddings by topic, we'll also declare a Topic struct that makes it easy to generate Utterance instances ready for embedding.

Note that although Rig provides a derive macro for Embed, we implement the trait manually here: only content gets embedded, while the other fields are preserved as metadata. The id is a UUID because Qdrant only accepts unsigned integers or UUIDs as point IDs.

// topics.rs
use rig::{embeddings::{EmbedError, TextEmbedder}, Embed};

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
pub struct Utterance {
    pub id: String,
    pub topic: String,
    pub content: String,
}

impl Utterance {
    pub fn new(topic: &str, content: &str) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            topic: topic.to_string(),
            content: content.to_string(),
        }
    }
}

impl Embed for Utterance {
    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
        // Embeddings only need to be generated for the `content` field,
        // since incoming queries are compared against the content.
        embedder.embed(self.content.to_owned());

        Ok(())
    }
}

Note that while Utterance::new() takes a topic name, having to write the topic name every time you instantiate an Utterance quickly gets cumbersome. It might not matter during development, but what about when you need to create a large number of snippets across, say, 50-60 topics? How can we make this more ergonomic?

This is where our Topic struct comes in: it makes it easier to create Utterance instances for a given topic. See below:

// topics.rs
pub struct Topic {
    name: String
}

impl Topic {
    pub fn new(name: &str) -> Self {
        let name = name.to_string();
        Self {
            name
        }
    }

    pub fn new_utterance(&self, content: &str) -> Utterance {
        Utterance::new(&self.name, content)
    }
}
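With dozens of topics, it can help to keep the example sentences in a table and let Topic do the tagging. Here's a std-only sketch of that idea. The structs mirror the article's Utterance and Topic, but ids come from a simple counter instead of uuid (an assumption to keep it dependency-free; real Qdrant points need UUIDs or integers), and `utterances_from_table` is a hypothetical helper.

```rust
#[derive(Debug, Clone)]
pub struct Utterance {
    pub id: String,
    pub topic: String,
    pub content: String,
}

pub struct Topic {
    name: String,
}

impl Topic {
    pub fn new(name: &str) -> Self {
        Self { name: name.to_string() }
    }

    // Same idea as the article's `new_utterance`, but with an explicit id
    pub fn new_utterance(&self, id: usize, content: &str) -> Utterance {
        Utterance {
            id: id.to_string(),
            topic: self.name.clone(),
            content: content.to_string(),
        }
    }
}

/// Hypothetical helper: flattens (topic, examples) rows into one Vec<Utterance>.
fn utterances_from_table(table: &[(&str, Vec<&str>)]) -> Vec<Utterance> {
    let mut next_id = 0;
    let mut out = Vec::new();
    for (name, examples) in table {
        // One Topic per row, so the topic name is written exactly once
        let topic = Topic::new(name);
        for content in examples.iter() {
            out.push(topic.new_utterance(next_id, content));
            next_id += 1;
        }
    }
    out
}

fn main() {
    let table = vec![
        ("bees", vec!["How do bees communicate?", "How many eggs does a queen lay?"]),
        ("weather", vec!["Will it rain tomorrow?"]),
    ];
    for u in utterances_from_table(&table) {
        println!("{} [{}] {}", u.id, u.topic, u.content);
    }
}
```

The resulting Vec<Utterance> is the same shape the router's embedding method accepts, so adding a topic becomes a one-row change.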

Now to build our semantic router! For this part, we'll create a new() method that initialises an OpenAI embedding client from the OPENAI_API_KEY environment variable, connects to a local Qdrant instance, creates the collection (with the given name) if it doesn't already exist, and returns the struct.

// semantic_router.rs
use std::env;

use rig::{embeddings::EmbeddingsBuilder, providers::openai::{Client, EmbeddingModel as Model, TEXT_EMBEDDING_ADA_002}};
use rig_qdrant::QdrantVectorStore;
use qdrant_client::{qdrant::{CreateCollectionBuilder, Distance, PointStruct, QueryPointsBuilder, VectorParamsBuilder}, Payload, Qdrant};

// Note that this assumes you're hosting Qdrant locally (gRPC port 6334)
const QDRANT_URL: &str = "http://localhost:6334";
const COLLECTION_NAME: &str = "SEMANTIC_ROUTING";
// text-embedding-ada-002 returns 1536-dimensional vectors
const COLLECTION_SIZE: usize = 1536;

pub struct SemanticRouter {
    model: Model,
    // Client used for collection management and upserts
    qdrant: Qdrant,
    // Rig's vector store wrapper, used for similarity queries
    vector_store: QdrantVectorStore<Model>,
}

impl SemanticRouter {
    pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
        // Initialize OpenAI client.
        // Get your API key from https://platform.openai.com/api-keys
        let openai_api_key = env::var("OPENAI_API_KEY")?;
        let openai_client = Client::new(&openai_api_key);
        let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

        // QdrantVectorStore::new takes a Qdrant client by value, so we build a
        // second client against the same URL. (Cloning an Arc and calling
        // Arc::into_inner on it would return None and panic on unwrap, because
        // the original Arc still holds a strong reference.)
        let qdrant = Qdrant::from_url(QDRANT_URL).build()?;
        let store_client = Qdrant::from_url(QDRANT_URL).build()?;

        // Create a collection with 1536 dimensions if it doesn't exist.
        // Note: make sure the dimensions match the size of the embeddings
        // returned by the model you are using.
        if !qdrant.collection_exists(COLLECTION_NAME).await? {
            qdrant
                .create_collection(
                    CreateCollectionBuilder::new(COLLECTION_NAME)
                        .vectors_config(VectorParamsBuilder::new(COLLECTION_SIZE as u64, Distance::Cosine)),
                )
                .await?;
        }

        // Return payloads with results so matches deserialize into Utterance
        let query_params = QueryPointsBuilder::new(COLLECTION_NAME).with_payload(true);
        let vector_store = QdrantVectorStore::new(store_client, model.clone(), query_params.build());

        Ok(Self { model, qdrant, vector_store })
    }
}
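The collection's vector size has to match the embedding model's output size. Qdrant rejects vectors of the wrong dimension at upsert time anyway, but if you ever swap models it can help to fail early with a clearer message. Here's a minimal std-only sketch of such a check; `check_dimensions` is a hypothetical helper, not part of qdrant-client or Rig, and COLLECTION_SIZE mirrors the article's constant.

```rust
const COLLECTION_SIZE: usize = 1536;

/// Hypothetical pre-upsert check: every embedding must match the collection size.
fn check_dimensions(embeddings: &[Vec<f32>]) -> Result<(), String> {
    for (i, e) in embeddings.iter().enumerate() {
        if e.len() != COLLECTION_SIZE {
            return Err(format!(
                "embedding {i} has {} dimensions, but the collection expects {COLLECTION_SIZE}",
                e.len()
            ));
        }
    }
    Ok(())
}

fn main() {
    let good = vec![vec![0.0_f32; 1536]];
    let bad = vec![vec![0.0_f32; 1536], vec![0.0_f32; 3072]];
    println!("{:?}", check_dimensions(&good));
    println!("{:?}", check_dimensions(&bad));
}
```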

Next, we need a method that embeds utterances and adds them to our Qdrant collection. We generate an embedding for each utterance with Rig's EmbeddingsBuilder, convert each embedding into a Vec<f32>, and wrap it in a PointStruct with the utterance as its payload. Then we upsert all of the points in one call.

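// semantic_router.rs
use qdrant_client::qdrant::UpsertPointsBuilder;
use rig::vector_store::VectorStoreIndex;

use crate::topics::Utterance;

impl SemanticRouter {
    // .. other method(s) above
    pub async fn embed_utterances(&self, utterances: Vec<Utterance>) -> Result<(), Box<dyn std::error::Error>> {
        // One embedding per utterance (only `content` is embedded)
        let documents = EmbeddingsBuilder::new(self.model.clone())
            .documents(utterances)?
            .build()
            .await?;

        let points = documents
            .into_iter()
            .map(|(d, embeddings)| -> Result<PointStruct, Box<dyn std::error::Error>> {
                // Rig returns f64 values; Qdrant stores f32 vectors
                let vec: Vec<f32> = embeddings.first().vec.iter().map(|&x| x as f32).collect();
                // Store the whole utterance (id, topic, content) as the payload
                let payload = Payload::try_from(serde_json::to_value(&d)?)?;
                Ok(PointStruct::new(d.id.clone(), vec, payload))
            })
            .collect::<Result<Vec<PointStruct>, _>>()?;

        // wait(true) returns only once the points are indexed,
        // so they can be queried immediately afterwards
        self.qdrant
            .upsert_points(UpsertPointsBuilder::new(COLLECTION_NAME, points).wait(true))
            .await?;

        Ok(())
    }
}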
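Two details in that method are worth seeing in isolation: the f64 to f32 narrowing, and the fact that Qdrant upserts by point id, so a point whose id already exists is overwritten rather than duplicated. The sketch below is a std-only illustration of both; `Point`, `InMemoryStore` and `to_f32` are hypothetical stand-ins, not Qdrant's actual implementation.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a Qdrant point: id, f32 vector, topic payload.
#[derive(Debug, Clone, PartialEq)]
struct Point {
    id: String,
    vector: Vec<f32>,
    topic: String,
}

/// Hypothetical in-memory store keyed by point id.
#[derive(Default)]
struct InMemoryStore {
    points: HashMap<String, Point>,
}

impl InMemoryStore {
    /// Insert new points, or replace existing points that share an id.
    fn upsert(&mut self, points: Vec<Point>) {
        for p in points {
            self.points.insert(p.id.clone(), p);
        }
    }
}

/// Narrow an f64 embedding to f32, as done before building each PointStruct.
fn to_f32(embedding: &[f64]) -> Vec<f32> {
    embedding.iter().map(|&x| x as f32).collect()
}

fn main() {
    let mut store = InMemoryStore::default();
    let v = to_f32(&[0.25, 0.5, 0.75]);
    store.upsert(vec![Point { id: "a".into(), vector: v.clone(), topic: "bees".into() }]);
    // Same id again: the point is replaced, not duplicated
    store.upsert(vec![Point { id: "a".into(), vector: v, topic: "weather".into() }]);

    let p = &store.points["a"];
    println!("{} point(s), topic = {}, vector = {:?}", store.points.len(), p.topic, p.vector);
}
```

Because each Utterance gets a fresh UUID, re-embedding the same sentences produces new ids, and therefore new points rather than replacements.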

Finally, we need a query method that checks whether a prompt matches one of our conversation topics. We only want results that are semantically similar to the prompt, so we check that the similarity score of the closest match is above a threshold (0.85 here) and only return its metadata if it is. If not, we return an error, which then gets handled higher up the stack.

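// semantic_router.rs
impl SemanticRouter {
    // .. other methods
    pub async fn query(&self, query: &str) -> Result<Utterance, Box<dyn std::error::Error>> {
        // Minimum cosine similarity for a match; tune this for your model
        const SIMILARITY_THRESHOLD: f64 = 0.85;

        // Fetch only the single closest utterance: (score, id, utterance)
        let results = self.vector_store
            .top_n::<Utterance>(query, 1)
            .await?;

        // An empty collection returns no results, so don't index blindly
        match results.into_iter().next() {
            Some((score, _id, utterance)) if score > SIMILARITY_THRESHOLD => Ok(utterance),
            _ => Err("No relevant snippet found.".into()),
        }
    }
}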

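The decision inside query has three outcomes: a match above the threshold, a best match that isn't similar enough, and no results at all (for example, an empty collection). Here's a std-only sketch of that logic with distinct error messages for each miss; `pick_match` is a hypothetical helper, and its (score, id, topic) tuples stand in for the (score, id, Utterance) tuples returned by top_n. The 0.85 cutoff comes from the article; since a good cutoff depends on the embedding model, it's worth tuning against real queries.

```rust
/// Hypothetical mirror of the check in `query`: keep the best result only if
/// its score is strictly above `threshold`.
fn pick_match(results: Vec<(f64, String, String)>, threshold: f64) -> Result<String, String> {
    match results.into_iter().next() {
        Some((score, _id, topic)) if score > threshold => Ok(topic),
        Some((score, _id, _topic)) => Err(format!("best score {score:.3} is not above {threshold}")),
        None => Err("collection returned no results".to_string()),
    }
}

fn main() {
    let hit = vec![(0.93, "id-1".to_string(), "bees".to_string())];
    let miss = vec![(0.78, "id-2".to_string(), "bees".to_string())];

    println!("{:?}", pick_match(hit, 0.85));
    println!("{:?}", pick_match(miss, 0.85));
    println!("{:?}", pick_match(Vec::new(), 0.85));
}
```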

Now to wrap this up, let's add some code to our main function to show that this works:

// main.rs
pub mod semantic_router;
pub mod topics;

use semantic_router::SemanticRouter;
use topics::{Topic, Utterance};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let router = SemanticRouter::new().await?;

    let bees_topic = Topic::new("bees");

    // Create a vector of strings, then map each one to an `Utterance`
    // tagged with the "bees" topic
    let bee_facts = vec![
        "Bees communicate with their hive mates through intricate dances that convey the location of nectar-rich flowers.",
        "A single bee can visit up to 5,000 flowers in a day, tirelessly collecting nectar and pollen.",
        "The queen bee can lay up to 2,000 eggs in a single day during peak season.",
    ]
    .into_iter()
    .map(|x| bees_topic.new_utterance(x))
    .collect::<Vec<Utterance>>();

    // Embed utterances into Qdrant.
    // Note: each run generates fresh UUIDs, so re-running the program
    // adds another copy of these points to the collection.
    router.embed_utterances(bee_facts).await?;

    let bee_answer = router.query("how many flowers does a bee visit in a day?").await?;
    println!("Topic: {}", bee_answer.topic);

    // Note that this query *should* error out as it's unrelated,
    // in which case we simply tell the user we can't help them
    match router.query("what is skibidi toilet").await {
        Ok(res) => println!("Unexpectedly found a topic: {}", res.topic),
        Err(_) => println!("Sorry, I can't help you with that."),
    };

    Ok(())
}

Extending this example

So, now you know how to build your own semantic router! Here's one way to extend it:

Usage in a web application

For usage with web applications, you may find that your shared application state needs to implement Clone, but our router doesn't implement Clone. An easy fix is to wrap it in Arc<T> (i.e. Arc<SemanticRouter>): cloning an Arc is cheap, and every clone points at the same router. Note that Arc<T> can only be shared across threads if the inner type is already Send + Sync; wrapping a type in Arc doesn't make it thread-safe on its own!
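Here's a std-only sketch of that pattern. `Router` is a hypothetical stand-in for SemanticRouter (deliberately not Clone), `AppState` stands in for your framework's shared state, and plain threads stand in for request handlers. The `assert_send_sync` helper is a common compile-time trick: it only compiles if the type really is Send + Sync.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for SemanticRouter: note it does not implement Clone.
struct Router {
    topics: Vec<String>,
}

impl Router {
    fn knows(&self, topic: &str) -> bool {
        self.topics.iter().any(|t| t == topic)
    }
}

/// Shared state: Clone only bumps the Arc's reference count,
/// so every clone points at the same Router.
#[derive(Clone)]
struct AppState {
    router: Arc<Router>,
}

/// Compiles only if T is Send + Sync, i.e. shareable through Arc across threads.
fn assert_send_sync<T: Send + Sync>() {}

fn main() {
    assert_send_sync::<Router>();

    let state = AppState {
        router: Arc::new(Router { topics: vec!["bees".to_string()] }),
    };

    // Each "request handler" gets its own cheap clone of the state
    let handles: Vec<_> = vec!["bees", "skibidi toilet"]
        .into_iter()
        .map(|q| {
            let state = state.clone();
            thread::spawn(move || state.router.knows(q))
        })
        .collect();

    for h in handles {
        println!("{}", h.join().unwrap());
    }
}
```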

Finishing up

Thanks for reading! Hopefully this article has helped you learn a little bit more about how you can make your RAG and LLM-assisted applications even better than before.
