338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""
|
||
CrewAI Flow that orchestrates the blog-generation pipeline.
|
||
|
||
Flow
|
||
----
|
||
1. **Research crew** – a critical researcher with web-search investigates the
|
||
topic and produces verified findings.
|
||
2. **Writing crew** – four creative journalists write draft blog articles
|
||
in parallel (async tasks).
|
||
3. **Editor crew** – a critical editor loads the journalist drafts into
|
||
ChromaDB, queries for the most relevant context, and produces the final
|
||
polished markdown document complete with a metadata header (Title, Date,
|
||
Category, Tags, Slug, Authors, Summary).
|
||
|
||
The ChromaDB integration is preserved from the original implementation: each
|
||
journalist draft is chunked, embedded, and stored in a collection; the editor
|
||
receives the top-N most relevant chunks as context.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
import re
|
||
import string
|
||
from datetime import datetime
|
||
from typing import TYPE_CHECKING
|
||
|
||
if TYPE_CHECKING:
|
||
import chromadb # noqa: F811
|
||
|
||
from crewai.flow.flow import Flow, listen, start
|
||
from ollama import Client
|
||
from pydantic import BaseModel, ConfigDict
|
||
|
||
from ai_generators.crews.editor_crew.editor_crew import EditorCrew
|
||
from ai_generators.crews.research_crew.research_crew import ResearchCrew
|
||
from ai_generators.crews.writing_crew.writing_crew import WritingCrew
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# State
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class BlogFlowState(BaseModel):
    """State shared by every step of the blog generation flow.

    Every field defaults to an empty value; fields are either supplied as
    ``kickoff`` inputs or written by the flow steps while they run.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Inputs describing the blog post to generate.
    title: str = ""
    inner_title: str = ""
    content: str = ""

    # Outputs of the research, writing and editor steps respectively.
    research_findings: str = ""
    drafts: list[str] = []
    final_document: str = ""

    # Values handed to the editor crew alongside the draft context.
    date: str = ""
    authors: str = ""
    category: str = ""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Flow
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class BlogFlow(Flow[BlogFlowState]):
|
||
"""Orchestrate researcher → journalists → editor via CrewAI Flows.
|
||
|
||
Usage::
|
||
|
||
flow = BlogFlow()
|
||
result = flow.kickoff(inputs={
|
||
"title": "my_blog_slug",
|
||
"inner_title": "My Blog Title",
|
||
"content": "<original content>",
|
||
})
|
||
print(result) # final markdown document
|
||
"""
|
||
|
||
# ------------------------------------------------------------------
|
||
# Helpers – Ollama / ChromaDB / embedding utilities
|
||
# ------------------------------------------------------------------
|
||
|
||
@staticmethod
def _get_ollama_url() -> str:
    """Assemble the Ollama base URL from environment variables.

    Reads ``OLLAMA_PROTOCOL``, ``OLLAMA_HOST`` and ``OLLAMA_PORT`` and joins
    them as ``<protocol>://<host>:<port>``.

    Raises:
        KeyError: If any of the three variables is not set.
    """
    env = os.environ
    protocol = env["OLLAMA_PROTOCOL"]
    host = env["OLLAMA_HOST"]
    port = env["OLLAMA_PORT"]
    return f"{protocol}://{host}:{port}"
|
||
|
||
@staticmethod
def _get_chroma_client() -> "chromadb.HttpClient":
    """Create an HTTP client for the ChromaDB server named in the environment.

    Reads ``CHROMA_PORT`` (converted to ``int``) and ``CHROMA_HOST``.

    Raises:
        RuntimeError: If the ``chromadb`` package cannot be imported.
        KeyError: If ``CHROMA_PORT`` or ``CHROMA_HOST`` is not set.
        ValueError: If ``CHROMA_PORT`` is not an integer.
    """
    # chromadb is imported here, not at module level, on purpose:
    # it unconditionally loads hnswlib (a native C++ library compiled
    # with AVX instructions) even when only HttpClient is used to talk
    # to an external server.  Deferring the import avoids
    # "Illegal instruction" (SIGILL) crashes in environments that lack
    # AVX support (e.g. act, older CI runners).
    try:
        import chromadb
    except ImportError as exc:
        message = (
            "chromadb is required for the editor phase but could not be "
            f"imported: {exc}"
        )
        raise RuntimeError(message) from exc

    port = int(os.environ["CHROMA_PORT"])
    host = os.environ["CHROMA_HOST"]
    return chromadb.HttpClient(host=host, port=port)
|
||
|
||
@staticmethod
def _get_ollama_client() -> Client:
    """Return an Ollama ``Client`` whose host is the URL from ``_get_ollama_url``."""
    base_url = BlogFlow._get_ollama_url()
    return Client(host=base_url)
|
||
|
||
@staticmethod
def _id_generator(size: int = 6) -> str:
    """Return ``size`` random characters drawn from ``A-Z`` and ``0-9``.

    Characters come from the ``random`` module, one ``random.choice`` call
    per character.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = [random.choice(alphabet) for _ in range(size)]
    return "".join(picked)
