# server.py
import json
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import bcrypt
import jwt
import mysql.connector
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field, validator

from game_settings_validator import get_game_validator

# Pull settings from a local .env file (if any) before os.getenv() is consulted.
load_dotenv()

# Application object, plus the shared validator that SaveData uses for game state.
app = FastAPI()
game_validator = get_game_validator()

ENVIRONMENT = os.getenv('ENVIRONMENT', 'development')

# Browser origins permitted to make cross-origin requests, per deployment environment.
# NOTE(review): production also whitelists the plain-http origin; confirm it is
# still required now that the https origin exists.
CORS_ORIGINS = dict(
    development=["http://localhost:3000", "http://127.0.0.1:3000"],
    production=["https://iammrfrank.net", "http://iammrfrank.net"],
)

# An ENVIRONMENT value that is not a key of CORS_ORIGINS fails here with KeyError
# at import time.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS[ENVIRONMENT],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# JWT Configuration
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440  # 24 hours

# Fallback used only when JWT_SECRET_KEY is unset. This value is public (it is in
# the source), so tokens signed with it can be forged by anyone: it is tolerated
# for local development but refused in production.
_DEFAULT_DEV_SECRET_KEY = "your-secret-key-here"
SECRET_KEY = os.getenv('JWT_SECRET_KEY', _DEFAULT_DEV_SECRET_KEY)
if ENVIRONMENT == 'production' and SECRET_KEY in ('', _DEFAULT_DEV_SECRET_KEY):
    raise RuntimeError("JWT_SECRET_KEY must be set to a private value in production")

# NOTE(review): tokenUrl only feeds the OpenAPI docs; the real login route is
# /api/auth/login and takes JSON, not the OAuth2 form this scheme advertises.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Database Configuration
# Keyword arguments for mysql.connector.connect(); each connection field can be
# overridden through the environment (or the .env file loaded above).
# NOTE(review): a real-looking password used to be hard-coded here as the
# DB_PASSWORD default. It has been removed: supply DB_PASSWORD via the
# environment, and rotate the old credential since it remains in VCS history.
DB_CONFIG = {
    "host": os.getenv('DB_HOST', 'localhost'),
    "user": os.getenv('DB_USER', 'idlegame'),
    "password": os.getenv('DB_PASSWORD', ''),
    "database": os.getenv('DB_NAME', 'idle_game_db'),
    "charset": "utf8mb4",
    "use_unicode": True,
    "collation": "utf8mb4_unicode_ci"
}

# Pydantic Models
class UserCreate(BaseModel):
    """Credentials payload shared by the register and login endpoints.

    ``username`` must be 3-20 characters drawn only from ``[a-zA-Z0-9_-]``;
    ``password`` must contain at least one non-whitespace character.
    """
    username: str = Field(..., min_length=3, max_length=20)
    password: str = Field(..., min_length=1)  # Only requiring the password to not be empty

    @validator('username')
    def validate_username(cls, v):
        # fullmatch, not match with ^...$: "$" also matches just before a trailing
        # "\n", so re.match would have accepted e.g. "alice\n".
        if not re.fullmatch(r'[a-zA-Z0-9_-]+', v):
            raise ValueError('Username can only contain letters, numbers, underscores, and hyphens')
        return v

    @validator('password')
    def validate_password(cls, v):
        # The password is only ever bcrypt-hashed and passed as a bound query
        # parameter, never interpolated into SQL, so characters like ; < > & '
        # are harmless. The previous blacklist gave no SQL-injection protection
        # and only rejected legitimate passwords.
        if not v.strip():
            raise ValueError('Password cannot be empty')
        return v

class SaveData(BaseModel):
    """Payload accepted by ``POST /api/save``.

    Fields:
        userId: Owner of the save; must be positive (the endpoint additionally
            checks it against the authenticated user).
        timestamp: Client save time in epoch milliseconds; may be at most five
            minutes ahead of the server clock.
        version: Save-format version in ``X.Y.Z`` form.
        state: Game state, checked by ``game_validator.validate_save_data``.
    """
    userId: int
    timestamp: int
    version: str
    state: Dict[str, Any]

    @validator('userId')
    def validate_user_id(cls, v):
        if v <= 0:
            raise ValueError('User ID must be positive')
        return v

    @validator('timestamp')
    def validate_timestamp(cls, v):
        # "Now" in epoch milliseconds from an aware UTC datetime. The previous
        # datetime.utcnow().timestamp() treated the naive UTC value as local
        # time, skewing "now" by the server's UTC offset on non-UTC hosts.
        current_time = int(datetime.now(timezone.utc).timestamp() * 1000)
        if v > current_time + 300000:  # 5 minutes future tolerance
            raise ValueError('Timestamp cannot be in the future')
        return v

    @validator('version')
    def validate_version(cls, v):
        if not re.match(r'^\d+\.\d+\.\d+$', v):
            raise ValueError('Invalid version format. Must be in format X.Y.Z')
        return v

    @validator('state')
    def validate_game_state(cls, v):
        # Any failure from the game validator is re-raised as ValueError so
        # pydantic reports it as a field validation error; the cause is chained.
        try:
            game_validator.validate_save_data({'state': v})
        except Exception as e:
            raise ValueError(f"Invalid game state: {str(e)}") from e
        return v

