// Temp to test rad layout
import React, { useEffect, useState, useRef, useCallback } from 'react';
import { useRAD } from './RADHandlerLogic';

import * as d3 from 'd3';

// Base glyph dimensions used throughout this file: createNode gives leaf activities exactly this
// size and adds it to any preset size of other nodes; the case-refinement and case-condition
// drawing helpers also derive their geometry from these values.
const defaultWidth = 20;
const defaultHeight = 20;

const RADLayout = () => {
    const { radID, version, getRoles, getThreads, getSubthreads, getThreadSequence,
        getRoleThreads } = useRAD();
    const [radLayout, setRadLayout] = useState(null);

    const viewBoxWidth = 1500;
    const viewBoxHeight = 1500;

    const layoutRef = useRef();

    // One-time setup: builds the <svg> element and the root <g> that the draw pass renders into.
    useEffect(() => {
        const host = layoutRef.current;
        // Nothing to do when the container is not mounted or already holds an <svg>.
        if (!host || host.querySelector('svg')) {
            return;
        }

        const svgNode = d3.select(host)
            .append('svg')
            .attr('id', 'rad-layout-svg')
            .attr("viewBox", `0 0 ${viewBoxWidth} ${viewBoxHeight}`)
            .node();

        // Root group that holds every diagram element.
        const rootGroup = d3.select(svgNode).append('g');
        rootGroup.attr('id', 'rad-layout');
        rootGroup.attr("cursor", "grab");
        rootGroup.attr("pointer-events", "all");
    }, []);



    // Collects, in pre-order, every node of the layout tree whose `class` is exactly `classType`.
    // When `parent_id` is truthy, only nodes whose `parent_id` equals it are kept.
    const getNodesByClass = useCallback((tree, classType, parent_id = null) => {
        const found = [];
        // Explicit stack instead of recursion; children are pushed in reverse so they pop in order.
        const pending = [tree];
        while (pending.length > 0) {
            const { node, children } = pending.pop();
            const classMatches = node.class === classType;
            const parentMatches = !parent_id || node.parent_id === parent_id;
            if (classMatches && parentMatches) {
                found.push(node);
            }
            for (let i = children.length - 1; i >= 0; i -= 1) {
                pending.push(children[i]);
            }
        }
        return found;
    }, []);

    // Like getNodesByClass, but matches when the node's `class` string CONTAINS `classType`
    // (String.prototype.includes), so multi-class values such as "a b" match on "a".
    // Results are in pre-order; a truthy `parent_id` restricts matches to that parent.
    const getNodesWithClass = useCallback((tree, classType, parent_id = null) => {
        const collect = ({ node, children }) => {
            const isMatch = node.class.includes(classType)
                && (!parent_id || node.parent_id === parent_id);
            const below = children.flatMap((child) => collect(child));
            return isMatch ? [node, ...below] : below;
        };
        return collect(tree);
    }, []);

/*     const getNodeById = useCallback((tree, id) => {
        let node = null;
        function traverse(tree) {
            if (tree.node.id === id) {
                node = tree.node;
            }
            tree.children.forEach(child => traverse(child));
        }
        traverse(tree);
        return node;
    }, []); */

    // Returns the { node, children } subtree whose node.id equals `id`, or null when none matches.
    // The whole tree is always walked in pre-order; if several nodes share the id, the last one
    // visited is returned.
    const getSubtreeById = useCallback((tree, id) => {
        const walk = (current, bestSoFar) => {
            const best = current.node.id === id ? current : bestSoFar;
            return current.children.reduce((acc, child) => walk(child, acc), best);
        };
        return walk(tree, null);
    }, []);

    // Returns the direct children (array of { node, children }) of the subtree whose id is
    // `thread_id`. getSubtreeById yields null for an unknown id, so fall back to an empty list
    // instead of throwing on `null.children`.
    const getChildrenByThreadId = useCallback((tree, thread_id) => {
        const thread = getSubtreeById(tree, thread_id);
        console.log('getChildrenByThreadId', thread);
        return thread?.children ?? [];
    }, [getSubtreeById]);

    // Produces one { state_id, position } entry per direct child of the given thread.
    //   - state_id is the child's prestate_id.
    //   - x is the horizontal centre of the child.
    //   - y is 20 above the child for the first entry, 20 below the child's bottom edge for the
    //     last entry (when there is more than one child), and the child's vertical centre otherwise.
    // NOTE(review): an earlier description spoke of midpoints between consecutive children and of
    // the last entry being the final activity's poststate; the code does neither — confirm intent.
    const getStatePositions = useCallback((tree, thread_id) => {
        const threadChildren = getChildrenByThreadId(tree, thread_id);
        if (threadChildren.length === 0) {
            return [];
        }

        const lastIndex = threadChildren.length - 1;
        return threadChildren.map(({ node }, index) => {
            const { position, size, prestate_id } = node;
            const x = position.x + size.width / 2;

            let y;
            if (index === 0) {
                y = position.y - 20;
            } else if (index === lastIndex) {
                y = position.y + size.height + 20;
            } else {
                y = position.y + size.height / 2;
            }

            return { state_id: prestate_id, position: { x, y } };
        });
    }, [getChildrenByThreadId]);
                

    useEffect(() => {
        console.log('Building RAD Layout...', radID, version);

        // Wraps a raw record in a layout-tree entry of the form { node, children }.
        // - A record with a truthy activity_id and no children gets the default glyph size.
        // - Any other record gets the default size plus its own preset size, if it has one.
        //   A missing size contributes 0; previously `number + undefined` produced NaN.
        // The input record is not mutated; `size` is attached to a shallow copy.
        function createNode(node, children = []) {
            const isActivity = node.activity_id && !children.length;
            const size = isActivity
                ? { width: defaultWidth, height: defaultHeight }
                : {
                    width: defaultWidth + (node.size?.width ?? 0),
                    height: defaultHeight + (node.size?.height ?? 0),
                };

            return { node: { ...node, size }, children };
        }

        // Expands a refinement activity (class 'case-refinement', 'part-refinement' or
        // 'part-repeat') with its subthreads, recursively, and returns a NEW sized node:
        //   - each subthread becomes a child whose class gains ' thread-group', sized by
        //     setThreadSize and holding its activities positioned by setActivityPositions;
        //   - the refinement itself is sized by sizeRefinement and its children are positioned
        //     by setSubthreadPositions.
        // Any other activity is returned unchanged.
        // Fix: the result of the recursive call is now used for nested activities. Previously it
        // was discarded, so nested refinements kept their default size and unpositioned
        // subthreads. The input node's children array is no longer mutated.
        function addSubthreadsToRefinement(activityNode) {
            const { node } = activityNode;
            const refinementClasses = ['case-refinement', 'part-refinement', 'part-repeat'];
            if (!node.class || !refinementClasses.includes(node.class)) {
                // No subthreads, return the activity node as is
                return activityNode;
            }

            const subthreadNodes = getSubthreads(node.class, node.id).map(subthread => {
                const activities = getThreadSequence(subthread.id, subthread.prestate_id)
                    .map(activity => addSubthreadsToRefinement(
                        createNode({ ...activity, id: activity.activity_id, parent_id: activity.thread_id })
                    ));

                const size = setThreadSize(activities);
                const positionedActivities = setActivityPositions(activities);

                return {
                    node: {
                        ...createNode(subthread).node,
                        class: `${subthread.class} thread-group`,
                        size,
                        parent_id: node.id,
                    },
                    children: positionedActivities,
                };
            });

            // Resize the refinement activity node to fit the subthreads.
            const children = [...activityNode.children, ...subthreadNodes];
            const size = sizeRefinement(children);
            return { node: { ...node, size }, children: setSubthreadPositions(children) };
        }

        // Size of a refinement node: widths of all children added to a base of 20, and the
        // tallest child height, never less than 10.
        function sizeRefinement(nodes) {
            let width = 20;  // padding
            let height = 10; // padding / minimum height
            for (const { node } of nodes) {
                width = width + node.size.width;
                height = Math.max(height, node.size.height);
            }
            return { width, height };
        }

        // Lays subthreads out left to right, ordered by node.index: the first starts at x = 40
        // and each following one starts where the previous one ends (previous x + width).
        // y is always 0. Returns new entries with `node.position` set.
        // Fixes: sorts a copy so the caller's array is not reordered in place, and keeps a running
        // x instead of rebuilding the accumulator array on every step.
        function setSubthreadPositions(subthreads) {
            const xOffset = 40;
            const yOffset = 0;
            const ordered = [...subthreads].sort((a, b) => a.node.index - b.node.index);

            let nextX = xOffset;
            return ordered.map(subthread => {
                const positioned = {
                    ...subthread,
                    node: {
                        ...subthread.node,
                        position: { x: nextX, y: yOffset },
                    },
                };
                nextX += subthread.node.size.width;
                return positioned;
            });
        }

        // Size of a role's content: the sum of all child widths and the sum of all child
        // heights, each on top of a padding of 10.
        function setRoleSize(nodes) {
            const padding = 10;
            const total = (pick) => nodes.reduce((sum, child) => sum + pick(child.node.size), padding);
            return {
                width: total((size) => size.width),
                height: total((size) => size.height),
            };
        }

        // Size of a thread: the widest child (at least 60) and, starting from 20, an extra
        // 2.5 x child height + 20 of vertical room per child.
        function setThreadSize(nodes) {
            let width = 60;  // padding / minimum width
            let height = 20; // padding
            nodes.forEach(({ node: { size } }) => {
                width = Math.max(width, size.width);
                height = height + size.height * 2.5 + 20;
            });
            return { width, height };
        }

        // Stacks activities top to bottom, ordered by node.index: the first sits at y = 40 and
        // each following one at (previous y + 2 x previous height + 20). x is always 0.
        // Returns new entries with `node.position` set.
        // Fixes: sorts a copy so the caller's array is not reordered in place, and keeps a running
        // y instead of rebuilding the accumulator array on every step.
        function setActivityPositions(activities) {
            const xOffset = 0;
            const yOffset = 40;
            const ordered = [...activities].sort((a, b) => a.node.index - b.node.index);

            let nextY = yOffset;
            return ordered.map(activity => {
                const positioned = {
                    ...activity,
                    node: {
                        ...activity.node,
                        position: { x: xOffset, y: nextY },
                    },
                };
                nextY = nextY + activity.node.size.height * 2 + 20;
                return positioned;
            });
        }

        // Builds the whole layout tree: root -> roles -> threads -> activities (with refinements
        // expanded by addSubthreadsToRefinement).
        //   - Threads take their position from thread.x / thread.y, their size from setThreadSize
        //     and their activities positioned by setActivityPositions.
        //   - A role is at least as large as its own width/height and otherwise fits its
        //     content (setRoleSize) plus 40 in each dimension.
        // Fix: a role without width/height no longer yields NaN (Math.max(undefined, n) is NaN).
        function buildRADLayout() {
            // Flattens a subtree into [descendants..., node] order; used for the debug log only.
            function postOrderTraversal(node) {
                return node.children.reduce((acc, child) => acc.concat(postOrderTraversal(child)), []).concat(node.node);
            }

            const root = createNode({ id: 'root', name: 'Root Node', class: 'root' });

            getRoles().forEach(role => {
                const roleNode = createNode(role);

                getRoleThreads(role.id).forEach(thread => {
                    const threadNode = createNode(thread);
                    // Add the thread's activities in sequence order, expanding refinements.
                    getThreadSequence(thread.id, thread.prestate_id).forEach(activity => {
                        const activityNode = createNode({ ...activity, id: activity.activity_id, parent_id: activity.thread_id });
                        threadNode.children.push(addSubthreadsToRefinement(activityNode));
                    });

                    const position = { x: thread.x, y: thread.y };
                    const size = setThreadSize(threadNode.children);
                    const updatedChildren = setActivityPositions(threadNode.children);
                    roleNode.children.push({
                        node: { ...threadNode.node, position, size, parent_id: thread.role_id },
                        children: updatedChildren,
                    });
                });

                const contentSize = setRoleSize(roleNode.children);
                const roleSize = {
                    width: Math.max(roleNode.node.width ?? 0, contentSize.width + 40),
                    height: Math.max(roleNode.node.height ?? 0, contentSize.height + 40),
                };
                root.children.push({ ...roleNode, node: { ...roleNode.node, size: roleSize, parent_id: 'root' } });
            });

            console.log('postOrderTraversal', postOrderTraversal(root));
            return root;
        }

        setRadLayout(buildRADLayout());
    }, [radID, version, getRoles, getThreads, getSubthreads, getThreadSequence, getRoleThreads]);

    // Draw pass: renders the current layout tree into the existing <svg> whenever it changes.
    useEffect(() => {
        // Guards: need both a layout tree and a mounted container.
        if (!radLayout || !layoutRef.current) return;

        console.log('RAD Layout updated...', radLayout);

        const rootGroup = d3.select(layoutRef.current).select('svg').select("#rad-layout");

        // Leaf activity kinds drawn inside a thread, in drawing order:
        // [exact node class, draw function, message logged on click].
        const leafRenderers = [
            ['activity interaction', drawInteractionGroups, 'clicked interaction'],
            ['activity action', drawActionGroups, 'clicked action'],
            ['activity trigger', drawTriggerGroups, 'clicked trigger'],
            ['activity start-role', drawStartRoleGroups, 'clicked start role'],
            ['activity ellipsis', drawEllipsisGroups, 'clicked ellipsis'],
        ];

        // Draws everything that belongs to one thread (or case-condition subthread).
        const renderThread = (thread) => {
            const threadGroup = rootGroup.select(`g.thread-group[data-id="${thread.id}"]`);

            leafRenderers.forEach(([nodeClass, draw, message]) => {
                const nodes = getNodesByClass(radLayout, nodeClass, thread.id);
                draw(threadGroup, nodes, d3.drag(), () => console.log(message));
            });

            // Case refinements, then their case-condition subthreads (recursively).
            const caseRefs = getNodesByClass(radLayout, 'case-refinement', thread.id);
            drawCaseRefinementGroups(threadGroup, caseRefs, d3.drag(), () => console.log('clicked caserefinement'));
            caseRefs.forEach((caseRef) => {
                const caseGroup = threadGroup.select(`g.case-refinement[data-id="${caseRef.id}"]`);
                const caseConditions = getNodesWithClass(radLayout, 'case-condition', caseRef.id);
                drawCaseConditionGroups(caseGroup, caseConditions, d3.drag(), () => console.log('clicked case condition'));
                renderThreads(caseConditions);
            });

            // TODO: part-refinement and part-repeat nodes are not drawn yet.

            const statePositions = getStatePositions(radLayout, thread.id);
            console.log('statePositions', statePositions);
            drawStateLines(threadGroup, statePositions, d3.drag(), () => console.log('clicked state line'));
        };

        const renderThreads = (threads) => {
            threads.forEach((thread) => renderThread(thread));
        };

        const roles = getNodesByClass(radLayout, 'role');
        drawRoleGroups(rootGroup, roles, d3.drag(), () => console.log('clicked role'));

        roles.forEach((role) => {
            const roleThreads = getNodesByClass(radLayout, 'thread', role.id);
            const roleGroup = rootGroup.select(`g.role-group[data-id="${role.id}"]`);
            drawThreadGroups(roleGroup, roleThreads, d3.drag(), () => console.log('clicked thread'));
            renderThreads(roleThreads);
        });
    }, [radLayout, getNodesByClass, getNodesWithClass, getStatePositions]);

    return (
        <div id="rad-display" >
            <p>Rad Layout</p>
            <div className='rad-diagram-container' >
                <div ref={layoutRef}></div>
            </div>
            {radLayout && <pre>{JSON.stringify(radLayout, null, 4)}</pre>
            }
        </div>
    );
};

