import { useEffect, useMemo, useRef } from "react";
import * as d3 from "d3";
import { useDocumentContext } from "../../context/DocumentContext";
import { Chunk } from "../../generated/protos/chunk";

/** Role of a node in a relationship path: `from` entity → `value` → `to` entity. */
type NodeType = "from" | "value" | "to";

/**
 * A node in the force simulation. `id` is both the key that d3's
 * `forceLink().id()` resolves link endpoints against and the text rendered
 * under the node's circle. Position/velocity fields come from
 * `d3.SimulationNodeDatum`.
 */
type Node = d3.SimulationNodeDatum & {
  id: string;
  type: NodeType;
};

/**
 * A directed edge. `source`/`target` are created as node ids (strings);
 * d3's `forceLink` replaces them with the matching `Node` objects when the
 * simulation initializes, hence the `string | Node` union.
 */
type Link = d3.SimulationLinkDatum<Node> & {
  source: string | Node;
  target: string | Node;
  // NOTE(review): in this file it is always set to "" and never rendered.
  label: string;
};

/** Nodes and links fed to the d3 force simulation. */
type GraphData = {
  nodes: Node[];
  links: Link[];
};

/**
 * Interactive force-directed graph of the selected entity relationships.
 *
 * Relationship chunks are collected from every document in `DocumentContext`.
 * Each chunk whose relationship id is listed in `relationshipIds` is drawn as
 * a `from → value → to` path. The SVG supports zoom/pan and node dragging.
 *
 * @param relationshipIds - ids of the relationships to draw.
 */
export default function KnowledgeGraph({
  relationshipIds,
}: {
  relationshipIds: string[];
}) {
  const { allDocuments } = useDocumentContext();

  // Recomputed only when the documents or the selection change. d3 mutates
  // these node/link objects in place (positions, resolved link endpoints).
  const graphData = useMemo(() => {
    const relationshipChunks = allDocuments.flatMap(
      (doc) => doc.relationshipChunks
    );

    return createNodeGraph(relationshipIds, relationshipChunks);
  }, [allDocuments, relationshipIds]);

  const svgRef = useRef<SVGSVGElement>(null);

  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement) return;

    const svg = d3.select(svgElement);

    // Clear the previous render first, so an empty selection does not leave
    // a stale graph on screen.
    svg.selectAll("*").remove();
    if (!graphData.nodes.length) return;

    // The <svg> is 100% wide, so centre on its measured width. Fall back to
    // 746 if it reports 0 (not laid out). Height matches the attribute below.
    const width = svgElement.getBoundingClientRect().width || 746;
    const height = 400;

    // Arrowhead used by every link's `marker-end`. refX shifts the tip back
    // from the line's end point so it sits outside the target circle.
    svg
      .append("defs")
      .append("marker")
      .attr("id", "arrowhead")
      .attr("viewBox", "-0 -5 10 10")
      .attr("refX", 38)
      .attr("refY", 0)
      .attr("orient", "auto")
      .attr("markerWidth", 6)
      .attr("markerHeight", 6)
      .append("svg:path")
      .attr("d", "M 0,-5 L 10,0 L 0,5")
      .attr("fill", "#999");

    const simulation = d3
      .forceSimulation<Node>(graphData.nodes)
      .force(
        "link",
        d3
          .forceLink<Node, Link>(graphData.links)
          // Links reference nodes by id string until forceLink resolves them.
          .id((d) => d.id)
          .distance(200)
      )
      .force("charge", d3.forceManyBody().strength(-500))
      .force("center", d3.forceCenter(width / 2, height / 2));

    // All graph content lives in one <g> so zoom/pan is a single transform.
    const g = svg.append("g");
    svg.call(
      d3
        .zoom<SVGSVGElement, unknown>()
        .extent([
          [0, 0],
          [width, height],
        ])
        .scaleExtent([0.1, 4])
        .on("zoom", (event) => {
          g.attr("transform", event.transform);
        })
    );

    const links = g
      .append("g")
      .selectAll("line")
      .data(graphData.links)
      .join("line")
      .attr("stroke", "#999")
      .attr("stroke-width", 1)
      .attr("marker-end", "url(#arrowhead)");

    const nodes = g
      .append("g")
      .selectAll<SVGGElement, Node>("g")
      .data(graphData.nodes)
      .join("g")
      .attr("class", "node")
      .call(
        d3
          .drag<SVGGElement, Node>()
          .on("start", dragstarted)
          .on("drag", dragged)
          // NOTE(review): `as any` kept to satisfy `selection.call` typing;
          // worth revisiting so the cast can be removed.
          .on("end", dragended) as any
      );

    nodes
      .append("circle")
      .attr("r", 18)
      .attr("fill", (d) => {
        switch (d.type) {
          case "from":
            return "#FF9800";
          case "value":
            return "#4CAF50";
          case "to":
            return "#2196F3";
          default:
            return "#999";
        }
      });

    nodes
      .append("text")
      .text((d) => d.id)
      .attr("text-anchor", "middle")
      .attr("dy", "30px")
      .attr("fill", "#333")
      .style("font-size", "12px")
      .style("font-weight", "600");

    // On each tick, copy simulated positions onto the SVG elements. By now
    // forceLink has replaced link endpoints with Node objects.
    simulation.on("tick", () => {
      links
        .attr("x1", (d) => (d.source as Node).x!)
        .attr("y1", (d) => (d.source as Node).y!)
        .attr("x2", (d) => (d.target as Node).x!)
        .attr("y2", (d) => (d.target as Node).y!);

      nodes.attr("transform", (d) => `translate(${d.x},${d.y})`);
    });

    // Drag handlers: reheat the simulation and pin the node (fx/fy) while it
    // is dragged, then let it cool down and release the pin on drop.
    function dragstarted(
      event: d3.D3DragEvent<SVGGElement, Node, unknown>,
      d: Node
    ) {
      if (!event.active) simulation.alphaTarget(0.3).restart();
      d.fx = d.x;
      d.fy = d.y;
    }

    function dragged(
      event: d3.D3DragEvent<SVGGElement, Node, unknown>,
      d: Node
    ) {
      d.fx = event.x;
      d.fy = event.y;
    }

    function dragended(
      event: d3.D3DragEvent<SVGGElement, Node, unknown>,
      d: Node
    ) {
      if (!event.active) simulation.alphaTarget(0);
      d.fx = null;
      d.fy = null;
    }

    // Stop the simulation's internal timer on re-run/unmount; otherwise it
    // keeps ticking and updating elements that have already been removed.
    return () => {
      simulation.stop();
    };
  }, [graphData]);

  return (
    <svg
      ref={svgRef}
      width={"100%"}
      height={400}
      style={{ border: "1px solid #999", borderRadius: "8px" }}
    />
  );
}
/**
 * Builds the node/link model for the selected relationships.
 *
 * Every chunk whose `entityRelationship.relationshipId` is in
 * `relationshipIds` contributes the path `from → value → to`:
 *  - the `from` node is labelled with the from-entity's name,
 *  - the `to` node with the to-entity's name followed by its subcomponents,
 *    space-separated,
 *  - the `value` node with `"<value> (<qualifier>)"`.
 * Nodes are de-duplicated by id. A label shared by several relationships
 * becomes one node, which keeps the type it was first added with. Identical
 * edges are emitted only once.
 *
 * Chunks without an `entityRelationship` are skipped, as are relationships
 * missing a from-entity, to-entity or value.
 *
 * @param relationshipIds - ids of the relationships to include.
 * @param relationshipChunks - candidate chunks to search.
 * @returns nodes and links. Link endpoints are node ids, which d3's
 *   `forceLink().id()` later resolves to node objects.
 */
