import dagre from "dagre";
import { Edge, Node, Position } from "reactflow";

import { getMaxPoints, getMinPoints } from "./getMinMaxPoints";

/**
 * Width in px. Used as every node's size in the dagre graph, and as the
 * rendered `style.width` of non-group nodes.
 */
const NODE_WIDTH = 200;
/**
 * Height in px. Used as every node's size in the dagre graph, and as the
 * rendered `style.height` of non-group nodes.
 */
const NODE_HEIGHT = 40;
/**
 * Spacing in px. Used as the inset of children inside a group, the extra
 * size added to each group, and the horizontal gap between consecutive groups.
 */
const PADDING = 30;

/**
 * Runs dagre (rank direction left-to-right) over every node and edge.
 * It writes each node's initial top-left position and handle sides, and
 * mutates `nodes` in place.
 */
const runDagreLayout = (nodes: Node<any>[], edges: Edge<any>[]): void => {
  const dagreGraph = new dagre.graphlib.Graph();

  dagreGraph.setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({ rankdir: "LR" });

  // Every node, groups included, is given the same NODE_WIDTH x NODE_HEIGHT box for dagre.
  for (const { id } of nodes) {
    dagreGraph.setNode(id, { width: NODE_WIDTH, height: NODE_HEIGHT });
  }

  for (const { source, target } of edges) {
    dagreGraph.setEdge(source, target);
  }

  dagre.layout(dagreGraph);

  for (const node of nodes) {
    const { x, y } = dagreGraph.node(node.id);
    node.targetPosition = Position.Left;
    node.sourcePosition = Position.Right;

    // dagre reports node centres; shift to a top-left origin.
    // Nodes without a parent (presumably the groups) are pinned to y = 0 so they share one row.
    node.position = {
      x: x - NODE_WIDTH / 2 + PADDING,
      y: node.parentId ? y - NODE_HEIGHT / 2 : 0,
    };
  }
};

/**
 * Moves every non-group node into its parent's frame and fixes its rendered size.
 *
 * Each node is shifted by its parent's `minPoints` entry, minus PADDING.
 * Presumably that entry holds the smallest child coordinates per parent id
 * (see getMinPoints). Nodes with no `parentId`, or whose parent has no
 * entry, keep their position and a warning is logged. Previously they threw
 * a TypeError.
 */
const moveChildrenIntoGroups = (nodes: Node<any>[]): void => {
  const minPoints = getMinPoints(nodes);

  for (const node of nodes) {
    if (node.type === "group") continue;

    const parentMin = node.parentId ? minPoints[node.parentId] : undefined;
    if (parentMin) {
      node.position = {
        x: node.position.x - (parentMin.minX - PADDING),
        y: node.position.y - (parentMin.minY - PADDING),
      };
    } else {
      console.warn("Couldn't find minPoints", node);
    }

    node.style = {
      ...node.style,
      width: NODE_WIDTH,
      height: NODE_HEIGHT,
    };
  }
};

/**
 * Sizes each group from its `maxPoints` entry and places groups left to right.
 *
 * Each group after the first starts PADDING px after the right edge of the
 * previous *group* in array order. The first group keeps its dagre x.
 * Previously the group used the previous array element, which is a child
 * node whenever groups and children are interleaved.
 */
const placeGroups = (nodes: Node<any>[]): void => {
  const maxPoints = getMaxPoints(nodes);
  // Right edge (x + width) of the most recently placed group; undefined until one is seen.
  let prevGroupRight: number | undefined;

  for (const node of nodes) {
    if (node.type !== "group") continue;

    if (prevGroupRight !== undefined) {
      node.position.x = prevGroupRight + PADDING;
    }

    // Presumably maxX/maxY are the furthest child coordinates inside this group (see getMaxPoints).
    // Without an entry, the group falls back to a single node's size plus padding.
    const max = maxPoints[node.id];
    const width = (max?.maxX ?? 0) + NODE_WIDTH + PADDING;
    const height = (max?.maxY ?? 0) + NODE_HEIGHT + PADDING;

    node.style = {
      ...node.style,
      width,
      height,
    };

    if (!max) {
      console.warn("Couldn't find maxPoints", node);
    }

    prevGroupRight = node.position.x + width;
  }
};

/**
 * Lays out a React Flow graph made of group nodes (`type === "group"`) and their children.
 *
 * Steps:
 * 1. Run dagre over every node and edge.
 * 2. Re-base each child inside its parent group.
 * 3. Size the groups and lay them out in a single row, PADDING px apart.
 *
 * NOTE: `nodes` is mutated in place (position, handle sides, style). The same
 * `nodes` and `edges` arrays are returned.
 *
 * @param nodes - Group nodes and their children; children are identified by `parentId`.
 * @param edges - Edges between nodes; only `source` and `target` are read.
 * @returns The same `nodes` and `edges` arrays, with `nodes` updated.
 */
const getLayoutedElements = (nodes: Node<any>[], edges: Edge<any>[]) => {
  runDagreLayout(nodes, edges);
  moveChildrenIntoGroups(nodes);
  placeGroups(nodes);

  return { nodes, edges };
};

export default getLayoutedElements;