export default RADLayout;

/**
 * Draws one horizontal `line.state-line` per entry of `statePositions` inside `container`.
 * Each line starts at the entry's position and extends 20 units to the right; lines without
 * a matching entry are removed.
 *
 * @param {object} container - d3 selection to draw into.
 * @param {Array<{position: {x: number, y: number}}>} statePositions - one entry per line.
 */
const drawStateLines = (container, statePositions) => {
    const tickLength = 20;
    const bound = container.selectAll('line.state-line').data(statePositions);

    const entered = bound.enter().append('line').classed('state-line', true);
    entered.merge(bound)
        .attr('x1', ({ position }) => position.x)
        .attr('y1', ({ position }) => position.y)
        .attr('x2', ({ position }) => position.x + tickLength)
        .attr('y2', ({ position }) => position.y);

    bound.exit().remove();
};

/**
 * Joins `rolesData` (keyed by id) to `g.role-group` elements in `container` and draws each role:
 * a rounded `rect.role`, a bold name label above it, a tick tag when `preexists` is truthy and
 * an 'n' tag when `multiple` is truthy. Groups are translated to (d.x, d.y), defaulting to 0.
 *
 * @param {object} container - d3 selection that holds the role groups.
 * @param {Array<object>} rolesData - role nodes; reads id, x, y, size, label, preexists, multiple.
 * @param {Function} onDrag - drag behaviour applied (via call) to newly created groups.
 * @param {Function} onClick - click listener attached to newly created role rectangles.
 */
