# Fichier: models.py (VERSION POSTGRESQL + UUID READY)

import uuid
from datetime import datetime, timezone
from flask_login import UserMixin
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import validates
import pyotp
import re
from itsdangerous import URLSafeTimedSerializer as Serializer, SignatureExpired, BadSignature
from flask import current_app
from extensions import db, login_manager
# ========================================================================
# INITIALISATION DB
# ========================================================================
# db = SQLAlchemy()

# ========================================================================
# FONCTION UTILITAIRE POUR GÉNÉRER LE SECRET 2FA
# ========================================================================

def generate_2fa_secret():
    """Génère un secret aléatoire pour TOTP (2FA)."""
    return pyotp.random_base32()

# ========================================================================
# MODÈLE USER
# ========================================================================

class User(db.Model, UserMixin):
    """
    Modèle représentant un utilisateur du système.
    Supporte 3 rôles : admin, manager, user (client final).
    """
    __tablename__ = 'users'  # ✅ CHANGÉ : 'user' -> 'users' (convention)
    
    # ✅ UUID comme clé primaire (String(36) compatible SQLite + PostgreSQL)
    id = db.Column(
        db.String(36), 
        primary_key=True, 
        default=lambda: str(uuid.uuid4())
    )
    
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password = db.Column(db.String(200), nullable=False)  # ✅ CHANGÉ : 60 -> 200 (bcrypt peut varier)
    company_name = db.Column(db.String(150), nullable=True)
    logo_filename = db.Column(db.String(255), nullable=True)  # Nom fichier logo uploadé
    show_logo_on_reports = db.Column(db.Boolean, default=False, nullable=False)
    show_company_name_on_reports = db.Column(db.Boolean, default=False, nullable=False)
    use_manager_branding = db.Column(db.Boolean, default=True, nullable=False)  # Pour users uniquement
    
    # Rôles et permissions
    role = db.Column(db.String(20), default='user', nullable=False, index=True)
    is_active_account = db.Column(db.Boolean, default=True, nullable=False)
    can_audit = db.Column(db.Boolean, default=True, nullable=False)
    # Multi-connexion (pour managers)
    max_concurrent_sessions = db.Column(db.Integer, default=1, nullable=False)
    # 1 = users normaux, 5+ = managers, -1 = illimité (admins)
    
    # Dates
    engagement_end_date = db.Column(db.DateTime, nullable=True)
    created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), nullable=False)
    updated_at = db.Column(
        db.DateTime, 
        default=lambda: datetime.now(timezone.utc), 
        onupdate=lambda: datetime.now(timezone.utc),
        nullable=False
    )
    
    # 2FA
    two_factor_secret = db.Column(db.String(32), nullable=True)
    two_factor_enabled = db.Column(db.Boolean, default=False, nullable=False)
    two_factor_recovery_codes = db.Column(db.Text, nullable=True)
    
    # Relations Manager/Client (auto-référence)
    # ✅ Foreign Key vers users.id (UUID String)
    manager_id = db.Column(
        db.String(36), 
        db.ForeignKey('users.id'), 
        nullable=True, 
        index=True
    )

    # ✅ NOUVEAU : Relation pour les membres d'équipe (Sub-users)
    parent_id = db.Column(
        db.String(36),
        db.ForeignKey('users.id'),
        nullable=True,
        index=True
    )
    
    # ✅ Relation bidirectionnelle
    clients = db.relationship(
        'User', 
        backref=db.backref('manager', remote_side=[id]), 
        lazy='select',  # ✅ CHANGÉ : lazy=True -> lazy='dynamic' (meilleur pour queries)
        cascade='save-update, merge',
        foreign_keys=[manager_id]
    )

    # ✅ NOUVEAU : Relation Équipe (Le patron accède à ses employés)
    team_members = db.relationship(
        'User',
        backref=db.backref('parent', remote_side=[id]),
        lazy='select',
        foreign_keys=[parent_id]
    )
    
    
    # Relation avec rapports
    rapports = db.relationship(
        'Rapport', 
        backref='auteur', 
        lazy='select',  # ✅ CHANGÉ
        cascade='all, delete-orphan'
    )

    def __init__(self, **kwargs):
        super(User, self).__init__(**kwargs)
        # Génération automatique du secret 2FA
        if not self.two_factor_secret:
            self.two_factor_secret = generate_2fa_secret()

    # ========================================================================
    # VALIDATIONS
    # ========================================================================
    
    @validates('email')
    def validate_email(self, key, email):
        if not email:
            raise ValueError("L'email ne peut pas être vide")
        email_regex = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if not re.match(email_regex, email):
            raise ValueError(f"Format d'email invalide : {email}")
        return email.lower().strip()
    
    @validates('role')
    def validate_role(self, key, role):
        allowed_roles = ['admin', 'manager', 'user']
        if role not in allowed_roles:
            raise ValueError(f"Rôle invalide : {role}. Doit être : {', '.join(allowed_roles)}")
        return role
    
    @validates('company_name')
    def validate_company_name(self, key, name):
        if name and len(name) > 150:
            raise ValueError("Le nom d'entreprise ne peut dépasser 150 caractères")
        return name.strip() if name else None

    # ========================================================================
    # MÉTHODES FLASK-LOGIN
    # ========================================================================
    
    def is_active(self):
        """Requis par Flask-Login pour vérifier si l'utilisateur peut se connecter."""
        return self.is_active_account

    # ========================================================================
    # MÉTHODES - RÉINITIALISATION MOT DE PASSE
    # ========================================================================

    def get_reset_token(self, expires_sec=1800, secret_key=None):
        """Génère un token sécurisé pour réinitialiser le mot de passe."""
        key = str(secret_key) if secret_key else str(current_app.config['SECRET_KEY'])
        s = Serializer(key)
        return s.dumps({'user_id': self.id}, salt='password-reset')

    @staticmethod
    def verify_reset_token(token, secret_key=None, max_age=1800):
        """Vérifie un token de réinitialisation et retourne l'utilisateur."""
        key = str(secret_key) if secret_key else str(current_app.config['SECRET_KEY'])
        s = Serializer(key)
        try:
            user_id = s.loads(token, salt='password-reset', max_age=max_age)['user_id']
            
            # Validation UUID format (défense en profondeur)
            if not isinstance(user_id, str) or len(user_id) != 36:
                return None
                
        except (SignatureExpired, BadSignature):
            return None
        
        return User.query.get(user_id)

    # ========================================================================
    # MÉTHODES - 2FA (TOTP)
    # ========================================================================

    def get_totp_uri(self):
        """Retourne l'URI pour générer le QR code 2FA."""
        return pyotp.totp.TOTP(self.two_factor_secret).provisioning_uri(
            name=self.email,
            issuer_name="MCyber Consulting"
        )

    def verify_totp_code(self, token):
        """Vérifie un code TOTP (2FA) avec fenêtre de tolérance."""
        totp = pyotp.TOTP(self.two_factor_secret)
        return totp.verify(token, valid_window=1)

    # ========================================================================
    # MÉTHODES UTILITAIRES
    # ========================================================================
    
    def can_create_audit(self):
        """Vérifie si l'utilisateur peut créer un audit."""
        return self.is_active_account and self.can_audit
    
    def is_manager_of(self, user):
        """Vérifie si cet utilisateur est le manager d'un autre."""
        return self.role == 'manager' and user.manager_id == self.id
    
    def get_all_clients(self):
        """Retourne tous les clients gérés par ce manager."""
        if self.role != 'manager':
            return []
        return User.query.filter_by(manager_id=self.id).all()
    
    def has_clients(self):
        """Vérifie si ce manager a des clients."""
        if self.role != 'manager':
            return False
        return User.query.filter_by(manager_id=self.id).count() > 0

    def __repr__(self):
        return f"<User {self.email} (Role: {self.role}, Active: {self.is_active_account})>"

    @property
    def owner_id(self):
        """
        Retourne l'ID du propriétaire des données.
        - Si je suis le Manager Principal : retourne mon ID.
        - Si je suis un Membre d'équipe : retourne l'ID de mon parent.
        """
        return self.parent_id if self.parent_id else self.id

    def get_team_clients(self):
        """
        Retourne les clients gérés par l'équipe entière (Patron + Membres).
        """
        if self.role != 'manager':
            return []
        
        # On cherche tous les clients dont le manager est SOIT moi, SOIT mon patron
        target_id = self.owner_id
        # Note: On utilise User ici, assure-toi que l'interpréteur le comprend
        # Sinon utilise self.__class__.query...
        return User.query.filter_by(manager_id=target_id).all()

        

