2026-05-21 22:44:04 +10:00

338 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
CrewAI Flow that orchestrates the blog-generation pipeline.
Flow
----
1. **Research crew** - a critical researcher with web-search investigates the
   topic and produces verified findings.
2. **Writing crew** - four creative journalists write draft blog articles
   in parallel (async tasks).
3. **Editor crew** - a critical editor loads the journalist drafts into
   ChromaDB, queries for the most relevant context, and produces the final
   polished markdown document complete with a metadata header (Title, Date,
   Category, Tags, Slug, Authors, Summary).
The ChromaDB integration is preserved from the original implementation: each
journalist draft is chunked, embedded, and stored in a collection; the editor
receives the top-N most relevant chunks as context.
"""
from __future__ import annotations
import json
import os
import random
import re
import string
from datetime import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import chromadb # noqa: F811
from crewai.flow.flow import Flow, listen, start
from ollama import Client
from pydantic import BaseModel, ConfigDict
from ai_generators.crews.editor_crew.editor_crew import EditorCrew
from ai_generators.crews.research_crew.research_crew import ResearchCrew
from ai_generators.crews.writing_crew.writing_crew import WritingCrew
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class BlogFlowState(BaseModel):
    """State object shared by every step of the blog generation flow.

    Every field defaults to an empty value, so a state can be created
    without arguments and filled in as the flow progresses.
    """

    # Allow field types that pydantic has no built-in validator for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # --- Blog identity and source material -----------------------------
    title: str = ""
    inner_title: str = ""
    content: str = ""

    # --- Outputs recorded by the individual flow steps -----------------
    research_findings: str = ""
    drafts: list[str] = []
    final_document: str = ""

    # --- Values handed to the editor for the metadata header -----------
    date: str = ""
    authors: str = ""
    category: str = ""
# ---------------------------------------------------------------------------
# Flow
# ---------------------------------------------------------------------------
class BlogFlow(Flow[BlogFlowState]):
"""Orchestrate researcher → journalists → editor via CrewAI Flows.
Usage::
flow = BlogFlow()
result = flow.kickoff(inputs={
"title": "my_blog_slug",
"inner_title": "My Blog Title",
"content": "<original content>",
})
print(result) # final markdown document
"""
# ------------------------------------------------------------------
# Helpers: Ollama / ChromaDB / embedding utilities
# ------------------------------------------------------------------
@staticmethod
def _get_ollama_url() -> str:
    """Build the Ollama base URL from the process environment.

    Returns:
        A string of the form ``<protocol>://<host>:<port>`` assembled from
        ``OLLAMA_PROTOCOL``, ``OLLAMA_HOST`` and ``OLLAMA_PORT``.

    Raises:
        KeyError: If any of the three environment variables is not set.
    """
    env = os.environ
    protocol = env["OLLAMA_PROTOCOL"]
    host = env["OLLAMA_HOST"]
    port = env["OLLAMA_PORT"]
    return f"{protocol}://{host}:{port}"
@staticmethod
def _get_chroma_client() -> "chromadb.HttpClient":
    """Create an HTTP client for the external ChromaDB server.

    The server location is read from the ``CHROMA_HOST`` and
    ``CHROMA_PORT`` environment variables.

    Raises:
        RuntimeError: If the ``chromadb`` package cannot be imported.
        KeyError: If ``CHROMA_PORT`` or ``CHROMA_HOST`` is not set.
        ValueError: If ``CHROMA_PORT`` is not an integer.
    """
    # The import is deliberately local instead of module-level: chromadb
    # unconditionally loads hnswlib (a native C++ library compiled with
    # AVX instructions) even when only HttpClient is used against an
    # external server. Importing lazily avoids "Illegal instruction"
    # (SIGILL) crashes in environments without AVX support (e.g. act,
    # older CI runners) for code paths that never reach the editor phase.
    try:
        import chromadb
    except ImportError as exc:
        message = (
            "chromadb is required for the editor phase but could not be "
            f"imported: {exc}"
        )
        raise RuntimeError(message) from exc

    port = int(os.environ["CHROMA_PORT"])
    host = os.environ["CHROMA_HOST"]
    return chromadb.HttpClient(host=host, port=port)
@staticmethod
def _get_ollama_client() -> Client:
    """Return an Ollama ``Client`` bound to the URL built from the environment."""
    base_url = BlogFlow._get_ollama_url()
    return Client(host=base_url)
@staticmethod
def _id_generator(size: int = 6) -> str:
    """Return a random identifier of uppercase letters and digits.

    Args:
        size: Number of characters to generate.

    Returns:
        A string of ``size`` characters drawn from ``A-Z`` and ``0-9``
        using the module-level ``random`` generator.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked: list[str] = []
    for _ in range(size):
        picked.append(random.choice(alphabet))
    return "".join(picked)