function drawRoleGroups(container, rolesData, onDrag, onClick) {
    const existing = container.selectAll("g.role-group")
        .data(rolesData, (role) => role.id);

    // New groups: wrapper <g> with drag behaviour.
    const entered = existing.enter().append("g")
        .attr("class", "role-group")
        .attr('data-id', (role) => role.id)
        .attr("cursor", "grab")
        .call(onDrag);

    // New groups: the role rectangle, corner radius a tenth of the role size.
    entered.append('rect')
        .attr("class", "role")
        .attr("data-id", (role) => role.id)
        .attr("rx", ({ size }) => size.width / 10 || 0)
        .attr("ry", ({ size }) => size.height / 10 || 0)
        .on("click", onClick);

    // Everything below applies to new and existing groups alike.
    const all = entered.merge(existing);

    all.attr('transform', (role) => `translate(${(role.x || 0)}, ${(role.y || 0)})`);

    // Name label: exactly one per group, centred above the rectangle.
    all.selectAll('text.role-name')
        .data((role) => [role])
        .join('text')
        .attr("class", "role-name")
        .attr('x', ({ size }) => size.width / 2)
        .attr('y', -18)
        .attr('text-anchor', 'middle')
        .attr('font-size', '18px')
        .attr('font-weight', 'bold')
        .text((role) => role.label || "new role");

    // Tick tag: present only while `preexists` is truthy.
    all.selectAll('text.role-tag-tick')
        .data((role) => (role.preexists ? [role] : []))
        .join('text')
        .attr("class", "role-tag-tick")
        .attr('x', ({ size }) => size.width - 20)
        .attr('y', -2)
        .text('✓')
        .attr('font-weight', 'bold');

    // 'n' tag: present only while `multiple` is truthy.
    all.selectAll('text.role-tag-multiple')
        .data((role) => (role.multiple ? [role] : []))
        .join('text')
        .attr("class", "role-tag-multiple")
        .attr('x', ({ size }) => size.width - 10)
        .attr('y', 0)
        .text('n');

    // Rectangle size is (re)applied last, for new and existing roles.
    all.select('rect.role')
        .attr("width", ({ size }) => size.width)
        .attr("height", ({ size }) => size.height);

    existing.exit().remove();
}

