import subprocess
import re
import shutil
from typing import Dict, Tuple, List, Optional, Union

# --- Configuration ---
# Maximum total score. The per-check maxima returned by
# get_max_score_breakdown() add up to exactly this value, so the total is
# used as-is without any normalisation step.
MAX_SCORE_TOTAL: int = 100

# --- Basic utilities ---

def dig_available() -> bool:
    """Return True when a ``dig`` executable can be located on the PATH."""
    dig_path = shutil.which("dig")
    return dig_path is not None

def run_dig(domain: str, record_type: str, server: str = "@8.8.8.8", timeout: int = 8) -> Tuple[str, str]:
    """Run ``dig <record_type> <domain> <server> +short`` and return ``(answer, diagnostics)``.

    Args:
        domain: Name to query. An empty name, or one starting with "-", "+"
            or "@", is rejected without running dig, because dig would parse
            it as an option, a query flag or a server instead of a name.
        record_type: Record type passed to dig (e.g. "TXT", "MX", "TLSA").
        server: Resolver in dig's ``@address`` syntax.
        timeout: Seconds after which the dig subprocess is abandoned.

    Returns:
        ``(answer, diagnostics)``.

        ``answer`` holds only answer lines. Lines beginning with ";" are
        dig's comment/diagnostic lines, e.g. ";; connection timed out; no
        servers could be reached", which ``+short`` still prints on stdout.
        They are moved into ``diagnostics`` so callers never mistake them
        for records.

        On any failure ``answer`` is "" and ``diagnostics`` describes the
        problem. A missing dig binary yields a message containing
        "introuvable".
    """
    if not domain or domain[0] in "-+@":
        return "", f"ERROR: Nom de domaine invalide: {domain!r}"

    command = ["dig", record_type, domain, server, "+short"]
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError:
        return "", "ERROR: L'outil 'dig' est introuvable."
    except subprocess.TimeoutExpired:
        return "", "ERROR: Timeout lors de l'exécution de dig."
    except Exception as e:  # boundary: report any other launch failure instead of raising
        return "", f"ERROR: Exception lors de l'exécution de dig: {e}"

    # Separate dig's ";" diagnostic lines from the actual answer data.
    raw_lines = result.stdout.splitlines()
    notes = [line.strip() for line in raw_lines if line.lstrip().startswith(";")]
    stdout = "\n".join(line for line in raw_lines if not line.lstrip().startswith(";")).strip()
    stderr = "\n".join([result.stderr.strip(), *notes]).strip()

    if result.returncode != 0 and not stdout:
        return "", f"Dig command failed with return code {result.returncode}. Stderr: {stderr}"
    return stdout, stderr

def check_record_exists(domain: str, record_type: str) -> Tuple[bool, str, str]:
    """Query ``record_type`` for ``domain`` and report whether any answer came back.

    Returns:
        ``(found, detail, content)``. ``content`` is the raw dig output when
        found and "N/A" otherwise. When run_dig reports that dig is missing
        (an error containing "introuvable"), that error becomes ``detail``.
    """
    answer, problem = run_dig(domain, record_type)
    if problem and "introuvable" in problem:
        return False, problem, "N/A"

    found = bool(answer) and "NXDOMAIN" not in answer
    if not found:
        return False, "Enregistrement non trouvé ou résultat ambigu.", "N/A"
    return True, "Enregistrement trouvé.", answer

# --- Masking helpers ---

