diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py
index 03e734deb..30ea4d5f6 100644
--- a/crawl4ai/__init__.py
+++ b/crawl4ai/__init__.py
@@ -42,6 +42,7 @@
     LLMContentFilter,
     RelevantContentFilter,
 )
+from .document_extraction_strategy import DocumentExtractionStrategy, DocumentExtractionResult
 from .models import CrawlResult, MarkdownGenerationResult, DisplayMode
 from .components.crawler_monitor import CrawlerMonitor
 from .link_preview import LinkPreview
diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py
index 37b5480d9..2e3ecb310 100644
--- a/crawl4ai/async_configs.py
+++ b/crawl4ai/async_configs.py
@@ -1403,6 +1403,7 @@ def __init__(
         extraction_strategy: ExtractionStrategy = None,
         chunking_strategy: ChunkingStrategy = RegexChunking(),
         markdown_generator: MarkdownGenerationStrategy = DefaultMarkdownGenerator(),
+        document_extraction_strategy=None,
         only_text: bool = False,
         css_selector: str = None,
         target_elements: List[str] = None,
@@ -1526,6 +1527,7 @@ def __init__(
         self.extraction_strategy = extraction_strategy
         self.chunking_strategy = chunking_strategy
         self.markdown_generator = markdown_generator
+        self.document_extraction_strategy = document_extraction_strategy
         self.only_text = only_text
         self.css_selector = css_selector
         self.target_elements = target_elements or []
diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py
index 36b999fd1..61a8ddab1 100644
--- a/crawl4ai/async_webcrawler.py
+++ b/crawl4ai/async_webcrawler.py
@@ -455,45 +455,83 @@ async def arun(
                 async_response = await self.crawler_strategy.crawl(
                     url, config=config)
 
-                html = sanitize_input_encode(async_response.html)
-                screenshot_data = async_response.screenshot
-                pdf_data = async_response.pdf_data
-                js_execution_result = async_response.js_execution_result
-
-                self.logger.url_status(
-                    url=cache_context.display_url,
-                    success=bool(html),
-                    timing=time.perf_counter() - t1,
-                    tag="FETCH",
-                )
-
-                crawl_result = await self.aprocess_html(
-                    url=url, html=html,
-                    extracted_content=extracted_content,
-                    config=config,
-                    screenshot_data=screenshot_data,
-                    pdf_data=pdf_data,
-                    verbose=config.verbose,
-                    is_raw_html=True if url.startswith("raw:") else False,
-                    redirected_url=async_response.redirected_url,
-                    original_scheme=urlparse(url).scheme,
-                    **kwargs,
-                )
-
-                crawl_result.status_code = async_response.status_code
-                is_raw_url = url.startswith("raw:") or url.startswith("raw://")
-                crawl_result.redirected_url = async_response.redirected_url or (None if is_raw_url else url)
-                crawl_result.redirected_status_code = async_response.redirected_status_code
-                crawl_result.response_headers = async_response.response_headers
-                crawl_result.downloaded_files = async_response.downloaded_files
-                crawl_result.js_execution_result = js_execution_result
-                crawl_result.mhtml = async_response.mhtml_data
-                crawl_result.ssl_certificate = async_response.ssl_certificate
-                crawl_result.network_requests = async_response.network_requests
-                crawl_result.console_messages = async_response.console_messages
-                crawl_result.success = bool(html)
-                crawl_result.session_id = getattr(config, "session_id", None)
-                crawl_result.cache_status = "miss"
+                # Document extraction: detect binary documents before HTML processing
+                doc_strategy = getattr(config, "document_extraction_strategy", None)
+                if doc_strategy and doc_strategy.detect(async_response):
+                    self.logger.info(
+                        message="Document detected for {url}, using document extraction",
+                        tag="DOCUMENT",
+                        params={"url": url},
+                    )
+                    doc_result = await doc_strategy.extract(async_response, url)
+                    crawl_result = CrawlResult(
+                        url=url,
+                        html="",
+                        success=True,
+                        cleaned_html="",
+                        _markdown=MarkdownGenerationResult(
+                            raw_markdown=doc_result.content,
+                            markdown_with_citations=doc_result.content,
+                            references_markdown="",
+                            fit_markdown="",
+                            fit_html="",
+                        ),
+                        metadata={
+                            "is_document": True,
+                            "content_type": doc_result.content_type,
+                            **(doc_result.metadata or {}),
+                        },
+                        status_code=async_response.status_code,
+                        response_headers=async_response.response_headers,
+                        downloaded_files=async_response.downloaded_files,
+                        redirected_url=async_response.redirected_url,
+                        redirected_status_code=async_response.redirected_status_code,
+                        ssl_certificate=async_response.ssl_certificate,
+                        network_requests=async_response.network_requests,
+                        console_messages=async_response.console_messages,
+                        session_id=getattr(config, "session_id", None),
+                        cache_status="miss",
+                    )
+                else:
+                    html = sanitize_input_encode(async_response.html)
+                    screenshot_data = async_response.screenshot
+                    pdf_data = async_response.pdf_data
+                    js_execution_result = async_response.js_execution_result
+
+                    self.logger.url_status(
+                        url=cache_context.display_url,
+                        success=bool(html),
+                        timing=time.perf_counter() - t1,
+                        tag="FETCH",
+                    )
+
+                    crawl_result = await self.aprocess_html(
+                        url=url, html=html,
+                        extracted_content=extracted_content,
+                        config=config,
+                        screenshot_data=screenshot_data,
+                        pdf_data=pdf_data,
+                        verbose=config.verbose,
+                        is_raw_html=True if url.startswith("raw:") else False,
+                        redirected_url=async_response.redirected_url,
+                        original_scheme=urlparse(url).scheme,
+                        **kwargs,
+                    )
+
+                    crawl_result.status_code = async_response.status_code
+                    is_raw_url = url.startswith("raw:") or url.startswith("raw://")
+                    crawl_result.redirected_url = async_response.redirected_url or (None if is_raw_url else url)
+                    crawl_result.redirected_status_code = async_response.redirected_status_code
+                    crawl_result.response_headers = async_response.response_headers
+                    crawl_result.downloaded_files = async_response.downloaded_files
+                    crawl_result.js_execution_result = js_execution_result
+                    crawl_result.mhtml = async_response.mhtml_data
+                    crawl_result.ssl_certificate = async_response.ssl_certificate
+                    crawl_result.network_requests = async_response.network_requests
+                    crawl_result.console_messages = async_response.console_messages
+                    crawl_result.success = bool(html)
+                    crawl_result.session_id = getattr(config, "session_id", None)
+                    crawl_result.cache_status = "miss"
 
                 # Check if blocked (skip for raw: URLs —
                 # caller-provided content, anti-bot N/A)
diff --git a/crawl4ai/document_extraction_strategy.py b/crawl4ai/document_extraction_strategy.py
new file mode 100644
index 000000000..f50707775
--- /dev/null
+++ b/crawl4ai/document_extraction_strategy.py
@@ -0,0 +1,90 @@
+"""
+Document Extraction Strategy — abstract base for detecting and extracting
+text from binary documents (PDF, DOCX, XLSX, etc.) during the crawl pipeline.
+
+When configured on CrawlerRunConfig, the strategy is checked after browser
+navigation but before HTML content scraping. If it detects a document, it
+extracts text directly — skipping the HTML pipeline entirely.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .models import AsyncCrawlResponse
+
+
+@dataclass
+class DocumentExtractionResult:
+    """Result of extracting text from a binary document."""
+
+    content: str
+    """Extracted text content (plain text or markdown)."""
+
+    content_type: str
+    """MIME type or file extension (e.g., 'application/pdf', 'pdf')."""
+
+    source_path: Optional[Path] = None
+    """Local file path if the document was downloaded."""
+
+    metadata: dict = field(default_factory=dict)
+    """Optional metadata (title, author, page count, etc.)."""
+
+
+class DocumentExtractionStrategy(ABC):
+    """
+    Abstract strategy for detecting and extracting text from binary documents.
+
+    Subclass this and implement ``detect()`` and ``extract()`` using your
+    preferred extraction backend (Kreuzberg, PyMuPDF, Docling, etc.).
+
+    Example::
+
+        class KreuzbergDocumentStrategy(DocumentExtractionStrategy):
+            DOCUMENT_TYPES = {"application/pdf", "application/msword", ...}
+
+            def detect(self, response):
+                if response.downloaded_files:
+                    return True
+                ct = (response.response_headers or {}).get("content-type", "")
+                return ct.split(";")[0].strip() in self.DOCUMENT_TYPES
+
+            async def extract(self, response, url):
+                from kreuzberg import extract_file
+                path = Path(response.downloaded_files[0])
+                result = await extract_file(str(path))
+                return DocumentExtractionResult(
+                    content=result.content,
+                    content_type=path.suffix.lstrip("."),
+                    source_path=path,
+                )
+    """
+
+    @abstractmethod
+    def detect(self, response: "AsyncCrawlResponse") -> bool:
+        """Return True if the response represents a binary document.
+
+        Implementations can check:
+        - ``response.downloaded_files`` — browser triggered a download
+        - ``response.response_headers`` — Content-Type header
+        - ``response.status_code`` — failed navigation
+        - URL extension heuristics
+        """
+        ...
+
+    @abstractmethod
+    async def extract(
+        self, response: "AsyncCrawlResponse", url: str
+    ) -> DocumentExtractionResult:
+        """Extract text content from the document.
+
+        Args:
+            response: The crawl response (may contain downloaded file paths).
+            url: The original URL that was crawled.
+
+        Returns:
+            DocumentExtractionResult with extracted text.
+        """
+        ...
diff --git a/tests/test_document_extraction_strategy_1890.py b/tests/test_document_extraction_strategy_1890.py
new file mode 100644
index 000000000..a2f6e8668
--- /dev/null
+++ b/tests/test_document_extraction_strategy_1890.py
@@ -0,0 +1,303 @@
+"""
+Tests for #1890: DocumentExtractionStrategy — pluggable binary document
+detection and extraction in the crawl pipeline.
+
+Tests the abstract interface, a concrete mock implementation, the
+CrawlerRunConfig integration, and the pipeline routing logic.
+"""
+import pytest
+import asyncio
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import sys, os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from crawl4ai.document_extraction_strategy import (
+    DocumentExtractionStrategy,
+    DocumentExtractionResult,
+)
+from crawl4ai.async_configs import CrawlerRunConfig
+from crawl4ai.models import AsyncCrawlResponse, MarkdownGenerationResult
+
+
+# ── Concrete test implementation ─────────────────────────────────────────
+
+class MockDocumentStrategy(DocumentExtractionStrategy):
+    """Test implementation that detects documents by downloaded_files."""
+
+    def detect(self, response) -> bool:
+        return bool(response.downloaded_files)
+
+    async def extract(self, response, url) -> DocumentExtractionResult:
+        filename = response.downloaded_files[0]
+        # FIX: interpolate the downloaded filename — the original body had a
+        # placeholder-free f-string, leaving `filename` unused and breaking
+        # the suite's own `"report.pdf" in result.content` assertions.
+        return DocumentExtractionResult(
+            content=f"Extracted text from {filename}",
+            content_type="pdf",
+            source_path=Path(filename),
+            metadata={"pages": 5},
+        )
+
+
+class ContentTypeDocumentStrategy(DocumentExtractionStrategy):
+    """Test implementation that detects by Content-Type header."""
+
+    DOCUMENT_TYPES = {"application/pdf", "application/msword"}
+
+    def detect(self, response) -> bool:
+        ct = (response.response_headers or {}).get("content-type", "")
+        return ct.split(";")[0].strip() in self.DOCUMENT_TYPES
+
+    async def extract(self, response, url) -> DocumentExtractionResult:
+        return DocumentExtractionResult(
+            content="Document content here",
+            content_type="pdf",
+        )
+
+
+class NeverDetectStrategy(DocumentExtractionStrategy):
+    """Strategy that never detects documents — HTML pipeline should be used."""
+
+    def detect(self, response) -> bool:
+        return False
+
+    async def extract(self, response, url) -> DocumentExtractionResult:
+        raise RuntimeError("Should never be called")
+
+
+# ── Helper ───────────────────────────────────────────────────────────────
+
+def make_response(**kwargs):
+    """Create a mock AsyncCrawlResponse."""
+    defaults = {
+        "html": "",
+        "response_headers": {},
+        "status_code": 200,
+        "downloaded_files": None,
+        "redirected_url": None,
+        "redirected_status_code": None,
+    }
+    defaults.update(kwargs)
+    resp = MagicMock(spec=AsyncCrawlResponse)
+    for k, v in defaults.items():
+        setattr(resp, k, v)
+    return resp
+
+
+# ── Base class tests ─────────────────────────────────────────────────────
+
+class TestDocumentExtractionResult:
+
+    def test_basic_creation(self):
+        result = DocumentExtractionResult(
+            content="Hello world",
+            content_type="pdf",
+        )
+        assert result.content == "Hello world"
+        assert result.content_type == "pdf"
+        assert result.source_path is None
+        assert result.metadata == {}
+
+    def test_with_all_fields(self):
+        result = DocumentExtractionResult(
+            content="Text",
+            content_type="application/pdf",
+            source_path=Path("/tmp/doc.pdf"),
+            metadata={"pages": 10, "author": "Jane"},
+        )
+        assert result.source_path == Path("/tmp/doc.pdf")
+        assert result.metadata["pages"] == 10
+
+    def test_cannot_instantiate_abstract(self):
+        """DocumentExtractionStrategy is abstract — can't instantiate directly."""
+        with pytest.raises(TypeError):
+            DocumentExtractionStrategy()
+
+
+# ── Detection tests ──────────────────────────────────────────────────────
+
+class TestDetection:
+
+    def test_detect_by_downloaded_files(self):
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+        assert strategy.detect(resp) is True
+
+    def test_no_downloaded_files_not_detected(self):
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=None)
+        assert strategy.detect(resp) is False
+
+    def test_empty_downloaded_files_not_detected(self):
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=[])
+        assert strategy.detect(resp) is False
+
+    def test_detect_by_content_type(self):
+        strategy = ContentTypeDocumentStrategy()
+        resp = make_response(
+            response_headers={"content-type": "application/pdf; charset=utf-8"}
+        )
+        assert strategy.detect(resp) is True
+
+    def test_html_content_type_not_detected(self):
+        strategy = ContentTypeDocumentStrategy()
+        resp = make_response(
+            response_headers={"content-type": "text/html; charset=utf-8"}
+        )
+        assert strategy.detect(resp) is False
+
+    def test_never_detect_strategy(self):
+        strategy = NeverDetectStrategy()
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+        assert strategy.detect(resp) is False
+
+
+# ── Extraction tests ─────────────────────────────────────────────────────
+
+class TestExtraction:
+
+    @pytest.mark.asyncio
+    async def test_extract_returns_result(self):
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=["/tmp/report.pdf"])
+        result = await strategy.extract(resp, "https://example.com/report.pdf")
+
+        assert isinstance(result, DocumentExtractionResult)
+        assert "report.pdf" in result.content
+        assert result.content_type == "pdf"
+        assert result.source_path == Path("/tmp/report.pdf")
+        assert result.metadata == {"pages": 5}
+
+    @pytest.mark.asyncio
+    async def test_extract_content_type_strategy(self):
+        strategy = ContentTypeDocumentStrategy()
+        resp = make_response(
+            response_headers={"content-type": "application/pdf"}
+        )
+        result = await strategy.extract(resp, "https://example.com/doc.pdf")
+        assert result.content == "Document content here"
+
+
+# ── CrawlerRunConfig integration ─────────────────────────────────────────
+
+class TestConfigIntegration:
+
+    def test_default_is_none(self):
+        config = CrawlerRunConfig()
+        assert config.document_extraction_strategy is None
+
+    def test_set_strategy(self):
+        strategy = MockDocumentStrategy()
+        config = CrawlerRunConfig(document_extraction_strategy=strategy)
+        assert config.document_extraction_strategy is strategy
+
+    def test_none_strategy_no_effect(self):
+        """When strategy is None, the config should work normally."""
+        config = CrawlerRunConfig(document_extraction_strategy=None)
+        assert config.document_extraction_strategy is None
+
+
+# ── Pipeline routing tests ───────────────────────────────────────────────
+
+class TestPipelineRouting:
+    """Test that the arun integration point routes correctly."""
+
+    def test_detect_true_skips_html_pipeline(self):
+        """When detect() returns True, HTML processing should be skipped."""
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+
+        # Simulate the routing logic from async_webcrawler.py
+        doc_strategy = strategy
+        if doc_strategy and doc_strategy.detect(resp):
+            route = "document"
+        else:
+            route = "html"
+
+        assert route == "document"
+
+    def test_detect_false_uses_html_pipeline(self):
+        """When detect() returns False, normal HTML processing should proceed."""
+        strategy = NeverDetectStrategy()
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+
+        doc_strategy = strategy
+        if doc_strategy and doc_strategy.detect(resp):
+            route = "document"
+        else:
+            route = "html"
+
+        assert route == "html"
+
+    def test_no_strategy_uses_html_pipeline(self):
+        """When no strategy is configured, always use HTML pipeline."""
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+
+        doc_strategy = None
+        if doc_strategy and doc_strategy.detect(resp):
+            route = "document"
+        else:
+            route = "html"
+
+        assert route == "html"
+
+    @pytest.mark.asyncio
+    async def test_document_result_has_markdown(self):
+        """Document extraction result should be wrapped in MarkdownGenerationResult."""
+        strategy = MockDocumentStrategy()
+        resp = make_response(
+            downloaded_files=["/tmp/report.pdf"],
+            status_code=200,
+            response_headers={},
+            redirected_url=None,
+            redirected_status_code=None,
+        )
+
+        doc_result = await strategy.extract(resp, "https://example.com/report.pdf")
+
+        # Simulate what async_webcrawler.py does
+        md_result = MarkdownGenerationResult(
+            raw_markdown=doc_result.content,
+            markdown_with_citations=doc_result.content,
+            references_markdown="",
+            fit_markdown="",
+            fit_html="",
+        )
+
+        assert "report.pdf" in md_result.raw_markdown
+        assert md_result.fit_markdown == ""
+
+    @pytest.mark.asyncio
+    async def test_document_metadata_includes_is_document(self):
+        """CrawlResult metadata should include is_document=True."""
+        strategy = MockDocumentStrategy()
+        resp = make_response(downloaded_files=["/tmp/doc.pdf"])
+        doc_result = await strategy.extract(resp, "https://example.com/doc.pdf")
+
+        metadata = {
+            "is_document": True,
+            "content_type": doc_result.content_type,
+            **(doc_result.metadata or {}),
+        }
+
+        assert metadata["is_document"] is True
+        assert metadata["content_type"] == "pdf"
+        assert metadata["pages"] == 5
+
+
+# ── Import tests ─────────────────────────────────────────────────────────
+
+class TestImports:
+
+    def test_importable_from_crawl4ai(self):
+        from crawl4ai import DocumentExtractionStrategy, DocumentExtractionResult
+        assert DocumentExtractionStrategy is not None
+        assert DocumentExtractionResult is not None
+
+    def test_importable_from_module(self):
+        from crawl4ai.document_extraction_strategy import (
+            DocumentExtractionStrategy,
+            DocumentExtractionResult,
+        )
+        assert DocumentExtractionStrategy is not None