//WebSocketProvider.js
import React, { createContext, useEffect, useRef, useState } from 'react';
import { MlKem768 } from 'mlkem';
import { toByteArray, fromByteArray } from 'base64-js';

export const WebSocketContext = createContext();

const WebSocketProvider = ({ children, socketRef, operationQueue, sessionId, setSessionId, clientPhone, threadId, onThreadAdded, joinThread, clientThreads }) => {
    // const socketRef = useRef(null);
    // const [threads, setThreads] = useState([]);
    const [isReady, setIsReady] = useState(false);
    // const operationQueue = useRef([]);
    const [messages, setMessages] = useState([]);
    const imageChunksRef = useRef({});
    const aesKeysRef = useRef({ key1: null, key2: null });
    const savedSharedSecret = useRef(null);

    /**
     * Decodes a base64 string into raw bytes.
     *
     * Delegates to `toByteArray` from base64-js.
     *
     * @param {string} base64String - Base64-encoded text.
     * @returns {Uint8Array} The decoded bytes.
     */
    const base64ToUint8Array = (base64String) => toByteArray(base64String);

    /**
     * Encodes raw bytes as a base64 string.
     *
     * Delegates to `fromByteArray` from base64-js.
     *
     * @param {Uint8Array} uint8Array - Bytes to encode.
     * @returns {string} Base64 text for the given bytes.
     */
    const uint8ArrayToBase64 = (uint8Array) => fromByteArray(uint8Array);

    /**
     * Derives raw AES-GCM key bytes from a shared secret using HKDF-SHA-256.
     *
     * Both the secret and the salt are passed through TextEncoder, and the
     * HKDF `info` parameter is empty.
     *
     * NOTE(review): the caller passes `savedSharedSecret.current`, the value
     * returned by `MlKem768.encap`. If that is a Uint8Array, TextEncoder will
     * encode its string form rather than the raw bytes - confirm the server
     * derives its keys the same way before changing anything here.
     *
     * @param {*} inputSharedSecret - Secret used as HKDF input keying material.
     * @param {string} salt - HKDF salt.
     * @param {number} [length=32] - Key size in bytes (multiplied by 8 for the AES-GCM bit length).
     * @returns {Promise<Uint8Array>} The exported raw key bytes.
     */
    const deriveKey = async (inputSharedSecret, salt, length = 32) => {
        const { subtle } = window.crypto;
        const utf8 = new TextEncoder();
        const secretBytes = utf8.encode(inputSharedSecret);
        const saltBytes = utf8.encode(salt);

        // HKDF input keys may only be used for derivation, never exported.
        const hkdfInput = await subtle.importKey('raw', secretBytes, { name: 'HKDF' }, false, ['deriveKey']);

        const hkdfParams = {
            name: 'HKDF',
            hash: 'SHA-256',
            salt: saltBytes,
            info: new Uint8Array(),
        };
        const aesParams = { name: 'AES-GCM', length: length * 8 };

        // The derived key must be extractable because the raw bytes are returned.
        const aesKey = await subtle.deriveKey(hkdfParams, hkdfInput, aesParams, true, ['encrypt', 'decrypt']);

        return new Uint8Array(await subtle.exportKey('raw', aesKey));
    };

    /**
     * Encrypts a string with AES-GCM under a fresh random 12-byte IV.
     *
     * The result is split into ciphertext and the 16-byte authentication tag,
     * and every field is base64-encoded so the envelope can be sent as JSON.
     *
     * Base64 encoding goes through `uint8ArrayToBase64` rather than
     * `btoa(String.fromCharCode(...bytes))`: spreading a large byte array into
     * call arguments throws a RangeError, which made large payloads (such as
     * image chunks) fail to encrypt.
     *
     * @param {Uint8Array} key - Raw AES key bytes.
     * @param {string} plainTextMessage - Text to encrypt.
     * @returns {Promise<{ciphertext: string, iv: string, tag: string}|undefined>}
     *     The base64 envelope, or `undefined` if encryption failed (the error is logged).
     */
    const singleEncryptMessage = async (key, plainTextMessage) => {
        const TAG_BYTES = 16; // matches tagLength: 128 below

        const iv = window.crypto.getRandomValues(new Uint8Array(12)); // Generate a 12-byte random IV
        const plainTextBuffer = new TextEncoder().encode(plainTextMessage);

        const cryptoKey = await window.crypto.subtle.importKey(
            'raw',
            key,
            'AES-GCM',
            false,
            ['encrypt']
        );

        try {
            console.log("Attempting encryption...");
            const encrypted = new Uint8Array(
                await window.crypto.subtle.encrypt(
                    {
                        name: 'AES-GCM',
                        iv: iv,
                        tagLength: 128, // 128-bit tag length
                    },
                    cryptoKey,
                    plainTextBuffer
                )
            );

            // WebCrypto appends the tag to the ciphertext; split them apart.
            const tagOffset = encrypted.length - TAG_BYTES;

            return {
                ciphertext: uint8ArrayToBase64(encrypted.subarray(0, tagOffset)),
                iv: uint8ArrayToBase64(iv),
                tag: uint8ArrayToBase64(encrypted.subarray(tagOffset)),
            };
        } catch (error) {
            console.error('Error encrypting message:', error);
        }
    };

    /**
     * Decrypts one AES-GCM envelope back into a string.
     *
     * The envelope's `ciphertext`, `iv` and `tag` fields are base64 text; the
     * ciphertext and tag are joined again before decryption because WebCrypto
     * expects the tag appended to the ciphertext.
     *
     * @param {Uint8Array} key - Raw AES key bytes.
     * @param {{ciphertext: string, iv: string, tag: string}} encryptedMessage - Base64 envelope.
     * @returns {Promise<string|undefined>} The decrypted text, or `undefined`
     *     if decryption failed (the error is logged).
     */
    const singleDecryptMessage = async (key, encryptedMessage) => {
        const { ciphertext, iv, tag } = encryptedMessage;
        const decodeBase64 = (encoded) => Uint8Array.from(atob(encoded), (ch) => ch.charCodeAt(0));

        const cipherBytes = decodeBase64(ciphertext);
        const ivBytes = decodeBase64(iv);
        const tagBytes = decodeBase64(tag);

        const cryptoKey = await window.crypto.subtle.importKey('raw', key, 'AES-GCM', false, ['decrypt']);

        try {
            console.log("Attempting decrypt...")

            // Rebuild the "ciphertext || tag" layout.
            const sealed = new Uint8Array(cipherBytes.length + tagBytes.length);
            sealed.set(cipherBytes, 0);
            sealed.set(tagBytes, cipherBytes.length);

            const algorithm = {
                name: 'AES-GCM',
                iv: ivBytes,
                tagLength: 128, // 128-bit tag length
            };
            const plainBuffer = await window.crypto.subtle.decrypt(algorithm, cryptoKey, sealed);

            return new TextDecoder().decode(plainBuffer);
        } catch (error) {
            console.error('Error decrypting message:', error);
        }
    };

    /**
     * Encrypts a message twice: first with `key1`, then the JSON-serialized
     * result of that with `key2`.
     *
     * `singleEncryptMessage` reports failure by returning `undefined`. If the
     * first layer fails, this function stops instead of running the second
     * layer on `JSON.stringify(undefined)`, which would silently encrypt an
     * empty payload and return a valid-looking envelope.
     *
     * @param {Uint8Array} key1 - Raw AES key for the inner layer.
     * @param {Uint8Array} key2 - Raw AES key for the outer layer.
     * @param {string} plainTextMessage - Text to encrypt.
     * @returns {Promise<{ciphertext: string, iv: string, tag: string}|undefined>}
     *     The outer envelope, or `undefined` if either layer failed (the error is logged).
     */
    const doubleEncrypt = async (key1, key2, plainTextMessage) => {
        try {
            // First encryption with key1
            const firstEncryption = await singleEncryptMessage(key1, plainTextMessage);
            if (!firstEncryption) {
                throw new Error('First encryption layer failed');
            }

            // Second encryption with key2 over the serialized inner envelope
            return await singleEncryptMessage(key2, JSON.stringify(firstEncryption));
        } catch (error) {
            console.error('Error in doubleEncrypt:', error);
        }
    };

    /**
     * Reverses `doubleEncrypt`: removes the outer layer with `key2`, parses the
     * JSON envelope it contained, then removes the inner layer with `key1`.
     *
     * @param {Uint8Array} key1 - Raw AES key for the inner layer.
     * @param {Uint8Array} key2 - Raw AES key for the outer layer.
     * @param {{ciphertext: string, iv: string, tag: string}} encryptedMessage - Outer envelope.
     * @returns {Promise<string|undefined>} The decrypted text, or `undefined`
     *     if any step threw (the error is logged).
     */
    const doubleDecrypt = async (key1, key2, encryptedMessage) => {
        try {
            const outerPlainText = await singleDecryptMessage(key2, encryptedMessage);
            const innerEnvelope = JSON.parse(outerPlainText);
            return await singleDecryptMessage(key1, innerEnvelope);
        } catch (error) {
            console.error('Error in doubleDecrypt:', error);
        }
    };

    // Opens the WebSocket, performs the key exchange and dispatches incoming
    // messages. Runs again (closing the previous socket) whenever `threadId`
    // changes.
    //
    // Incoming message types handled here:
    //   thread_created / thread_joined - forwarded to `onThreadAdded`
    //   init      - ML-KEM encapsulation against the server's public key
    //   chatKeys  - AES keys, double-decrypted with keys derived from the shared secret
    //   chat      - decrypted and prepended to `messages` (current thread only)
    //   image_chunk - decrypted, buffered and reassembled (current thread only)
    //
    // NOTE(review): `sessionId`, `clientThreads`, `onThreadAdded` and
    // `setSessionId` are read inside the handlers but are not in the dependency
    // array, so the handlers see the values from the render in which the socket
    // was opened - confirm that is intended.
    // NOTE(review): `init` is ignored once `savedSharedSecret` is set, and that
    // ref survives the reconnect triggered by a `threadId` change - confirm the
    // server does not expect a new encapsulation on each connection.
    useEffect(() => {
        const wsUrl = window.location.hostname === 'localhost'
            ? `ws://localhost:8080`
            : `wss://sinbi-websocket-platform-all.azurewebsites.net`;

        // Keep a local handle so the handlers and the cleanup below always act
        // on the socket this effect created, even if `socketRef.current` has
        // since been replaced or cleared.
        const socket = new WebSocket(wsUrl);
        socketRef.current = socket;

        socket.onopen = () => {
            console.log('WebSocket connected');
            setIsReady(true);

            console.log('ON CONNECTION SESSION ID:', sessionId);
            // Notify the server of the connection
            socket.send(
                JSON.stringify({
                    type: 'connect',
                    sessionId: sessionId
                })
            );
        };

        socket.onmessage = async (event) => {
            try {
                const message = JSON.parse(event.data);

                if (message.type === 'thread_created') {
                    console.log("===== THREAD CREATED =====");
                    onThreadAdded(message.threadId, message.threadName, message.sessionId);

                } else if (message.type === 'thread_joined') {
                    // Only report threads the client does not already know about.
                    if (!clientThreads.find((thread) => thread.id === message.threadId)) {
                        console.log("===== THREAD JOINED =====");
                        onThreadAdded(message.threadId, message.threadName, message.sessionId);
                    }
                }
                else if (message.type === 'init' && !savedSharedSecret.current) {
                    console.log("===== INIT =====");
                    console.log('Received sessionId:', message.sessionId);
                    setSessionId(message.sessionId);
                    const keyEncap = new MlKem768();
                    const [cipherText, sharedSecret] = await keyEncap.encap(
                        base64ToUint8Array(message.publicKey)
                    );

                    savedSharedSecret.current = sharedSecret;
                    console.log("Sending encapsulated key...");
                    socket.send(
                        JSON.stringify({
                            type: 'encapsulation',
                            cipherText: uint8ArrayToBase64(cipherText),
                            sessionId: message.sessionId,
                        })
                    );
                } else if (message.type === 'chatKeys') {
                    console.log("===== CHAT KEYS =====");
                    const encryptedMessage = message.encryptedMessage;
                    const quantumKey1 = await deriveKey(savedSharedSecret.current, 'sinbi-sa1t');
                    const quantumKey2 = await deriveKey(savedSharedSecret.current, 'sinbi-sa2t');

                    const decryptedKeys = await doubleDecrypt(
                        quantumKey1,
                        quantumKey2,
                        encryptedMessage
                    );

                    // If decryption failed this throws and is logged by the catch below.
                    const parsedKeys = JSON.parse(decryptedKeys);
                    aesKeysRef.current = {
                        key1: base64ToUint8Array(parsedKeys.key1),
                        key2: base64ToUint8Array(parsedKeys.key2),
                    };
                    setSessionId(parsedKeys.sessionId);
                    // Never log the key material itself.
                    console.log('Decrypted AES keys received');
                    console.log('Session ID:', parsedKeys.sessionId);

                    socket.send(
                        JSON.stringify({
                            type: 'chatKeysAck',
                            sessionId: parsedKeys.sessionId,
                        })
                    );

                } else if (message.type === 'chat' && message.threadId === threadId) {
                    console.log("===== CHAT =====");
                    const decryptedMessage = await doubleDecrypt(
                        aesKeysRef.current.key1,
                        aesKeysRef.current.key2,
                        message.text
                    );
                    const completedMessage = {
                        ...message,
                        text: decryptedMessage,
                    };
                    // Newest message first.
                    setMessages((prev) => [completedMessage, ...prev]);
                } else if (message.type === 'image_chunk' && message.threadId === threadId) {
                    console.log("===== IMAGE =====");
                    const { sender, imageChunk, chunkIndex, totalChunks } = message;
                    const chunkThreadId = message.threadId;

                    // Decrypt the image chunk
                    const decryptedChunk = await doubleDecrypt(
                        aesKeysRef.current.key1,
                        aesKeysRef.current.key2,
                        imageChunk
                    );

                    // Look up (or create) the buffer only after the await: another
                    // handler may have completed an image and deleted the entry
                    // while this chunk was being decrypted.
                    let threadData = imageChunksRef.current[chunkThreadId];
                    if (!threadData) {
                        threadData = { chunks: [], totalChunks, sender, sessionId: message.sessionId };
                        imageChunksRef.current[chunkThreadId] = threadData;
                    }

                    // Store the decrypted chunk at its position
                    threadData.chunks[chunkIndex] = decryptedChunk;

                    // If all chunks have been received, reassemble the image
                    if (threadData.chunks.filter(Boolean).length === totalChunks) {
                        const fullImage = threadData.chunks.join('');
                        const completedMessage = {
                            ...message,
                            image: fullImage,
                        };
                        setMessages((prev) => [completedMessage, ...prev]);
                        delete imageChunksRef.current[chunkThreadId]; // Cleanup
                    }

                }
            } catch (error) {
                console.error('Error handling WebSocket message:', error);
            }
        };

        socket.onclose = () => {
            console.log('WebSocket disconnected');
            setIsReady(false);
        };

        // Close the socket this effect opened, not whatever the ref holds now.
        return () => socket.close();
    }, [threadId]); // Reconnect or filter based on threadId

    /**
     * Double-encrypts an outgoing payload with the session AES keys and sends
     * it over the WebSocket as JSON.
     *
     * For `type === 'chat'` the encrypted field is `message.text`; for
     * `type === 'image_chunk'` it is `message.imageChunk`, and `chunkIndex` /
     * `totalChunks` are forwarded unencrypted. Any other `type` is ignored.
     *
     * Nothing is sent when the socket is not open, when the AES keys have not
     * been received yet, or when encryption fails; each case is logged. Key
     * material is deliberately never written to the console.
     *
     * @param {object} message - Payload; see above for the fields read per type.
     * @param {string} [type='chat'] - Either 'chat' or 'image_chunk'.
     * @returns {Promise<void>}
     */
    const sendMessage = async (message, type = 'chat') => {
        console.log('Attempting to send message');

        const isChat = type === 'chat';
        if (!isChat && type !== 'image_chunk') {
            return;
        }

        if (!socketRef.current || socketRef.current.readyState !== WebSocket.OPEN) {
            console.warn('Cannot send message: WebSocket is not open');
            return;
        }

        const { key1, key2 } = aesKeysRef.current;
        if (!key1 || !key2) {
            console.error('Cannot send message: encryption keys are not available yet');
            return;
        }

        const encryptedPayload = await doubleEncrypt(key1, key2, isChat ? message.text : message.imageChunk);
        if (!encryptedPayload) {
            // doubleEncrypt logs the cause and returns undefined on failure;
            // sending would deliver a message with no content.
            console.error('Cannot send message: encryption failed');
            return;
        }

        let messageObject;
        if (isChat) {
            console.log('ThreadId RIGHT before sending:', threadId);
            messageObject = {
                text: encryptedPayload,
                type,
                threadId, // Include threadId
                sessionId,
                sender: clientPhone,
            };
        } else {
            messageObject = {
                imageChunk: encryptedPayload,
                type,
                threadId, // Include threadId
                sessionId,
                sender: clientPhone,
                chunkIndex: message.chunkIndex,
                totalChunks: message.totalChunks,
            };
        }

        socketRef.current.send(JSON.stringify(messageObject));
    };

    return (
        <WebSocketContext.Provider value={{ messages, sendMessage, aesKeys: aesKeysRef.current }}>
            {children}
        </WebSocketContext.Provider>
    );
};

export default WebSocketProvider;
