/**
  \file websockets.h
  \author FC
  \date 2026
  \brief Minimal websocket implementation.
*/

/**
  \brief Send exactly \p len bytes from \p buf, retrying after partial writes.
  \return 1 if every byte was handed to send(), 0 if send() reported an
          error or accepted nothing.
*/
static int ws_send_all(socket_t fd, const void *buf, size_t len) {
    const char *p = (const char *)buf;

    while (len > 0) {
        /* Cap each call so the length also fits send() implementations whose
           length parameter is an int. */
        size_t chunk = len > ((size_t)1 << 30) ? ((size_t)1 << 30) : len;
        int n = (int)send(fd, p, chunk, 0);
        if (n <= 0) return 0;
        p += n;
        len -= (size_t)n;
    }
    return 1;
}

/**
  \brief Send \p msg as one unfragmented websocket text frame.

  The frame has FIN set, opcode 0x1 (text), and the MASK bit clear.
  RFC 6455 5.1 requires frames sent by a server to be unmasked. The payload
  length uses the shortest encoding: 7-bit, 16-bit (126) or 64-bit (127),
  in network byte order.

  The header and payload are each sent in a loop so a short send() cannot
  leave a truncated frame on the wire. Failures are not reported (the
  function returns void); the payload is skipped if the header could not be
  sent, so the peer never sees payload bytes without their header.

  \param fd  connected client socket.
  \param msg NUL-terminated payload; a NULL pointer sends nothing.
*/
void websocket_send_text(socket_t fd, const char *msg) {
    if (msg == NULL) return;

    const uint64_t len = (uint64_t)strlen(msg);
    uint8_t header[10];
    size_t header_len;

    header[0] = 0x81;                   /* FIN=1, opcode 0x1 (text) */

    if (len <= 125) {
        header[1] = (uint8_t)len;
        header_len = 2;
    } else if (len <= 0xFFFF) {
        header[1] = 126;                /* 16-bit extended length follows */
        header[2] = (uint8_t)((len >> 8) & 0xFF);
        header[3] = (uint8_t)(len & 0xFF);
        header_len = 4;
    } else {
        header[1] = 127;                /* 64-bit extended length follows */
        for (int i = 0; i < 8; i++)
            header[2 + i] = (uint8_t)((len >> (56 - 8 * i)) & 0xFF);
        header_len = 10;
    }

    if (ws_send_all(fd, header, header_len))
        ws_send_all(fd, msg, (size_t)len);
}

/**
  \brief Read exactly \p len bytes from \p fd into \p buf, looping over short reads.

  \param fd  connected socket.
  \param buf destination, at least \p len bytes.
  \param len number of bytes to read.
  \return \p len on success. Returns 0 in these cases:
          - recv() fails (as flagged by SOCKET_IO_ERROR_CHECK, or any negative result);
          - the peer closes the connection before \p len bytes arrive;
          - \p len <= 0.
*/
int recv_all(socket_t fd, void *buf, int len) {
    char *dst = (char *)buf;
    int total = 0;

    while (total < len) {
        int n = recv(fd, dst + total, len - total, 0);
        /* The project macro runs first so whatever it does on error still
           happens. The explicit n <= 0 test also catches recv() == 0
           (orderly shutdown), which would otherwise loop forever if the
           macro only flags negative results, and any negative n it might
           not flag, which would otherwise shrink 'total'. */
        if (SOCKET_IO_ERROR_CHECK(n) || n <= 0) return 0;
        total += n;
    }

    return total;
}


