# Fichier: services/dns_service.py
"""
Service d'analyse DNS sécurisé.
Vérifie SPF, DKIM, DMARC, DNSSEC, MTA-STS, DANE, BIMI.
"""

import subprocess
import re
import shutil
import os
from typing import Dict, Tuple, List, Optional, Union

# ========================================================================
# CONFIGURATION
# ========================================================================

# Maximum number of points each individual check can contribute.
# The keys are the same labels used for the entries of the results mapping.
PLAN_SCORES: Dict[str, int] = {
    "MX (Mail Base)": 5,
    "SPF (Security)": 25,
    "DKIM": 10,
    "DMARC": 25,
    "DNSSEC": 10,
    "MTA-STS": 10,
    "DANE TLS": 10,
    "BIMI": 5,
}

# Ceiling applied to the aggregated score.
MAX_SCORE_TOTAL: int = 100

# ========================================================================
# VALIDATION SÉCURISÉE DU DOMAINE
# ========================================================================

# Regex stricte : lettres, chiffres, tirets, points uniquement
DOMAIN_REGEX = re.compile(
    r'^(?=.{1,253}$)'                                        # whole name: at most 253 characters
    r'(?:[A-Za-z0-9](?:[A-Za-z0-9-]{0,61}[A-Za-z0-9])?\.)+'  # labels: 1-63 chars, no leading/trailing hyphen
    r'[A-Za-z]{2,63}$'                                       # TLD: 2-63 letters
)


def is_valid_domain(domain: str) -> bool:
    """
    Strictly validate a domain name.

    The value is stripped and lower-cased before being checked. It is accepted
    only if it is a dot-separated sequence of letter/digit/hyphen labels
    (none of which starts or ends with a hyphen) followed by an alphabetic TLD.

    CRITICAL: this is the gate that keeps attacker-controlled text out of the
    arguments later handed to subprocess.

    Args:
        domain: Candidate domain name.

    Returns:
        True if the domain is syntactically valid, False otherwise
        (including for empty or non-string input).
    """
    if not domain or not isinstance(domain, str):
        return False

    domain = domain.strip().lower()

    # Reasonable length bounds ("a.co" is the shortest accepted form).
    if len(domain) > 253 or len(domain) < 4:
        return False

    # Defense in depth: the regex below already excludes these characters, but
    # an explicit deny-list keeps the intent obvious if the pattern is relaxed.
    forbidden = (';', '&', '|', '$', '`', '(', ')', '{', '}', '[', ']',
                 '<', '>', '!', '\\', '"', "'", '\n', '\r', ' ')
    if any(char in domain for char in forbidden):
        return False

    # Final structural check.
    return DOMAIN_REGEX.fullmatch(domain) is not None

def sanitize_domain(domain: str) -> Optional[str]:
    """
    Normalize user input into a bare domain name and validate it.

    The input may be a plain domain or a URL-like string. The scheme
    (http/https), path, query string, fragment, port, a single trailing dot
    and a leading "www." are removed before validation.

    Args:
        domain: Raw user input.

    Returns:
        The cleaned, lower-cased domain, or None if the input is not a string,
        is empty, or does not pass is_valid_domain().
    """
    # Non-string input (None, numbers, lists...) is invalid, not a crash.
    if not isinstance(domain, str):
        return None

    candidate = domain.strip().lower()
    if not candidate:
        return None

    # Drop the scheme if present.
    candidate = re.sub(r'^https?://', '', candidate)

    # Keep only the host part: cut at the first path, query or fragment marker.
    candidate = re.split(r'[/?#]', candidate, maxsplit=1)[0]

    # Drop an explicit port ("example.com:8443").
    candidate = re.sub(r':\d+$', '', candidate)

    # Accept the fully-qualified form with a single trailing dot.
    if candidate.endswith('.'):
        candidate = candidate[:-1]

    # Drop a leading "www." label.
    if candidate.startswith('www.'):
        candidate = candidate[4:]

    return candidate if is_valid_domain(candidate) else None

# ========================================================================
# UTILITAIRES DIG
# ========================================================================

