import { Dictionary, groupBy, isEqual } from 'lodash';
import { hierarchy, pack } from 'd3-hierarchy';
import { scaleLinear } from 'd3-scale';
import { hsl } from 'd3-color';
import { select, Selection } from 'd3-selection';
import {
  forceSimulation,
  forceManyBody,
  forceX,
  forceY,
  forceLink,
  forceCollide,
  Simulation,
  ForceLink,
} from 'd3-force';
import { line, curveBundle } from 'd3-shape';
import { GraphNode, GraphLink } from 'model';
import { setAllNodesCollapsed, calcNodeDescendantCounts } from 'utils';
import { D3Link, D3Node, GraphView, SvgDatum } from './graphTypes';
import { graphNodeInteraction, graphSvgInteraction } from './graphInteraction';
import {
  addToHighlightedNodes,
  arrowSize,
  assignColorSelections,
  attachLinksToNodes,
  calcGroupNodeHullPath,
  calcSelectedGroupNodeHullPath,
  combineLinks,
  computeForceLinkDistance,
  constrainChildNode,
  generateLinkPoints,
  getGroupNodeFillColor,
  getLeafNodeFillColor,
  getLeafNodeLabelColor,
  getNodeBorderColor,
  highlightedColorDark,
  pruneGraph,
  shouldIncludeLink,
  truncateText,
} from './graphHelpers';

/** Maps a link's normalised weight (0–1) to its stroke width: heavier links draw thicker. */
const linkWidthScale = scaleLinear().range([0.5, 3.5]).domain([0, 1]);

/** Maps a link's normalised weight (0–1) to an HSL lightness: heavier links draw darker. */
const linkColorScale = scaleLinear().range([0.6, 0.25]).domain([0, 1]);

/**
 * Path generator for link polylines. Points are objects carrying `x`/`y`; the bundle curve with
 * beta 0.9 (closer to 1 follows the control points more tightly) gives gently bent edges.
 */
const d3line = line()
  .x((point) => (point as any).x)
  .y((point) => (point as any).y)
  .curve(curveBundle.beta(0.9));

/**
 * Renders a force-directed, circle-packed graph into `container`, replacing any SVG previously
 * drawn there, and returns a handle for pushing later view-config changes.
 *
 * @param container - Host element; its current offset size becomes the SVG viewBox.
 * @param graphRootNode - Root of the node hierarchy to lay out.
 * @param graphLinks - Links between nodes. A weight-sorted copy is kept internally; the caller's
 *   array is not reordered.
 * @param viewConfig - Initial view configuration (prune value, removed/expanded nodes, colors, …).
 * @param onViewConfigChange - Stored on the SVG datum for the interaction handlers.
 * @param onNodeHighlight - Stored on the SVG datum for the interaction handlers.
 * @returns A {@link GraphView} whose `setViewConfig` re-lays-out or re-styles the graph.
 */