/**
  \brief Receive one complete (possibly fragmented) websocket data message.

  Frames are read until one with FIN set has been received. Each data
  frame's payload is unmasked with the frame's key and appended to \p out.

  Control frames may arrive between fragments:
  - Ping (0x9): answered with a best-effort Pong that echoes its payload.
  - Pong (0xA): discarded.
  Neither contributes to \p out.

  The connection is treated as failed when:
  - a Close frame (0x8) arrives;
  - a reserved opcode arrives;
  - a control frame is fragmented or longer than 125 bytes;
  - a frame using the 64-bit length form is longer than 65536 bytes.

  \param fd         connected socket.
  \param out        destination for the reassembled payload. It is not
                    NUL-terminated by this function.
  \param bufferSize capacity of \p out in bytes.
  \return number of payload bytes written to \p out. Returns 0 on socket
          error, close, protocol violation, or when the message would exceed
          \p bufferSize. An empty message also yields 0.
*/
int websocket_recv_frame(socket_t fd, char *out, size_t bufferSize) {
    unsigned char *dst = (unsigned char *)out;
    size_t total_len = 0;

    for (;;) {
        uint8_t hdr[2];
        if (recv_all(fd, hdr, 2) <= 0) return 0;

        const uint8_t opcode = hdr[0] & 0x0F;
        const int fin        = (hdr[0] & 0x80) != 0;
        const int masked     = (hdr[1] & 0x80) != 0;
        /* Kept in 64 bits so the 8-byte extended length cannot wrap on
           targets where size_t is 32 bits. */
        uint64_t len = hdr[1] & 0x7F;

        if (len == 126) {
            uint8_t ext[2];
            if (recv_all(fd, ext, 2) <= 0) return 0;
            len = ((uint64_t)ext[0] << 8) | ext[1];
        } else if (len == 127) {
            uint8_t ext[8];
            if (recv_all(fd, ext, 8) <= 0) return 0;
            len = 0;
            for (int i = 0; i < 8; i++)
                len = (len << 8) | ext[i];
            if (len > 65536) return 0;   /* hard cap for 64-bit-length frames */
        }

        if (opcode == 0x8) return 0;                     /* Close */
        if ((opcode >= 0x3 && opcode <= 0x7) || opcode >= 0xB)
            return 0;                                    /* reserved opcode */

        /* An all-zero key makes the XOR below a no-op for unmasked frames. */
        uint8_t mask[4] = {0, 0, 0, 0};
        if (masked && recv_all(fd, mask, 4) <= 0) return 0;

        if (opcode & 0x8) {
            /* Ping (0x9) or Pong (0xA). RFC 6455 5.5: control frames are
               never fragmented and carry at most 125 payload bytes. */
            uint8_t ctl[2 + 125];
            if (!fin || len > 125) return 0;
            if (len > 0 && recv_all(fd, ctl + 2, (int)len) <= 0) return 0;
            for (size_t i = 0; i < (size_t)len; i++)
                ctl[2 + i] ^= mask[i % 4];
            if (opcode == 0x9) {
                /* Best-effort Pong (FIN | 0xA), unmasked, same payload; a
                   failure will surface on the next read. */
                ctl[0] = 0x8A;
                ctl[1] = (uint8_t)len;
                send(fd, (const void *)ctl, 2 + (size_t)len, 0);
            }
            continue;
        }

        /* total_len <= bufferSize always holds here, so this cannot wrap. */
        if (len > bufferSize - total_len) return 0;

        /* recv_all() reports a 0-byte read as failure, so empty frames
           must skip it. */
        if (len > 0) {
            unsigned char *payload = dst + total_len;
            if (recv_all(fd, payload, (int)len) <= 0) return 0;
            for (size_t i = 0; i < (size_t)len; i++)
                payload[i] ^= mask[i % 4];
        }

        total_len += (size_t)len;

        if (fin) break;
    }

    return (int)total_len;
}

/* Rotate a left by b bits. Intended for uint32_t operands with 0 < b < 32;
   b == 0 would shift by 32, which is undefined. */
#define ROTLEFT(a,b) (((a) << (b)) | ((a) >> (32-(b))))

/**
  \brief Apply the SHA-1 compression function to one 64-byte block.

  The block is read as sixteen big-endian 32-bit words and expanded to the
  80-word message schedule. The five chaining words in \p state are updated
  in place (FIPS 180-4, section 6.1.2).
*/
void sha1_transform(uint32_t state[5], const uint8_t data[64]) {
    uint32_t a, b, c, d, e, t, m[80];
    int i;

    for (i = 0; i < 16; i++) {
        m[i] = ((uint32_t)data[i * 4] << 24) |
               ((uint32_t)data[i * 4 + 1] << 16) |
               ((uint32_t)data[i * 4 + 2] << 8) |
               ((uint32_t)data[i * 4 + 3]);
    }
    for (i = 16; i < 80; i++) {
        m[i] = ROTLEFT(m[i - 3] ^ m[i - 8] ^ m[i - 14] ^ m[i - 16], 1);
    }

    a = state[0]; b = state[1]; c = state[2]; d = state[3]; e = state[4];

    for (i = 0; i < 80; i++) {
        uint32_t f, k;
        if (i < 20)      { f = (b & c) | ((~b) & d);         k = 0x5A827999; }
        else if (i < 40) { f = b ^ c ^ d;                    k = 0x6ED9EBA1; }
        else if (i < 60) { f = (b & c) | (b & d) | (c & d);  k = 0x8F1BBCDC; }
        else             { f = b ^ c ^ d;                    k = 0xCA62C1D6; }
        t = ROTLEFT(a, 5) + f + e + k + m[i];
        e = d; d = c; c = ROTLEFT(b, 30); b = a; a = t;
    }

    state[0] += a; state[1] += b; state[2] += c; state[3] += d; state[4] += e;
}