/**
 * Joins `data` (keyed by id) to `g.thread-group` elements in `container`. Each group holds a
 * rounded `rect.thread` sized from d.size and is translated to (d.x, d.y).
 *
 * Changes from the previous version:
 *  - missing x / y fall back to 0, as drawRoleGroups does, instead of producing the invalid
 *    transform "translate(undefined, undefined)";
 *  - the entered groups are no longer re-bound with a second, redundant data() call;
 *  - the background rectangle is selected as 'rect.thread' rather than the first 'rect' found.
 *
 * @param {object} container - d3 selection that holds the thread groups.
 * @param {Array<object>} data - thread nodes; reads id, x, y and size.
 * @param {Function} onDrag - drag behaviour applied (via call) to newly created groups.
 * @param {Function} onClick - click listener attached to newly created thread rectangles.
 */
const drawThreadGroups = (container, data, onDrag, onClick) => {
    const threadGroups = container.selectAll("g.thread-group")
        .data(data, d => d.id);

    const threadGroupsEnter = threadGroups.enter().append("g")
        .attr("class", "thread-group")
        .attr('data-id', d => d.id)
        .attr("cursor", "grab")
        .call(onDrag);

    threadGroupsEnter.append('rect')
        .attr("class", "thread")
        .attr("data-id", d => d.id)
        .attr("x", -10)
        .attr("rx", 5)
        .attr("ry", 5)
        .on("click", onClick);

    // New and existing groups share position and size updates.
    const allGroups = threadGroupsEnter.merge(threadGroups);

    allGroups.attr('transform', d => `translate(${(d.x || 0)}, ${(d.y || 0)})`);

    allGroups.select('rect.thread')
        .attr("width", d => d.size.width)
        .attr("height", d => d.size.height);

    threadGroups.exit().remove();
};

