# Copyright (c) 2025 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import ssl
import uuid
from datetime import datetime
from functools import wraps

from flask import Flask, request, jsonify
from keystoneauth1 import loading as ks_loading
from openstack import connection
from oslo_log import log as logging
from requests.adapters import HTTPAdapter
from sqlalchemy.exc import IntegrityError, DisconnectionError, OperationalError

from vmms.db.sqlalchemy.database import init_database_config
from vmms.db.sqlalchemy.database import get_session as get_db_session
from vmms.db.sqlalchemy.models import VMMigration, MigrationState
from vmms.mistral_client import MistralClient
from vmms.policy import enforcer as policy_enforcer

# Lazily-populated cache of the parsed configuration object. get_config()
# fills it on first use so the configuration is parsed once per process.
_CONFIG = None

# Module logger (oslo.log) shared by every function and route in this module.
LOG = logging.getLogger(__name__)

class LegacyTLSAdapter(HTTPAdapter):
    """requests transport adapter that can negotiate legacy TLS versions.

    It is mounted on the source-cloud session when ``accept_legacy_tls`` is
    enabled, so endpoints that only speak TLS 1.0/1.1 can still be reached.
    """

    def init_poolmanager(self, *args, **kwargs):
        """Create the urllib3 pool manager with a legacy-friendly SSL context."""
        # PROTOCOL_TLS is deprecated since Python 3.10; PROTOCOL_TLS_CLIENT
        # replaces it but defaults to check_hostname=True / CERT_REQUIRED.
        # Reset both to the former PROTOCOL_TLS defaults. check_hostname has
        # to be cleared before verify_mode may be set to CERT_NONE.
        # NOTE(review): urllib3 is expected to re-apply the session's
        # ``verify`` setting to this context when connecting -- confirm for
        # the deployed urllib3 version.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        # Plain 'DEFAULT' keeps OpenSSL's configured security level, and
        # OpenSSL 3 only permits TLS 1.0/1.1 at security level 0. Lower it
        # explicitly, otherwise the minimum_version below has no effect.
        context.set_ciphers('DEFAULT:@SECLEVEL=0')
        # Allow handshakes down to TLS 1.0.
        context.minimum_version = ssl.TLSVersion.TLSv1
        # NOTE(review): this cap also rules out TLS 1.3 for servers that
        # support it -- kept from the original; confirm it is intended.
        context.maximum_version = ssl.TLSVersion.TLSv1_2

        kwargs['ssl_context'] = context
        return super().init_poolmanager(*args, **kwargs)

def is_valid_uuid(uuid_string):
    """Return True if *uuid_string* is a string that parses as a UUID.

    Any form accepted by ``uuid.UUID`` counts as valid: hyphenated, 32 bare
    hex digits, surrounding braces, or a ``urn:uuid:`` prefix. Non-string
    input (for example ``None`` or an int) returns False instead of raising.
    """
    if not isinstance(uuid_string, str):
        return False
    try:
        uuid.UUID(uuid_string)
    except ValueError:
        return False
    return True

def is_valid_iso_datetime(datetime_string):
    """Return True if *datetime_string* parses as an ISO 8601 datetime.

    Every ``Z`` is rewritten to ``+00:00`` before parsing, so the UTC
    designator is accepted even on Python versions whose
    ``datetime.fromisoformat`` does not understand ``Z``. Non-string input
    returns False instead of raising.
    """
    if not isinstance(datetime_string, str):
        return False
    try:
        datetime.fromisoformat(datetime_string.replace('Z', '+00:00'))
    except ValueError:
        return False
    return True

def validate_required_fields(data, required_fields):
    """Report whether *data* carries a non-None value for each required key.

    Returns False at the first name in *required_fields* that is missing
    from *data* or mapped to None. An empty *required_fields* yields True.
    """
    for name in required_fields:
        if name not in data or data[name] is None:
            return False
    return True

def get_config():
    """Return the process-wide configuration object.

    The first call imports ``vmms.config`` and caches the result of
    ``config.init_config()`` in the module-level ``_CONFIG``. Later calls
    return that cached object without parsing again.
    """
    global _CONFIG
    if _CONFIG is not None:
        return _CONFIG
    # Imported lazily so merely importing this module does not pull in and
    # initialise the configuration machinery.
    from vmms import config
    _CONFIG = config.init_config()
    return _CONFIG