def dig_available() -> bool:
    """
    Report whether the dig binary can be executed.

    Only the fixed location /usr/bin/dig is checked (the PATH is not
    searched). The path must be a regular file with execute permission, so a
    directory or a non-executable file at that location is reported as
    unavailable.

    Returns:
        True if /usr/bin/dig exists and is executable, False otherwise.
    """
    dig_path = "/usr/bin/dig"
    return os.path.isfile(dig_path) and os.access(dig_path, os.X_OK)

def run_dig(target: str, record_type: str, server: str = "@8.8.8.8", timeout: int = 8) -> Tuple[str, str]:
    """
    Run a `dig +short` query safely and return its answer.

    Every argument is validated before the process is started, and the command
    is executed without a shell. The function never raises: failures are
    reported through the second element of the returned tuple.

    Args:
        target: DNS name to query (special names such as "_dmarc.example.com"
            are allowed).
        record_type: Record type; must be one of the whitelisted types.
        server: Resolver in dig syntax, i.e. "@" followed by an address or
            host name.
        timeout: Maximum run time of the dig process, in seconds.

    Returns:
        (output, error). On success `output` holds the answer lines (possibly
        empty) and `error` is "". On failure `output` is "" and `error`
        describes the problem.
    """
    # Basic anti-injection validation (special DNS sub-domains stay allowed).
    if not target or not isinstance(target, str):
        return "", "ERROR: Cible invalide."

    target = target.strip().lower()

    # Reasonable length bounds.
    if len(target) > 255 or len(target) < 3:
        return "", "ERROR: Cible invalide (longueur)."

    # Characters that must never reach the command line.
    forbidden = (';', '&', '|', '$', '`', '(', ')', '{', '}', '[', ']',
                 '<', '>', '!', '\\', '"', "'", '\n', '\r', ' ')
    if any(char in target for char in forbidden):
        return "", "ERROR: Caractères non autorisés dans la cible."

    # dig reads an argument starting with '-' as an option (e.g. -f <file>),
    # with '+' as a query flag and with '@' as a server: reject them as names.
    if target[0] in "-+@":
        return "", "ERROR: Caractères non autorisés dans la cible."

    # Record type whitelist.
    allowed_types = ('MX', 'TXT', 'DNSKEY', 'TLSA', 'A', 'AAAA', 'NS', 'SOA', 'CNAME')
    if not isinstance(record_type, str) or record_type.upper() not in allowed_types:
        return "", f"ERROR: Type d'enregistrement non autorisé: {record_type}"

    # The server is also a command-line argument: only "@host" / "@ip" is accepted.
    if not isinstance(server, str) or not re.fullmatch(r'@[A-Za-z0-9:][A-Za-z0-9.:-]*', server):
        return "", "ERROR: Serveur DNS invalide."

    command = ["/usr/bin/dig", record_type, target, server, "+short"]

    try:
        result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=timeout,
            shell=False  # CRITICAL: never shell=True with user-supplied input
        )
    except FileNotFoundError:
        return "", "ERROR: L'outil 'dig' est introuvable sur le serveur."
    except subprocess.TimeoutExpired:
        return "", "ERROR: Timeout lors de l'exécution de dig."
    except Exception as e:
        # Boundary around an external process: report the failure, never raise.
        return "", f"ERROR: Exception dig: {str(e)}"

    stderr = result.stderr.strip()
    lines = [line.strip() for line in result.stdout.splitlines()]

    # Even with +short, dig can print ";;" diagnostics (for example when the
    # resolver cannot be reached) on stdout. They are not DNS records and must
    # not be mistaken for an answer.
    records = [line for line in lines if line and not line.startswith(';')]
    diagnostics = [line for line in lines if line.startswith(';')]

    if result.returncode != 0 and not records:
        reason = stderr or " ".join(diagnostics)
        return "", f"Dig failed (code {result.returncode}): {reason}"

    return "\n".join(records), ""

def check_record_exists(domain: str, record_type: str) -> Tuple[bool, str, str]:
    """
    Tell whether a DNS record of the given type exists for a name.

    Args:
        domain: DNS name to query.
        record_type: Record type forwarded to run_dig().

    Returns:
        (exists, detail, content). `content` is the raw dig answer when the
        record exists and "N/A" otherwise. When dig itself is missing, the
        dig error message is returned as `detail`.
    """
    answer, failure = run_dig(domain, record_type)

    # A missing dig binary is reported as-is rather than as "not found".
    if failure and "introuvable" in failure:
        return False, failure, "N/A"

    if not answer or "NXDOMAIN" in answer:
        return False, "Enregistrement non trouvé.", "N/A"

    return True, "Enregistrement trouvé.", answer