/**
 * Joins `data` (keyed by id) to `g.interaction` groups in `container`. New groups take their
 * class from d.class and get a data-id, grab cursor, click listener and drag behaviour. Every
 * group is then drawn with drawInteraction and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - interaction nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawInteractionGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (activity) => activity.class)
        .attr('data-id', (activity) => activity.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.interaction")
        .data(data, (activity) => activity.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderInteraction(activity) {
        drawInteraction(d3.select(this), activity);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Draws one interaction glyph: the basic rectangle (drawRect) and, when `data.initiator` is
 * truthy, three diagonal hatch lines across it. When `data` is present but not an initiator,
 * any existing hatch lines are removed. Nothing beyond drawRect happens for falsy `data`.
 *
 * @param {object} container - d3 selection of the interaction's group.
 * @param {object} data - interaction node; reads initiator and size.
 */
const drawInteraction = (container, data) => {
    drawRect(container, data);
    if (!data) {
        return;
    }

    if (!data.initiator) {
        // Not an initiator: make sure no hatching is left over.
        for (const hatchClass of ['hatch1', 'hatch2', 'hatch3']) {
            container.selectAll(`line.${hatchClass}`).remove();
        }
        return;
    }

    const width = ({ size }) => size.width;
    const height = ({ size }) => size.height;
    const halfWidth = ({ size }) => size.width / 2;
    const halfHeight = ({ size }) => size.height / 2;

    // Corner-to-corner diagonal plus one shorter parallel stroke on either side of it.
    const hatches = [
        { name: 'hatch1', x1: 0, y1: height, x2: width, y2: 0 },
        { name: 'hatch2', x1: 0, y1: halfHeight, x2: halfWidth, y2: 0 },
        { name: 'hatch3', x1: halfWidth, y1: height, x2: width, y2: halfHeight },
    ];

    hatches.forEach(({ name, x1, y1, x2, y2 }) => {
        const stroke = container.selectAll(`line.${name}`).data([data]);
        stroke.enter().append('line').classed(name, true).merge(stroke)
            .attr('x1', x1)
            .attr('y1', y1)
            .attr('x2', x2)
            .attr('y2', y2);
        stroke.exit().remove();
    });
};

