#!/usr/bin/env python3
"""
Maat Journal Decryption Tool
============================

This tool allows you to decrypt your Maat Journal export files using your recovery phrase.
It gives you full control over your data by allowing you to access your journal entries
outside of the app.

SECURITY: Your recovery phrase is the key to your data. Keep it safe and never share it.

How it works:
1. Your 12-word recovery phrase is converted to a cryptographic seed
2. The seed is used to derive an encryption key using HKDF-SHA256
3. Each journal entry is decrypted using AES-256-GCM encryption
4. The decrypted data is saved as a readable JSON file

Usage:
    python maat_decrypt.py --input your_export.json --output decrypted_entries.json

Requirements:
    - Python 3.9 or higher
    - Your 12-word recovery phrase in a .env file
    - Install dependencies: pip install python-dotenv mnemonic pycryptodome

For help: python maat_decrypt.py --help

Learn more: https://github.com/your-repo/maat-journal/docs/encryption.md
"""

import argparse
import base64
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    from Crypto.Cipher import AES
    from Crypto.Hash import SHA256
    from Crypto.Protocol.KDF import HKDF
    from dotenv import load_dotenv
    from mnemonic import Mnemonic
except ImportError as e:
    print("ERROR: Missing required dependencies!")
    print(f"Error: {e}")
    print("\nInstall dependencies with:")
    print("   pip install python-dotenv mnemonic pycryptodome")
    print("\nOr if you're using a virtual environment:")
    print("   source venv/bin/activate && pip install python-dotenv mnemonic pycryptodome")
    sys.exit(1)


# Encryption constants
# Salt and info ("context") inputs passed to HKDF in derive_master_key().
HKDF_SALT = b"maat-journal-v1"
HKDF_INFO = b"master-encryption-key"
# decrypt_gcm_base64() splits each decoded blob as: IV, ciphertext, tag.
IV_LENGTH = 16  # bytes
GCM_TAG_LENGTH = 16  # bytes

# Maat export format versioning
# Version that process_export() migrates to and records in the output metadata.
MAAT_EXPORT_FORMAT_VERSION = "1.0"
SUPPORTED_MAAT_VERSIONS = ["1.0"]  # Add new versions here as they're released
DEFAULT_MAAT_VERSION = "1.0"  # Version to assume if not specified


def load_recovery_phrase() -> str:
    """Load and validate the recovery phrase from a .env file.

    Every candidate .env file that exists is loaded with ``override=False``,
    so a ``RECOVERY_PHRASE`` value that is already set (by the process
    environment or by an earlier file in the list) is not replaced by a
    later file.

    The phrase itself is never printed: terminal output can end up in
    scrollback, logs or screenshots, and the phrase is the key to the data.

    Returns:
        The phrase with all whitespace runs collapsed to single spaces.

    Exits:
        Calls ``sys.exit(1)`` when no .env file is found, when
        ``RECOVERY_PHRASE`` is missing or empty, or when it does not contain
        exactly 12 words.
    """
    # Try multiple locations for .env file
    env_locations = [
        Path.cwd() / ".env",
        Path(__file__).parent / ".env",
        Path.home() / ".maat-journal" / ".env",
    ]

    loaded = False
    for env_path in env_locations:
        if env_path.exists():
            load_dotenv(dotenv_path=env_path, override=False)
            loaded = True
            print(f"Loaded .env from: {env_path}")

    if not loaded:
        print("ERROR: No .env file found!")
        print("\nCreate a .env file with your recovery phrase:")
        print("   RECOVERY_PHRASE=\"word1 word2 word3 word4 word5 word6 word7 word8 word9 word10 word11 word12\"")
        print("\nPlace the .env file in one of these locations:")
        for loc in env_locations:
            print(f"   - {loc}")
        sys.exit(1)

    phrase = os.getenv("RECOVERY_PHRASE")
    if not phrase:
        print("ERROR: RECOVERY_PHRASE not found in .env file!")
        print("\nAdd your recovery phrase to the .env file:")
        print("   RECOVERY_PHRASE=\"word1 word2 word3 word4 word5 word6 word7 word8 word9 word10 word11 word12\"")
        sys.exit(1)

    # Normalize whitespace and validate
    words = phrase.split()
    phrase = " ".join(words)

    if len(words) != 12:
        print(f"ERROR: Recovery phrase must be exactly 12 words, found {len(words)}")
        # Deliberately do NOT echo the phrase here; report only the word count.
        print("   Check RECOVERY_PHRASE in your .env file (the phrase is not shown here for safety)")
        sys.exit(1)

    print(f"Loaded recovery phrase ({len(words)} words)")
    return phrase


