"""Socket.IO server for real-time chat."""
import logging
from urllib.parse import parse_qs
import jwt
import socketio
from bson import ObjectId
from bson.errors import InvalidId

from auth import JWT_ALGORITHM, get_jwt_secret
from deps import get_db

logger = logging.getLogger(__name__)

# Shared Socket.IO server instance; the event handlers in this module register
# themselves on it through ``@sio.event``.
# NOTE(review): cors_allowed_origins="*" accepts connections from any origin,
# while _extract_token also accepts the ``access_token`` cookie. Confirm this
# combination is intended (or that the cookie is SameSite-restricted), since it
# may let a third-party page open an authenticated socket.
sio = socketio.AsyncServer(
    async_mode="asgi",
    cors_allowed_origins="*",
    ping_timeout=60,
    ping_interval=25,
)


def _extract_token(environ: dict, auth: dict | None) -> str | None:
    if auth and isinstance(auth, dict):
        tok = auth.get("token")
        if tok:
            return tok
    # cookies
    cookie_str = environ.get("HTTP_COOKIE", "") or ""
    for part in cookie_str.split(";"):
        kv = part.strip().split("=", 1)
        if len(kv) == 2 and kv[0] == "access_token":
            return kv[1]
    # query
    qs = environ.get("QUERY_STRING", "")
    qd = parse_qs(qs)
    tk = qd.get("token", [None])[0]
    return tk


@sio.event
async def connect(sid, environ, auth=None):
    """Authenticate a new socket and bind it to its user.

    The token located by ``_extract_token`` must decode as a JWT whose
    ``type`` claim is ``"access"`` and whose ``sub`` claim is an ObjectId
    string matching a document in ``db.users``. On success the user's id,
    role and name are stored in the socket session and the socket joins the
    ``user:<id>`` room.

    Args:
        sid: Socket.IO session id of the connecting client.
        environ: Environ mapping of the handshake request.
        auth: Optional auth payload sent by the client.

    Returns:
        ``True`` to accept the connection, ``False`` to reject it.
    """
    token = _extract_token(environ, auth)
    if not token:
        logger.info(f"socket {sid} rejected: no token")
        return False

    # Only the decode step can raise InvalidTokenError, so keep the try narrow.
    try:
        payload = jwt.decode(token, get_jwt_secret(), algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError as e:
        logger.info(f"socket {sid} invalid token: {e}")
        return False

    if payload.get("type") != "access":
        return False

    # Reject a missing or non-string subject explicitly: indexing would raise
    # KeyError, and ObjectId(None) would silently mint a brand-new id.
    subject = payload.get("sub")
    if not isinstance(subject, str):
        return False
    try:
        user_oid = ObjectId(subject)
    except InvalidId:
        return False
    uid = str(user_oid)

    db = get_db()
    user = await db.users.find_one({"_id": user_oid})
    if not user:
        return False

    # "role" is indexed directly: a user document without it raises KeyError.
    await sio.save_session(sid, {
        "user_id": uid,
        "role": user["role"],
        "name": user.get("name", ""),
    })
    await sio.enter_room(sid, f"user:{uid}")
    # .get(): a missing email must not fail a connection that is already set up.
    logger.info(f"socket {sid} connected as {user.get('email')}")
    return True


@sio.event
async def disconnect(sid):
    """Log that the socket identified by ``sid`` has disconnected."""
    message = f"socket {sid} disconnected"
    logger.info(message)


@sio.event
async def join_channel(sid, data):
    """Add the socket to a channel's room after an access check.

    Access is granted when the session role is ``"admin"`` or ``"pm"``, or
    when the session's user id appears in the channel document's ``members``.

    Args:
        sid: Socket.IO session id of the caller.
        data: Client payload; expected to be a dict with a ``channel_id``
            string. Anything else is treated as a missing or invalid id.

    Returns:
        ``{"ok": True}`` on success, otherwise ``{"ok": False, "error": ...}``.
    """
    # The payload is client-controlled: it may be None, a list, a string, ...
    channel_id = data.get("channel_id") if isinstance(data, dict) else None
    if not channel_id:
        return {"ok": False, "error": "missing channel_id"}

    session = await sio.get_session(sid)
    if not session:
        return {"ok": False, "error": "no session"}

    # ObjectId() raises TypeError (not InvalidId) for non-string input such as
    # numbers or dicts, so screen the type before converting.
    if not isinstance(channel_id, str):
        return {"ok": False, "error": "invalid channel"}
    try:
        channel_oid = ObjectId(channel_id)
    except InvalidId:
        return {"ok": False, "error": "invalid channel"}

    db = get_db()
    ch = await db.channels.find_one({"_id": channel_oid})
    if not ch:
        return {"ok": False, "error": "channel not found"}

    # NOTE(review): compares the string user id against ``members`` as stored;
    # presumably members are stored as strings — confirm against the schema.
    privileged = session["role"] in ("admin", "pm")
    if not privileged and session["user_id"] not in (ch.get("members") or []):
        return {"ok": False, "error": "forbidden"}

    await sio.enter_room(sid, channel_id)
    return {"ok": True}


@sio.event
async def leave_channel(sid, data):
    """Remove the socket from a channel's room.

    Args:
        sid: Socket.IO session id of the caller.
        data: Client payload; expected to be a dict with a ``channel_id``
            string. Malformed payloads are ignored.

    Returns:
        Always ``{"ok": True}``; leaving is best-effort.
    """
    # The payload is client-controlled; a non-dict would break ``.get`` and an
    # unhashable id (list/dict) would break the room lookup.
    channel_id = data.get("channel_id") if isinstance(data, dict) else None
    if isinstance(channel_id, str) and channel_id:
        await sio.leave_room(sid, channel_id)
    return {"ok": True}


@sio.event
async def typing(sid, data):
    """Relay a "typing" indicator to the other sockets in a channel room.

    The event is forwarded only when the sending socket is itself in the
    target room. Without that check any authenticated client could push
    events into channels it was never admitted to, or into other users'
    personal rooms.

    Args:
        sid: Socket.IO session id of the caller.
        data: Client payload; expected to be a dict with a ``channel_id``
            string. Malformed payloads are ignored.
    """
    # The payload is client-controlled, so validate its shape first.
    channel_id = data.get("channel_id") if isinstance(data, dict) else None
    if not isinstance(channel_id, str) or not channel_id:
        return

    session = await sio.get_session(sid)
    if not session:
        return

    # Authorization: the sender must already be in the room it targets.
    if channel_id not in sio.rooms(sid):
        return

    await sio.emit("typing", {
        "channel_id": channel_id,
        "user_id": session["user_id"],
        "name": session["name"],
    }, room=channel_id, skip_sid=sid)