def _connect_to_source_cloud():
    """Open an openstacksdk connection to the source cloud, or return None.

    Auth, session and endpoint selection come from the ``source_cloud``
    config group. Failures are logged with their traceback, not raised.
    """
    CONF = get_config()

    try:
        auth = ks_loading.load_auth_from_conf_options(CONF, "source_cloud")
        # NOTE(review): verify=False is passed explicitly and appears to take
        # precedence over any cafile/insecure option in [source_cloud],
        # disabling certificate verification -- confirm this is intended.
        base_session = ks_loading.load_session_from_conf_options(
            CONF, "source_cloud", auth=auth, verify=False)
        if CONF.accept_legacy_tls:
            # ``base_session.session`` is the underlying requests session.
            base_session.session.mount('https://', LegacyTLSAdapter())
        adapter = ks_loading.load_adapter_from_conf_options(
            CONF, "source_cloud", session=base_session, auth=auth)
    except Exception as e:
        LOG.error(f"Could not load configuration for source_cloud: {e}",
                  exc_info=True)
        return None

    try:
        return connection.Connection(
            session=base_session,
            region_name=adapter.region_name,
            interface=adapter.interface,
        )
    except Exception as e:
        LOG.error(f"Could not connect to source_cloud: {e}", exc_info=True)
        return None


def get_vm_details_from_source_cloud(vm_identifier):
    """Resolve *vm_identifier* to ``(server_id, server_name)`` on the source cloud.

    A valid UUID is looked up by server ID. Anything else is treated as a
    server name, which must match exactly one server across all projects.

    :param vm_identifier: server UUID or exact server name.
    :returns: ``(id, name)`` of the single matching server, or
        ``(None, None)`` if the cloud cannot be reached, nothing matches, or
        the name is ambiguous. Errors are logged, not raised.
    """
    conn = _connect_to_source_cloud()
    if conn is None:
        return None, None

    try:
        if is_valid_uuid(vm_identifier):
            servers = list(conn.compute.servers(all_projects=True,
                                                uuid=vm_identifier))
        else:
            # Nova applies the ``name`` filter as a regular expression, so
            # "web" would also return "web-01". Keep exact matches only.
            candidates = conn.compute.servers(all_projects=True,
                                              name=vm_identifier)
            servers = [s for s in candidates if s.name == vm_identifier]
    except Exception as e:
        LOG.error(f"Failed to fetch VM details from source cloud: {e}", exc_info=True)
        return None, None

    if not servers:
        LOG.error(f"No VM found with identifier: {vm_identifier} across all projects")
        return None, None
    if len(servers) > 1:
        LOG.error(f"Multiple VMs found with name: {vm_identifier} across all projects")
        return None, None

    server = servers[0]
    return server.id, server.name

# The one Flask application object; every route in this module registers on it.
app = Flask(import_name=__name__)

def require_vmms_policy(action):
    """Decorator that enforces policy rule *action* on a Flask view.

    Thin wrapper around ``policy_enforcer.require_policy``. It looks up the
    configuration through ``get_config()`` on each call, so the decorator
    can be applied at import time, before the configuration is parsed.

    :param action: policy rule name, e.g. ``'vmms:add'``.
    """
    def decorator(f):
        # ``wraps`` keeps the view's __name__/__doc__. Flask derives the
        # default endpoint name from __name__, so without it every decorated
        # view would be registered as "wrapper".
        @wraps(f)
        def wrapper(*args, **kwargs):
            CONF = get_config()
            return policy_enforcer.require_policy(CONF, action)(f)(*args, **kwargs)
        return wrapper
    return decorator

# Decorator factory applied to the routes below as
# ``@policy_checker('<rule>')``. get_config is passed uncalled, presumably
# so the enforcer can load the configuration lazily at request time --
# confirm in vmms/policy/enforcer.py.
policy_checker = policy_enforcer.require_policy_factory(get_config)

@app.route('/healthcheck', methods=['GET'])
def health_check():
    """Liveness endpoint (``GET /healthcheck``).

    It carries no ``policy_checker`` decorator, so no authentication is
    required. It always answers 200 with a small static JSON body.
    """
    payload = {'status': 'ok', 'service': 'vmms'}
    return jsonify(payload), 200