# Helper Functions
def get_db():
    """Open and return a fresh MySQL connection built from ``DB_CONFIG``.

    Any failure is printed to stdout and surfaced as a generic HTTP 500; the
    original error text stays server-side.

    Raises:
        HTTPException: status 500, detail "Database connection failed".
    """
    try:
        return mysql.connector.connect(**DB_CONFIG)
    except Exception as exc:
        print(f"Database connection error: {str(exc)}")
        raise HTTPException(status_code=500, detail="Database connection failed")

async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token to the user id stored in its ``sub`` claim.

    Returns:
        The ``sub`` claim. Tokens minted by this module store it as ``str(user_id)``.

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified (bad
            signature, malformed, expired, ...) or carries no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError:
        # PyJWT (the package behind ``import jwt``) derives all its decode
        # failures from InvalidTokenError. The previous ``jwt.JWTError`` is not
        # defined by PyJWT: evaluating it raised AttributeError, so every bad
        # token produced a 500 instead of a 401.
        raise HTTPException(status_code=401, detail="Invalid authentication token") from None

    user_id: Optional[str] = payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=401, detail="Invalid authentication token")
    return user_id

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Sign a JWT containing ``data`` plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed (callers in this module pass ``{"sub": str(user_id)}``).
            The dict is copied, never mutated; an ``exp`` key in it is overwritten.
        expires_delta: Token lifetime. Defaults to ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded token, as returned by ``jwt.encode``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now": datetime.utcnow() is deprecated (Python 3.12) and
    # returns a naive value.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# API Endpoints
@app.post("/api/auth/register")
async def register(user: UserCreate):
    """Create a user account and return it with a freshly issued access token.

    Args:
        user: Validated username/password payload.

    Returns:
        dict with ``userId`` (new row id), ``username`` and ``token``.

    Raises:
        HTTPException: 400 if the username contains whitespace or is already
            registered; 500 if anything else fails (details are logged
            server-side, not returned).
    """
    # NOTE(review): mysql.connector and bcrypt calls are presumably blocking
    # inside this async handler; consider a plain `def` so FastAPI runs it in
    # its threadpool.
    conn = get_db()
    cursor = conn.cursor(dictionary=True)

    try:
        # Additional username validation
        if any(char.isspace() for char in user.username):
            raise HTTPException(status_code=400, detail="Username cannot contain whitespace")

        # NOTE(review): check-then-insert can race with a concurrent
        # registration; a UNIQUE index on users.username is the real guarantee.
        cursor.execute("SELECT id FROM users WHERE username = %s", (user.username,))
        if cursor.fetchone():
            raise HTTPException(status_code=400, detail="Username already registered")

        hashed_password = bcrypt.hashpw(user.password.encode('utf-8'), bcrypt.gensalt())

        cursor.execute(
            "INSERT INTO users (username, password_hash) VALUES (%s, %s)",
            (user.username, hashed_password)
        )
        conn.commit()

        user_id = cursor.lastrowid
        access_token = create_access_token({"sub": str(user_id)})

        return {
            "userId": user_id,
            "username": user.username,
            "token": access_token
        }

    except HTTPException:
        # Deliberate client errors raised above must reach the caller unchanged.
        # Previously the generic handler below turned them into 500s.
        raise
    except Exception as e:
        conn.rollback()
        print(f"Registration error: {str(e)}")
        # Do not echo internal error text (e.g. SQL errors) back to the client.
        raise HTTPException(status_code=500, detail="Registration failed") from e
    finally:
        cursor.close()
        conn.close()

@app.post("/api/auth/login")
async def login(user: UserCreate):
    """Check a username/password pair and issue an access token.

    Args:
        user: Username/password payload (same validation as registration).

    Returns:
        dict with ``userId``, ``username`` and ``token``.

    Raises:
        HTTPException: 401 with one shared message whether the username is
            unknown or the password is wrong.
    """
    conn = get_db()
    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute(
            "SELECT id, username, password_hash FROM users WHERE username = %s",
            (user.username,)
        )
        db_user = cursor.fetchone()

        if not db_user:
            raise HTTPException(status_code=401, detail="Invalid username or password")

        # register() stores the raw bytes from bcrypt.hashpw. Depending on the
        # column type, the driver may hand it back as str or as bytes/bytearray.
        # bcrypt needs bytes, and calling .encode() unconditionally crashed on
        # the latter.
        stored_hash = db_user['password_hash']
        if isinstance(stored_hash, str):
            stored_hash = stored_hash.encode('utf-8')

        if not bcrypt.checkpw(user.password.encode('utf-8'), bytes(stored_hash)):
            raise HTTPException(status_code=401, detail="Invalid username or password")

        access_token = create_access_token({"sub": str(db_user['id'])})

        return {
            "userId": db_user['id'],
            "username": db_user['username'],
            "token": access_token
        }

    finally:
        cursor.close()
        conn.close()