@staticmethod
def _split_into_chunks(text: str, chunk_size: int = 100) -> list[str]:
    """Split ``text`` into chunks of at most ``chunk_size`` words.

    Words are maximal runs of non-whitespace characters; inside a chunk
    they are re-joined with single spaces, so the original whitespace is
    not preserved.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk. Values below 1
            produce one word per chunk.

    Returns:
        The chunks in order. Empty or whitespace-only text gives ``[]``.
    """
    tokens = re.findall(r"\S+", text)
    # A size below 1 closes the chunk after every word, i.e. acts like 1.
    step = max(chunk_size, 1)
    return [
        " ".join(tokens[start : start + step])
        for start in range(0, len(tokens), step)
    ]
@staticmethod
def _get_embeddings(chunks: list[str]) -> list[list[float]]:
    """Embed ``chunks`` with the Ollama model named by ``EMBEDDING_MODEL``.

    Args:
        chunks: Texts to embed in a single request.

    Returns:
        The ``"embeddings"`` entry of the response, or an empty list when
        the entry is missing or the embed call raises. Failures are
        printed rather than propagated.

    Raises:
        KeyError: If ``EMBEDDING_MODEL`` is not set (read before the
            guarded embed call).
    """
    client = BlogFlow._get_ollama_client()
    model_name = os.environ["EMBEDDING_MODEL"]

    vectors: list[list[float]]
    try:
        response = client.embed(model=model_name, input=chunks)
        vectors = response.get("embeddings", [])  # type: ignore[assignment]
    except Exception as exc:
        # Best-effort: callers treat an empty result as "skip this draft".
        print(f"Error generating embeddings: {exc}")
        vectors = []
    return vectors
def _load_drafts_to_vector_db(self, drafts: list[str]) -> "chromadb.Collection":
    """Load journalist drafts into a new ChromaDB collection and return it.

    Each draft is split into word chunks, embedded, and added to a
    collection named ``blog_<title>_<random id>``. Drafts that yield no
    chunks or no embeddings are skipped with a printed notice.

    Args:
        drafts: Draft texts; the draft at position ``i`` is labelled
            ``journalist_<i + 1>``.

    Returns:
        The collection the chunks were added to. It may be empty if
        every draft was skipped.
    """
    chroma = self._get_chroma_client()
    collection_name = (
        f"blog_{self.state.title.lower().replace(' ', '_')}_{self._id_generator()}"
    )
    collection = chroma.get_or_create_collection(name=collection_name)

    for position, draft in enumerate(drafts, start=1):
        model_name = f"journalist_{position}"

        chunks = self._split_into_chunks(draft)
        # Covers both an empty chunk list and chunks that are all blank.
        if not any(chunk.strip() for chunk in chunks):
            print(f"Skipping {model_name} no content generated")
            continue

        print(f"Generating embeddings for {model_name}")
        embeds = self._get_embeddings(chunks)
        if not embeds:
            print(f"Skipping {model_name} no embeddings generated")
            continue

        # Keep only complete (chunk, embedding) pairs if the embedder
        # returned a different number of vectors than chunks. Both lists
        # are non-empty here, so at least one pair always remains.
        usable = min(len(chunks), len(embeds))
        chunks = chunks[:usable]
        embeds = embeds[:usable]

        # The separator keeps ids unambiguous: without it, chunk 10 of
        # journalist_1 and chunk 0 of journalist_11 would both become
        # "journalist_110".
        ids = [f"{model_name}_{chunk_index}" for chunk_index in range(usable)]
        metadata = [{"model_agent": model_name} for _ in range(usable)]

        print(f"Loading into collection for {model_name}")
        collection.add(
            documents=chunks,
            embeddings=embeds,  # type: ignore[arg-type]
            ids=ids,
            metadatas=metadata,  # type: ignore[arg-type]
        )

    return collection
