|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | + |
# File suffixes (matched lower-cased) that make a file eligible for collection.
INCLUDE_EXTENSIONS = {
    # source code
    ".py",
    ".ts",
    ".tsx",
    ".js",
    ".jsx",
    ".java",
    ".go",
    ".rs",
    ".rb",
    ".cs",
    ".cpp",
    ".c",
    ".h",
    # docs and configuration
    ".md",
    ".json",
    ".yaml",
    ".yml",
    ".toml",
}

# Exact file names that are eligible regardless of their suffix.
INCLUDE_FILENAMES = {
    ".env.example",
    "Dockerfile",
    "docker-compose.yml",
}

# Directory names that are never descended into or listed.
SKIP_DIRS = {
    # version control and tool state
    ".git",
    ".grokcode",
    # dependencies and virtual environments
    "node_modules",
    "vendor",
    ".venv",
    "venv",
    "env",
    # caches and build output
    "__pycache__",
    "dist",
    "build",
    ".next",
    ".nuxt",
    "target",
}

# Size limit in bytes; larger files are left out of the collection.
MAX_FILE_SIZE = 50 * 1024  # 50 KB

# File names that are quoted at greater length in the "Key Files" section.
KEY_FILES = {
    "README.md",
    "pyproject.toml",
    "package.json",
    "go.mod",
    "Cargo.toml",
    "pom.xml",
    "Dockerfile",
    "docker-compose.yml",
}

KEY_FILE_LINES = 100  # leading lines quoted from each key file
OTHER_FILE_LINES = 30  # leading lines quoted from each other source file
OTHER_FILES_CAP = 60  # avoid enormous summaries
| 22 | + |
| 23 | + |
def collect_files(path: Path) -> list[Path]:
    """Recursively collect relevant source files below *path*.

    A file is kept when its lower-cased suffix is in ``INCLUDE_EXTENSIONS``
    or its exact name is in ``INCLUDE_FILENAMES``, and its size does not
    exceed ``MAX_FILE_SIZE`` bytes. Files that cannot be stat'ed are skipped.

    Directories named in ``SKIP_DIRS`` are pruned from the walk, so their
    contents are never even listed. Only directories *below* ``path`` are
    tested: a root that itself lives under a directory called ``build`` or
    ``env`` is still scanned normally.

    Args:
        path: Root directory to scan. A missing or non-directory root
            yields an empty list.

    Returns:
        The matching file paths, sorted.
    """
    import os  # local import: os.walk allows pruning directories in place

    result: list[Path] = []
    for dirpath, dirnames, filenames in os.walk(path):
        # Assigning through the slice mutates the list os.walk iterates next,
        # which is what stops it from descending into ignored directories.
        dirnames[:] = [name for name in dirnames if name not in SKIP_DIRS]

        for filename in filenames:
            entry = Path(dirpath, filename)
            wanted = (
                filename in INCLUDE_FILENAMES
                or entry.suffix.lower() in INCLUDE_EXTENSIONS
            )
            if not wanted:
                continue
            # os.walk reports every non-directory entry (sockets, broken
            # symlinks, ...); keep regular files only.
            if not entry.is_file():
                continue
            try:
                too_large = entry.stat().st_size > MAX_FILE_SIZE
            except OSError:
                # Best effort: a file that vanished or is unreadable is skipped.
                continue
            if not too_large:
                result.append(entry)
    return sorted(result)
| 41 | + |
| 42 | + |
def _build_tree(path: Path, max_depth: int = 2) -> str:
    """Build a directory tree string up to max_depth levels.

    The first line is ``str(path) + "/"``. Every entry below it gets one
    line prefixed with ``├── ``; directories carry a trailing ``/``. Within a
    directory, non-files (directories) are listed before files, each group
    ordered by name. Entries whose name is in ``SKIP_DIRS`` are omitted.

    A directory that cannot be listed (permission denied, removed during
    the walk, or ``path`` not being a directory at all) is shown without
    children instead of raising.

    Args:
        path: Root of the tree.
        max_depth: Number of levels below ``path`` to render; ``0`` renders
            only the root line.

    Returns:
        The tree as a single newline-joined string.
    """
    lines: list[str] = [str(path) + "/"]

    def _walk(directory: Path, depth: int) -> None:
        if depth > max_depth:
            return
        try:
            # is_file() is False for directories, so they sort first.
            entries = sorted(
                directory.iterdir(), key=lambda x: (x.is_file(), x.name)
            )
        except OSError:
            # Covers PermissionError, FileNotFoundError and
            # NotADirectoryError: the tree is best-effort output.
            return

        indent = " " * (depth - 1)  # one indent unit per level below the root
        for entry in entries:
            if entry.name in SKIP_DIRS:
                continue
            is_dir = entry.is_dir()
            suffix = "/" if is_dir else ""
            lines.append(f"{indent}├── {entry.name}{suffix}")
            if is_dir and depth < max_depth:
                _walk(entry, depth + 1)

    _walk(path, 1)
    return "\n".join(lines)