# ========================================================================
# MASQUAGE DES DONNÉES SENSIBLES
# ========================================================================

def mask_record_content(content: str, record_type: str) -> str:
    """
    Shorten key material so that records can be displayed compactly.

    Args:
        content: Raw record content; empty values and "N/A" are returned as-is.
        record_type: "DNSKEY", "TLSA" or "DKIM" trigger masking; any other
            value leaves the content untouched.

    Returns:
        The masked representation, or the original content when no masking
        applies.
    """
    if not content or content == "N/A":
        return content

    if record_type in ("DNSKEY", "TLSA"):
        # Only the first answer line is considered; its last
        # whitespace-separated token is treated as the key.
        head = (content.splitlines() or [content])[0].strip()
        tokens = head.split()
        if len(tokens) >= 2 and len(tokens[-1]) > 30:
            key_token = tokens[-1]
            return f"{key_token[:10]}...{key_token[-10:]} [MASQUÉ]"
        return "[Clé masquée]"

    if record_type != "DKIM":
        return content

    # DKIM: the key is the value of the p= tag.
    found = re.search(r'p=([^;"]+)', content)
    if found is None:
        return content

    public_key = found.group(1).strip().replace('"', '')
    if len(public_key) <= 20:
        return public_key
    return f"{public_key[:10]}...{public_key[-10:]} [MASQUÉE]"

# ========================================================================
# CHECKS INDIVIDUELS
# ========================================================================

def check_mx_record(domain: str) -> Dict:
    """
    Check the MX records of a domain (up to 5 points).

    Two or more mail hosts score 5 points, a single host scores 3.

    Args:
        domain: Domain to query.

    Returns:
        A result dict with "status", "detail", "score" and "content", plus
        "mx_hosts": the list of mail host names, which the DANE check needs.
    """
    answer, _ = run_dig(domain, "MX")

    hosts = []
    points = 0
    message = "Aucun enregistrement MX trouvé."

    if answer and "NXDOMAIN" not in answer:
        for entry in answer.splitlines():
            fields = entry.split()
            if not fields:
                continue
            # Answer lines look like "10 mail.example.com." (or just the
            # host): the host name is always the last field.
            name = fields[-1].rstrip('.')
            if '.' in name:
                hosts.append(name)

        count = len(hosts)
        if count >= 2:
            points = 5
            message = f"Plusieurs serveurs MX trouvés ({count}). Redondance OK."
        elif count == 1:
            points = 3
            message = "Un seul serveur MX trouvé. Redondance faible."
        else:
            message = "Format MX non reconnu."

    return {
        "status": points > 0,
        "detail": message,
        "score": points,
        "content": "\n".join(hosts) if hosts else "N/A",
        "mx_hosts": hosts,  # consumed by the DANE check
    }

def check_spf_record(domain: str) -> Dict:
    """
    Check the SPF record of a domain (up to 25 points).

    The first TXT record containing "v=spf1" is scored from the qualifier of
    its `all` mechanism: -all 25, ~all 20, ?all 15, +all (or bare "all") 5,
    no `all` mechanism 5.

    Args:
        domain: Domain to query.

    Returns:
        A result dict with "status", "detail", "score" and "content" (the SPF
        record, or "N/A").
    """
    exists, _, content = check_record_exists(domain, "TXT")
    if not exists:
        return {"status": False, "detail": "Aucun enregistrement TXT trouvé.", "score": 0, "content": "N/A"}

    spf_lines = [
        line.strip().replace('"', '')
        for line in content.splitlines()
        if "v=spf1" in line.lower()
    ]
    if not spf_lines:
        return {
            "status": False,
            "detail": "Aucun enregistrement SPF (v=spf1) trouvé.",
            "score": 0,
            "content": "N/A",
        }

    spf_record = spf_lines[0]
    score, detail = _score_spf_all(spf_record)
    return {"status": score > 0, "detail": detail, "score": score, "content": spf_record}


