"""
note_stats.py - Obsidian 筆記字數統計工具
統計母筆記及其 [[雙向鏈結]] 子筆記、孫筆記的字數與行數。

用法:
    python -X utf8 note_stats.py <markdown_file> [--depth N] [--json] [--no-tree] [--vault DIR]

範例:
    # 統計母筆記 + 子筆記（預設 depth=1）
    python -X utf8 note_stats.py "research/2026/02/2026-02-06-琉球素地不動產估價專案規劃與研究.md"

    # 統計到孫筆記（depth=2）
    python -X utf8 note_stats.py "research/2026/02/xxx.md" --depth 2

    # 只統計母筆記本身（depth=0）
    python -X utf8 note_stats.py "research/2026/02/xxx.md" --depth 0

    # 輸出 JSON 格式
    python -X utf8 note_stats.py "research/2026/02/xxx.md" --json
"""

import argparse
import glob
import json
import os
import re
import sys
from pathlib import Path

# Force UTF-8 on stdout so the Chinese report text prints regardless of the
# console's default encoding. Guarded because sys.stdout may be None (e.g.
# under pythonw) or replaced by an object without reconfigure() (e.g.
# io.StringIO); an unconditional call would then fail at import time.
if hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(encoding='utf-8')

# Obsidian vault root (defaults to Research_zoo); main() may override it.
VAULT_ROOT = Path(r"C:\Users\User\Documents\GitHub\Research_zoo")


def extract_wikilinks(content: str) -> set[str]:
    """從 markdown 內容中提取所有 [[wikilink]] 的筆記名稱（去除 # 錨點和 | 別名）"""
    pattern = r'\[\[([^\]#|]+?)(?:#[^\]|]*)?(?:\|[^\]]*?)?\]\]'
    links = set(re.findall(pattern, content))
    # 只保留看起來像筆記名稱的（排除純 URL 等）
    return {link.strip() for link in links if not link.startswith('http')}


def find_note_file(note_name: str, parent_dir: Path) -> Path | None:
    """尋找筆記檔案，依序搜尋：1. 同資料夾 2. vault 全域"""
    # 確保有 .md 副檔名
    if not note_name.endswith('.md'):
        note_name += '.md'

    # 1. 同資料夾
    candidate = parent_dir / note_name
    if candidate.exists():
        return candidate

    # 2. vault 全域搜尋（使用 glob，可能較慢但可靠）
    results = list(VAULT_ROOT.rglob(note_name))
    if results:
        return results[0]

    return None


def count_file(filepath: Path) -> dict:
    """計算單一檔案的字數和行數"""
    try:
        content = filepath.read_text(encoding='utf-8')
        return {
            'chars': len(content),
            'lines': len(content.splitlines()),
            'content': content,
            'exists': True,
        }
    except Exception as e:
        return {
            'chars': 0,
            'lines': 0,
            'content': '',
            'exists': False,
            'error': str(e),
        }


def analyze_note(filepath: Path, depth: int, visited: set | None = None) -> dict:
    """遞迴分析筆記及其鏈結子筆記

    Args:
        filepath: 筆記檔案路徑
        depth: 遞迴深度（0=只看自己，1=子筆記，2=孫筆記...）
        visited: 已造訪的檔案（避免循環鏈結）

    Returns:
        分析結果 dict
    """
    if visited is None:
        visited = set()

    filepath = filepath.resolve()
    visited.add(filepath)

    file_info = count_file(filepath)
    if not file_info['exists']:
        return {
            'name': filepath.stem,
            'path': str(filepath),
            'chars': 0,
            'lines': 0,
            'exists': False,
            'children': [],
        }

    result = {
        'name': filepath.stem,
        'path': str(filepath),
        'chars': file_info['chars'],
        'lines': file_info['lines'],
        'exists': True,
        'children': [],
    }

    if depth > 0:
        links = extract_wikilinks(file_info['content'])
        parent_dir = filepath.parent

        for link in sorted(links):
            child_path = find_note_file(link, parent_dir)
            if child_path is None:
                result['children'].append({
                    'name': link,
                    'path': None,
                    'chars': 0,
                    'lines': 0,
                    'exists': False,
                    'children': [],
                })
            elif child_path.resolve() not in visited:
                child_result = analyze_note(child_path, depth - 1, visited)
                result['children'].append(child_result)

    return result


def collect_all_stats(result: dict) -> dict:
    """從遞迴結構中收集所有統計數據"""
    all_notes = []
    not_found = []

    def _collect(node, level):
        entry = {
            'name': node['name'],
            'chars': node['chars'],
            'lines': node['lines'],
            'level': level,
            'exists': node['exists'],
        }
        if node['exists']:
            all_notes.append(entry)
        else:
            not_found.append(entry)

        for child in node.get('children', []):
            _collect(child, level + 1)

    _collect(result, 0)
    return {'notes': all_notes, 'not_found': not_found}


