import { ReactNode, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { Group } from "@visx/group";
import { useTooltip, useTooltipInPortal } from "@visx/tooltip";
import { Point } from "d3-dag/dist/dag";

import { convertToTreeStructure, isNodeVisible, updateTreeRendering } from "utils/tree";
import Box from "ds/components/Box";
import Typography from "ds/components/Typography";
import useAnalytics, { AnalyticsPage } from "hooks/useAnalytics";

import { ConfigNode, ConfigTypeComponents, Connection, Position, PositionedNode } from "./types";
import styles from "./styles.module.css";
import {
  COLUMN_WIDTH,
  CONNECTION_BORDER_RADIUS,
  CONNECTION_SPACE,
  TOOLTIP_TOP_OFFSET,
} from "./constants";
import TreeChartConnection from "./Connection";
import { createEndId, createStartId } from "./utils";

// Horizontal step between two consecutive tree levels: a node at depth `d` is placed at
// x = (d - 1) * COLUMN_SPACE. It is the column width plus the room reserved for the connector
// drawn between columns (four border radii and two connection spaces).
// NOTE(review): the 4x / 2x factors presumably mirror the connector geometry in ./Connection — confirm.
const COLUMN_SPACE = COLUMN_WIDTH + 4 * CONNECTION_BORDER_RADIUS + 2 * CONNECTION_SPACE;

/**
 * Props accepted by `TreeChart`.
 *
 * `T` is the union of node type keys; it ties every node to its entry in `nodeTypes`.
 */
type TreeChartProps<T extends string> = {
  /** Flat list of nodes; the hierarchy is rebuilt from each node's `id` / `parent` fields. */
  nodes: ConfigNode<T>[];
  /** Per node type: the component used to render it and a `height(item)` function used for layout. */
  nodeTypes: ConfigTypeComponents<T>;
  /**
   * Optional space around the chart. `top` / `left` offset the drawing group and the tooltip,
   * and all four sides are added to the computed SVG size.
   */
  margin?: Partial<Record<"top" | "right" | "bottom" | "left", number>>;
  /**
   * Id of the currently active node (or of an entry inside a node's `group`). When it changes,
   * the branch leading to it is expanded; it is also forwarded to every node component.
   */
  activeId?: string;
  /** Analytics page passed to `useAnalytics` for the expand / collapse tracking events. */
  analyticsPage?: AnalyticsPage;
};

/**
 * Collapsible tree diagram rendered as an SVG inside a scrollable container.
 *
 * The flat `nodes` list is converted into a tree, filtered down to the nodes that are currently
 * visible (according to the expanded keys), and laid out in columns: one column per tree depth,
 * with nodes stacked vertically using the height reported by `nodeTypes[type].height`.
 * Node components report their connection points through `setConnectionPoints`; once both ends
 * of a parent → child link are known, a `TreeChartConnection` is drawn between them.
 *
 * @param nodes flat list of nodes to display
 * @param nodeTypes renderer and height function for every node type
 * @param margin optional space around the chart
 * @param activeId id of the active node; the branch leading to it is expanded when it changes
 * @param analyticsPage page used when tracking expand / collapse events
 */
const TreeChart = <T extends string>({
  nodes,
  nodeTypes,
  margin,
  activeId,
  analyticsPage,
}: TreeChartProps<T>) => {
  const trackSegmentAnalyticsEvent = useAnalytics({
    page: analyticsPage,
    callbackTrackProviders: { segment: true },
  });
  const { tooltipData, tooltipLeft, tooltipTop, tooltipOpen, showTooltip, hideTooltip } =
    useTooltip<{ text: ReactNode }>();
  // Outer scrollable wrapper; scrolled programmatically after a node is expanded.
  const scrollableContainerRef = useRef<HTMLDivElement>(null);
  const { containerRef, TooltipInPortal } = useTooltipInPortal({
    detectBounds: false,
    scroll: true,
  });
  const { flatList } = useMemo(
    () => convertToTreeStructure<ConfigNode<T>>(nodes, "id", "parent", "name"),
    [nodes]
  );

  // Ids of every node in the tree, passed to `isNodeVisible`.
  const allKeysSet = useMemo(() => {
    const allKeys = flatList.map((item) => item.id);
    return new Set(allKeys);
  }, [flatList]);

  // Connection anchor positions reported by the node components, keyed by start/end point id.
  const [connectionPoints, setConnectionPoints] = useState<Record<string, Position>>({});
  const [expandedKeys, setExpandedKeys] = useState(new Set<string>());
  // const [, setLoadingKeys] = useState(new Set<string>());

  const visibleNodes = useMemo(
    () => flatList.filter((node) => isNodeVisible(node, allKeysSet, expandedKeys)),
    [expandedKeys, allKeysSet, flatList]
  );

  // When the active id changes, expand the branch that leads to the node holding it (either the
  // node itself or a node whose `group` contains an entry with that id).
  // NOTE(review): the dependency list only contains `activeId`. This looks intentional: with
  // `expandedKeys` included, the branch would be re-expanded right after the user collapses it.
  // It also means the effect does not re-run when `flatList` changes — confirm that is desired.
  useEffect(() => {
    if (activeId && !expandedKeys.has(activeId)) {
      const activeNode = flatList.find(
        (node) =>
          node.id === activeId || !!node.item.group?.find((member) => member.id === activeId)
      );

      if (activeNode?.path && activeNode.parentId && !expandedKeys.has(activeNode.parentId)) {
        const { expandedNodes } = updateTreeRendering(flatList, [activeNode.path]);
        setExpandedKeys(expandedNodes);
      }
    }
  }, [activeId]);

  const [positionedNodes, connections, containerHeight, containerWidth] = useMemo(() => {
    const positionedNodes: PositionedNode<T>[] = [];
    const connections: Connection[] = [];

    let containerWidth = 0;
    let containerHeight = 0;

    // Running vertical cursor: the y at which the next node will be placed.
    let y = 0;

    for (let i = 0; i < visibleNodes.length; i++) {
      const node = visibleNodes[i];
      const prevNode = visibleNodes[i - 1];

      // Change y position for first child: step the cursor back by the parent's height so the
      // first child starts at the same y as its parent, in the next column.
      if (prevNode && prevNode.path.length !== node.path.length && prevNode.id === node.parentId) {
        const prevNodeHeight = nodeTypes[prevNode.item.type].height(prevNode.item);
        y = y - prevNodeHeight;
      }

      positionedNodes.push({
        ...node,
        item: {
          ...node.item,
          position: {
            // One column per tree depth.
            x: (node.path.length - 1) * COLUMN_SPACE,
            y,
          },
        },
      } as PositionedNode<T>);

      // Container size and next iteration y position
      const nodeHeight = nodeTypes[node.item.type].height(node.item);
      const nodeEndY = y + nodeHeight;
      const nodeEndX = node.path.length * COLUMN_SPACE;

      if (nodeEndY > y) {
        y = nodeEndY;
      }

      if (nodeEndY > containerHeight) {
        containerHeight = nodeEndY;
      }

      if (nodeEndX > containerWidth) {
        containerWidth = nodeEndX;
      }

      // Create connection (only once both anchor points have been reported by the components)
      if (node.parentId) {
        const startPointId = createEndId(node.parentId);
        const startPoint = connectionPoints[startPointId];
        const endPointId = createStartId(node.id);
        const endPoint = connectionPoints[endPointId];

        if (startPoint && endPoint) {
          connections.push({
            start: startPoint,
            end: endPoint,
            id: `${endPointId}-${startPointId}`,
          });
        }
      }
    }

    return [
      positionedNodes,
      connections,
      containerHeight + (margin?.top || 0) + (margin?.bottom || 0),
      containerWidth + (margin?.left || 0) + (margin?.right || 0),
    ];
  }, [visibleNodes, connectionPoints, nodeTypes, margin]);

  /**
   * Expands or collapses the node identified by `key`.
   *
   * Nothing happens for nodes whose children still have to be loaded. When a node is expanded
   * and its `position` is provided, the container is scrolled towards its children.
   */
  const toggleKey = useCallback(
    (key: string, hasChildrenToLoad?: boolean, position?: Position) => {
      if (hasChildrenToLoad) {
        // TODO: add query to load children and the switch keys
        // setLoadingKeys((prev) => {
        //   const newKeys = new Set(prev);
        //   if (newKeys.has(key)) {
        //     newKeys.delete(key);
        //   } else {
        //     newKeys.add(key);
        //   }
        //   return newKeys;
        // });
        // TODO: then unset when loaded or error
        return;
      }

      const isExpanded = expandedKeys.has(key);

      // The updater must stay pure: React may invoke it more than once (e.g. in Strict Mode),
      // so tracking and scrolling are performed outside of it.
      setExpandedKeys((prev) => {
        const newKeys = new Set(prev);
        if (newKeys.has(key)) {
          newKeys.delete(key);
        } else {
          newKeys.add(key);
        }
        return newKeys;
      });

      if (isExpanded) {
        trackSegmentAnalyticsEvent?.("Diagram node collapsed");
        return;
      }

      trackSegmentAnalyticsEvent?.("Diagram node expanded");

      if (position) {
        // We have to wait until the element is rendered on the screen before scrolling to it
        requestAnimationFrame(() => {
          scrollableContainerRef.current?.scrollTo({
            behavior: "smooth",
            left: position.x + COLUMN_SPACE * 2,
            top: position.y,
          });
        });
      }
    },
    [expandedKeys, trackSegmentAnalyticsEvent]
  );

  // Opens the tooltip with the given content at the given coordinates.
  const handleMouseEnterForTooltip = useCallback(
    (text: ReactNode, coordinates: Point | null) => {
      showTooltip({
        tooltipData: { text },
        tooltipTop: coordinates?.y,
        tooltipLeft: coordinates?.x,
      });
    },
    [showTooltip]
  );

  return (
    <>
      {tooltipOpen && tooltipData && (
        <TooltipInPortal
          top={tooltipTop}
          left={tooltipLeft}
          offsetLeft={margin?.left}
          offsetTop={(margin?.top || 0) - TOOLTIP_TOP_OFFSET}
          className={styles.tooltipContainer}
          unstyled
        >
          <Box className={styles.tooltip} direction="column" gap="medium">
            <Typography tag="span" variant="p-body3" color="on-inversed">
              {tooltipData.text}
            </Typography>
          </Box>
        </TooltipInPortal>
      )}
      <Box grow="1" className={styles.wrapper} ref={scrollableContainerRef}>
        <Box style={{ minWidth: containerWidth }}>
          <svg
            ref={containerRef}
            width={containerWidth}
            height={containerHeight}
            className={styles.treeChart}
          >
            <Group x={0} y={0} textAnchor="top" top={margin?.top} left={margin?.left}>
              {positionedNodes.map((node) => {
                const Component = nodeTypes[node.item.type].Component;

                return (
                  <Component
                    onMouseEnter={handleMouseEnterForTooltip}
                    onMouseLeave={hideTooltip}
                    id={node.id}
                    connectionPoints={connectionPoints}
                    setConnectionPoints={setConnectionPoints}
                    onToggle={() =>
                      toggleKey(node.id, node.item.hasChildrenToLoad, node.item.position)
                    }
                    key={node.id}
                    item={node.item}
                    isParent={!!node.children.length}
                    isOpened={expandedKeys.has(node.id)}
                    columnWidth={COLUMN_WIDTH}
                    activeId={activeId}
                  />
                );
              })}

              {connections.map((connection) => (
                <TreeChartConnection key={connection.id} {...connection} />
              ))}
            </Group>
          </svg>
        </Box>
      </Box>
    </>
  );
};

export default TreeChart;