|
||
|
||
@staticmethod
def _split_into_chunks(text: str, chunk_size: int = 100) -> list[str]:
    """Split ``text`` into chunks of at most ``chunk_size`` words.

    Words are maximal runs of non-whitespace characters; inside a chunk they
    are re-joined with single spaces, so original whitespace is not kept.
    The last chunk holds whatever words remain.  Text without any word
    yields an empty list.
    """
    tokens = re.findall(r"\S+", text)
    # A chunk is closed as soon as it holds at least ``chunk_size`` words,
    # which means a zero or negative size produces one word per chunk.
    step = max(chunk_size, 1)
    return [
        " ".join(tokens[start:start + step])
        for start in range(0, len(tokens), step)
    ]
|
||
|
||
@staticmethod
def _get_embeddings(chunks: list[str]) -> list[list[float]]:
    """Embed ``chunks`` with the Ollama model named by ``EMBEDDING_MODEL``.

    Returns:
        The ``"embeddings"`` entry of the Ollama response, or an empty list
        when that entry is missing or the embed call raises (the error is
        printed, not propagated).

    Raises:
        KeyError: If ``EMBEDDING_MODEL`` or an Ollama connection variable
            is not set; these lookups happen before the guarded call.
    """
    client = BlogFlow._get_ollama_client()
    model_name = os.environ["EMBEDDING_MODEL"]
    try:
        response = client.embed(model=model_name, input=chunks)
        vectors = response.get("embeddings", [])
    except Exception as exc:
        # Best effort: the caller treats an empty result as "skip this draft".
        print(f"Error generating embeddings: {exc}")
        vectors = []
    return vectors  # type: ignore[no-any-return]
|
||
|
||
def _load_drafts_to_vector_db(self, drafts: list[str]) -> "chromadb.Collection":
    """Load journalist drafts into a new ChromaDB collection and return it.

    Each draft is split into word chunks, embedded, and added to the
    collection with ``model_agent`` metadata naming the journalist
    (``journalist_1``, ``journalist_2``, ... in list order).  Drafts that
    produce no chunks or no embeddings are skipped with a printed notice.

    Args:
        drafts: Draft texts, one per journalist.

    Returns:
        The collection; it is empty when every draft was skipped.
    """
    chroma = self._get_chroma_client()

    # Build a name ChromaDB will accept: only lowercase alphanumerics, "_"
    # and "-" from the title, and a bounded total length.  The fixed "blog"
    # prefix and the alphanumeric random suffix guarantee valid first and
    # last characters.
    # NOTE(review): 63 is the cap enforced by older ChromaDB releases;
    # confirm against the deployed server version.
    prefix = "blog"
    suffix = self._id_generator()
    max_slug_len = 63 - len(prefix) - len(suffix) - 2  # two "_" separators
    slug = re.sub(r"[^a-z0-9_-]", "_", self.state.title.lower())[:max_slug_len]
    collection = chroma.get_or_create_collection(name=f"{prefix}_{slug}_{suffix}")

    for index, draft in enumerate(drafts, start=1):
        model_name = f"journalist_{index}"
        chunks = self._split_into_chunks(draft)
        if not any(chunk.strip() for chunk in chunks):
            print(f"Skipping {model_name} – no content generated")
            continue

        print(f"Generating embeddings for {model_name}")
        embeds = self._get_embeddings(chunks)
        if not embeds:
            print(f"Skipping {model_name} – no embeddings generated")
            continue

        # Keep only complete chunk/embedding pairs.  Both lists are
        # non-empty here, so at least one pair always remains.
        pair_count = min(len(chunks), len(embeds))
        chunks = chunks[:pair_count]
        embeds = embeds[:pair_count]

        # The "_" separator keeps ids unambiguous: without it
        # journalist_1 + chunk 10 and journalist_11 + chunk 0 would both
        # be "journalist_110".
        ids = [f"{model_name}_{position}" for position in range(pair_count)]
        metadata = [{"model_agent": model_name} for _ in range(pair_count)]

        print(f"Loading into collection for {model_name}")
        collection.add(
            documents=chunks,
            embeddings=embeds,  # type: ignore[arg-type]
            ids=ids,
            metadatas=metadata,  # type: ignore[arg-type]
        )
    return collection