def mnemonic_to_seed(phrase: str) -> bytes:
    """Turn the recovery phrase into the seed bytes used for key derivation.

    A failed BIP39 checksum only produces a warning; the seed is still
    computed and returned so decryption can be attempted.
    """
    english = Mnemonic("english")

    checksum_ok = english.check(phrase)
    if not checksum_ok:
        for warning_line in (
            "WARNING: Recovery phrase failed BIP39 checksum validation",
            "   This might be okay if your phrase was generated by the app",
            "   Proceeding with decryption attempt...",
        ):
            print(warning_line)

    # The app derives the seed with an empty passphrase, so do the same here.
    derived_seed = english.to_seed(phrase, passphrase="")
    print("Generated cryptographic seed from recovery phrase")
    return derived_seed


def derive_master_key(seed: bytes) -> bytes:
    """Return the 32-byte master key obtained by running HKDF-SHA256 over *seed*."""
    master_key = HKDF(
        seed,
        32,
        HKDF_SALT,
        SHA256,
        context=HKDF_INFO,
    )
    print("Derived master encryption key using HKDF")
    return master_key


def decrypt_gcm_base64(b64_data: str, key: bytes) -> bytes:
    """Decrypt a base64-encoded AES-GCM blob laid out as IV, ciphertext, tag.

    Args:
        b64_data: Base64 text. Once decoded, the first ``IV_LENGTH`` bytes are
            used as the nonce, the last ``GCM_TAG_LENGTH`` bytes as the
            authentication tag, and everything in between as the ciphertext.
        key: AES key handed to the cipher.

    Returns:
        The plaintext bytes, returned only after the tag has been verified.

    Raises:
        ValueError: If the input cannot be base64-decoded, if it is too short
            to hold an IV, a tag and at least one ciphertext byte, or if tag
            verification fails. The underlying error is chained as
            ``__cause__``.
    """
    try:
        raw = base64.b64decode(b64_data)
    except (ValueError, TypeError) as e:
        # binascii.Error is a ValueError subclass; TypeError covers non-text input.
        raise ValueError(f"Invalid base64 data: {e}") from e

    # This single check guarantees a non-empty ciphertext between IV and tag,
    # so no second length check is needed after slicing.
    if len(raw) < IV_LENGTH + GCM_TAG_LENGTH + 1:
        raise ValueError("Encrypted data too short (corrupted or invalid format)")

    iv = raw[:IV_LENGTH]
    ciphertext = raw[IV_LENGTH:-GCM_TAG_LENGTH]
    tag = raw[-GCM_TAG_LENGTH:]

    cipher = AES.new(key, AES.MODE_GCM, nonce=iv)
    try:
        # Decrypt and authenticate in one call so unverified plaintext is
        # never held in a local variable.
        return cipher.decrypt_and_verify(ciphertext, tag)
    except ValueError as e:
        raise ValueError(f"Decryption failed - wrong key or corrupted data: {e}") from e


def detect_maat_export_version(data: Dict[str, Any]) -> str:
    """Return the export format version recorded in the export, as a string.

    The version is read from ``data["backup_metadata"]["maat_export_format_version"]``.
    ``DEFAULT_MAAT_VERSION`` is returned when ``backup_metadata`` is missing
    or is not a JSON object, or when the version field is missing or empty.

    A version stored as a JSON number (e.g. ``1.0``) is converted to its
    string form so that callers comparing against version strings behave
    consistently.
    """
    metadata = data.get("backup_metadata")
    if not isinstance(metadata, dict):
        # Covers both a missing key and an explicit null / wrong type.
        return DEFAULT_MAAT_VERSION

    # Check for explicit Maat version
    version = metadata.get("maat_export_format_version")
    if not version:
        return DEFAULT_MAAT_VERSION

    version_text = str(version).strip()
    # A whitespace-only value carries no information; fall back to the default.
    return version_text or DEFAULT_MAAT_VERSION


def validate_maat_export_version(version: str) -> bool:
    """Report whether *version* appears in ``SUPPORTED_MAAT_VERSIONS``."""
    is_supported = version in SUPPORTED_MAAT_VERSIONS
    return is_supported