@app.route('/v2/vms', methods=['POST'])
@policy_checker('vmms:add')
def add_vm():
    """Schedule a VM for migration (``POST /v2/vms``).

    Request JSON object:

    * ``vm_identifier`` (required): server UUID or exact server name on the
      source cloud.
    * ``scheduled_time`` (optional): ISO 8601 datetime string; a ``Z``
      suffix is accepted.

    Returns 201 with the new migration record, or one of:

    * 400 for a malformed body or field.
    * 404 if the VM could not be resolved on the source cloud (including
      when the cloud is unreachable).
    * 409 if the VM is already scheduled.
    * 500 on any other error, after rolling back.
    """
    # silent=True: a missing/malformed JSON body or wrong Content-Type yields
    # None (reported as 400 below) instead of an exception that became a 500.
    data = request.get_json(silent=True)

    if not isinstance(data, dict) or 'vm_identifier' not in data:
        return jsonify({'error': 'Missing required field: vm_identifier'}), 400

    vm_identifier = data['vm_identifier']
    if not vm_identifier or not isinstance(vm_identifier, str):
        return jsonify({'error': 'vm_identifier must be a non-empty string'}), 400

    # Validate scheduled_time before contacting the source cloud, so bad
    # input is rejected cheaply. A non-string value is a 400, not a 500.
    scheduled_time = None
    raw_time = data.get('scheduled_time')
    if raw_time:
        if not isinstance(raw_time, str):
            return jsonify({'error': 'Invalid scheduled_time format'}), 400
        try:
            scheduled_time = datetime.fromisoformat(raw_time.replace('Z', '+00:00'))
        except ValueError:
            return jsonify({'error': 'Invalid scheduled_time format'}), 400

    session = None
    vm_id = None
    try:
        LOG.info(f"Adding VM with identifier: {vm_identifier}")

        # Resolve the identifier to a concrete server on the source cloud.
        vm_id, vm_name = get_vm_details_from_source_cloud(vm_identifier)
        if not vm_id or not vm_name:
            return jsonify({'error': f'Could not find VM with identifier: {vm_identifier}'}), 404

        LOG.info(f"Found VM: {vm_name} ({vm_id})")

        session = get_db_session()
        migration = VMMigration()
        migration.id = str(uuid.uuid4())
        migration.vm_id = vm_id
        migration.vm_name = vm_name
        if scheduled_time is not None:
            migration.scheduled_time = scheduled_time
        migration.state = MigrationState.SCHEDULED

        session.add(migration)
        session.commit()
        return jsonify(migration.to_dict()), 201

    except IntegrityError as e:
        if session is not None:
            session.rollback()
        # Presumably a unique constraint on vm_id: the VM already has a
        # migration record.
        if 'uniq_vm_id' in str(e) or 'vm_id' in str(e).lower():
            return jsonify({'error': f'VM with ID {vm_id} is already scheduled for migration'}), 409
        LOG.error(f"Database integrity error adding VM: {e}", exc_info=True)
        return jsonify({'error': 'Database integrity error'}), 500
    except Exception as e:
        LOG.error(f"Error adding VM: {e}", exc_info=True)
        if session is not None:
            session.rollback()
        return jsonify({'error': str(e)}), 500

@app.route('/v2/vms', methods=['GET'])
@policy_checker('vmms:list')
def list_vms():
    """List migration records as a JSON array (``GET /v2/vms``).

    An optional ``?state=<value>`` query parameter restricts the result to
    one ``MigrationState`` value. An unknown value yields 400, and any
    unexpected error yields 500.
    """
    try:
        wanted_state = request.args.get('state', None)

        if wanted_state:
            known_states = [member.value for member in MigrationState]
            if wanted_state not in known_states:
                message = f'Invalid state. Allowed: {", ".join(known_states)}'
                return jsonify({'error': message}), 400

        query = get_db_session().query(VMMigration)
        if wanted_state:
            query = query.filter(VMMigration.state == wanted_state)

        return jsonify([row.to_dict() for row in query.all()])
    except Exception as e:
        LOG.error(f"Error listing VMs: {e}", exc_info=True)
        return jsonify({'error': str(e)}), 500