@staticmethod
def _query_vector_db(collection: "chromadb.Collection", query_text: str) -> str:
    """Return the draft chunks most relevant to ``query_text``.

    The query text is embedded with the model named by
    ``EMBEDDING_MODEL`` and the collection is asked for up to 100
    results.

    Args:
        collection: The collection to search.
        query_text: Text whose embedding drives the similarity query.

    Returns:
        The first result list joined with blank lines, or a fixed
        fallback sentence when nothing is returned or the query raises.
        Embedding and query failures are printed, not propagated.

    Raises:
        KeyError: If ``EMBEDDING_MODEL`` is not set (read before the
            guarded calls).
    """
    client = BlogFlow._get_ollama_client()
    model_name = os.environ["EMBEDDING_MODEL"]

    # Step 1: embed the query. On any failure a placeholder embedding is
    # substituted and the query below is still attempted.
    try:
        response = client.embed(model=model_name, input=query_text)
        query_embed = response.get("embeddings", [])
        if not query_embed:
            print(
                "Warning: Failed to generate query embeddings, "
                "falling back to empty list"
            )
            query_embed = [[]]
    except Exception as exc:
        print(f"Error generating query embeddings: {exc}")
        query_embed = [[]]

    # Step 2: run the similarity query and flatten the first result list.
    try:
        hits = collection.query(  # type: ignore[arg-type]
            query_embeddings=query_embed,
            n_results=100,
        )
        documents = hits.get("documents", [])
        top_documents = documents[0] if documents else ()
        if len(top_documents) == 0:
            print("Warning: No relevant documents found in collection")
            return "No relevant information found in drafts."
        return "\n\n".join(top_documents)
    except Exception as exc:
        print(f"Error querying collection: {exc}")
        return "No relevant information found in drafts due to query error."
# ------------------------------------------------------------------
# Flow steps
# ------------------------------------------------------------------
@start()
def research(self) -> str:
    """Run the research crew on the blog topic and record its findings.

    The crew receives the state's ``inner_title`` and ``content``.

    Returns:
        The crew's raw output, which is also stored in
        ``self.state.research_findings``.
    """
    rule = "=" * 60
    print(rule)
    print("RESEARCH PHASE investigating topic")
    print(rule)

    crew = ResearchCrew().crew()
    crew_inputs = {
        "inner_title": self.state.inner_title,
        "content": self.state.content,
    }
    result = crew.kickoff(inputs=crew_inputs)

    self.state.research_findings = result.raw
    print("Research phase complete")
    return result.raw
@listen(research)
def write_drafts(self, research_findings: str) -> list[str]:
    """Run the writing crew and gather one draft per task output.

    Args:
        research_findings: Output of the research step, passed to the
            crew together with the state's ``inner_title`` and
            ``content``.

    Returns:
        The raw text of every task output in order; the same list is
        stored in ``self.state.drafts``.
    """
    rule = "=" * 60
    print(rule)
    print("WRITING PHASE 4 journalists drafting in parallel")
    print(rule)

    crew = WritingCrew().crew()
    crew_inputs = {
        "inner_title": self.state.inner_title,
        "content": self.state.content,
        "research_findings": research_findings,
    }
    result = crew.kickoff(inputs=crew_inputs)

    # One draft per task the crew executed.
    drafts: list[str] = [task_output.raw for task_output in result.tasks_output]
    self.state.drafts = drafts
    print(f"Writing phase complete {len(drafts)} drafts produced")
    return drafts