/**
 * Joins `data` (keyed by id) to `g.action` groups in `container`. New groups take their class
 * from d.class and get a data-id, grab cursor, click listener and drag behaviour. Every group
 * is then drawn with drawRect and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - action nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawActionGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (activity) => activity.class)
        .attr('data-id', (activity) => activity.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.action")
        .data(data, (activity) => activity.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderAction(activity) {
        drawRect(d3.select(this), activity);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Joins `data` (keyed by id) to `g.trigger` groups in `container`. New groups take their class
 * from d.class and get a data-id, grab cursor, click listener and drag behaviour. Every group
 * is then drawn with drawTrigger and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - trigger nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawTriggerGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (activity) => activity.class)
        .attr('data-id', (activity) => activity.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.trigger")
        .data(data, (activity) => activity.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderTrigger(activity) {
        drawTrigger(d3.select(this), activity);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Draws one trigger glyph: a vertical `line.trig` at x = 10 running from 5 above the top to
 * 5 below the node's height, a fixed arrow-shaped `path.trig`, and the label (drawLabel).
 * When `data` is falsy the container's own datum is used for the shapes, if it has one.
 *
 * @param {object} container - d3 selection of the trigger's group.
 * @param {object} [data] - trigger node; reads size (the label is handled by drawLabel).
 */
const drawTrigger = (container, data) => {
    let bound;
    if (data) {
        bound = [data];
    } else if (container.datum()) {
        bound = [container.datum()];
    } else {
        bound = [];
    }

    const stem = container.selectAll('line.trig').data(bound);
    stem.enter().append('line').classed('trig', true).merge(stem)
        .attr('x1', 10)
        .attr('y1', -5)
        .attr('x2', 10)
        .attr('y2', ({ size }) => size.height + 5);
    stem.exit().remove();

    const arrowOutline = `M0,0  l5,0 l3,3 l7,0 v-3 l5,5  l-5,5 v-3 l-7,0 l-3,3 l-5,0 l3,-5 Z`;
    const arrow = container.selectAll('path.trig').data(bound);
    arrow.enter().append('path').classed('trig', true).merge(arrow)
        .attr('d', arrowOutline);
    arrow.exit().remove();

    drawLabel(container, data);
};

/**
 * Joins `data` (keyed by id) to `g.start-role` groups in `container`. New groups take their
 * class from d.class and get a data-id, grab cursor, click listener and drag behaviour. Every
 * group is then drawn with drawStartRole and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - start-role nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawStartRoleGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (activity) => activity.class)
        .attr('data-id', (activity) => activity.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.start-role")
        .data(data, (activity) => activity.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderStartRole(activity) {
        drawStartRole(d3.select(this), activity);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};
/**
 * Draws one start-role glyph: the basic rectangle (drawRect) crossed by both diagonals
 * (`line.diag1`, `line.diag2`), followed by the label (drawLabel).
 * When `data` is falsy the container's own datum is used for the diagonals, if it has one.
 *
 * @param {object} container - d3 selection of the start-role's group.
 * @param {object} [data] - start-role node; reads size.
 */
const drawStartRole = (container, data) => {
    let bound;
    if (data) {
        bound = [data];
    } else if (container.datum()) {
        bound = [container.datum()];
    } else {
        bound = [];
    }

    drawRect(container, data);

    const width = ({ size }) => size.width;
    const height = ({ size }) => size.height;

    // Top-left to bottom-right.
    const falling = container.selectAll('line.diag1').data(bound);
    falling.enter().append('line').classed('diag1', true).merge(falling)
        .attr('x1', 0)
        .attr('y1', 0)
        .attr('x2', width)
        .attr('y2', height);
    falling.exit().remove();

    // Top-right to bottom-left.
    const rising = container.selectAll('line.diag2').data(bound);
    rising.enter().append('line').classed('diag2', true).merge(rising)
        .attr('x1', width)
        .attr('y1', 0)
        .attr('x2', 0)
        .attr('y2', height);
    rising.exit().remove();

    drawLabel(container, data);
};

/**
 * Joins `data` (keyed by id) to `g.ellipsis` groups in `container`. New groups take their class
 * from d.class and get a data-id, grab cursor, click listener and drag behaviour. Every group
 * is then drawn with drawEllipsis and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - ellipsis nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawEllipsisGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (activity) => activity.class)
        .attr('data-id', (activity) => activity.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.ellipsis")
        .data(data, (activity) => activity.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderEllipsis(activity) {
        drawEllipsis(d3.select(this), activity);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Draws one ellipsis glyph: a fixed zig-zag `path.ellipsis`, followed by the label (drawLabel).
 * When `data` is falsy the container's own datum is used for the path, if it has one.
 *
 * @param {object} container - d3 selection of the ellipsis group.
 * @param {object} [data] - ellipsis node.
 */