# ========================================================================
# MODÈLE RAPPORT
# ========================================================================

class Rapport(db.Model):
    """
    Modèle représentant un rapport d'audit de cybersécurité.
    """
    __tablename__ = 'rapports'  # ✅ CHANGÉ : 'rapport' -> 'rapports' (convention pluriel)
    
    # ✅ UUID comme clé primaire
    id = db.Column(
        db.String(36), 
        primary_key=True, 
        default=lambda: str(uuid.uuid4())
    )
    
    # Métadonnées
    date_creation = db.Column(
        db.DateTime, 
        default=lambda: datetime.now(timezone.utc), 
        nullable=False, 
        index=True
    )
    nom_client = db.Column(db.String(150), default='Client Inconnu', nullable=False)
    
    # Score et données
    score_total = db.Column(db.Integer, nullable=False)
    raw_data = db.Column(db.Text, nullable=True)
    preconisations_json = db.Column(db.Text, nullable=False)
    
    # ✅ Foreign Key vers users.id (UUID String)
    user_id = db.Column(
        db.String(36), 
        db.ForeignKey('users.id'), 
        nullable=False, 
        index=True
    )

    # ========================================================================
    # VALIDATIONS
    # ========================================================================
    
    @validates('score_total')
    def validate_score(self, key, score):
        if not isinstance(score, int):
            raise ValueError("Le score doit être un entier")
        if not 0 <= score <= 100:
            raise ValueError(f"Le score doit être entre 0 et 100 (reçu: {score})")
        return score
    
    @validates('nom_client')
    def validate_nom_client(self, key, nom):
        if not nom or not nom.strip():
            return 'Client Inconnu'
        if len(nom) > 150:
            raise ValueError("Le nom du client ne peut dépasser 150 caractères")
        return nom.strip()