@staticmethod
def _compute_authors() -> str:
    """Build an author string from the CONTENT_CREATOR_MODELS env var.

    The variable is expected to hold a JSON list of model names. A bare
    JSON string is accepted as a single model. Each name is reduced to
    its last ``/``-separated component, stripped of any tag suffix
    (e.g. ``:latest``), and given an ``.ai`` suffix. Multiple models are
    joined with ``', '``.

    Returns:
        The joined author string, or ``"unknown.ai"`` when the variable
        is unset, is not valid JSON, is not a list or string, or yields
        no usable model names.
    """
    try:
        models = json.loads(os.environ["CONTENT_CREATOR_MODELS"])
    except (KeyError, json.JSONDecodeError):
        models = []

    # Valid JSON is not necessarily a list: "null" or a number would make
    # iteration raise, and a bare string would be walked char by char.
    if isinstance(models, str):
        models = [models]
    elif not isinstance(models, list):
        models = []

    authors: list[str] = []
    for model in models:
        if not isinstance(model, str):
            continue
        # Take the last path component before stripping the tag, so a
        # registry port ("localhost:5000/phi3:mini") is not mistaken for
        # the tag separator.
        base_name = model.split("/")[-1].split(":")[0]
        if base_name:
            authors.append(base_name + ".ai")

    return ", ".join(authors) or "unknown.ai"
@listen(write_drafts)
def edit_final(self, drafts: list[str]) -> str:
    """Produce the final document from the journalist drafts.

    The drafts are loaded into the vector DB, the chunks most relevant to
    the editor's brief are retrieved, and the editor crew is run with
    that context plus the metadata-header values (date, authors,
    category).

    Args:
        drafts: Draft texts produced by the writing step.

    Returns:
        The editor crew's raw output, which is also stored in
        ``self.state.final_document``.
    """
    rule = "=" * 60
    print(rule)
    print("EDITOR PHASE producing final document")
    print(rule)

    state = self.state

    # ---- Metadata header values ----
    if not state.date:
        state.date = datetime.now().strftime("%Y-%m-%d %H:%M")
    # NOTE(review): authors is recomputed on every run, unlike date and
    # category which keep a caller-supplied value - confirm intended.
    state.authors = self._compute_authors()
    if not state.category:
        # No category supplied: pass an instruction placeholder instead.
        state.category = "<pick one word that best describes the topic, e.g. Homelab, DevOps, Security, Networking>"

    # ---- Vector DB: store drafts, then retrieve relevant chunks ----
    print("Loading drafts into vector database")
    collection = self._load_drafts_to_vector_db(drafts)

    # The editor's brief doubles as the similarity-search query text.
    brief_parts = [
        "You are an editor taking information from 3 Software ",
        "Developers and Data experts writing a 5000 word blog article. ",
        "You like when they use almost no code examples. ",
        "You are also Australian. The title for the blog is ",
        f"{state.inner_title}. ",
        "The basis for the content of the blog is: ",
        f"<blog>{state.content}</blog>",
    ]
    editor_brief = "".join(brief_parts)
    draft_context = self._query_vector_db(collection, editor_brief)
    print("Showing pertinent info from drafts used in final edited edition")

    # ---- Editor crew ----
    crew = EditorCrew().crew()
    crew_inputs = {
        "inner_title": state.inner_title,
        "content": state.content,
        "draft_context": draft_context,
        "date": state.date,
        "authors": state.authors,
        "category": state.category,
    }
    result = crew.kickoff(inputs=crew_inputs)

    state.final_document = result.raw
    print("Editor phase complete")
    return result.raw