@app.route('/v2/vms/<string:migration_id>', methods=['DELETE'])
@policy_checker('vmms:remove')
def remove_vm(migration_id):
    """Delete one migration record (``DELETE /v2/vms/<migration_id>``).

    Returns:

    * 400 when the id is not a UUID.
    * 404 when no such record exists.
    * 200 once the row is deleted and committed.
    * 500 on any other error, after rolling back the session.
    """
    if not is_valid_uuid(migration_id):
        return jsonify({'error': 'Invalid migration ID format. Must be a valid UUID.'}), 400

    session = None
    try:
        session = get_db_session()
        record = session.query(VMMigration).filter_by(id=migration_id).first()
        if not record:
            return jsonify({'error': f'Migration with ID {migration_id} not found'}), 404

        session.delete(record)
        session.commit()
        return jsonify({'message': 'Migration removed successfully'}), 200
    except Exception as e:
        LOG.error(f"Error removing VM: {e}", exc_info=True)
        if session is not None:
            session.rollback()
        return jsonify({'error': str(e)}), 500


@app.route('/v2/vms/<string:migration_id>', methods=['GET'])
@policy_checker('vmms:show')
def show_vm(migration_id):
    """Return one migration record as JSON (``GET /v2/vms/<migration_id>``).

    Returns 400 for a non-UUID id, 404 when the record is absent, and 500
    on unexpected errors.
    """
    if not is_valid_uuid(migration_id):
        return jsonify({'error': 'Invalid migration ID format. Must be a valid UUID.'}), 400

    try:
        record = (
            get_db_session()
            .query(VMMigration)
            .filter_by(id=migration_id)
            .first()
        )
        if not record:
            return jsonify({'error': f'Migration with ID {migration_id} not found'}), 404
        return jsonify(record.to_dict()), 200
    except Exception as e:
        LOG.error(f"Error showing VM: {e}", exc_info=True)
        return jsonify({'error': str(e)}), 500


@app.route('/v2/vms/<string:migration_id>', methods=['PUT'])
@policy_checker('vmms:update')
def update_vm(migration_id):
    """Update the schedule and/or state of a migration (``PUT /v2/vms/<id>``).

    The JSON body must be an object containing at least one of:

    * ``scheduled_time``: ISO 8601 string (``Z`` suffix accepted), or
      ``null`` to clear it. A non-null value is only accepted while the
      stored migration is SCHEDULED (the state *before* this request).
    * ``state``: one of the ``MigrationState`` values. Setting it to
      SCHEDULED also clears ``workflow_exec``.

    ``workflow_exec`` is read-only, and any other key is rejected. All
    request input is validated before the record is modified, so a 400
    never leaves partial changes pending in the session.

    Returns 200 with the updated record, or one of:

    * 400 on invalid input.
    * 404 if the migration does not exist.
    * 500 on unexpected errors, after rolling back.
    """
    LOG.debug(f"update_vm called with migration_id: {migration_id}")

    if not is_valid_uuid(migration_id):
        return jsonify({'error': 'Invalid migration ID format. Must be a valid UUID.'}), 400

    # silent=True turns a missing/malformed JSON body into None (reported as
    # a 400 below) instead of an exception that would surface as a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({'error': 'No data provided'}), 400
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Give workflow_exec its specific message. The generic whitelist check
    # below would otherwise report it only as an invalid field.
    if 'workflow_exec' in data:
        return jsonify({'error': 'workflow_exec is read-only and cannot be set via API'}), 400

    allowed_fields = ['scheduled_time', 'state']
    if not any(field in data for field in allowed_fields):
        return jsonify({'error': 'Missing field: scheduled_time or state'}), 400

    invalid_fields = [field for field in data if field not in allowed_fields]
    if invalid_fields:
        return jsonify({'error': f'Invalid fields: {", ".join(invalid_fields)}. Allowed: {", ".join(allowed_fields)}'}), 400

    # Parse scheduled_time up front. When the key is present, None means
    # "clear the schedule".
    new_scheduled_time = None
    raw_time = data.get('scheduled_time')
    if raw_time is not None:
        if not isinstance(raw_time, str):
            return jsonify({'error': 'scheduled_time must be a string in ISO format'}), 400
        try:
            new_scheduled_time = datetime.fromisoformat(raw_time.replace('Z', '+00:00'))
        except ValueError:
            return jsonify({'error': 'Invalid scheduled_time format. Must be ISO 8601 format.'}), 400

    # Validate the requested state before touching the record, so a bad
    # value cannot leave an earlier modification pending in the session.
    if 'state' in data:
        allowed_states = [state.value for state in MigrationState]
        if data['state'] not in allowed_states:
            return jsonify({'error': f'Invalid state. Allowed: {", ".join(allowed_states)}'}), 400

    session = None
    try:
        session = get_db_session()
        migration = session.query(VMMigration).filter_by(id=migration_id).first()

        if not migration:
            return jsonify({'error': f'Migration with ID {migration_id} not found'}), 404

        if 'scheduled_time' in data:
            # Clearing is allowed in any state; setting only while SCHEDULED.
            if new_scheduled_time is not None and migration.state != MigrationState.SCHEDULED:
                return jsonify({'error': 'Can only update scheduled_time for scheduled migrations'}), 400
            migration.scheduled_time = new_scheduled_time

        if 'state' in data:
            migration.state = data['state']
            # Re-scheduling drops any previous workflow execution reference.
            if data['state'] == MigrationState.SCHEDULED.value:
                migration.workflow_exec = None

        session.commit()
        return jsonify(migration.to_dict()), 200

    except Exception as e:
        LOG.error(f"Error updating VM: {e}", exc_info=True)
        if session is not None:
            session.rollback()
        return jsonify({'error': str(e)}), 500