|
||
|
||
@staticmethod
def _query_vector_db(
    collection: "chromadb.Collection",
    query_text: str,
    n_results: int = 100,
) -> str:
    """Return the stored chunks most relevant to ``query_text``.

    The query text is embedded with the Ollama model named by
    ``EMBEDDING_MODEL`` and used for a similarity query on ``collection``.

    Args:
        collection: ChromaDB collection holding the draft chunks.
        query_text: Text to embed and search with.
        n_results: Maximum number of chunks to request (default 100).

    Returns:
        The matching chunks joined by blank lines, or a fixed fallback
        sentence when nothing matches or when embedding/querying fails.
        Failures are printed, never raised.

    Raises:
        KeyError: If ``EMBEDDING_MODEL`` or an Ollama connection variable
            is not set; these lookups happen before the guarded calls.
    """
    query_error = "No relevant information found in drafts due to query error."

    ollama_client = BlogFlow._get_ollama_client()
    embed_model = os.environ["EMBEDDING_MODEL"]

    try:
        embed_result = ollama_client.embed(model=embed_model, input=query_text)
        query_embed = embed_result.get("embeddings", [])
    except Exception as exc:
        print(f"Error generating query embeddings: {exc}")
        return query_error

    if not query_embed:
        # Without an embedding the query cannot succeed, so skip it instead
        # of sending an empty vector and waiting for the resulting error.
        print(
            "Warning: Failed to generate query embeddings, "
            "skipping vector DB query"
        )
        return query_error

    try:
        query_result = collection.query(
            query_embeddings=query_embed,  # type: ignore[arg-type]
            n_results=n_results,
        )
        documents = query_result.get("documents", [])
        # One query embedding was sent, so only the first result list matters.
        if documents and documents[0]:
            return "\n\n".join(documents[0])
        print("Warning: No relevant documents found in collection")
        return "No relevant information found in drafts."
    except Exception as exc:
        print(f"Error querying collection: {exc}")
        return query_error
|
||
|
||
# ------------------------------------------------------------------
|
||
# Flow steps
|
||
# ------------------------------------------------------------------
|
||
|
||
@start()
def research(self) -> str:
    """Run the research crew on the blog topic and store its findings.

    The crew receives the flow's ``inner_title`` and ``content``.

    Returns:
        The crew's raw output, also saved to ``self.state.research_findings``.
    """
    banner = "=" * 60
    print(banner)
    print("RESEARCH PHASE – investigating topic")
    print(banner)

    crew = ResearchCrew().crew()
    crew_inputs = {
        "inner_title": self.state.inner_title,
        "content": self.state.content,
    }
    result = crew.kickoff(inputs=crew_inputs)

    self.state.research_findings = result.raw
    print("Research phase complete")
    return result.raw
|
||
|
||
@listen(research)
def write_drafts(self, research_findings: str) -> list[str]:
    """Run the writing crew and collect one draft per task output.

    Args:
        research_findings: Output of the research step, passed to the crew
            together with the flow's ``inner_title`` and ``content``.

    Returns:
        The raw text of every task output, in the order the crew reports
        them; also saved to ``self.state.drafts``.
    """
    banner = "=" * 60
    print(banner)
    print("WRITING PHASE – 4 journalists drafting in parallel")
    print(banner)

    crew = WritingCrew().crew()
    crew_inputs = {
        "inner_title": self.state.inner_title,
        "content": self.state.content,
        "research_findings": research_findings,
    }
    result = crew.kickoff(inputs=crew_inputs)

    # Every task output of the crew counts as one draft.
    drafts: list[str] = [task_output.raw for task_output in result.tasks_output]

    self.state.drafts = drafts
    print(f"Writing phase complete – {len(drafts)} drafts produced")
    return drafts