/**
  \brief Compute the SHA-1 digest of \p len bytes at \p input.

  \param input message bytes (not required to be NUL-terminated).
  \param len   message length in bytes.
  \param hash  receives the 20-byte digest in big-endian word order.
*/
void sha1(const char *input, size_t len, uint8_t hash[20]) {
    uint32_t state[5] = { 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0 };
    const uint8_t *bytes = (const uint8_t *)input;
    uint8_t block[64];
    const size_t full_blocks = len / 64;
    const size_t rem = len % 64;

    for (size_t i = 0; i < full_blocks; i++)
        sha1_transform(state, bytes + i * 64);

    /* Padding: the tail, a 0x80 byte, zeros, then the 64-bit big-endian
       bit length. It spills into a second block when fewer than 8 bytes
       remain after the 0x80 marker. */
    memset(block, 0, sizeof block);
    if (rem) memcpy(block, bytes + full_blocks * 64, rem);
    block[rem] = 0x80;

    if (rem >= 56) {
        sha1_transform(state, block);
        memset(block, 0, sizeof block);
    }

    /* Widen before multiplying: len * 8 in a 32-bit size_t would overflow
       for inputs of 512 MiB or more. */
    const uint64_t bitlen = (uint64_t)len * 8;
    for (int i = 0; i < 8; i++)
        block[63 - i] = (uint8_t)((bitlen >> (i * 8)) & 0xFF);

    sha1_transform(state, block);

    for (int i = 0; i < 5; i++) {
        hash[i * 4 + 0] = (uint8_t)((state[i] >> 24) & 0xFF);
        hash[i * 4 + 1] = (uint8_t)((state[i] >> 16) & 0xFF);
        hash[i * 4 + 2] = (uint8_t)((state[i] >> 8) & 0xFF);
        hash[i * 4 + 3] = (uint8_t)(state[i] & 0xFF);
    }
}

/* Standard Base64 alphabet (RFC 4648, section 4). */
static const char b64_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/**
  \brief Encode \p len bytes of \p input as padded Base64 into \p out.

  Every 3 input bytes become 4 output characters. A trailing group of 1 or 2
  bytes is padded with "==" or "=" respectively. The result is
  NUL-terminated, so \p out must hold 4 * ((len + 2) / 3) + 1 characters.
*/
void base64_encode(const uint8_t *input, size_t len, char *out) {
    size_t pos = 0;
    char *dst = out;

    /* Whole 3-byte groups. */
    while (pos + 3 <= len) {
        const uint32_t triple = ((uint32_t)input[pos] << 16)
                              | ((uint32_t)input[pos + 1] << 8)
                              | (uint32_t)input[pos + 2];
        dst[0] = b64_table[(triple >> 18) & 0x3F];
        dst[1] = b64_table[(triple >> 12) & 0x3F];
        dst[2] = b64_table[(triple >> 6) & 0x3F];
        dst[3] = b64_table[triple & 0x3F];
        dst += 4;
        pos += 3;
    }

    /* Final partial group, padded with '='. */
    const size_t left = len - pos;
    if (left != 0) {
        uint32_t triple = (uint32_t)input[pos] << 16;
        if (left == 2)
            triple |= (uint32_t)input[pos + 1] << 8;
        dst[0] = b64_table[(triple >> 18) & 0x3F];
        dst[1] = b64_table[(triple >> 12) & 0x3F];
        dst[2] = (left == 2) ? b64_table[(triple >> 6) & 0x3F] : '=';
        dst[3] = '=';
        dst += 4;
    }

    *dst = '\0';
}

#define WS_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

/* Lower-case an ASCII letter; other bytes pass through (locale-independent). */
static char ws_ascii_lower(char ch) {
    return (ch >= 'A' && ch <= 'Z') ? (char)(ch - 'A' + 'a') : ch;
}