const drawEllipsis = (container, data) => {
    let bound;
    if (data) {
        bound = [data];
    } else if (container.datum()) {
        bound = [container.datum()];
    } else {
        bound = [];
    }

    const zigzagOutline = `M10,0  v5 l5,3 l-10,3 l10,3 l-10,3 l5,3 v5`;
    const zigzag = container.selectAll('path.ellipsis').data(bound);
    zigzag.enter().append('path').classed('ellipsis', true).merge(zigzag)
        .attr('d', zigzagOutline);
    zigzag.exit().remove();

    drawLabel(container, data);
};

/**
 * Joins `data` (keyed by id) to `g.case-refinement` groups in `container`. New groups take
 * their class from d.class and get a data-id, grab cursor, click listener and drag behaviour.
 * Every group is then drawn with drawCaseRefinement and translated to d.position.
 *
 * @param {object} container - d3 selection of the parent group.
 * @param {Array<object>} data - case-refinement nodes; reads id, class and position.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawCaseRefinementGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (refinement) => refinement.class)
        .attr('data-id', (refinement) => refinement.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.case-refinement")
        .data(data, (refinement) => refinement.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderCaseRefinement(refinement) {
        drawCaseRefinement(d3.select(this), refinement);
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Draws the header of a case refinement: a `path.rpath` connector that drops vertically from
 * half a default glyph height above the origin and then bends to the right, plus a
 * right-aligned `text.label` at the origin showing the label (or "case").
 * The connector geometry depends only on defaultWidth / defaultHeight, not on the node.
 *
 * @param {object} container - d3 selection of the refinement's group.
 * @param {object} data - case-refinement node, bound to the connector path.
 * @param {Function} [onClick] - accepted but not used by this function.
 */
const drawCaseRefinement = (container, data, onClick) => {
    // Start above the group origin, horizontally centred on a default-sized glyph.
    const originX = defaultWidth / 2; // This needs to be abstracted
    const originY = -defaultHeight / 2;
    // y at which the vertical stroke ends and the curve begins.
    const bendY = 0;

    // Vertical stroke, then a cubic bezier that leaves downwards and arrives horizontally.
    const connectorPath = [
        `M${originX},${originY}`,
        `V${bendY}`,
        `C${originX},${bendY + 2} ${originX + 5},${bendY + 5} ${originX + 34},${bendY + 5}`,
    ].join(' ');

    const connector = container.selectAll('path.rpath').data([data]);
    connector.enter().append('path').classed('rpath', true).merge(connector)
        .attr('d', () => connectorPath);
    connector.exit().remove();

    // The label binds to whatever datum the container itself carries.
    const caption = container.selectAll('text.label').data((datum) => [datum]);
    caption.enter().append('text').classed('label', true).merge(caption)
        .attr('x', 0)
        .attr('y', 0)
        .attr('dy', '0.35em')
        .attr('text-anchor', 'end')
        .text((datum) => datum.label || "case");
};


/**
 * Draws the basic activity glyph: a rectangle of the node's size at the origin, a 5-unit
 * connector stub above it (`line.line1`) and below it (`line.line2`), both at the horizontal
 * centre, followed by the label (drawLabel).
 * When `data` is falsy the container's own datum is used for the shapes, if it has one.
 *
 * @param {object} container - d3 selection of the activity's group.
 * @param {object} [data] - activity node; reads size.
 */
const drawRect = (container, data) => {
    let bound;
    if (data) {
        bound = [data];
    } else if (container.datum()) {
        bound = [container.datum()];
    } else {
        bound = [];
    }

    const centreX = ({ size }) => size.width / 2;
    const bottomY = ({ size }) => size.height;

    // Body: note this binds to every 'rect' found under the container.
    const body = container.selectAll('rect').data(bound);
    body.enter().append('rect').merge(body)
        .attr("width", ({ size }) => size.width)
        .attr("height", bottomY)
        .attr('x', 0)
        .attr('y', 0);
    body.exit().remove();

    // Stub above the rectangle.
    const upperStub = container.selectAll('line.line1').data(bound);
    upperStub.enter().append('line').classed('line1', true).merge(upperStub)
        .attr('x1', centreX)
        .attr('y1', -5)
        .attr('x2', centreX)
        .attr('y2', 0);
    upperStub.exit().remove();

    // Stub below the rectangle.
    const lowerStub = container.selectAll('line.line2').data(bound);
    lowerStub.enter().append('line').classed('line2', true).merge(lowerStub)
        .attr('x1', centreX)
        .attr('y1', bottomY)
        .attr('x2', centreX)
        .attr('y2', ({ size }) => size.height + 5);
    lowerStub.exit().remove();

    drawLabel(container, data);
};

/**
 * Draws the node's text label (`text.label`, 12px) to the right of its glyph, vertically
 * centred on the node's height. A missing label renders as an empty string.
 * When `data` is falsy the container's own datum is used, if it has one.
 *
 * @param {object} container - d3 selection of the node's group.
 * @param {object} [data] - node; reads size.height and label.
 */