def mask_record_content(content: str, record_type: str) -> str:
    """Shorten key material in DNSKEY, TLSA and DKIM records for concise display.

    Args:
        content: Record text as produced by the checks. This is dig ``+short``
            output, possibly prefixed with a label such as
            ``"HTTPS (example.com): "``.
        record_type: "DNSKEY", "TLSA" or "DKIM". Any other value returns
            ``content`` unchanged.

    Returns:
        For DNSKEY/TLSA: ``"<first 10>...<last 10> [MASQUÉ]"``, built from the
        key material of the first line only. Returns ``"[Clé masquée]"`` when
        no key longer than 30 characters is found.

        For DKIM: ``"<first 10>...<last 10> [MASQUÉE]"``, or the whole key
        when it has 20 characters or fewer. Returns ``content`` unchanged when
        no non-empty ``p=`` tag is present.

        Empty input and "N/A" are returned as-is.
    """
    if not content or content == "N/A":
        return content

    # DNSSEC (DNSKEY) and DANE (TLSA)
    if record_type in ("DNSKEY", "TLSA"):
        lines = content.splitlines()
        first_line = lines[0].strip() if lines else content.strip()
        parts = first_line.split()

        # dig splits long base64/hex data into space-separated chunks. The key
        # is therefore everything after the three numeric fields (DNSKEY:
        # flags protocol algorithm; TLSA: usage selector matching-type), not
        # just the last token.
        key_part = ""
        for start in range(len(parts) - 3):
            if all(token.isdigit() for token in parts[start:start + 3]):
                key_part = "".join(parts[start + 3:])
                break
        else:
            # Unexpected layout: fall back to the last token.
            if len(parts) >= 2:
                key_part = parts[-1]

        if len(key_part) > 30:  # long enough to be key material
            masked_key = key_part[:10] + "..." + key_part[-10:]
            return f"{masked_key} [MASQUÉ]"
        # Short or unrecognised data: only show a masking marker.
        return "[Clé masquée]"

    # DKIM (public key)
    if record_type == "DKIM":
        # Anchor on a tag boundary so "p=" inside another tag's value
        # (e.g. a note "n=...p=...") is not taken for the public key.
        m = re.search(r'(?:^|;)[\s"]*p=([^;"]+)', content)
        if m:
            # Drop quotes and the whitespace dig inserts between TXT strings.
            key = "".join(m.group(1).replace('"', '').split())
            if key:
                masked_key = key[:10] + "..." + key[-10:] if len(key) > 20 else key
                return f"{masked_key} [MASQUÉE]"
        return content

    return content


# --- Checks (per-check maxima add up to 100 points) ---

def check_mx_record(domain: str) -> Dict:
    """Check the domain's MX records and score mail-server redundancy (max 5 points).

    Scoring: 5 for two or more MX hosts, 3 for exactly one, 0 otherwise.
    A "null MX" (exchange ``.``, RFC 7505: the domain declares it accepts no
    mail) is reported in ``detail`` and does not count as a host.

    Returns:
        Dict with:
          - "status": score > 0.
          - "detail": human-readable result.
          - "score": points awarded.
          - "content": one host per line, or "N/A".
          - "mx_hosts": host names without the trailing dot. Used by the
            DANE check; check_all_dns_security strips this key from its
            output.
    """
    # MAX SCORE: 5
    exists, detail, content = check_record_exists(domain, "MX")
    score = 0
    mx_hosts: List[str] = []
    content_to_display = "N/A"

    if exists and content:
        null_mx = False
        for line in content.splitlines():
            parts = line.split()
            # Expected "+short" format: "<preference> <exchange>."
            if len(parts) < 2 or not parts[-1].endswith('.'):
                continue
            host = parts[-1].rstrip('.')
            if host:
                mx_hosts.append(host)
            else:
                # Exchange "." is a null MX: there is no host to count or to
                # probe for TLSA records.
                null_mx = True

        if len(mx_hosts) >= 2:
            score = 5  # Max
            detail = f"Plusieurs serveurs MX trouvés ({len(mx_hosts)}). Redondance OK."
        elif len(mx_hosts) == 1:
            score = 3  # Partial
            detail = "Un seul serveur MX trouvé. Redondance faible."
        elif null_mx:
            detail = "MX nul (RFC 7505) : le domaine déclare n'accepter aucun courrier."
        else:
            detail = "Enregistrement MX trouvé mais aucun hôte exploitable."

        if mx_hosts:
            content_to_display = "\n".join(mx_hosts)

    return {"status": score > 0, "detail": detail, "score": score, "content": content_to_display, "mx_hosts": mx_hosts}