@app.get("/api/auth/verify")
async def verify_token(current_user: str = Depends(get_current_user)):
    """Confirm that the bearer token still maps to an existing user row.

    Returns:
        dict with ``userId`` (the token's ``sub`` claim, as returned by
        ``get_current_user``) and ``username``.

    Raises:
        HTTPException: 401 if no user has the token's id.
    """
    # NOTE(review): userId is the string from the token here, but login/register
    # return it as an int. Confirm the client tolerates both before unifying.
    conn = get_db()
    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("SELECT username FROM users WHERE id = %s", (current_user,))
        row = cursor.fetchone()
        if not row:
            raise HTTPException(status_code=401, detail="User not found")
        username = row['username']
        return {"userId": current_user, "username": username}
    finally:
        cursor.close()
        conn.close()

@app.post("/api/save")
async def save_game(save_data: SaveData, current_user: str = Depends(get_current_user)):
    """Create or overwrite the authenticated user's save row.

    Args:
        save_data: Validated save payload; ``userId`` must match the token.
        current_user: User id from the bearer token (string).

    Returns:
        ``{"status": "success", "message": "Save successful"}``.

    Raises:
        HTTPException: 403 if the payload targets another user; 500 if the
            write fails (details are logged server-side, not returned).
    """
    if str(save_data.userId) != current_user:
        raise HTTPException(status_code=403, detail="Not authorized to save for this user")

    conn = get_db()
    cursor = conn.cursor()

    try:
        state_json = json.dumps(save_data.state)

        # NOTE(review): select-then-insert is not atomic under concurrent saves.
        # If save_data.user_id is UNIQUE, INSERT ... ON DUPLICATE KEY UPDATE
        # would make this a single statement. Confirm the schema first.
        cursor.execute("""
            SELECT id FROM save_data WHERE user_id = %s
        """, (current_user,))

        existing_save = cursor.fetchone()

        if existing_save:
            cursor.execute("""
                UPDATE save_data 
                SET save_timestamp = %s, version = %s, state_json = %s
                WHERE user_id = %s
            """, (save_data.timestamp, save_data.version, state_json, current_user))
        else:
            cursor.execute("""
                INSERT INTO save_data (user_id, save_timestamp, version, state_json)
                VALUES (%s, %s, %s, %s)
            """, (current_user, save_data.timestamp, save_data.version, state_json))

        conn.commit()
        return {"status": "success", "message": "Save successful"}

    except Exception as e:
        conn.rollback()
        print(f"Save error: {str(e)}")
        # Do not echo internal error text (e.g. SQL errors) back to the client.
        raise HTTPException(status_code=500, detail="Save failed") from e
    finally:
        cursor.close()
        conn.close()

@app.get("/api/save/{user_id}")
async def get_save(user_id: str, current_user: str = Depends(get_current_user)):
    """Return the caller's stored save, or an empty default if none exists.

    Args:
        user_id: Path parameter; must equal the token's user id exactly.
        current_user: User id from the bearer token (string).

    Returns:
        dict with ``timestamp``, ``version`` and ``state`` (the decoded JSON
        state, or ``None`` plus a current epoch-ms timestamp and version
        "1.0.0" when nothing has been saved yet).

    Raises:
        HTTPException: 403 if ``user_id`` is not the caller; 500 on database or
            decoding errors (details are logged server-side, not returned).
    """
    if user_id != current_user:
        raise HTTPException(status_code=403, detail="Not authorized to access this save")

    conn = get_db()
    cursor = conn.cursor(dictionary=True)

    try:
        cursor.execute("""
            SELECT save_timestamp as timestamp, version, state_json 
            FROM save_data 
            WHERE user_id = %s
        """, (user_id,))

        save = cursor.fetchone()
        if not save:
            # "Now" in epoch milliseconds from an aware UTC datetime. The
            # previous datetime.utcnow().timestamp() treated the naive UTC value
            # as local time and was off by the server's UTC offset.
            return {
                "timestamp": int(datetime.now(timezone.utc).timestamp() * 1000),
                "version": "1.0.0",
                "state": None
            }

        return {
            "timestamp": save['timestamp'],
            "version": save['version'],
            "state": json.loads(save['state_json'])
        }

    except Exception as e:
        print(f"Load error: {str(e)}")
        # Do not echo internal error text (e.g. SQL errors) back to the client.
        raise HTTPException(status_code=500, detail="Failed to load save") from e
    finally:
        cursor.close()
        conn.close()

if __name__ == "__main__":
    # Script entry point: `python server.py` serves the app with uvicorn on all
    # interfaces, port 8000. uvicorn is imported lazily because only this path
    # needs it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)