# ========================================================================
# MODÈLE ANNOUNCEMENT (BANDEAU ADMIN)
# ========================================================================

class Announcement(db.Model):
    """
    Modèle représentant un bandeau d'information affiché globalement.
    """
    __tablename__ = 'announcements'
    
    id = db.Column(
        db.String(36), 
        primary_key=True, 
        default=lambda: str(uuid.uuid4())
    )
    
    message = db.Column(db.Text, nullable=False)
    target = db.Column(db.String(20), default='all', nullable=False)  # all, manager, user
    is_active = db.Column(db.Boolean, default=True, nullable=False)
    created_at = db.Column(
        db.DateTime, 
        default=lambda: datetime.now(timezone.utc), 
        nullable=False
    )
    updated_at = db.Column(
        db.DateTime, 
        default=lambda: datetime.now(timezone.utc), 
        onupdate=lambda: datetime.now(timezone.utc),
        nullable=False
    )
    
    @validates('target')
    def validate_target(self, key, target):
        allowed_targets = ['all', 'manager', 'user']
        if target not in allowed_targets:
            raise ValueError(f"Cible invalide : {target}")
        return target
    
    @validates('message')
    def validate_message(self, key, message):
        if not message or not message.strip():
            raise ValueError("Le message ne peut pas être vide")
        if len(message) > 500:
            raise ValueError("Le message ne peut dépasser 500 caractères")
        return message.strip()
    
    def __repr__(self):
        return f"<Announcement {self.id[:8]}... (Active: {self.is_active}, Target: {self.target})>"

        # ========================================================================
# MODÈLE DOCUMENT (DOCUMENTATION)
# ========================================================================

class Document(db.Model):
    """
    Modèle représentant un document PDF téléchargeable.
    """
    __tablename__ = 'documents'
    
    id = db.Column(
        db.String(36), 
        primary_key=True, 
        default=lambda: str(uuid.uuid4())
    )
    
    title = db.Column(db.String(150), nullable=False)
    category = db.Column(db.String(50), default='guide', nullable=False)  # guide, cgv, other
    filename = db.Column(db.String(255), nullable=False)  # Nom fichier stocké
    original_filename = db.Column(db.String(255), nullable=False)  # Nom original
    file_size = db.Column(db.Integer, nullable=False)  # En octets
    
    created_at = db.Column(
        db.DateTime, 
        default=lambda: datetime.now(timezone.utc), 
        nullable=False
    )
    uploaded_by = db.Column(
        db.String(36), 
        db.ForeignKey('users.id'), 
        nullable=False
    )
    
    # Relation avec l'admin qui a uploadé
    uploader = db.relationship('User', backref='uploaded_documents', lazy='select')
    
    @validates('category')
    def validate_category(self, key, category):
        allowed_categories = ['guide', 'cgv', 'other']
        if category not in allowed_categories:
            raise ValueError(f"Catégorie invalide : {category}")
        return category
    
    @validates('title')
    def validate_title(self, key, title):
        if not title or not title.strip():
            raise ValueError("Le titre ne peut pas être vide")
        if len(title) > 150:
            raise ValueError("Le titre ne peut dépasser 150 caractères")
        return title.strip()
    
    def __repr__(self):
        return f"<Document {self.title} ({self.category})>"

        # ========================================================================
# AUDIT LOGS (TRAÇABILITÉ & DEBUG)
# ========================================================================