def _score_spf_all(spf_record: str) -> Tuple[int, str]:
    """Return (score, detail) from the qualifier of the record's `all` mechanism."""
    qualifier: Optional[str] = None

    # Inspect whole whitespace-separated terms only, so that host names such
    # as "include:spf-all.example.com" are not mistaken for "-all". Mechanism
    # names are case-insensitive, and a bare "all" means "+all".
    for term in spf_record.lower().split():
        matched = re.fullmatch(r'([+\-~?]?)all', term)
        if matched:
            qualifier = matched.group(1) or "+"
            break

    outcomes = {
        "-": (25, "SPF strict (-all) trouvé. Protection maximale."),
        "~": (20, "SPF soft-fail (~all) trouvé. Bonne protection."),
        "?": (15, "SPF avec ?all trouvé. Protection faible."),
        "+": (5, "SPF avec +all trouvé. Aucune protection (tout expéditeur accepté)."),
    }
    return outcomes.get(qualifier, (5, "SPF trouvé mais sans mécanisme de fin."))

def check_dkim_record(domain: str, selector: str = "default") -> Dict:
    """
    Check the DKIM record of a domain for one selector (up to 10 points).

    Scoring: a record carrying v=DKIM1 scores 10, a key without the v= tag 8,
    an unrecognized TXT record 2, and a revoked key (empty p= tag) 0.

    Args:
        domain: Domain to query.
        selector: DKIM selector; an empty value or None falls back to
            "default". Dotted selectors ("s1.mail") are accepted.

    Returns:
        A result dict with "status", "detail", "score" and "content" (the
        masked key, or "N/A").
    """
    invalid_selector = {
        "status": False,
        "detail": "Sélecteur DKIM invalide.",
        "score": 0,
        "content": "N/A"
    }

    if not selector:
        selector = "default"
    elif not isinstance(selector, str):
        return invalid_selector
    else:
        selector = selector.strip()

    # Each dot-separated label must start with a letter, digit or underscore.
    # The selector is the first part of the name handed to dig, where a
    # leading '-' would be read as a command-line option.
    if not re.fullmatch(r'[A-Za-z0-9_][A-Za-z0-9_-]*(?:\.[A-Za-z0-9_][A-Za-z0-9_-]*)*', selector):
        return invalid_selector

    dkim_domain = f"{selector}._domainkey.{domain}"
    output, _ = run_dig(dkim_domain, "TXT")
    score = 0
    detail = f"Aucun DKIM trouvé pour {dkim_domain}."
    content_display = "N/A"

    if output and "NXDOMAIN" not in output:
        # Clean up: join the lines and drop the quotes.
        raw = output.replace('\n', ' ').replace('"', '').strip()
        lowered = raw.lower()

        if re.search(r'(?:^|;)\s*p\s*=\s*(?:;|$)', raw, flags=re.IGNORECASE):
            # An empty p= tag means the key has been revoked.
            detail = f"Clé DKIM révoquée (p= vide) pour le sélecteur '{selector}'."
            content_display = raw
        elif 'v=dkim1' in lowered:
            detail = f"DKIM trouvé pour le sélecteur '{selector}'."
            score = 10
            content_display = mask_record_content(raw, "DKIM")
        elif 'p=' in lowered:
            # v=DKIM1 is sometimes omitted while the key is present.
            detail = f"Clé DKIM trouvée pour '{selector}' (tag v= manquant)."
            score = 8
            content_display = mask_record_content(raw, "DKIM")
        else:
            detail = f"TXT trouvé pour {dkim_domain} mais format DKIM non reconnu."
            score = 2

    return {"status": score >= 2, "detail": detail, "score": score, "content": content_display}