@app.route('/v2/vms/<string:migration_id>/output', methods=['GET'])
@policy_checker('vmms:output')
def get_vm_output(migration_id):
    """Return the Mistral workflow output of a migration.

    ``GET /v2/vms/<migration_id>/output``. Output is only served for
    migrations in MIGRATED or ERROR state that have a recorded
    ``workflow_exec``.

    Returns 200 with the output, or one of:

    * 400 for a bad id or a migration in any other state.
    * 404 if the migration or its workflow execution is missing.
    * 500 if retrieving the output fails.
    """
    LOG.debug(f"🐛 GET /v2/vms/{migration_id}/output called")

    if not is_valid_uuid(migration_id):
        LOG.debug(f"🐛 Invalid UUID format: {migration_id}")
        return jsonify({'error': 'Invalid migration ID format. Must be a valid UUID.'}), 400

    try:
        session = get_db_session()
        migration = session.query(VMMigration).filter_by(id=migration_id).first()

        if not migration:
            LOG.debug(f"🐛 Migration not found: {migration_id}")
            return jsonify({'error': f'Migration with ID {migration_id} not found'}), 404

        if not migration.workflow_exec:
            LOG.debug(f"🐛 No workflow execution for migration: {migration_id}")
            return jsonify({'error': f'Migration {migration_id} has no workflow execution'}), 404

        # Output is only exposed for the MIGRATED and ERROR states.
        if migration.state not in (MigrationState.MIGRATED, MigrationState.ERROR):
            LOG.debug(f"🐛 Invalid state for output retrieval: {migration.state}")
            return jsonify({
                'error': f'Cannot retrieve output for migration in state {migration.state}. '
                         f'Output is only available for MIGRATED or ERROR states.'
            }), 400

        # MistralClient is already imported at module level.
        mistral_client = MistralClient(get_config())
        output = mistral_client.get_migration_output(migration.workflow_exec)
        LOG.debug(f"✅ Successfully retrieved output for migration: {migration_id}")

        return jsonify(output), 200

    except Exception as e:
        LOG.error(f"⧱ Error getting VM output for migration {migration_id}: {e}", exc_info=True)
        return jsonify({'error': f'Failed to retrieve migration output: {str(e)}'}), 500


def get_application():
    """Build the WSGI application: the Flask ``app`` behind keystonemiddleware.

    Sets up oslo.log and the database configuration, then wraps ``app`` in
    the ``auth_token`` middleware with ``delay_auth_decision`` forced to
    True. Unauthenticated requests therefore still reach Flask, and the
    per-route policy checks decide access (``/healthcheck`` has none).

    :returns: the auth_token-wrapped WSGI callable.
    :raises Exception: any setup failure is logged and re-raised.
    """
    try:
        CONF = get_config()
        logging.setup(CONF, 'vmms-api')

        init_database_config(CONF)

        from keystonemiddleware import auth_token
        middleware_conf = dict(CONF.keystone_authtoken)
        # Selective auth: let token-less requests through to the app.
        middleware_conf['delay_auth_decision'] = True
        return auth_token.filter_factory(middleware_conf)(app)
    except Exception as e:
        # Fail closed. Returning the bare Flask app here would serve every
        # route without the keystone auth_token middleware in front of it.
        LOG.error(f"Error creating application: {e}", exc_info=True)
        raise