class AuditLog(db.Model):
    """
    Logs d'audit pour traçabilité et debug.
    Séparation : Logs activité (INFO) vs Logs erreurs (ERROR/CRITICAL)
    """
    __tablename__ = 'audit_logs'
    
    id = db.Column(db.String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    timestamp = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
    
    # Classification
    level = db.Column(db.String(20), nullable=False, index=True)  # INFO, WARNING, ERROR, CRITICAL
    category = db.Column(db.String(50), nullable=False, index=True)  # AUTH, USER, REPORT, ADMIN, SYSTEM, EMAIL
    action = db.Column(db.String(100), nullable=False)  # LOGIN_SUCCESS, USER_CREATED, PDF_ERROR, etc.
    
    # Contexte
    user_id = db.Column(db.String(36), db.ForeignKey('users.id'), nullable=True, index=True)
    target_id = db.Column(db.String(36), nullable=True)  # ID de l'objet concerné (rapport, user, etc.)
    ip_address = db.Column(db.String(45), nullable=True)
    user_agent = db.Column(db.String(255), nullable=True)
    
    # Contenu
    message = db.Column(db.Text, nullable=False)  # Message principal
    error_details = db.Column(db.Text, nullable=True)  # Stack trace si erreur
    suggested_fix = db.Column(db.Text, nullable=True)  # Solutions proposées (pour erreurs)
    
    # Métadonnées
    resolved = db.Column(db.Boolean, default=False, nullable=False)  # Erreur résolue ?
    resolved_at = db.Column(db.DateTime, nullable=True)
    resolved_by = db.Column(db.String(36), db.ForeignKey('users.id'), nullable=True)
    
    # Relations
    user = db.relationship('User', foreign_keys=[user_id], backref='audit_logs', lazy='select')
    resolver = db.relationship('User', foreign_keys=[resolved_by], lazy='select')
    
    def __repr__(self):
        return f'<AuditLog {self.level} {self.action} @ {self.timestamp}>'
    
    @staticmethod
    def log_info(category, action, message, user_id=None, target_id=None, ip=None, user_agent=None):
        """Helper pour logs INFO (activité normale)"""
        log = AuditLog(
            level='INFO',
            category=category,
            action=action,
            message=message,
            user_id=user_id,
            target_id=target_id,
            ip_address=ip,
            user_agent=user_agent
        )
        db.session.add(log)
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Erreur sauvegarde log INFO: {e}")
    
    @staticmethod
    def log_error(category, action, message, error_details=None, suggested_fix=None, user_id=None, target_id=None, ip=None, level='ERROR'):
        """Helper pour logs ERROR/CRITICAL (bugs, erreurs)"""
        log = AuditLog(
            level=level,
            category=category,
            action=action,
            message=message,
            error_details=error_details,
            suggested_fix=suggested_fix,
            user_id=user_id,
            target_id=target_id,
            ip_address=ip
        )
        db.session.add(log)
        try:
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"Erreur sauvegarde log ERROR: {e}")
    
    def mark_resolved(self, resolver_id):
        """Marquer une erreur comme résolue"""
        self.resolved = True
        self.resolved_at = datetime.now(timezone.utc)
        self.resolved_by = resolver_id
        db.session.commit()

        # ========================================================================
# USER SESSIONS (MULTI-CONNEXION)
# ========================================================================

class UserSession(db.Model):
    """
    Tracker des sessions actives pour multi-connexion.
    Permet de gérer les limites de connexions simultanées.
    """
    __tablename__ = 'user_sessions'
    
    id = db.Column(db.String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = db.Column(db.String(36), db.ForeignKey('users.id'), nullable=False, index=True)
    session_token = db.Column(db.String(255), unique=True, nullable=False, index=True)
    
    # Métadonnées connexion
    ip_address = db.Column(db.String(45), nullable=True)
    user_agent = db.Column(db.String(255), nullable=True)
    device_info = db.Column(db.String(100), nullable=True)  # "Chrome (Mac)", "Firefox (Windows)"
    location = db.Column(db.String(100), nullable=True)  # Ville si disponible
    
    # Timestamps
    created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), nullable=False, index=True)
    last_activity = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc), nullable=False)
    expires_at = db.Column(db.DateTime, nullable=False)
    
    # Relation
    user = db.relationship('User', backref='active_sessions', lazy='select')
    
    def __repr__(self):
        return f'<UserSession {self.user_id} from {self.ip_address}>'
    
    @staticmethod
    def cleanup_expired():
        """Supprime les sessions expirées"""
        from extensions import db
        expired = UserSession.query.filter(UserSession.expires_at < datetime.now(timezone.utc)).all()
        for session in expired:
            db.session.delete(session)
        db.session.commit()
        return len(expired)
    
    @staticmethod
    def get_active_count(user_id):
        """Compte le nombre de sessions actives pour un utilisateur"""
        UserSession.cleanup_expired()
        return UserSession.query.filter_by(user_id=user_id).count()
    
    def update_activity(self):
        """Met à jour le timestamp d'activité"""
        self.last_activity = datetime.now(timezone.utc)
        db.session.commit()

