import dagre from "dagre";
import { Edge, Node, Position } from "reactflow";

import {
  NODE_DIMENSIONS,
  NODE_PADDING,
} from "../../constants";
import {
  GraphDirection,
  GroupedNodes,
  MaxGroupPoints,
  MinGroupPoints,
} from "../../types";

/** Input for {@link getLayoutedElements}. */
type Props = {
  /**
   * Group nodes and their children. Positions, handle positions and styles
   * are overwritten in place.
   */
  nodes: Node<any>[];
  /** Used only to rank the nodes in dagre; returned untouched. */
  edges: Edge<any>[];
  /** Forwarded to dagre as the graph's `rankdir`. */
  direction: GraphDirection;
};

/**
 * Lays out grouped nodes with dagre and converts the result into
 * parent-relative (sub-flow) coordinates.
 *
 * Steps:
 *  1. Run dagre over every node and edge, then copy the resulting top-left
 *     corners onto the nodes. Nodes without a `parentId` are pinned to `y = 0`.
 *  2. Re-express each non-group child's position relative to its parent, so
 *     that the left-most / top-most child of a group sits `NODE_PADDING`
 *     inside it. Non-group nodes without a parent keep their step-1 position.
 *  3. Size every group so it encloses its children plus `NODE_PADDING`, and
 *     place groups left to right, each `NODE_PADDING` after the previously
 *     placed group.
 *
 * NOTE: `nodes` is mutated in place and the same array is returned.
 *
 * @returns `{ nodes, edges }` — the mutated input nodes and the input edges.
 */
const getLayoutedElements = ({ nodes, edges, direction }: Props) => {
  const dagreGraph = new dagre.graphlib.Graph();
  const { width, height } = NODE_DIMENSIONS;

  dagreGraph.setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({ rankdir: direction });

  // Every node is given the same fixed size for dagre's layout.
  nodes.forEach(({ id }) => {
    dagreGraph.setNode(id, { width, height });
  });

  edges.forEach(({ source, target }) => {
    dagreGraph.setEdge(source, target);
  });

  dagre.layout(dagreGraph);

  // Step 1: copy dagre's results back. dagre reports node centres, hence the
  // half-width / half-height shift to obtain top-left corners.
  nodes.forEach((node) => {
    const { x, y } = dagreGraph.node(node.id);

    // NOTE(review): handles are always Left/Right even if `direction` can be
    // a vertical rankdir — confirm this is intended.
    node.targetPosition = Position.Left;
    node.sourcePosition = Position.Right;

    node.position = {
      x: x - width / 2 + NODE_PADDING,
      // Parentless nodes (the groups) all share one row at the top.
      y: node.parentId ? y - height / 2 : 0,
    };
  });

  // Step 2: make child positions relative to their group.
  const minPoints = _getMinPoints(nodes);

  nodes.forEach((node) => {
    if (node.type === "group") {
      return;
    }

    // Nodes without a parent have no group to be relative to; leave them
    // where step 1 put them instead of dereferencing a missing entry.
    const min = node.parentId ? minPoints[node.parentId] : undefined;
    if (min) {
      node.position = {
        x: node.position.x - (min.minX - NODE_PADDING),
        y: node.position.y - (min.minY - NODE_PADDING),
      };
    }

    node.style = {
      ...node.style,
      width,
      height,
    };
  });

  // Step 3: size each group and lay groups out side by side.
  const maxPoints = _getMaxPoints(nodes);

  // Track the last placed *group* rather than `nodes[index - 1]`: the
  // preceding array entry may be a child whose position is relative to its
  // own parent and whose width is a single node's width.
  let previousGroup: { x: number; width: number } | undefined;

  nodes.forEach((node) => {
    if (node.type !== "group") {
      return;
    }

    if (previousGroup) {
      node.position.x = previousGroup.x + previousGroup.width + NODE_PADDING;
    }

    const max = maxPoints[node.id];
    let groupWidth: number;
    let groupHeight: number;

    if (max) {
      // Reach the right/bottom edge of the farthest child, plus padding.
      groupWidth = max.maxX + width + NODE_PADDING;
      groupHeight = max.maxY + height + NODE_PADDING;
    } else {
      // A group with no children gets room for a single node.
      groupWidth = width + NODE_PADDING;
      groupHeight = height + NODE_PADDING;
      console.warn("Couldn't find maxPoints", node);
    }

    node.style = {
      ...node.style,
      width: groupWidth,
      height: groupHeight,
    };

    previousGroup = { x: node.position.x, width: groupWidth };
  });

  return { nodes, edges };
};

/**
 * Buckets nodes by their `parentId`. Nodes whose `parentId` is missing or
 * empty are skipped.
 *
 * @param nodes - Nodes to partition.
 * @returns One array per distinct parent id, each holding that parent's
 *   children in their original order. Every returned array is non-empty.
 */
const _getGroups = (nodes: Node<any>[]) => {
  const byParent: GroupedNodes = {};

  for (const node of nodes) {
    const owner = node.parentId;
    if (!owner) {
      continue;
    }

    const bucket = byParent[owner];
    if (bucket) {
      bucket.push(node);
    } else {
      byParent[owner] = [node];
    }
  }

  return Object.values(byParent);
};

/**
 * Finds, for every parent group, the smallest `x` and `y` among its
 * children's current positions.
 *
 * @param nodes - Nodes whose `position` has already been assigned.
 * @returns A map from parent id to `{ minX, minY }`. Nodes without a
 *   `parentId` do not contribute.
 */
const _getMinPoints = (nodes: Node<any>[]) => {
  const result: MinGroupPoints = {};

  for (const members of _getGroups(nodes)) {
    // Folding with Math.min from Infinity matches Math.min(...values),
    // including NaN propagation, without spreading large arrays.
    let minX = Infinity;
    let minY = Infinity;

    for (const { position } of members) {
      minX = Math.min(minX, position.x);
      minY = Math.min(minY, position.y);
    }

    const owner = members[0].parentId as string;
    result[owner] = { minX, minY };
  }

  return result;
};

/**
 * Finds, for every parent group, the largest `x` and `y` among its
 * children's current positions.
 *
 * @param nodes - Nodes whose `position` has already been assigned.
 * @returns A map from parent id to `{ maxX, maxY }`. Nodes without a
 *   `parentId` do not contribute.
 */
const _getMaxPoints = (nodes: Node<any>[]) => {
  const result: MaxGroupPoints = {};

  for (const members of _getGroups(nodes)) {
    // Folding with Math.max from -Infinity matches Math.max(...values),
    // including NaN propagation, without spreading large arrays.
    let maxX = -Infinity;
    let maxY = -Infinity;

    for (const { position } of members) {
      maxX = Math.max(maxX, position.x);
      maxY = Math.max(maxY, position.y);
    }

    const owner = members[0].parentId as string;
    result[owner] = { maxX, maxY };
  }

  return result;
};

// The layout function is this module's default export.
export { getLayoutedElements as default };