| 65 | + |
| 66 | + |
def _relative_to_root(file: Path, root: Path) -> Path:
    """Return *file* relative to *root*, or just its name if it lies outside."""
    try:
        return file.relative_to(root)
    except ValueError:
        return Path(file.name)


def _render_excerpt(file: Path, label: Path, max_lines: int) -> str | None:
    """Render the first *max_lines* lines of *file* as a fenced block, or None if unreadable."""
    try:
        text = file.read_text(encoding="utf-8", errors="ignore")
    except OSError:
        # Best effort: an unreadable file is simply left out of the summary.
        return None
    head = text.splitlines()[:max_lines]
    return f"### {label}\n```\n" + "\n".join(head) + "\n```"


def build_summary(path: Path, files: list[Path]) -> str:
    """Build a structured codebase summary string to send to Grok.

    The summary is a sequence of markdown sections joined by blank lines:

    1. ``## Directory Structure`` - the tree produced by ``_build_tree``.
    2. ``## Key Files`` - the first ``KEY_FILE_LINES`` lines of every file
       whose name is in ``KEY_FILES``, headed by its path relative to
       ``path`` so that same-named files in different directories stay
       distinguishable.
    3. ``## Source Files`` - the first ``OTHER_FILE_LINES`` lines of at most
       ``OTHER_FILES_CAP`` remaining files.
    4. ``## Test Files`` - relative paths with "test" in any path component.
    5. ``## CI/CD Config`` - relative paths under ``.github``/``workflows``.

    Sections 2-5 are omitted when they have no entries. Files that cannot
    be read are skipped silently.

    Args:
        path: Project root; used for the tree and to relativise file paths.
        files: Files to summarise, typically the output of ``collect_files``.

    Returns:
        The assembled summary text.
    """
    sections: list[str] = ["## Directory Structure\n" + _build_tree(path)]

    key_files: list[tuple[Path, Path]] = []
    other_files: list[tuple[Path, Path]] = []
    test_files: list[Path] = []
    ci_files: list[Path] = []

    for f in files:
        rel = _relative_to_root(f, path)
        if f.name in KEY_FILES:
            key_files.append((f, rel))
        elif ".github" in rel.parts and "workflows" in rel.parts:
            # Checked before the "test" heuristic: workflow files are often
            # named test.yml / tests.yml and are CI config, not tests.
            ci_files.append(rel)
        elif any("test" in part.lower() for part in rel.parts):
            # rel.parts ends with the file name, so this covers it as well.
            test_files.append(rel)
        else:
            other_files.append((f, rel))

    if key_files:
        sections.append("## Key Files")
        for f, rel in key_files:
            excerpt = _render_excerpt(f, rel, KEY_FILE_LINES)
            if excerpt is not None:
                sections.append(excerpt)

    if other_files:
        # Derive the number from the constant so the heading cannot drift.
        sections.append(f"## Source Files (first {OTHER_FILE_LINES} lines each)")
        for f, rel in other_files[:OTHER_FILES_CAP]:
            excerpt = _render_excerpt(f, rel, OTHER_FILE_LINES)
            if excerpt is not None:
                sections.append(excerpt)

    if test_files:
        sections.append("## Test Files\n" + "\n".join(f"- {p}" for p in test_files))

    if ci_files:
        sections.append("## CI/CD Config\n" + "\n".join(f"- {p}" for p in ci_files))

    return "\n\n".join(sections)
0 commit comments