def migrate_maat_export_data(data: Dict[str, Any], from_version: str, to_version: str) -> Dict[str, Any]:
    """Return *data* converted from *from_version* to *to_version*.

    When the two versions are equal the input object itself is returned.
    Otherwise a shallow copy is returned; the migration steps below are
    placeholders and currently apply no changes to that copy.
    """
    if from_version != to_version:
        upgraded = data.copy()
        step = (from_version, to_version)

        if step == ("1.0", "1.1"):
            # Placeholder for the 1.0 -> 1.1 migration (example for future use),
            # e.g. upgraded["new_field"] = "default_value"
            pass
        elif step == ("1.1", "1.2"):
            # Placeholder for the 1.1 -> 1.2 migration (example for future use),
            # e.g. upgraded["another_field"] = "default_value"
            pass

        return upgraded

    return data


def convert_to_plain_text(data: Dict[str, Any]) -> str:
    """Render the export's journal entries as a plain-text report.

    Entries are listed newest first, ordered by the text of their
    ``created_at`` field. Each entry shows a formatted date, the mood and the
    content, followed by a summary footer with the entry count and the
    current local time.

    Args:
        data: Export dictionary; only ``journal_entries`` is read.

    Returns:
        The report text, or ``"No journal entries found."`` when the export
        has no entries.
    """
    entries = data.get("journal_entries", [])
    if not entries:
        return "No journal entries found."

    # Sort entries by date (newest first). A missing or null created_at is
    # treated as "" so the sort never compares None against a string.
    sorted_entries = sorted(
        entries,
        key=lambda x: str(x.get("created_at") or ""),
        reverse=True,
    )

    lines = [
        "Maat Journal - Decrypted Entries",
        "=" * 50,
        "",
    ]

    for i, entry in enumerate(sorted_entries, 1):
        # Get entry data
        created_at = entry.get("created_at", "")
        mood = entry.get("mood", "")
        content = entry.get("content", "")

        # Format date; fall back to the raw value when it is not ISO 8601 text.
        formatted_date = "Unknown date"
        if created_at:
            try:
                # fromisoformat() on older Pythons rejects a trailing "Z".
                date_obj = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
                formatted_date = date_obj.strftime("%B %d, %Y at %I:%M %p")
            except (AttributeError, TypeError, ValueError):
                formatted_date = str(created_at)

        # Format mood
        mood_text = f"Mood: {mood}" if mood else "Mood: Not specified"

        # Format content
        if isinstance(content, dict):
            # Structured content: prefer its own "content" field, otherwise
            # show the whole dictionary.
            content_text = content.get("content", "") or str(content)
        else:
            content_text = str(content) if content else "No content"

        # Build entry text
        lines.append(f"Entry #{i}")
        lines.append(f"Date: {formatted_date}")
        lines.append(mood_text)
        lines.append("-" * 30)
        # str() guards the final join against a non-string nested content value.
        lines.append(str(content_text))
        lines.append("")
        lines.append("")

    # Add summary
    lines.append("=" * 50)
    lines.append(f"Total Entries: {len(entries)}")
    lines.append(f"Export Date: {datetime.now().strftime('%B %d, %Y at %I:%M %p')}")

    return "\n".join(lines)