function createNodeGraph(
  relationshipIds: string[],
  relationshipChunks: Chunk[]
): GraphData {
  const graph: GraphData = { nodes: [], links: [] };

  // Set lookups keep construction linear in the number of chunks.
  const selectedIds = new Set(relationshipIds);
  const nodeIds = new Set<string>();
  const linkKeys = new Set<string>();

  const addNode = (id: string, type: NodeType) => {
    if (nodeIds.has(id)) return;
    nodeIds.add(id);
    graph.nodes.push({ id, type });
  };

  const addLink = (source: string, target: string) => {
    // Relationships that resolve to the same labels would otherwise stack
    // identical edges on top of each other.
    const key = `${source}\u0000${target}`;
    if (linkKeys.has(key)) return;
    linkKeys.add(key);
    graph.links.push({ source, target, label: "" });
  };

  for (const chunk of relationshipChunks) {
    const relationship = chunk.entityRelationship;
    if (!relationship || !selectedIds.has(relationship.relationshipId)) {
      continue;
    }

    const { fromEntity, toEntity, value } = relationship;
    if (!fromEntity || !toEntity || !value) continue;

    const fromNodeId = fromEntity.entityName;
    const toNodeId = [
      toEntity.entityName,
      ...toEntity.entitySubcomponents,
    ].join(" ");
    const valueNodeId = `${value.value} (${value.qualifier})`;

    addNode(fromNodeId, "from");
    addNode(valueNodeId, "value");
    addNode(toNodeId, "to");

    addLink(fromNodeId, valueNodeId);
    addLink(valueNodeId, toNodeId);
  }

  return graph;
}