def check_dmarc_record(domain: str) -> Dict:
    """
    Check the DMARC record of a domain (up to 25 points).

    Scoring by policy: reject 25, quarantine 15, none 5, unknown or missing
    policy 2.

    Args:
        domain: Domain to query; the record is looked up at "_dmarc.<domain>".

    Returns:
        A result dict with "status", "detail", "score" and "content" (the
        DMARC record, or "N/A").
    """
    dmarc_domain = f"_dmarc.{domain}"
    output, _ = run_dig(dmarc_domain, "TXT")
    dmarc_record = "N/A"
    score = 0
    detail = "Aucun enregistrement DMARC trouvé."

    if output and "NXDOMAIN" not in output:
        # Clean up: join the lines and drop the quotes.
        raw = output.replace('\n', ' ').replace('"', '').strip()

        if 'v=dmarc1' in raw.lower():
            dmarc_record = raw

            # Match the p= tag itself: it must start the record or follow a
            # ';', otherwise the tail of "sp=" (sub-domain policy) would be
            # read as the main policy.
            m_policy = re.search(r'(?:^|;)\s*p\s*=\s*([a-zA-Z]+)', raw, flags=re.IGNORECASE)

            if m_policy:
                policy = m_policy.group(1).lower().strip()

                if policy == "reject":
                    score = 25
                    detail = "DMARC reject trouvé. Protection maximale."
                elif policy == "quarantine":
                    score = 15
                    detail = "DMARC quarantine trouvé. Protection active."
                elif policy == "none":
                    score = 5
                    detail = "DMARC none (monitoring). Non protecteur mais présent."
                else:
                    score = 2
                    detail = f"DMARC avec politique inconnue: p={policy}."
            else:
                score = 2
                detail = "DMARC trouvé mais politique (p=) manquante."
        else:
            detail = "TXT trouvé mais pas de tag v=DMARC1."

    return {"status": score > 0, "detail": detail, "score": score, "content": dmarc_record}

def check_dnssec(domain: str) -> Dict:
    """
    Check whether DNSSEC keys are published for a domain (up to 10 points).

    Only the presence of a DNSKEY record is tested; the signature chain is
    not validated.

    Args:
        domain: Domain to query.

    Returns:
        A result dict with "status", "detail", "score" and "content" (the
        masked key, or "N/A" when no DNSKEY is found).
    """
    output, _ = run_dig(domain, "DNSKEY")
    exists = bool(output and "NXDOMAIN" not in output)

    if exists:
        score = 10
        detail = "DNSSEC activé (DNSKEY trouvé)."
        content = mask_record_content(output, "DNSKEY")
    else:
        score = 0
        detail = "DNSSEC non activé."
        # Same placeholder as the other checks instead of an empty string.
        content = "N/A"

    return {"status": exists, "detail": detail, "score": score, "content": content}

def check_mta_sts(domain: str) -> Dict:
    """
    Check the MTA-STS TXT record of a domain (up to 10 points).

    Args:
        domain: Domain to query; the record is looked up at
            "_mta-sts.<domain>".

    Returns:
        A result dict with "status", "detail", "score" and "content" (the raw
        TXT answer, or "N/A").
    """
    found, _, txt = check_record_exists(f"_mta-sts.{domain}", "TXT")

    if found and "v=STSv1" in txt:
        points = 10
        message = "MTA-STS (v=STSv1) détecté."
    else:
        points = 0
        message = "MTA-STS non détecté."

    return {"status": points > 0, "detail": message, "score": points, "content": txt}

def check_dane_tls(domain: str, mx_hosts: List[str]) -> Dict:
    """
    Check for DANE TLSA records (up to 10 points).

    TLSA is looked up for HTTPS on the domain itself ("_443._tcp.<domain>")
    and for SMTP on at most the first five mail hosts ("_25._tcp.<host>").
    Any hit scores the full 10 points.

    Args:
        domain: Domain to query.
        mx_hosts: Mail host names, as produced by the MX check.

    Returns:
        A result dict with "status", "detail", "score" and "content" (one line
        per service where TLSA was found, or "N/A").
    """
    def has_tlsa(owner: str) -> bool:
        answer, _ = run_dig(owner, "TLSA")
        return bool(answer and "NXDOMAIN" not in answer and answer.strip())

    found = []

    # 1. HTTPS on the main domain.
    if has_tlsa(f"_443._tcp.{domain}"):
        found.append(f"HTTPS ({domain})")

    # 2. SMTP (port 25) on each mail host, capped at five hosts.
    for candidate in mx_hosts[:5]:
        if not candidate or not isinstance(candidate, str):
            continue

        # Normalize the host name and skip anything too short to be one.
        mx_name = candidate.strip().lower().rstrip('.')
        if len(mx_name) < 3:
            continue

        if has_tlsa(f"_25._tcp.{mx_name}"):
            found.append(f"SMTP ({mx_name})")

    if found:
        return {
            "status": True,
            "detail": f"DANE TLSA détecté ({', '.join(found)}).",
            "score": 10,
            "content": "\n".join(found),
        }

    return {"status": False, "detail": "DANE TLSA non détecté.", "score": 0, "content": "N/A"}