const drawLabel = (container, data) => {
    let bound;
    if (data) {
        bound = [data];
    } else if (container.datum()) {
        bound = [container.datum()];
    } else {
        bound = [];
    }

    const caption = container.selectAll('text.label').data(bound);

    const entered = caption.enter().append('text').classed('label', true);
    entered.merge(caption)
        .attr('x', 25) // to the right of the shape
        .attr('y', ({ size }) => size.height / 2)
        .attr('dy', '0.35em') // nudges the text towards vertical centring
        .attr('font-size', '12px')
        .text((node) => node.label || "");

    // Remove any excess labels.
    caption.exit().remove();
};


/**
 * Joins `data` (keyed by id) to `g.case-condition` groups in `container`. New groups take
 * their class from d.class and get a data-id, grab cursor, click listener and drag behaviour.
 * Every group is drawn with drawCaseCondition, gets a `text.condition-label` showing the label
 * (or "c" + index), and is translated to d.position.
 *
 * @param {object} container - d3 selection of the parent case-refinement group.
 * @param {Array<object>} data - case-condition nodes; reads id, class, index, position, label.
 *   `data[d.index - 1]` is read as the preceding condition, so this presumably expects `data`
 *   to be ordered by a zero-based, gap-free `index` — TODO confirm.
 * @param {Function} onDrag - drag behaviour applied (via call) to new groups.
 * @param {Function} onClick - click listener for new groups.
 */
const drawCaseConditionGroups = (container, data, onDrag, onClick) => {
    const createGroup = (entering) => entering
        .append("g")
        .attr("class", (condition) => condition.class)
        .attr('data-id', (condition) => condition.id)
        .attr("cursor", "grab")
        .on("click", onClick)
        .call(onDrag);
    const keepGroup = (existing) => existing;
    const dropGroup = (leaving) => leaving.remove();

    const joined = container
        .selectAll("g.case-condition")
        .data(data, (condition) => condition.id)
        .join(createGroup, keepGroup, dropGroup);

    joined.each(function renderCaseCondition(condition) {
        const group = d3.select(this);

        // Horizontal distance to the preceding condition; 0 for index 0 (or a non-positive index).
        const offset = condition.index > 0
            ? condition.position.x - data[condition.index - 1].position.x
            : 0;
        drawCaseCondition(group, condition, offset);

        const caption = group.selectAll('text.condition-label').data((datum) => [datum]);
        caption.enter().append('text').classed('condition-label', true).merge(caption)
            .attr('x', 0)
            .attr('y', (datum) => datum.position.y - 10)
            .attr('dy', '0.35em')
            .text((datum) => datum.label || `c${datum.index}`)
            .attr('font-size', '12px');
        caption.exit().remove();
    });

    joined.attr('transform', ({ position }) => `translate(${position.x}, ${position.y})`);

    // The join above already removes exiting groups; this call is kept as in the original.
    joined.exit().remove();
};

/**
 * Draws one case-condition marker:
 *  - `line.cline1`: a horizontal stroke at y = 5 reaching back towards the previous condition
 *    (from -offset + 15 to 5); collapsed to zero length when the node's index is 0 or missing;
 *  - `line.cline3`: a vertical stroke below the marker at the default glyph's centre;
 *  - `polygon.cpoly`: a downward-pointing triangle sized from the default glyph dimensions.
 *
 * @param {object} container - d3 selection of the condition's group.
 * @param {object} data - case-condition node; reads index.
 * @param {number} offset - horizontal distance to the preceding condition.
 */
export const drawCaseCondition = (container, data, offset) => {
    const bound = [data];
    const isFirst = (condition) => !condition.index; // covers 0 as well as a missing index

    const link = container.selectAll('line.cline1').data(bound);
    link.enter().append('line').classed('cline1', true).merge(link)
        .attr('x1', (condition) => (isFirst(condition) ? 0 : -offset + 15))
        .attr('x2', (condition) => (isFirst(condition) ? 0 : 5))
        .attr('y1', 5)
        .attr('y2', 5);
    link.exit().remove();

    const centreX = defaultWidth / 2;
    const tipY = defaultHeight - 4;

    const drop = container.selectAll('line.cline3').data(bound);
    drop.enter().append('line').classed('cline3', true).merge(drop)
        .attr('x1', centreX)
        .attr('y1', tipY)
        .attr('x2', centreX)
        .attr('y2', defaultHeight + 20);
    drop.exit().remove();

    const triangle = container.selectAll('polygon.cpoly').data(bound);
    triangle.enter().append('polygon').classed('cpoly', true).merge(triangle)
        .attr('points', () => `${defaultWidth / 2},${defaultHeight - 4} ${defaultWidth - 4},-0 4,-0`);
    triangle.exit().remove();
};


