If you are a LangChain lover and/or do not use the Experimental Features from Haystack, this blog is not for you!
Need for this blog
I love the Haystack team and I hope they roll out an RC for Hayhooks soon, so we get a more intuitive experience. Until then, you can use this workaround, where we create a pipeline task and set "sync" streaming callbacks on the running event loop to collect and yield the chunks.
Hand Holding
You can just copy the code and it should do exactly what you are looking for, i.e. streaming as Server-Sent Events.
Packages
Make sure to have these packages installed. Give uv or poetry a shot.
python = ">=3.10,<3.13"
fastapi = "^0.111.0"
uvicorn = "^0.30.1"
haystack-ai = "^2.8.0"
haystack-experimental = "^0.4.0"
pydantic = "^2.7.2"
FastAPI
This tutorial is not a guide to understanding FastAPI, but just in case, here is a skeleton for the endpoint.
import asyncio
from typing import Dict, Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class ModalPipeline:
    def __init__(self, api_key: str):
        # something
        ...

    async def process_user_input(self, query: str) -> AsyncGenerator[str, None]:
        # something
        ...


class ChatQuery(BaseModel):
    """Chat query request model with user's input message."""

    query: str
    api_key: str


@app.post("/chat")
async def chat(chat_query: ChatQuery):
    pipeline = ModalPipeline(chat_query.api_key)
    return StreamingResponse(
        pipeline.process_user_input(chat_query.query),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "X-Accel-Buffering": "no",
        },
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
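Once the server is up, you can sanity-check the stream from any HTTP client. Here is a minimal sketch using httpx (an extra dependency, not in the package list above; the payload values are placeholders):

import httpx

payload = {"query": "What is Haystack?", "api_key": "sk-..."}

# Stream the response instead of buffering it, so we see SSE lines as they arrive
with httpx.stream("POST", "http://localhost:8000/chat", json=payload, timeout=None) as response:
    for line in response.iter_lines():
        if line:  # skip the blank separator lines between SSE events
            print(line)  # e.g. 'event: data' followed by 'data: "..."'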
Pipeline
Next, we define the asynchronous pipeline. Here, I will be passing the API_KEY through the endpoint, and therefore we add components to the pipeline dynamically.
from typing import Dict, Any, AsyncGenerator

from haystack_experimental.core import AsyncPipeline
from haystack.components.builders import PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

template = """
Answer the question.
Question: {{question}}
Answer:
"""


class ModalPipeline:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.generator = self.create_generator()

    def create_generator(self):
        return OpenAIGenerator(api_key=Secret.from_token(self.api_key), model="gpt-4o-mini")

    async def run_pipeline(self, pipeline: AsyncPipeline, input_data: Dict[str, Any]) -> AsyncGenerator[str, None]:
        # something
        ...

    async def process_user_input(self, query: str) -> AsyncGenerator[str, None]:
        rag_pipeline = AsyncPipeline()
        rag_pipeline.add_component("prompt_builder", PromptBuilder(template=template))
        rag_pipeline.add_component("generator", self.generator)
        rag_pipeline.connect("prompt_builder.prompt", "generator.prompt")

        # The key must match the template variable: {{question}}
        input_data = {
            "prompt_builder": {
                "question": query,
            },
            "generator": {},
        }

        async for chunk in self.run_pipeline(rag_pipeline, input_data):
            yield chunk
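One thing that bites easily here: the input key must match the template variable, which is question above, not query. A quick standalone check (my own sketch, not part of the pipeline):

from haystack.components.builders import PromptBuilder

builder = PromptBuilder(template=template)

# PromptBuilder just renders the Jinja template with the given variables
result = builder.run(question="What is Haystack?")
print(result["prompt"])
# Answer the question.
# Question: What is Haystack?
# Answer: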
The Hard Part
If you just want the finished functions, you can follow the gist.
If not, follow the code snippet below and I'll try to explain why it's implemented this way.
async def run_pipeline(self, pipeline: AsyncPipeline, input_data: Dict[str, Any]) -> AsyncGenerator[str, None]:
    request_collector = ChunkCollector()  # code snippet is below
    loop = asyncio.get_running_loop()

    # Create a sync wrapper for the async callback
    async def async_callback(chunk):
        await collect_chunk(request_collector.queue, chunk)  # code snippet is below

    def sync_callback(chunk):
        # Use run_coroutine_threadsafe instead of create_task
        future = asyncio.run_coroutine_threadsafe(async_callback(chunk), loop)
        try:
            # Wait for the coroutine to complete
            future.result()
        except Exception as e:
            print(f"Error in sync_callback: {str(e)}")

    # Set the callback using the sync wrapper
    input_data["generator"]["streaming_callback"] = sync_callback

    async def pipeline_runner():
        try:
            async for _ in pipeline.run(input_data):
                pass
        finally:
            await request_collector.queue.put(None)

    # Create a task for the pipeline
    pipeline_task = asyncio.create_task(pipeline_runner())

    try:
        # Start yielding chunks
        async for chunk in request_collector.generator():
            yield chunk
    finally:
        # Ensure the pipeline task is cleaned up
        if not pipeline_task.done():
            pipeline_task.cancel()
            try:
                await pipeline_task
            except asyncio.CancelledError:
                pass
Some Q&As for you
* Why don't we do a direct async callback without wrapping?

async def callback(chunk):
    await collect_chunk(request_collector.queue, chunk)

input_data["generator"]["streaming_callback"] = callback

The generator calls the callback synchronously, but here we would be handing it an async function: the call only creates a coroutine that is never awaited, so no chunk ever reaches the queue. That's why we need a sync wrapper around our async callback.
* Why don't we just create a task?

def sync_callback(chunk):
    asyncio.create_task(async_callback(chunk))

The callback is invoked from a different thread, where there is no running event loop, so asyncio.create_task raises a RuntimeError. We need a thread-safe way to schedule the coroutine on the main loop, which is exactly what run_coroutine_threadsafe gives us (a toy repro of both pitfalls follows below).
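If you want to see both failure modes outside Haystack, here is a toy repro (entirely my own sketch, nothing Haystack-specific), where a worker thread plays the role of the generator calling our callback:

import asyncio
import threading

async def async_callback(chunk):
    print(f"collected: {chunk}")

def generator_thread(loop):
    # Pitfall 1: calling the async function directly only creates a coroutine
    # object; its body never runs (Python warns "coroutine ... was never awaited").
    async_callback("chunk-1")

    # Pitfall 2: create_task requires a running event loop in *this* thread.
    try:
        asyncio.create_task(async_callback("chunk-2"))
    except RuntimeError as e:
        print(f"create_task failed: {e}")

    # The fix: hand the coroutine to the main thread's loop and block until done.
    future = asyncio.run_coroutine_threadsafe(async_callback("chunk-3"), loop)
    future.result()

async def main():
    loop = asyncio.get_running_loop()
    thread = threading.Thread(target=generator_thread, args=(loop,))
    thread.start()
    await asyncio.to_thread(thread.join)  # keep the loop free while we wait

asyncio.run(main())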
Chunking in SSE Format
We need to define the request_collector, which owns the queue, stores the chunks, and yields them from the queue (in SSE format).
from typing import AsyncGenerator
import uuid
import json
from asyncio import Queue

from haystack.dataclasses import StreamingChunk


class ChunkCollector:
    """Collects and queues streaming chunks."""

    def __init__(self):
        self.queue = Queue()

    async def generator(self) -> AsyncGenerator[str, None]:
        """Yields chunks from the queue."""
        # Send the initial metadata event
        yield 'event: metadata\n' + f'data: {{"run_id": "{uuid.uuid4()}"}}\n\n'
        while True:
            chunk = await self.queue.get()
            if chunk is None:
                # Send the end event
                yield 'event: end\n\n'
                break
            # Send a data event
            yield f'event: data\ndata: {json.dumps(chunk)}\n\n'


async def collect_chunk(queue: Queue, chunk: StreamingChunk):
    """
    Collect chunks and store them in the queue.

    :param queue: Queue to store the chunks
    :param chunk: StreamingChunk to be collected
    """
    if chunk and chunk.content:
        await queue.put(chunk.content)
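To see the exact wire format this produces, here is a tiny demo (a sketch built on the definitions above) that drives ChunkCollector by hand:

import asyncio

async def demo():
    collector = ChunkCollector()
    await collector.queue.put("Hello")
    await collector.queue.put(" world")
    await collector.queue.put(None)  # the sentinel that ends the stream

    async for event in collector.generator():
        print(repr(event))

asyncio.run(demo())
# 'event: metadata\ndata: {"run_id": "..."}\n\n'
# 'event: data\ndata: "Hello"\n\n'
# 'event: data\ndata: " world"\n\n'
# 'event: end\n\n'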
Frontend
You can directly use EventSource or fetch, but EventSource only supports GET requests and our endpoint expects a POST with a JSON body. For this tutorial, let's use fetch-event-source:
import { fetchEventSource } from "@microsoft/fetch-event-source";

const abortController = new AbortController();

await fetchEventSource("/api/v1/multi-modal/stream", {
  signal: abortController.signal,
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify(payload),
  onmessage(msg) {
    if (msg.event === "data") {
      const parsedData = JSON.parse(msg.data);
      console.log(parsedData);
    }
  },
  openWhenHidden: true,
  onclose() {
    // something
  },
  onerror(error) {
    console.error(error);
    throw error;
  },
});
Final Notes
I completely agree with vblagoje here, as sockets will be just 👌
If you found this blog helpful, just send a good vibe my way, whether it's my (research taking off || side project getting some honey || landing some gigs) ✌🏽!