/*
  Find header `name` (given without the colon) in the raw HTTP request `req`.

  - Field names are compared case-insensitively (RFC 7230 3.2).
  - Names are matched only at the start of a line, so the name appearing in
    the request target or inside another header's value does not match.
  - Scanning stops at the blank line that ends the header section.

  Returns a pointer to the first non-blank character of the value, or NULL.
*/
static const char *ws_find_header(const char *req, const char *name) {
    const size_t name_len = strlen(name);
    const char *line = req;

    while (line != NULL && *line != '\0') {
        if (*line == '\r' || *line == '\n') break;      /* end of headers */

        size_t k = 0;
        while (k < name_len && line[k] != '\0' &&
               ws_ascii_lower(line[k]) == ws_ascii_lower(name[k]))
            k++;
        if (k == name_len && line[k] == ':') {
            const char *value = line + k + 1;
            while (*value == ' ' || *value == '\t') value++;
            return value;
        }

        line = strchr(line, '\n');
        if (line != NULL) line++;
    }
    return NULL;
}

/*
  Copy a header value into `dst` (capacity `cap` > 0, including the NUL).

  Copying stops at end of line, or also at ',' when `stop_at_comma` is set
  (this keeps only the first entry of a list-valued header). Trailing blanks
  are dropped. The result is always NUL-terminated; returns its length.
*/
static size_t ws_copy_value(char *dst, size_t cap, const char *src, int stop_at_comma) {
    size_t n = 0;
    while (n + 1 < cap && src[n] != '\0' && src[n] != '\r' && src[n] != '\n' &&
           !(stop_at_comma && src[n] == ',')) {
        dst[n] = src[n];
        n++;
    }
    while (n > 0 && (dst[n - 1] == ' ' || dst[n - 1] == '\t'))
        n--;
    dst[n] = '\0';
    return n;
}

/**
  \brief Complete the server side of the websocket opening handshake.

  Steps:
  1. Read Sec-WebSocket-Key from \p req.
  2. Compute Sec-WebSocket-Accept as base64(SHA-1(key + WS_GUID)).
  3. If the peer is 127.0.0.1, take the client address in \p c->ip from a
     reverse proxy's X-Real-IP or X-Forwarded-For header.
  4. Send the "101 Switching Protocols" response.

  \param fd  connected client socket.
  \param req raw HTTP request text (NUL-terminated).
  \param c   connection record whose \c ip field may be rewritten.
  \return 1 if the response was fully sent, 0 if the key is missing or
          empty, or if sending the response failed.
*/
uint8_t handle_websocket_handshake(socket_t fd, const char *req, Conn *c) {
    const char *key_value = ws_find_header(req, "Sec-WebSocket-Key");
    if (!key_value) return 0;

    char key[128];
    if (ws_copy_value(key, sizeof key, key_value, 0) == 0) return 0;

    char combined[256];
    snprintf(combined, sizeof combined, "%s%s", key, WS_GUID);

    uint8_t sha1_result[20];
    sha1(combined, strlen(combined), sha1_result);

    char b64[64];                       /* 20 bytes -> 28 chars + NUL */
    base64_encode(sha1_result, sizeof sha1_result, b64);

    /* Behind a local reverse proxy (e.g. nginx) every peer appears as
       127.0.0.1, so take the client address from the proxy's header.
       NOTE(review): this is only trustworthy if the proxy sets or overwrites
       these headers; a client connecting locally can forge them. */
    if (strcmp(c->ip, "127.0.0.1") == 0) {
        const char *fwd = ws_find_header(req, "X-Real-IP");
        if (!fwd) fwd = ws_find_header(req, "X-Forwarded-For");
        if (fwd) {
            /* At most 15 chars + NUL: the same bound the code has always
               written into c->ip (dotted IPv4). Assumes c->ip holds at least
               16 bytes. The first X-Forwarded-For hop is the original client. */
            char real_ip[16];
            size_t n = ws_copy_value(real_ip, sizeof real_ip, fwd, 1);
            if (n > 0) memcpy(c->ip, real_ip, n + 1);
        }
    }

    char response[512];
    int resp_len = snprintf(response, sizeof response,
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: %s\r\n\r\n",
        b64);
    if (resp_len < 0 || (size_t)resp_len >= sizeof response) return 0;

    /* Loop so a short send() cannot leave a truncated handshake response. */
    size_t sent = 0;
    while (sent < (size_t)resp_len) {
        int n = (int)send(fd, response + sent, (size_t)resp_len - sent, 0);
        if (n <= 0) return 0;
        sent += (size_t)n;
    }
    return 1;
}