Update chromadb import to be lazy

This commit is contained in:
Andrew Ridgway 2026-05-21 21:28:03 +10:00
parent 1781a1dbf5
commit e69b83694c
Signed by: armistace
GPG Key ID: C8D9EAC514B47EF1

View File

@ -17,14 +17,19 @@ journalist draft is chunked, embedded, and stored in a collection; the editor
receives the top-N most relevant chunks as context. receives the top-N most relevant chunks as context.
""" """
from __future__ import annotations
import json import json
import os import os
import random import random
import re import re
import string import string
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import chromadb # noqa: F811
import chromadb
from crewai.flow.flow import Flow, listen, start from crewai.flow.flow import Flow, listen, start
from ollama import Client from ollama import Client
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -85,7 +90,9 @@ class BlogFlow(Flow[BlogFlowState]):
) )
@staticmethod @staticmethod
def _get_chroma_client() -> chromadb.HttpClient: def _get_chroma_client() -> "chromadb.HttpClient":
import chromadb
chroma_port = int(os.environ["CHROMA_PORT"]) chroma_port = int(os.environ["CHROMA_PORT"])
return chromadb.HttpClient(host=os.environ["CHROMA_HOST"], port=chroma_port) return chromadb.HttpClient(host=os.environ["CHROMA_HOST"], port=chroma_port)
@ -127,7 +134,7 @@ class BlogFlow(Flow[BlogFlowState]):
print(f"Error generating embeddings: {exc}") print(f"Error generating embeddings: {exc}")
return [] return []
def _load_drafts_to_vector_db(self, drafts: list[str]) -> chromadb.Collection: def _load_drafts_to_vector_db(self, drafts: list[str]) -> "chromadb.Collection":
"""Load journalist drafts into a new ChromaDB collection and return it.""" """Load journalist drafts into a new ChromaDB collection and return it."""
chroma = self._get_chroma_client() chroma = self._get_chroma_client()
collection_name = ( collection_name = (
@ -165,7 +172,7 @@ class BlogFlow(Flow[BlogFlowState]):
return collection return collection
@staticmethod @staticmethod
def _query_vector_db(collection: chromadb.Collection, query_text: str) -> str: def _query_vector_db(collection: "chromadb.Collection", query_text: str) -> str:
"""Query the ChromaDB collection and return the most relevant """Query the ChromaDB collection and return the most relevant
document chunks joined as a single string.""" document chunks joined as a single string."""
ollama_client = BlogFlow._get_ollama_client() ollama_client = BlogFlow._get_ollama_client()