def check_bimi_record(domain: str) -> Dict:
    """
    Check the BIMI record of a domain (up to 5 points).

    Args:
        domain: Domain to query; the record is looked up at
            "default._bimi.<domain>".

    Returns:
        A result dict with "status", "detail", "score" and "content" (the raw
        TXT answer, or "N/A").
    """
    found, _, txt = check_record_exists(f"default._bimi.{domain}", "TXT")

    points = 0
    if not found:
        message = "BIMI non trouvé."
    elif "v=BIMI1" in txt:
        points = 5
        message = "BIMI (v=BIMI1) trouvé."
    else:
        message = "TXT trouvé mais pas de tag v=BIMI1."

    return {"status": points > 0, "detail": message, "score": points, "content": txt}

# ========================================================================
# ORCHESTRATION
# ========================================================================

def check_all_dns_security(domain: str, dkim_selector: str = "default") -> Dict[str, Dict]:
    """
    Run every DNS security check for a domain.

    Args:
        domain: Raw user input; it is cleaned with sanitize_domain() first.
        dkim_selector: Selector used for the DKIM lookup.

    Returns:
        A dict mapping each check label to its result dict. If the domain is
        invalid or dig is unavailable, the dict holds a single "FATAL_ERROR"
        entry instead.
    """
    def fatal(reason: str) -> Dict[str, Dict]:
        return {
            "FATAL_ERROR": {
                "status": False,
                "detail": reason,
                "score": 0,
                "content": "N/A"
            }
        }

    # Refuse anything that is not a valid domain.
    clean_domain = sanitize_domain(domain)
    if not clean_domain:
        return fatal("Nom de domaine invalide.")

    # Nothing can run without dig.
    if not dig_available():
        return fatal("L'outil 'dig' n'est pas installé sur le serveur.")

    # MX goes first: its host list feeds the DANE check.
    mx_result = check_mx_record(clean_domain)
    mx_hosts = mx_result.get("mx_hosts", [])

    results = {"MX (Mail Base)": mx_result}
    results["SPF (Security)"] = check_spf_record(clean_domain)
    results["DKIM"] = check_dkim_record(clean_domain, dkim_selector)
    results["DMARC"] = check_dmarc_record(clean_domain)
    results["DNSSEC"] = check_dnssec(clean_domain)
    results["MTA-STS"] = check_mta_sts(clean_domain)
    results["DANE TLS"] = check_dane_tls(clean_domain, mx_hosts)
    results["BIMI"] = check_bimi_record(clean_domain)

    # The host list is an internal detail: remove it from the final result.
    results["MX (Mail Base)"].pop("mx_hosts", None)

    return results

def get_max_score_breakdown() -> Dict[str, int]:
    """
    Return the maximum score of each check.

    Returns:
        A new dict copied from PLAN_SCORES, so callers can modify it freely.
    """
    return dict(PLAN_SCORES)

def calculate_total_score(results: Dict[str, Dict]) -> Dict[str, Union[int, str]]:
    """
    Compute the overall score from the individual check results.

    Only entries whose label is listed in PLAN_SCORES are counted; the sum is
    clamped to the range 0..MAX_SCORE_TOTAL.

    Args:
        results: Mapping of check label to result dict.

    Returns:
        A dict with "score", "max_score" and "color": "green" from 70 points,
        "orange" above 35, "red" otherwise.
    """
    raw_total = sum(
        entry.get("score", 0)
        for label, entry in results.items()
        if label in PLAN_SCORES
    )
    final_score = max(0, min(raw_total, MAX_SCORE_TOTAL))

    # Traffic-light color for display.
    color = "green" if final_score >= 70 else ("orange" if final_score > 35 else "red")

    return {
        "score": final_score,
        "max_score": MAX_SCORE_TOTAL,
        "color": color
    }