def check_spf_record(domain: str) -> Dict:
    """Score the SPF policy published in the domain's TXT records (max 25 points).

    The SPF record is the first TXT record that begins with ``v=spf1``. It is
    scored on the qualifier of its ``all`` mechanism:
      - ``-all``: 25
      - ``~all``: 20
      - ``?all``: 15
      - no ``all`` term: 5
      - ``+all`` or a bare ``all``: 0, since these authorise every sender.

    Returns:
        Dict with "status" (score > 0), "detail", "score" and "content" (the
        SPF record, or "N/A").
    """
    # MAX SCORE: 25
    exists, detail, content = check_record_exists(domain, "TXT")
    spf_record = "N/A"
    score = 0
    if exists:
        # One TXT record per line. dig shows a record made of several
        # character-strings as "a" "b"; RFC 7208 §3.3 joins them with no
        # separator.
        records = [
            re.sub(r'"\s+"', '', line.strip()).replace('"', '').strip()
            for line in content.splitlines()
        ]
        # RFC 7208 §4.5: the record must begin with "v=spf1", followed by a
        # space or the end of the record.
        spf_records = [r for r in records if re.match(r'v=spf1(?:\s|$)', r, flags=re.IGNORECASE)]
        if spf_records:
            spf_record = spf_records[0]
            # Match "all" only as a standalone term. A plain search for "-all"
            # would also hit names such as "include:mail-all.example.com".
            m = re.search(r'(?:^|\s)([-+~?]?)all(?=\s|$)', spf_record, flags=re.IGNORECASE)
            if m:
                kind = m.group(1)
                if kind == "-":
                    score = 25  # Max (strict)
                    detail = "Enregistrement SPF trouvé avec mécanisme strict (-all)."
                elif kind == "~":
                    score = 20  # Good (soft-fail)
                    detail = "Enregistrement SPF trouvé avec soft-fail (~all)."
                elif kind == "?":
                    score = 15  # Weak (neutral)
                    detail = "Enregistrement SPF trouvé avec ?all."
                else:  # "+all" or bare "all" (implicit "+")
                    score = 0
                    detail = "Enregistrement SPF trouvé avec +all : tout expéditeur est autorisé (aucune protection)."
            else:
                score = 5
                detail = "Enregistrement SPF trouvé mais sans mécanisme de fin."
        else:
            detail = "Aucun enregistrement SPF ('v=spf1') trouvé dans les TXT."
    else:
        detail = "Aucun enregistrement TXT trouvé."

    return {"status": score > 0, "detail": detail, "score": score, "content": spf_record}

def check_dkim_record(domain: str, selector: str = "default") -> Dict:
    """Check the DKIM public-key record at ``<selector>._domainkey.<domain>`` (max 10 points).

    Args:
        domain: Domain to check.
        selector: DKIM selector. An empty value falls back to "default".

    Scoring:
      - 10: a TXT record containing ``v=DKIM1`` with a non-empty ``p=`` key.
      - 2: the key is revoked (empty ``p=``), or a TXT record exists
        without ``v=DKIM1``.
      - 0: no TXT record is found.

    Returns:
        Dict with "status" (score > 0), "detail", "score" and "content" (the
        record with its key masked, or "N/A").
    """
    # MAX SCORE: 10
    selector = selector or "default"
    dkim_domain = f"{selector}._domainkey.{domain}"
    exists, detail, content = check_record_exists(dkim_domain, "TXT")
    score = 0
    content_to_display = "N/A"

    if exists:
        # Keys longer than 255 bytes are published as several
        # character-strings, which dig shows as "a" "b". RFC 6376 §3.6.2.2
        # concatenates them with no separator, so join them before masking.
        records = [
            re.sub(r'"\s+"', '', line.strip()).replace('"', '').strip()
            for line in content.splitlines()
        ]
        dkim_records = [r for r in records if "v=DKIM1" in r]
        if dkim_records:
            dkim_record = dkim_records[0]
            # An empty "p=" tag means the key has been revoked (RFC 6376 §3.6.1).
            if re.search(r'(?:^|;)\s*p=\s*(?:;|$)', dkim_record):
                detail = f"Enregistrement DKIM trouvé pour le sélecteur '{selector}', mais la clé est révoquée (p= vide)."
                score = 2
            else:
                detail = f"Enregistrement DKIM trouvé pour le sélecteur '{selector}'."
                score = 10  # Max
            content_to_display = mask_record_content(dkim_record, "DKIM")
        else:
            detail = f"TXT trouvé pour {dkim_domain} mais pas de tag v=DKIM1."
            score = 2
    else:
        detail = f"Aucun enregistrement DKIM TXT trouvé pour {dkim_domain}."

    return {"status": score > 0, "detail": detail, "score": score, "content": content_to_display}