|
||
|
||
@staticmethod
def _compute_authors() -> str:
    """Build an author string from the ``CONTENT_CREATOR_MODELS`` env var.

    The variable should hold a JSON array of model names; a single JSON
    string is accepted as one model.  For each name the registry/namespace
    path and any tag suffix (e.g. ``:latest``) are removed and ``.ai`` is
    appended.  Multiple authors are joined with ``', '``.

    Returns:
        The author string, or ``"unknown.ai"`` when the variable is unset,
        is not valid JSON, or contains no usable model name.
    """
    try:
        models = json.loads(os.environ["CONTENT_CREATOR_MODELS"])
    except (KeyError, json.JSONDecodeError):
        models = []

    # A bare JSON string is one model, not a sequence of characters; any
    # other non-list value (null, number, object) carries no model names.
    if isinstance(models, str):
        models = [models]
    elif not isinstance(models, list):
        models = []

    authors = []
    for model in models:
        if not isinstance(model, str):
            continue
        # Drop the path first, then the tag, so a registry port such as
        # "localhost:5000/phi3:mini" is not mistaken for the tag separator.
        base = model.strip().rsplit("/", 1)[-1].split(":", 1)[0]
        if base:
            authors.append(f"{base}.ai")
    return ", ".join(authors) or "unknown.ai"
|
||
|
||
@listen(write_drafts)
def edit_final(self, drafts: list[str]) -> str:
    """Produce the final document from the journalist drafts.

    Fills in any missing ``date``, ``authors`` and ``category`` on the
    state, loads the drafts into the vector DB, queries it with an editor
    brief, and runs the editor crew with the returned context.

    Args:
        drafts: Draft texts produced by the writing step.

    Returns:
        The editor crew's raw output, also saved to
        ``self.state.final_document``.
    """
    banner = "=" * 60
    print(banner)
    print("EDITOR PHASE – producing final document")
    print(banner)

    # ---- Metadata values: keep whatever is already set on the state ----
    if not self.state.date:
        self.state.date = datetime.now().strftime("%Y-%m-%d %H:%M")
    # Same rule as date and category: only compute authors when none are set.
    if not self.state.authors:
        self.state.authors = self._compute_authors()
    if not self.state.category:
        # Placeholder instruction, passed to the editor crew as the category.
        self.state.category = "<pick one word that best describes the topic, e.g. Homelab, DevOps, Security, Networking>"

    # ---- Vector DB integration ----
    print("Loading drafts into vector database")
    collection = self._load_drafts_to_vector_db(drafts)

    # The brief is used only as the similarity query for the vector DB; it
    # is not passed to the editor crew.
    # NOTE(review): it says "3 Software Developers" while the writing phase
    # announces 4 journalists.  Text left unchanged because it drives
    # retrieval; confirm whether it should be updated.
    editor_brief = (
        "You are an editor taking information from 3 Software "
        "Developers and Data experts writing a 5000 word blog article. "
        "You like when they use almost no code examples. "
        "You are also Australian. The title for the blog is "
        f"{self.state.inner_title}. "
        "The basis for the content of the blog is: "
        f"<blog>{self.state.content}</blog>"
    )
    draft_context = self._query_vector_db(collection, editor_brief)
    print("Showing pertinent info from drafts used in final edited edition")

    # ---- Editor crew ----
    crew = EditorCrew().crew()
    crew_inputs = {
        "inner_title": self.state.inner_title,
        "content": self.state.content,
        "draft_context": draft_context,
        "date": self.state.date,
        "authors": self.state.authors,
        "category": self.state.category,
    }
    result = crew.kickoff(inputs=crew_inputs)

    self.state.final_document = result.raw
    print("Editor phase complete")
    return result.raw
|