def print_tree(result: dict, indent: int = 0):
    """Print a node and its descendants, one line per note.

    Each level is indented by two spaces. Every non-root line gets a
    '├─ ' marker. Found notes show character and line counts; unresolved
    ones show '(未找到)'.
    """
    lead = '  ' * indent + ('├─ ' if indent > 0 else '')
    if result['exists']:
        detail = f"{result['chars']:,} 字 / {result['lines']:,} 行"
    else:
        detail = "(未找到)"
    print(f"{lead}{result['name']}: {detail}")

    for sub in result.get('children', []):
        print_tree(sub, indent + 1)


def print_summary(stats: dict, root_name: str, depth: int):
    """Print the aggregated report.

    The report has three parts: per-level totals, grand totals, and a top-10
    ranking of linked notes by character count.

    Args:
        stats: dict with 'notes' (entries for found notes) and 'not_found'
            (entries for unresolved links). Each entry carries 'name',
            'chars', 'lines' and 'level'.
        root_name: name shown in the report header.
        depth: recursion depth shown in the report header.
    """
    notes = stats['notes']
    not_found = stats['not_found']

    if not notes:
        print("找不到任何筆記。")
        return

    level_names = {0: '母筆記', 1: '子筆記', 2: '孫筆記', 3: '曾孫筆記'}

    def level_label(lvl: int) -> str:
        # One fallback shared by the per-level totals and the ranking, so
        # levels beyond 3 are labelled the same in both sections.
        return level_names.get(lvl, f'第{lvl}層')

    # Per-level totals.
    levels = {}
    for note in notes:
        info = levels.setdefault(note['level'], {'count': 0, 'chars': 0, 'lines': 0})
        info['count'] += 1
        info['chars'] += note['chars']
        info['lines'] += note['lines']

    total_chars = sum(n['chars'] for n in notes)
    total_lines = sum(n['lines'] for n in notes)
    total_count = len(notes)

    print()
    print("=" * 60)
    print(f"  筆記統計：{root_name}")
    print(f"  遞迴深度：{depth}")
    print("=" * 60)
    print()

    for lvl in sorted(levels):
        info = levels[lvl]
        print(f"  {level_label(lvl)}：{info['count']} 篇 / {info['chars']:,} 字 / {info['lines']:,} 行")

    print(f"  {'─' * 40}")
    print(f"  總計：{total_count} 篇 / {total_chars:,} 字 / {total_lines:,} 行")
    print(f"  約 {total_chars / 1000:.0f}K 字（{total_chars / 10000:.1f} 萬字）")
    print()

    if not_found:
        print(f"  未找到的鏈結（{len(not_found)} 篇）：")
        for nf in not_found:
            print(f"    - {nf['name']}")
        print()

    # Top 10 linked notes by character count (the root note, level 0, is excluded).
    ranked = sorted((n for n in notes if n['level'] > 0), key=lambda x: x['chars'], reverse=True)
    if ranked:
        print(f"  字數排行（前 {min(10, len(ranked))} 名）：")
        for i, note in enumerate(ranked[:10], 1):
            print(f"    {i:2d}. {note['name']}: {note['chars']:,} 字 [{level_label(note['level'])}]")
        print()


def main():
    parser = argparse.ArgumentParser(description='Obsidian 筆記字數統計工具')
    parser.add_argument('file', help='母筆記的 markdown 檔案路徑')
    parser.add_argument('--depth', type=int, default=1,
                        help='遞迴深度：0=只看母筆記，1=含子筆記（預設），2=含孫筆記')
    parser.add_argument('--json', action='store_true', help='輸出 JSON 格式')
    parser.add_argument('--no-tree', action='store_true', help='不顯示樹狀結構')
    parser.add_argument('--vault', type=str, default=None,
                        help='Obsidian vault 根目錄（預設自動偵測）')

    args = parser.parse_args()

    # 設定 vault 根目錄
    global VAULT_ROOT
    if args.vault:
        VAULT_ROOT = Path(args.vault)

    # 解析檔案路徑
    filepath = Path(args.file)
    if not filepath.is_absolute():
        filepath = VAULT_ROOT / filepath

    if not filepath.exists():
        print(f"錯誤：找不到檔案 {filepath}")
        sys.exit(1)

    # 分析
    result = analyze_note(filepath, args.depth)
    stats = collect_all_stats(result)

    if args.json:
        output = {
            'root': result['name'],
            'depth': args.depth,
            'total_notes': len(stats['notes']),
            'total_chars': sum(n['chars'] for n in stats['notes']),
            'total_lines': sum(n['lines'] for n in stats['notes']),
            'not_found': len(stats['not_found']),
            'notes': stats['notes'],
            'not_found_list': [n['name'] for n in stats['not_found']],
        }
        print(json.dumps(output, ensure_ascii=False, indent=2))
    else:
        if not args.no_tree:
            print_tree(result)

        print_summary(stats, result['name'], args.depth)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