def process_export(data: Dict[str, Any], key: bytes) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Decrypt every journal entry in an export and annotate its metadata.

    The export version is detected, reported and (if it differs from
    ``MAAT_EXPORT_FORMAT_VERSION``) passed through the migration step. The
    entries are then decrypted on a deep copy of the data, so the caller's
    dictionary is not modified by decryption.

    For each entry the ciphertext is taken from ``content`` or, if that is
    empty, ``encrypted_content``. An entry is only modified once its payload
    has been decrypted, parsed and confirmed to be a JSON object; an entry
    that fails is left exactly as it was and is recorded in the failure list.

    Args:
        data: Parsed export file.
        key: Master key used for every entry.

    Returns:
        ``(result, failed)``. ``result`` is the decrypted copy with updated
        ``backup_metadata`` (and a ``failed_entries`` list when something
        failed). ``failed`` holds one ``{"entry_id", "index", "error"}``
        record per failed entry. When the export has no entries, the
        (possibly migrated) data is returned unchanged with an empty list.
    """
    # Detect and validate Maat export version
    maat_export_version = detect_maat_export_version(data)

    if not validate_maat_export_version(maat_export_version):
        print(f"WARNING: Maat export format version {maat_export_version} is not officially supported")
        print(f"Supported versions: {', '.join(SUPPORTED_MAAT_VERSIONS)}")
        print("Attempting to process anyway...")

    print(f"Detected Maat export format version: {maat_export_version}")

    # Migrate data to current version if needed
    if maat_export_version != MAAT_EXPORT_FORMAT_VERSION:
        print(f"Migrating data from version {maat_export_version} to {MAAT_EXPORT_FORMAT_VERSION}")
        data = migrate_maat_export_data(data, maat_export_version, MAAT_EXPORT_FORMAT_VERSION)

    entries = data.get("journal_entries", [])

    if not entries:
        print("WARNING: No journal entries found in export file")
        return data, []

    total = len(entries)
    print(f"Found {total} journal entries to decrypt")

    failed: List[Dict[str, Any]] = []
    successful = 0

    # Work on a copy to avoid mutating input
    result = json.loads(json.dumps(data))

    for idx, entry in enumerate(result.get("journal_entries", [])):
        if not isinstance(entry, dict):
            # A malformed entry must not abort the whole run.
            failed.append({
                "entry_id": None,
                "index": idx,
                "error": "Entry is not a JSON object",
            })
            continue

        encrypted_str: Optional[str] = entry.get("content") or entry.get("encrypted_content")
        if not encrypted_str:
            # Nothing to decrypt
            continue

        try:
            decrypted_bytes = decrypt_gcm_base64(encrypted_str, key)
            try:
                decrypted_obj = json.loads(decrypted_bytes.decode("utf-8"))
            except ValueError as e:
                # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
                raise ValueError(f"Decrypted data is not valid JSON: {e}") from e

            # Validate before touching the entry so a failure cannot strip
            # its encrypted_content.
            if not isinstance(decrypted_obj, dict):
                raise ValueError("Decrypted data is not a JSON object")
        except Exception as e:
            # Best effort per entry: record the failure and keep going.
            failed.append({
                "entry_id": entry.get("id"),
                "index": idx,
                "error": str(e),
            })
            continue

        # Replace fields to match unencrypted export format used by the app;
        # all other original fields are kept.
        entry.pop("encrypted_content", None)
        entry["content"] = decrypted_obj.get("content", "")
        if "metadata" in decrypted_obj:
            entry["metadata"] = decrypted_obj["metadata"]
        else:
            entry.pop("metadata", None)

        successful += 1
        if successful % 10 == 0:  # Progress indicator every 10 entries
            print(f"   Decrypted {successful}/{total} entries...")

    # Update metadata with version information. A null or non-object
    # backup_metadata is replaced so the assignments below cannot fail.
    backup_metadata = result.get("backup_metadata")
    if not isinstance(backup_metadata, dict):
        backup_metadata = {}
        result["backup_metadata"] = backup_metadata
    backup_metadata["is_encrypted"] = False
    backup_metadata["decrypted_at"] = datetime.now(timezone.utc).isoformat()
    backup_metadata["decryption_tool"] = "maat-decrypt-tool"
    backup_metadata["decryption_version"] = "1.0.0"
    backup_metadata["original_maat_export_format_version"] = maat_export_version
    backup_metadata["processed_maat_export_format_version"] = MAAT_EXPORT_FORMAT_VERSION

    print(f"Successfully decrypted {successful} entries")
    if failed:
        result["failed_entries"] = failed
        print(f"Failed to decrypt {len(failed)} entries")

    return result, failed


def _owner_only_opener(path: str, flags: int) -> int:
    """Opener for ``open()`` that creates new files with mode 0o600.

    The mode only applies when the file is created; an existing file keeps
    its current permissions.
    """
    return os.open(path, flags, 0o600)


def main() -> None:
    """Command-line entry point: parse arguments, decrypt the export, write the result.

    Steps: validate the input path, choose the output path (refusing to
    overwrite the input file), load the recovery phrase, derive the key,
    read and sanity-check the export, decrypt it, and write JSON or plain
    text. The process exits with status 1 on any input, read or write error.
    """
    parser = argparse.ArgumentParser(
        description="Decrypt Maat Journal export files using your recovery phrase",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Decrypt with automatic output filename (JSON format)
  python maat_decrypt.py --input my_export.json

  # Decrypt to plain text for easy reading
  python maat_decrypt.py --input my_export.json --format text

  # Specify custom output file
  python maat_decrypt.py --input my_export.json --output readable_entries.json

  # Export to plain text with custom filename
  python maat_decrypt.py --input my_export.json --format text --output my_journal.txt

  # Quiet mode (minimal output)
  python maat_decrypt.py --input my_export.json --quiet

Security Notes:
  • Your recovery phrase is the key to your data - keep it safe!
  • Never share your recovery phrase with anyone
  • Store your .env file securely
  • The decrypted file contains your personal data - handle with care

For more information: https://github.com/your-repo/maat-journal/docs/encryption.md
        """
    )

    parser.add_argument(
        "--input", "-i",
        required=True,
        help="Path to your encrypted Maat Journal export file (.json)"
    )
    parser.add_argument(
        "--output", "-o",
        help="Path for the decrypted output file "
             "(default: <input>_decrypted.json, or <input>_decrypted.txt with --format text)"
    )
    parser.add_argument(
        "--format", "-f",
        choices=["json", "text"],
        default="json",
        help="Output format: json (default) or text (plain text for reading)"
    )
    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Suppress progress messages (quiet mode)"
    )
    parser.add_argument(
        "--version", "-v",
        action="version",
        version="Maat Decrypt Tool v1.0.0"
    )

    args = parser.parse_args()

    # Validate input file
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"ERROR: Input file not found: {input_path}")
        sys.exit(1)

    if input_path.suffix.lower() != '.json':
        print(f"WARNING: Input file doesn't have .json extension: {input_path}")

    # Determine output path up front so a bad destination fails before any
    # key material is derived.
    if args.output:
        out_path = Path(args.output)
    else:
        # Use appropriate extension based on format
        extension = ".txt" if args.format == "text" else ".json"
        out_path = input_path.parent / f"{input_path.stem}_decrypted{extension}"

    # Writing onto the input would destroy the only encrypted copy.
    if out_path.resolve() == input_path.resolve():
        print("ERROR: Output file is the same as the input file; refusing to overwrite the export")
        print("   Choose a different path with --output")
        sys.exit(1)

    print("Maat Journal Decryption Tool")
    print("=" * 50)

    if not args.quiet:
        print(f"Input file: {input_path}")

    # Load and validate recovery phrase
    phrase = load_recovery_phrase()

    # Generate encryption key
    seed = mnemonic_to_seed(phrase)
    key = derive_master_key(seed)

    # Load and validate export file
    try:
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON file: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"ERROR: Error reading file: {e}")
        sys.exit(1)

    if not isinstance(data, dict):
        print("ERROR: Export file is not a valid JSON object")
        sys.exit(1)

    if "journal_entries" not in data:
        print("ERROR: Export file doesn't contain journal entries")
        print("   This might not be a valid Maat Journal export file")
        sys.exit(1)

    # Process and decrypt
    if not args.quiet:
        print("\nStarting decryption...")

    decrypted, failed = process_export(data, key)

    # Write output based on format. New files are created owner-only because
    # they hold the decrypted journal.
    try:
        if args.format == "text":
            # Convert to plain text
            text_content = convert_to_plain_text(decrypted)
            with open(out_path, "w", encoding="utf-8", opener=_owner_only_opener) as f:
                f.write(text_content)
        else:
            # Write as JSON
            with open(out_path, "w", encoding="utf-8", opener=_owner_only_opener) as f:
                json.dump(decrypted, f, ensure_ascii=False, indent=2)
                f.write("\n")
    except Exception as e:
        print(f"ERROR: Error writing output file: {e}")
        sys.exit(1)

    # Final summary
    print("\n" + "=" * 50)
    print("Decryption completed!")
    print(f"Output file: {out_path}")

    if failed:
        print(f"WARNING: {len(failed)} entries could not be decrypted")
        if args.format == "json":
            print("   Check the 'failed_entries' section in the output file for details")
        else:
            # The plain-text report contains only the entries themselves.
            print("   Re-run with --format json to get a 'failed_entries' section with details")
        print("   This might indicate corrupted data or wrong recovery phrase")

    print("\nNext steps:")
    print("   • Open the decrypted file to view your journal entries")
    print("   • Keep the decrypted file secure - it contains your personal data")
    print("   • Consider deleting the decrypted file after use")

    if not args.quiet:
        print("\nSecurity reminder:")
        print("   • Keep your recovery phrase safe and never share it")
        print("   • Store your .env file securely")
        print("   • The decrypted file contains sensitive personal data")

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