def check_dmarc_record(domain: str) -> Dict:
    """Score the DMARC policy published at ``_dmarc.<domain>`` (max 25 points).

    The DMARC record is the first TXT record that starts with ``v=DMARC1``.
    Its ``p=`` tag is scored:
      - reject: 25
      - quarantine: 15
      - none: 5
      - any other value, or a missing ``p=``: 2

    Returns:
        Dict with "status" (score > 0), "detail", "score" and "content" (the
        DMARC record, or "N/A").
    """
    # MAX SCORE: 25
    dmarc_domain = f"_dmarc.{domain}"
    exists, detail, content = check_record_exists(dmarc_domain, "TXT")
    dmarc_record = "N/A"
    score = 0

    if exists:
        # Treat each line as its own record (a CNAME target line must not be
        # merged in). Join the "a" "b" character-strings dig prints for a
        # split TXT record.
        records = [
            re.sub(r'"\s+"', '', line.strip()).replace('"', '').strip()
            for line in content.splitlines()
        ]
        dmarc_records = [
            r for r in records
            if re.match(r'v\s*=\s*DMARC1\s*(?:;|$)', r, flags=re.IGNORECASE)
        ]

        if dmarc_records:
            dmarc_record = dmarc_records[0]
            # Anchor on a tag boundary. An unanchored "p=" also matches the
            # "sp=" (subdomain policy) tag, e.g. "v=DMARC1; sp=none; p=reject".
            m_policy = re.search(r'(?:^|;)\s*p\s*=\s*([^;]*)', dmarc_record, flags=re.IGNORECASE)

            if m_policy:
                policy = m_policy.group(1).strip().lower()
                # DMARC scoring
                if policy == "reject":
                    score = 25  # Max (strict protection)
                    detail = "Politique DMARC reject trouvée. (Protection maximale)."
                elif policy == "quarantine":
                    score = 15  # Good (active protection)
                    detail = "Politique DMARC quarantine trouvée. (Protection active)."
                elif policy == "none":
                    score = 5  # Weak (monitoring only)
                    detail = "Politique DMARC de monitoring (p=none) trouvée. (Non protectrice, mais présent)."
                else:
                    score = 2
                    detail = f"DMARC trouvé avec politique inconnue/invalide: p={policy}."
            else:
                score = 2
                detail = "DMARC trouvé, mais la politique (p=) est manquante ou invalide."
        else:
            detail = "Aucun enregistrement DMARC ('v=DMARC1') trouvé dans _dmarc."
    else:
        detail = "Aucun enregistrement TXT pour _dmarc."

    return {"status": score > 0, "detail": detail, "score": score, "content": dmarc_record}

def check_dnssec(domain: str) -> Dict:
    """Check whether DNSSEC is deployed for ``domain`` (max 10 points).

    DNSKEY records are queried for the domain. When they exist, DS records
    (published in the parent zone) are queried as well.

    Scoring:
      - 10: both DNSKEY and DS are present.
      - 5: DNSKEY exists but no DS is found. Validating resolvers then cannot
        build a chain of trust to the zone.
      - 0: no DNSKEY is found.

    Returns:
        Dict with "status" (True when a DNSKEY was found), "detail", "score"
        and "content" (the masked first DNSKEY, or "N/A").
    """
    # MAX SCORE: 10
    dnskey_output, _ = run_dig(domain, "DNSKEY")
    has_dnskey = bool(dnskey_output and "NXDOMAIN" not in dnskey_output)
    if not has_dnskey:
        return {"status": False, "detail": "DNSSEC non activé ou DS Record manquant.", "score": 0, "content": "N/A"}

    ds_output, _ = run_dig(domain, "DS")
    has_ds = bool(ds_output and "NXDOMAIN" not in ds_output)
    if has_ds:
        score = 10
        detail = "DNSSEC activé (DNSKEY et DS trouvés)."
    else:
        score = 5
        detail = "DNSKEY trouvé mais aucun DS chez le parent : chaîne de confiance incomplète."

    content_to_display = mask_record_content(dnskey_output, "DNSKEY")
    return {"status": True, "detail": detail, "score": score, "content": content_to_display}

def check_mta_sts(domain: str) -> Dict:
    """Look for an MTA-STS TXT record (``v=STSv1``) at ``_mta-sts.<domain>`` (max 10 points).

    Only the TXT record is inspected. ``content`` is the raw TXT output, or
    "N/A" when no TXT record was found.
    """
    # MAX SCORE: 10
    found, _, txt = check_record_exists("_mta-sts." + domain, "TXT")
    has_sts = found and "v=STSv1" in txt
    points = 10 if has_sts else 0
    message = "MTA-STS TXT record (v=STSv1) détecté." if points > 0 else "MTA-STS non détecté."
    return {"status": points > 0, "detail": message, "score": points, "content": txt}