export function drawGraph(
  container: HTMLDivElement,
  graphRootNode: GraphNode,
  graphLinks: GraphLink[],
  viewConfig?: Record<string, any>,
  onViewConfigChange?: (viewConfig: Record<string, any>) => void,
  onNodeHighlight?: (node: GraphNode | undefined, event: any) => void,
) {
  const viewSize = [container.offsetWidth, container.offsetHeight];

  // Tear down any earlier drawing in this container. Its body simulations would otherwise keep
  // ticking after the SVG is detached and keep writing `node.data.state.coords`.
  const previousSvg = container.querySelector('svg');
  if (previousSvg) {
    const previousDatum = select(previousSvg).datum() as SvgDatum | undefined;
    if (previousDatum?.bodySimulations) {
      stopBodySimulations(previousDatum);
    }
    previousSvg.remove();
  }

  // Pruning expects links in ascending weight order; sort a copy so the caller's array is untouched.
  const sortedLinks = [...graphLinks].sort((a, b) => a.weight - b.weight);

  const svg = select(container)
    .append('svg')
    .style('width', '100%')
    .style('height', '100%')
    .attr('viewBox', `0 0 ${viewSize[0]} ${viewSize[1]}`);

  const svgDatum: SvgDatum = {
    viewSize,
    graphRootNode,
    graphLinks: sortedLinks,
    prunedGraphRootNode: graphRootNode,
    prunedGraphLinks: sortedLinks,
    bodySimulations: [],
    nodes: [],
    links: [],
    maxLinkWeight: 0,
    topNodes: [],
    topLinks: [],
    nodeMap: {},
    depthGroupedNodes: {},
    viewConfig,
    onViewConfigChange,
    onNodeHighlight,
    updateGraphLayout,
    updateGraphVisibility,
    runGraphForceSimulation,
  };

  svg.datum<SvgDatum>(svgDatum);
  svg.call(graphSvgInteraction);

  // Ten arrowhead markers (`link-arrow-1` … `link-arrow-10`), shaded like links of the matching
  // normalised weight bucket.
  svg
    .append('defs')
    .selectAll('marker')
    .data([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .enter()
    .append('marker')
    .attr('id', (d) => `link-arrow-${d}`)
    .attr('viewBox', '0 0 11 11')
    .attr('refX', 6)
    .attr('refY', 6)
    .attr('markerUnits', 'userSpaceOnUse')
    .attr('markerWidth', arrowSize)
    .attr('markerHeight', arrowSize)
    .attr('orient', 'auto')
    .append('path')
    .attr('d', 'M2,2 L10,6 L2,10 L6,6 L2,2')
    .attr('fill', (d) => hsl(0, 0, linkColorScale(d / 10)).toString());

  // Layer groups, back to front; only the node layers receive pointer events.
  const scene = svg.append('g');

  scene.append('g').attr('class', 'nodes group-nodes').call(graphNodeInteraction, svg);

  scene.append('g').attr('class', 'group-nodes-selected');

  scene.append('g').attr('class', 'nodes leaf-nodes').call(graphNodeInteraction, svg);

  scene.append('g').attr('class', 'leaf-nodes-selected');

  scene.append('g').attr('class', 'node-counts').attr('pointer-events', 'none');

  scene.append('g').attr('class', 'links').attr('pointer-events', 'none');

  scene.append('g').attr('class', 'target-points').attr('pointer-events', 'none');

  scene
    .append('g')
    .attr('class', 'node-labels')
    .attr('pointer-events', 'none')
    .attr('text-anchor', 'middle');

  updateGraphLayout(svg);

  const graphView: GraphView = {
    setViewConfig: (nextViewConfig: Record<string, any> | undefined) => {
      if (isEqual(nextViewConfig, svgDatum.viewConfig)) {
        return;
      }
      // Structural settings require a full relayout; anything else only needs restyling.
      const pruneValueChange = nextViewConfig?.pruneValue !== svgDatum.viewConfig?.pruneValue;
      const removedNodesChange =
        nextViewConfig?.removedNodes !== svgDatum.viewConfig?.removedNodes;
      const expandedNodesChange =
        nextViewConfig?.expandedNodes !== svgDatum.viewConfig?.expandedNodes;
      svgDatum.viewConfig = nextViewConfig;
      assignColorSelections(svgDatum.nodes, svgDatum.viewConfig?.colorSelections);
      if (pruneValueChange || removedNodesChange || expandedNodesChange) {
        updateGraphLayout(svg);
      } else if (!svgDatum.bodySimulationsRunning) {
        // While simulations run, their tick handler redraws anyway.
        updateGraphVisibility(svg);
      }
    },
  };

  return graphView;
}

/**
 * Recomputes the layout from scratch. It prunes the graph, circle-packs the remaining hierarchy,
 * rebuilds the node/link lookup structures on the SVG datum and (re)starts the force simulations.
 */
function updateGraphLayout(svg: Selection<SVGSVGElement, unknown, null, undefined>) {
  const svgDatum = svg.datum() as SvgDatum;
  const { viewSize } = svgDatum;
  const packer = pack<GraphNode>()
    .size([viewSize[0], viewSize[1]])
    // Radius shrinks with depth: 25 at depth 1, 12.5 at depth 2, …
    .radius((node) => 25 / node.depth)
    .padding(arrowSize);

  // runGraphForceSimulation (called at the end) is a no-op while simulations are running.
  stopBodySimulations(svgDatum);
  pruneGraph(svgDatum);
  calcNodeDescendantCounts(svgDatum.prunedGraphRootNode);
  setAllNodesCollapsed(svgDatum.prunedGraphRootNode, svgDatum.viewConfig?.expandedNodes);
  const tree = hierarchy(svgDatum.prunedGraphRootNode).count();
  const packed = packer(tree) as D3Node;

  // The packed root is excluded; only its descendants are drawn and simulated.
  const nodes = packed.descendants().slice(1);
  svgDatum.nodes = nodes;
  svgDatum.nodeMap = {};
  for (const node of nodes) {
    svgDatum.nodeMap[node.data.id] = node;
  }

  svgDatum.links = combineLinks(svgDatum.prunedGraphLinks);
  // Fold from 0 so an empty link list gives 0 (the datum's initial value) instead of -Infinity.
  // This also avoids the argument-count limit of spreading a very large array into Math.max.
  svgDatum.maxLinkWeight = svgDatum.links.reduce((max, link) => Math.max(max, link.weight), 0);
  svgDatum.topNodes = packed.children || [];
  svgDatum.topLinks = combineLinks(svgDatum.prunedGraphLinks, { depth: 1 });
  // Depth-1 nodes are driven by the top simulation; deeper depths each get a nested one.
  svgDatum.depthGroupedNodes = groupBy(nodes, 'depth');
  delete svgDatum.depthGroupedNodes[1];

  // If this is an update of the layout, re-establish the force-directed position of pre-existing nodes.
  nodes.forEach((node) => {
    if (node.data.state.coords !== undefined) {
      Object.assign(node, node.data.state.coords);
    }
  });

  assignColorSelections(nodes, svgDatum.viewConfig?.colorSelections);

  attachLinksToNodes(svgDatum.topLinks, svgDatum.nodeMap);
  attachLinksToNodes(svgDatum.links, svgDatum.nodeMap);

  runGraphForceSimulation(svg);
}

/**
 * Starts or re-heats the force simulations for the current layout. `bodySimulations[0]` drives
 * the top-level bodies, and `bodySimulations[depth - 1]` drives the nodes at each nested depth.
 * Does nothing while simulations are already running.
 *
 * @param svg - The graph SVG whose datum holds the layout.
 * @param manualAdjust - When true, existing simulations get only a small nudge (alpha floor 0.01).
 *   Otherwise they get a full re-heat (alpha floor 0.5).
 */
function runGraphForceSimulation(
  svg: Selection<SVGSVGElement, unknown, null, undefined>,
  manualAdjust?: boolean,
) {
  const svgDatum = svg.datum() as SvgDatum;
  const {
    prunedGraphLinks,
    topNodes,
    topLinks,
    nodeMap,
    depthGroupedNodes,
    bodySimulationsRunning,
  } = svgDatum;

  if (bodySimulationsRunning) {
    return;
  }

  // Mutated in place, never reassigned: the top simulation's tick handler holds this same array.
  const simulations = svgDatum.bodySimulations;
  const reheatAlpha = manualAdjust ? 0.01 : 0.5;

  // Create or restart top simulation.
  let topSimulation = simulations[0];
  if (topSimulation) {
    topSimulation.alpha(Math.max(topSimulation.alpha(), reheatAlpha));
  } else {
    topSimulation = createTopBodySimulation(svg);
    simulations.push(topSimulation);
  }

  topSimulation.nodes(topNodes);
  (topSimulation.force('link') as ForceLink<D3Node, D3Link>).links(topLinks);
  topSimulation.restart();
  svgDatum.bodySimulationsRunning = true;

  // Discard simulations for depths the current layout no longer has (e.g. after collapsing).
  // A stopped simulation keeps its last alpha. The top tick handler only redraws once the highest
  // alpha across all simulations is below 0.7, so a stale, still-hot one would block redraws.
  const maxDepth = Math.max(1, ...Object.keys(depthGroupedNodes).map(Number));
  simulations.splice(maxDepth).forEach((simulation) => simulation.stop());

  // Create or restart nested simulations. Depths are visited in ascending order, so an appended
  // simulation lands at index depth - 1 (assumes depths are contiguous, as in a tree).
  Object.entries(depthGroupedNodes).forEach(([key, nodes]) => {
    const depth = +key;
    const nodeIdSet = new Set(nodes.map((node) => node.data.id));
    // Links at this depth that touch at least one node of this depth.
    const nestedSimLinks = combineLinks(prunedGraphLinks, { depth }).filter(
      (link) => nodeIdSet.has(link.source as string) || nodeIdSet.has(link.target as string),
    );

    attachLinksToNodes(nestedSimLinks, nodeMap);

    let simulation = simulations[depth - 1];
    if (simulation) {
      simulation.alpha(Math.max(simulation.alpha(), reheatAlpha));
    } else {
      simulation = createNestedBodySimulation(svg);
      simulations.push(simulation);
    }

    simulation.nodes(nodes);
    (simulation.force('link') as ForceLink<D3Node, D3Link>).links(nestedSimLinks);
    // Read back by the nested simulation's tick handler.
    (simulation as any).nodeMap = nodeMap;
    (simulation as any).nodeIdSet = nodeIdSet;
    simulation.restart();
  });
}

/**
 * Creates the simulation that positions the top-level (depth 1) bodies. Its tick handler also
 * clamps/constrains every node, records positions in node state and periodically redraws.
 * When it ends, it stops all body simulations and paints the settled layout.
 */
function createTopBodySimulation(svg: Selection<SVGSVGElement, unknown, null, undefined>) {
  const svgDatum = svg.datum() as SvgDatum;
  const { viewSize, bodySimulations } = svgDatum;
  const center = [viewSize[0] / 2, viewSize[1] / 2];

  // The simulation is reused across relayouts. The charge accessor is re-evaluated whenever its
  // nodes are (re)assigned, so it tracks the current links: repel with links, attract without.
  const chargeStrength = () => (svgDatum.links && svgDatum.links.length > 0 ? -250 : 300);

  // Clamp depth-1 bodies inside the view, delegate deeper nodes to constrainChildNode, and
  // persist positions so a later relayout can restore them.
  function handleTopLevelBodyPositioning() {
    svgDatum.nodes.forEach((node) => {
      if (node.depth === 1) {
        node.x = Math.min(Math.max(node.x, node.r), viewSize[0] - node.r);
        node.y = Math.min(Math.max(node.y, node.r), viewSize[1] - node.r);
      } else {
        constrainChildNode(node);
      }
      node.data.state.coords = { x: node.x, y: node.y };
    });
  }

  let tickCount = 0;
  // Redraw on every 10th tick once all simulations have cooled below alpha 0.7.
  function onTick() {
    handleTopLevelBodyPositioning();
    const maxAlpha = Math.max(...bodySimulations.map((s) => s.alpha()));
    if (maxAlpha < 0.7 && tickCount % 10 === 0) {
      updateGraphElements(svg);
    }
    tickCount++;
  }

  const topSimulation = forceSimulation<D3Node, D3Link>()
    .force('charge', forceManyBody().strength(chargeStrength).theta(0.4))
    .force('link', forceLink<D3Node, D3Link>().strength(0.4).distance(computeForceLinkDistance(75)))
    .force(
      'collision',
      forceCollide<D3Node>().radius((d) => d.r + 10),
    )
    .force('x', forceX<D3Node>(center[0]).strength(0.1))
    .force('y', forceY<D3Node>(center[1]).strength(0.1))
    .on('tick', onTick)
    .on('end', () => {
      stopBodySimulations(svgDatum);
      // Ticks redraw only every 10th frame; paint the final settled positions once more.
      updateGraphElements(svg);
    });

  return topSimulation;
}

/** Halts every body simulation attached to the datum and clears its running flag. */
function stopBodySimulations(svgDatum: SvgDatum) {
  const { bodySimulations } = svgDatum;
  if (bodySimulations) {
    bodySimulations.forEach((simulation) => {
      simulation.stop();
    });
  }
  svgDatum.bodySimulationsRunning = false;
}

/**
 * Creates a simulation for the nodes of one nested depth. On each tick it constrains its own
 * nodes (via constrainChildNode), records their positions in node state, and refreshes link
 * endpoints that belong to other simulations.
 */
function createNestedBodySimulation(svg: Selection<SVGSVGElement, unknown, null, undefined>) {
  const svgDatum = svg.datum() as SvgDatum;

  // Links may reach nodes owned by another simulation. Those endpoints are swapped for a shallow
  // copy of the current node, so this simulation's link force sees an up-to-date position while
  // (presumably) acting on the copy rather than the node it does not own.
  function updateNonParticipatingLinkNodes(simulation: Simulation<D3Node, D3Link>) {
    // Attached by runGraphForceSimulation before each restart.
    const nodeIdSet = (simulation as any).nodeIdSet as Set<unknown> | undefined;
    if (!nodeIdSet) {
      return;
    }
    // Resolve the map on every tick. updateGraphLayout replaces svgDatum.nodeMap on each relayout
    // while this simulation object is reused, so a map captured at creation time would go stale.
    const nodeMap: SvgDatum['nodeMap'] = (simulation as any).nodeMap ?? svgDatum.nodeMap;
    const links = (simulation.force('link') as ForceLink<D3Node, D3Link>).links();
    links.forEach((link) => {
      const sourceId = (link.source as D3Node).data?.id;
      const sourceNode = nodeMap[sourceId];
      // Keep the current endpoint if the node is unknown instead of replacing it with `{}`.
      if (!nodeIdSet.has(sourceId) && sourceNode) {
        link.source = { ...sourceNode };
      }
      const targetId = (link.target as D3Node).data?.id;
      const targetNode = nodeMap[targetId];
      if (!nodeIdSet.has(targetId) && targetNode) {
        link.target = { ...targetNode };
      }
    });
  }

  function handleNestedBodyPositioning(simulation: Simulation<D3Node, D3Link>) {
    simulation.nodes().forEach((node) => {
      constrainChildNode(node);
      node.data.state.coords = { x: node.x, y: node.y };
    });
  }

  const simulation = forceSimulation<D3Node, D3Link>()
    .force('charge', forceManyBody().strength(-150))
    .force('link', forceLink<D3Node, D3Link>().strength(0.6).distance(computeForceLinkDistance(25)))
    .force(
      'collision',
      forceCollide<D3Node>().radius((d) => d.r + 5),
    )
    .on('tick', () => {
      handleNestedBodyPositioning(simulation);
      updateNonParticipatingLinkNodes(simulation);
    });

  return simulation;
}

/**
 * Binds the current layout to the SVG and then refreshes highlight/selection styling via
 * updateGraphVisibility. It draws group hulls, leaf circles, selection outlines,
 * descendant-count labels, link paths, arrow-anchor lines and node labels.
 */
function updateGraphElements(svg: Selection<SVGSVGElement, unknown, null, undefined>) {
  const svgDatum = svg.datum() as SvgDatum;
  const { nodes, links, maxLinkWeight } = svgDatum;

  if (!nodes || !links) {
    return;
  }

  // Divisor that normalises link weights to [0, 1]; fall back to 1 so all-zero weights don't give NaN.
  const weightDivisor = maxLinkWeight > 0 ? maxLinkWeight : 1;

  // Clear cached hull outline points; presumably the hull-path helpers refill them from fresh positions.
  nodes.forEach((node) => {
    node.outerPoints = node.paddedOuterPoints = undefined;
  });

  const leafNodes = nodes.filter((node) => node.children === undefined);
  const groupNodes = nodes.filter((node) => node.children !== undefined);
  const linkPoints: D3Node[][] = [];
  // One arrow anchor per target key; a later link with the same key overwrites an earlier one.
  const targetPointsMap: Dictionary<D3Node[]> = {};

  links.forEach((link) => {
    linkPoints.push(generateLinkPoints(link));
    // Read after generateLinkPoints, which is assumed to set targetKey/targetPoints on the link.
    targetPointsMap[link.targetKey as string] = link.targetPoints as D3Node[];
  });

  svg
    .select('g.group-nodes')
    .selectAll('path')
    .data(groupNodes)
    .join('path')
    .attr('stroke', highlightedColorDark)
    .attr('stroke-width', 1.5)
    .attr('d', calcGroupNodeHullPath);

  svg
    .select('g.group-nodes-selected')
    .selectAll('path')
    .data(groupNodes)
    .join('path')
    .attr('fill', 'none')
    .attr('stroke', highlightedColorDark)
    .attr('stroke-dasharray', '3,2')
    .attr('stroke-width', 1.5)
    .attr('d', calcSelectedGroupNodeHullPath);

  svg
    .select('g.leaf-nodes-selected')
    .selectAll('circle')
    .data(leafNodes)
    .join('circle')
    .attr('fill', 'none')
    .attr('stroke', highlightedColorDark)
    .attr('stroke-width', 1.5)
    .attr('stroke-dasharray', (d) => (d.r < 10 ? '2,1' : '3,2'))
    .attr('transform', (d) => `translate(${d.x},${d.y})`)
    .attr('r', (d) => d.r + Math.min(d.r / 4, 4));

  svg
    .select('g.leaf-nodes')
    .selectAll('circle')
    .data(leafNodes)
    .join('circle')
    .attr('stroke-width', 1.5)
    .attr('transform', (d) => `translate(${d.x},${d.y})`)
    .attr('r', (d) => d.r);

  svg
    .select('g.node-counts')
    .selectAll('text')
    .data(leafNodes)
    .join('text')
    .attr('text-anchor', 'middle')
    .attr('alignment-baseline', 'middle')
    .style('font', (d) => (d.depth === 1 ? '10px sans-serif' : '6px sans-serif'))
    .attr('transform', (d) => `translate(${d.x},${d.y})`);

  // linkPoints[i] was generated from links[i], so the index selects the matching weight.
  svg
    .select('g.links')
    .selectAll('path')
    .data(linkPoints)
    .join('path')
    .attr('d', d3line as any)
    .style('fill', 'none')
    .style('stroke', (_, i) =>
      hsl(0, 0, linkColorScale(links[i].weight / weightDivisor)).toString(),
    )
    .attr('stroke-width', (_, i) => linkWidthScale(links[i].weight / weightDivisor));

  // Zero-width lines that exist only to carry an arrowhead. Their `marker-end` (one of the
  // `link-arrow-1` … `link-arrow-10` defs) is chosen per weight in updateGraphVisibility.
  svg
    .select('g.target-points')
    .selectAll('line')
    .data(Object.entries(targetPointsMap))
    .join('line')
    .attr('data-key', (d) => d[0])
    .attr('x1', (d) => d[1][0].x)
    .attr('y1', (d) => d[1][0].y)
    .attr('x2', (d) => d[1][1].x)
    .attr('y2', (d) => d[1][1].y)
    .attr('stroke-width', 0)
    .style('fill', 'none');

  // Attributes are applied after join so newly entered labels receive them too.
  svg
    .select('g.node-labels')
    .selectAll('text')
    .data(nodes)
    .join('text')
    .attr('text-anchor', 'middle')
    .text((d) => truncateText(d.data.label))
    .style('font', (d) => (d.depth === 1 ? '10px sans-serif' : '6px sans-serif'));

  updateGraphVisibility(svg);
}

/**
 * Restyles the already-bound graph elements for the current highlight and selection state.
 * It sets fills, borders, opacity, selection-outline visibility, descendant counts, label
 * placement, link visibility and the arrowhead marker for each arrow anchor.
 */
function updateGraphVisibility(svg: Selection<SVGSVGElement, unknown, null, undefined>) {
  const svgDatum = svg.datum() as SvgDatum;
  const { highlightedNodeId, maxLinkWeight, nodeMap } = svgDatum;
  const links = svgDatum.links ?? [];
  // The id may name a node absent from the current layout; treat that as "nothing highlighted".
  const highlightedNode: D3Node | undefined = highlightedNodeId
    ? nodeMap[highlightedNodeId]
    : undefined;
  const highlightedNodeIds = new Set<string>();
  const weightDivisor = maxLinkWeight > 0 ? maxLinkWeight : 1;

  // Weights of the visible links feeding each arrow anchor, keyed by link.targetKey.
  const targetWeightsMap: Dictionary<number[]> = {};
  links.forEach((link) => {
    targetWeightsMap[link.targetKey as string] = [];
  });

  // Marker bucket 1–10 for an anchor ([key, points] entry), from its heaviest visible link.
  // 0 means no visible link; no `link-arrow-0` marker is defined, so no arrowhead is drawn.
  function getTargetWeight(d: any[]) {
    const targetWeights = targetWeightsMap[d[0]];
    if (!targetWeights || targetWeights.length === 0) {
      return 0;
    }
    return Math.ceil((10 * Math.max(...targetWeights)) / weightDivisor);
  }

  function getNodeOpacity(node: D3Node) {
    return highlightedNode && !highlightedNodeIds.has(node.data.id) ? 0.1 : 1;
  }

  function getNodeLabelColor(node: D3Node) {
    return node.data.id === highlightedNodeId ? 'inherit' : '#555';
  }

  // With a highlight, only links selected by shouldIncludeLink count, and their endpoints join
  // the highlighted set. Without one, every link counts.
  if (highlightedNode) {
    addToHighlightedNodes(highlightedNode, highlightedNodeIds);
  }
  links.forEach((link) => {
    const source = link.source as D3Node;
    const target = link.target as D3Node;
    if (highlightedNode) {
      if (!shouldIncludeLink([source, target], highlightedNode)) {
        return;
      }
      addToHighlightedNodes(source, highlightedNodeIds);
      addToHighlightedNodes(target, highlightedNodeIds);
    }
    targetWeightsMap[link.targetKey as string]?.push(link.weight);
  });

  svg
    .select('g.group-nodes')
    .selectAll<any, D3Node>('path')
    .attr('fill', (d) => getGroupNodeFillColor(d, d === highlightedNode))
    .attr('stroke', (d) =>
      getNodeBorderColor(d, d === highlightedNode, highlightedNodeIds.has(d.data.id)),
    )
    .attr('opacity', getNodeOpacity);

  svg
    .select('g.group-nodes-selected')
    .selectAll<any, D3Node>('path')
    .attr('opacity', getNodeOpacity)
    .style('visibility', (d) =>
      svgDatum.selectedNodeIds?.includes(d.data.id) ? 'visible' : 'hidden',
    );

  svg
    .select('g.leaf-nodes')
    .selectAll<any, D3Node>('circle')
    .attr('fill', (d) =>
      getLeafNodeFillColor(d, d === highlightedNode, highlightedNodeIds.has(d.data.id)),
    )
    .attr('stroke', (d) =>
      getNodeBorderColor(d, d === highlightedNode, highlightedNodeIds.has(d.data.id)),
    )
    .attr('opacity', getNodeOpacity);

  svg
    .select('g.leaf-nodes-selected')
    .selectAll<any, D3Node>('circle')
    .attr('opacity', getNodeOpacity)
    .style('visibility', (d) =>
      svgDatum.selectedNodeIds?.includes(d.data.id) ? 'visible' : 'hidden',
    );

  // Show a descendant count only on collapsed nodes that actually have children.
  svg
    .select('g.node-counts')
    .selectAll<any, D3Node>('text')
    .attr('fill', getLeafNodeLabelColor)
    .attr('opacity', getNodeOpacity)
    .text((d) => {
      if (!d.data.state.collapsed || !d.data.state.children?.length) {
        return '';
      } else {
        return d.data.state.descendantCount;
      }
    });

  // Labels sit just above the node (or at an explicit labelX/labelY), lifted a little further
  // when the node is selected so they clear the selection outline.
  svg
    .select('g.node-labels')
    .selectAll<any, D3Node>('text')
    .attr('opacity', getNodeOpacity)
    .style('fill', getNodeLabelColor)
    .attr('transform', (d) => {
      const labelX = d.labelX || d.x;
      let labelY = d.labelY || d.y - d.r * 1.1;
      if (svgDatum.selectedNodeIds?.includes(d.data.id)) {
        labelY -= 4;
      }
      return `translate(${labelX},${labelY})`;
    });

  svg
    .select('g.links')
    .selectAll<any, D3Node[]>('path')
    .attr('opacity', (d) => (highlightedNode && !shouldIncludeLink(d, highlightedNode) ? 0 : 1));

  svg
    .select('g.target-points')
    .selectAll('line')
    .attr('marker-end', (d: any) => `url(#link-arrow-${getTargetWeight(d)})`);
}
