Sophia Parafina
Text-to-SQL: Generating SQL with Nebius AI Studio (part 2)

In the previous post, we created documents based on the tables in the Northwind Trader database. We used Nebius AI Studio's embedding model to vectorize the documents and insert them into Postgres with the vector extension. This post shows how to query the database and use the results in a prompt. Let's dive into the code.
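As a reminder, the code below assumes the table created in part 1. A sketch of the DDL, stored here as a string (the column names match the query code in this post, and the 1024 dimension matches the `dimensions` parameter passed to the embedding model; the exact DDL from part 1 may differ):

```python
# Sketch of the table assumed by the queries below (from part 1);
# column names and the 1024 dimension match the code in this post.
SCHEMA_SQL = """
CREATE EXTENSION IF NOT EXISTS vector;

CREATE TABLE IF NOT EXISTS items (
    id bigserial PRIMARY KEY,
    chunk text,               -- document describing one table
    embedding vector(1024)    -- BAAI/bge-en-icl embedding
);
"""
```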

A simple client

The script begins with instantiating a Nebius client and the database connection parameters.

import json
import os

import psycopg
from openai import OpenAI

API_KEY = os.environ.get('NEBIUS_API_KEY')

conn = psycopg.connect(dbname='rag', user="postgres", autocommit=True)

client = OpenAI(
    base_url="https://api.studio.nebius.ai/v1/",
    api_key=API_KEY,
)

The query function retrieves the most relevant document chunks from the database. To embed the query, we reuse the create_vector function defined below. The search orders rows by cosine distance; pgvector also offers other vector search operators.

def query(query_string):
    vector = create_vector(query_string)
    embedding_query = "[" + ",".join(map(str, vector)) + "]"

    # <=> is pgvector's cosine distance operator: lower values are more similar
    query_sql = """
    SELECT chunk, embedding <=> %s::vector AS distance
    FROM items
    ORDER BY embedding <=> %s::vector
    LIMIT 20;
    """
    data = conn.execute(query_sql, (embedding_query, embedding_query)).fetchall()

    return [row[0] for row in data]

def create_vector(text):
    embedding = client.embeddings.create(
        model="BAAI/bge-en-icl",
        input=text,
        dimensions=1024,
        encoding_format="float",
    ).data[0].embedding

    return embedding

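The other pgvector operators can be swapped in without changing the rest of the code. A hypothetical helper as a sketch (the operator names are pgvector's; `build_search_sql` itself is not part of the post):

```python
# pgvector distance operators:
#   <->  L2 (Euclidean) distance
#   <#>  negative inner product
#   <=>  cosine distance
OPERATORS = {"l2": "<->", "inner_product": "<#>", "cosine": "<=>"}

def build_search_sql(metric="cosine", limit=20):
    """Build a parameterized nearest-neighbor query against the items table."""
    op = OPERATORS[metric]
    return (
        f"SELECT chunk, embedding {op} %s::vector AS distance\n"
        f"FROM items\n"
        f"ORDER BY embedding {op} %s::vector\n"
        f"LIMIT {limit};"
    )
```

The choice of operator should match the index type created on the `embedding` column, if any.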

The database results, added as context, enrich the OpenAI-style prompt. The prompt is sent to Nebius's Qwen2.5-Coder-7B-Instruct model, which was trained on SQL among other programming languages.

def create_prompt(llm_query, database_results):
    content_start = (
        "Write a SQL statement using the database information provided.\n\n"+
        "Context:\n"
    )

    content_end = (
        f"\n\nQuestion: {llm_query}\nAnswer:"
    )

    content = (
        content_start + "\n\n---\n\n".join(database_results) + 
        content_end
    )

    prompt = [{'role': 'user', 'content': content }]

    return prompt

def create_completion(prompt):

    completion = client.chat.completions.create(
        model = "Qwen/Qwen2.5-Coder-7B-Instruct",
        messages = prompt,
        temperature=0.6
    )

    return completion.to_json()

Next, we write a query in English: "List the number of suppliers alphabetically by country." The query is sent to the database to retrieve the most relevant document chunks.

client_query = "List the number of suppliers alphabetically by country."
rag_results = query(client_query)
prompt = create_prompt(client_query, rag_results)
response = json.loads(create_completion(prompt))

print(response["choices"][0]["message"]["content"])
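Depending on sampling, the model may wrap the statement in a Markdown code fence or surround it with explanatory prose. A small helper to pull out just the SQL (a sketch; the fallback regex assumes a single statement ending in a semicolon):

```python
import re

def extract_sql(reply: str) -> str:
    """Pull the first SQL statement out of a model reply.

    Handles fenced replies (```sql ... ```) and falls back to the
    first SELECT/INSERT/UPDATE/DELETE statement in plain text.
    """
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", reply, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    plain = re.search(r"\b(SELECT|INSERT|UPDATE|DELETE)\b.*?;", reply,
                      re.DOTALL | re.IGNORECASE)
    return plain.group(0).strip() if plain else reply.strip()
```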

The retrieved chunks are added to the prompt as context, and the model returns the following.

To list the number of suppliers alphabetically by country, you can use the following SQL statement:

SELECT country, COUNT(*) AS number_of_suppliers
FROM suppliers
GROUP BY country
ORDER BY country;

This query performs the following actions:
- `SELECT country, COUNT(*) AS number_of_suppliers`: Selects the `country` column and counts the number of rows for each country.
- `FROM suppliers`: Specifies the table from which to retrieve the data.
- `GROUP BY country`: Groups the results by the `country` column.
- `ORDER BY country`: Orders the results alphabetically by the `country` column.

For convenience, use pgAdmin or a similar tool to test the result.
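Alternatively, you could run the generated statement from Python with a crude read-only check first. This guard is only an illustration, not a real SQL validator:

```python
def is_safe_select(sql: str) -> bool:
    """Crude guard: accept only a single read-only SELECT statement."""
    stripped = sql.strip().rstrip(";")
    # reject anything that isn't a SELECT, or that chains a second statement
    return stripped.upper().startswith("SELECT") and ";" not in stripped
```

If the check passes, `conn.execute(sql).fetchall()` runs it, assuming the connection points at the Northwind database.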

Result of LLM generated SQL

Takeaways

This code is a testbed to try out concepts. Several changes had to be made to the documents to produce working SQL:

  • The documents included the database name along with the table name. The database name was removed, since the table name alone is sufficient for queries.
  • Tables without relations should be removed; otherwise the LLM uses them to produce SQL. This was the case for the us_states table, which did not have a foreign key for joins.
  • Qwen2.5-Coder creates table aliases and sometimes assigns them to the wrong columns.

Formulating a query requires specificity: using table names and terms similar to column names produces better results. So far, the LLM can generate standard SQL constructs such as COUNT, DISTINCT, GROUP BY, ORDER BY, and OUTER JOIN.

The results are promising. The following post will show how to build an agent to run the SQL.