def check_dane_tls(domain: str, mx_hosts: List[str]) -> Dict:
    """Look for DANE TLSA records (max 10 points).

    TLSA is queried on ``_443._tcp.<domain>`` first, then on
    ``_25._tcp.<host>`` for every MX host, in order. Any answer earns the
    full score. ``content`` is the masked form of the collected answers
    (see mask_record_content).
    """
    # MAX SCORE: 10
    targets = [("HTTPS", domain, f"_443._tcp.{domain}")]
    targets += [("SMTP", host, f"_25._tcp.{host}") for host in mx_hosts]

    found = []
    for label, owner, qname in targets:
        answer, _ = run_dig(qname, "TLSA")
        if answer and "NXDOMAIN" not in answer:
            found.append(f"{label} ({owner}): {answer}")

    has_tlsa = len(found) > 0
    if has_tlsa:
        message = "DANE TLSA détecté (via HTTPS et/ou SMTP/MX)."
    else:
        message = "DANE TLSA non détecté sur le domaine ou les hôtes MX."

    masked = mask_record_content("\n".join(found), "TLSA")
    return {"status": has_tlsa, "detail": message, "score": 10 if has_tlsa else 0, "content": masked}

def check_bimi_record(domain: str) -> Dict:
    """Check the BIMI TXT record under the ``default._bimi`` subdomain (max 5 points).

    The full score is given only when the TXT record contains ``v=BIMI1``.
    ``content`` is the raw TXT output, or "N/A" when none was found.
    """
    # MAX SCORE: 5
    found, _, txt = check_record_exists(f"default._bimi.{domain}", "TXT")

    if not found:
        points, message = 0, "Aucun enregistrement BIMI (default._bimi) trouvé."
    elif "v=BIMI1" in txt:
        points, message = 5, "Enregistrement BIMI (v=BIMI1) trouvé."
    else:
        points, message = 0, "TXT trouvé pour BIMI mais pas de tag v=BIMI1."

    return {"status": points > 0, "detail": message, "score": points, "content": txt}

# --- Orchestration ---

def check_all_dns_security(domain: str, dkim_selector: str = "default") -> Dict[str, Dict]:
    """Run every DNS / e-mail security check for ``domain``.

    Returns:
        A dict keyed by check label (the same labels as
        get_max_score_breakdown()), in the order the checks are run. If dig
        is not installed, a single "FATAL_ERROR" entry with score 0 is
        returned instead.
    """
    if not dig_available():
        fatal = {"status": False, "detail": "L'outil 'dig' est introuvable.", "score": 0, "content": "N/A"}
        return {"FATAL_ERROR": fatal}

    mx_result = check_mx_record(domain)
    # The MX host list feeds the DANE check but is not part of the public result.
    mx_hosts = mx_result.pop("mx_hosts", [])

    return {
        "MX (Mail Base)": mx_result,
        "SPF (Security)": check_spf_record(domain),
        "DKIM": check_dkim_record(domain, dkim_selector),
        "DMARC": check_dmarc_record(domain),
        "DNSSEC": check_dnssec(domain),
        "MTA-STS": check_mta_sts(domain),
        "DANE TLS": check_dane_tls(domain, mx_hosts),
        "BIMI": check_bimi_record(domain),
    }

def get_max_score_breakdown() -> Dict[str, int]:
    """Return the maximum points each check can earn; the values sum to 100.

    A new dict is built on every call, so callers may mutate it freely.
    """
    limits = (
        ("MX (Mail Base)", 5),
        ("SPF (Security)", 25),
        ("DKIM", 10),
        ("DMARC", 25),
        ("DNSSEC", 10),
        ("MTA-STS", 10),
        ("DANE TLS", 10),
        ("BIMI", 5),
    )
    return dict(limits)

def calculate_total_score(results: Dict[str, Dict]) -> Dict[str, Union[int, str]]:
    """Compute the total score out of MAX_SCORE_TOTAL and its display colour.

    Only entries whose key appears in get_max_score_breakdown() are counted,
    so e.g. a "FATAL_ERROR" entry contributes nothing. Each entry's score is
    capped at that check's own maximum, and the total is clamped to
    [0, MAX_SCORE_TOTAL]. No normalisation is applied because the per-check
    maxima already sum to 100.

    Colour: "red" below 35, "orange" from 35 to 69, "green" from 70.

    Returns:
        {"score": int, "max_score": int, "color": str}
    """
    max_scores = get_max_score_breakdown()
    max_possible_score = MAX_SCORE_TOTAL

    current_score = sum(
        min(data.get("score", 0), max_scores[protocol])
        for protocol, data in results.items()
        if protocol in max_scores
    )
    final_score = max(0, min(current_score, max_possible_score))

    # Thresholds: red < 35 <= orange < 70 <= green (a score of exactly 35 is orange).
    if final_score >= 70:
        color = "green"
    elif final_score >= 35:
        color = "orange"
    else:
        color = "red"

    return {
        "score": final_score,
        "max_score": max_possible_score,
        "color": color,
    }