import { useRef, useEffect, useState } from "react";
import {
  Chart as ChartJS,
  ChartData,
  ChartOptions,
  TooltipItem,
  registerables,
} from "chart.js";
import ChartAnnotationPlugin from "chartjs-plugin-annotation";
import { Paper, Box, Typography, Select, MenuItem, Chip } from "@mui/material";
import { TrendingUp } from "lucide-react";

// Registered once at module load: every built-in Chart.js controller, element,
// scale and plugin (`registerables`), followed by the annotation plugin that the
// graph below uses to draw its dashed mean lines.
const chartComponents = [...registerables, ChartAnnotationPlugin];
ChartJS.register(...chartComponents);

/** A numeric curve sample, obtained by parsing the string coordinates of a `DistributionPoint`. */
interface Point {
  x: number;
  y: number;
}

/** One curve sample as received, with both coordinates encoded as strings. */
interface DistributionPoint {
  x: string;
  y: string;
}

/** A single predicted distribution for one metric. */
interface EnsemblePredictionDistribution {
  /** Metric this distribution belongs to; distributions are grouped by this value. */
  metricName: string;
  /** Sampled curve of the distribution, as string coordinates. */
  distributionCurvePoints: DistributionPoint[];
  /** Mean, as a numeric string. */
  mean: string;
  /** Standard deviation, as a numeric string. */
  stdDev: string;
  /** Optional time attached to this distribution (presumably when it was predicted — confirm). */
  timestamp?: Date;
}

/** Props for `EnsembleGraph`. */
interface EnsembleGraphProps {
  /** All distributions to choose from, possibly spanning several metrics. */
  ensemblePredictionDistributions: EnsemblePredictionDistribution[];
}

/** Hue (degrees) of the first distribution's colour. */
const BASE_HUE = 210;
/** Hue rotation (degrees) between consecutive distributions. */
const HUE_STEP = 30;
/** Fraction of the x-data span added as empty margin on each side of the x axis. */
const X_PADDING_FRACTION = 0.05;

/** Hue for the `index`-th (time-ordered) distribution; its curve and mean line share it. */
const seriesHue = (index: number): number => BASE_HUE + index * HUE_STEP;

/** Orders distributions oldest first; a missing timestamp sorts as time 0. */
const byTimestamp = (
  a: EnsemblePredictionDistribution,
  b: EnsemblePredictionDistribution
): number => (a.timestamp?.getTime() ?? 0) - (b.timestamp?.getTime() ?? 0);

/**
 * Parses a curve's string coordinates and rescales y so the tallest point is 1.
 * This lets curves with different absolute densities share one "Relative Density" axis.
 *
 * The peak is found with `reduce` instead of `Math.max(...ys)`, so long curves cannot hit
 * the engine's argument-count limit. When there is no positive, finite peak (empty curve,
 * all-zero or unparseable densities), the parsed values are returned unscaled. Dividing by
 * that peak would otherwise turn every point into NaN or Infinity.
 */
const normalizeDistribution = (points: DistributionPoint[]): Point[] => {
  const parsed = points.map((p) => ({ x: parseFloat(p.x), y: parseFloat(p.y) }));
  const peak = parsed.reduce((max, p) => (p.y > max ? p.y : max), -Infinity);
  if (!(peak > 0 && Number.isFinite(peak))) return parsed;
  return parsed.map((p) => ({ x: p.x, y: p.y / peak }));
};

/**
 * Computes padded x-axis limits covering every finite x value across `distributions`.
 *
 * @returns `undefined` when no finite x value exists, so Chart.js auto-scales instead of
 *   receiving ±Infinity. A zero-width span (one point, or identical x values) gets a
 *   fallback margin so the axis does not collapse to min === max.
 */
const computeXBounds = (
  distributions: EnsemblePredictionDistribution[]
): { min: number; max: number } | undefined => {
  let lo = Infinity;
  let hi = -Infinity;
  for (const dist of distributions) {
    for (const point of dist.distributionCurvePoints) {
      const x = parseFloat(point.x);
      if (!Number.isFinite(x)) continue;
      if (x < lo) lo = x;
      if (x > hi) hi = x;
    }
  }
  if (lo > hi) return undefined;
  const span = hi - lo;
  const padding =
    span > 0 ? span * X_PADDING_FRACTION : Math.abs(lo) * X_PADDING_FRACTION || 1;
  return { min: lo - padding, max: hi + padding };
};

/** Short month/day label in the user's locale, or "No date" when the timestamp is absent. */
const formatTime = (date?: Date): string => {
  if (!date) return "No date";
  return date.toLocaleString(undefined, {
    month: "short",
    day: "numeric",
  });
};

/**
 * Overlays the predicted distributions of one metric, ordered oldest first. Each curve is
 * peak-normalised to 1, and a dashed vertical line marks its mean. A dropdown selects the
 * metric from the names present in `ensemblePredictionDistributions`.
 */
const EnsembleGraph: React.FC<EnsembleGraphProps> = ({
  ensemblePredictionDistributions,
}) => {
  const uniqueMetrics = Array.from(
    new Set(ensemblePredictionDistributions.map((d) => d.metricName))
  );
  const [selectedMetric, setSelectedMetric] = useState(uniqueMetrics[0] ?? "");
  // `useState` reads its initializer only on mount. If the data arrives later, or the chosen
  // metric vanishes from new props, the stored value is stale. So derive the metric actually
  // shown during render; this avoids syncing it with an extra effect.
  const activeMetric = uniqueMetrics.includes(selectedMetric)
    ? selectedMetric
    : uniqueMetrics[0] ?? "";
  const activeDistributionCount = ensemblePredictionDistributions.filter(
    (d) => d.metricName === activeMetric
  ).length;

  const canvasRef = useRef<HTMLCanvasElement>(null);
  const chartRef = useRef<ChartJS | null>(null);
  const containerRef = useRef<HTMLDivElement>(null);

  useEffect(() => {
    const ctx = canvasRef.current?.getContext("2d");
    if (!ctx) return;

    // `filter` returns a fresh array, so sorting it does not mutate the props.
    const selectedData = ensemblePredictionDistributions
      .filter((d) => d.metricName === activeMetric)
      .sort(byTimestamp);

    if (selectedData.length === 0) return;

    const xBounds = computeXBounds(selectedData);

    const data: ChartData = {
      datasets: selectedData.map((dist, i) => {
        const hue = seriesHue(i);
        return {
          label: formatTime(dist.timestamp),
          data: normalizeDistribution(dist.distributionCurvePoints),
          borderColor: `hsl(${hue}, 70%, 50%)`,
          backgroundColor: `hsla(${hue}, 70%, 50%, 0.1)`,
          fill: true,
          tension: 0.4,
          pointRadius: 0,
          borderWidth: 2,
        };
      }),
    };

    const options: ChartOptions = {
      responsive: true,
      maintainAspectRatio: true,
      aspectRatio: 3,
      scales: {
        x: {
          type: "linear",
          min: xBounds?.min,
          max: xBounds?.max,
          title: {
            display: true,
            text: activeMetric,
          },
          grid: {
            color: "rgba(0,0,0,0.05)",
          },
        },
        y: {
          beginAtZero: true,
          title: {
            display: true,
            text: "Relative Density",
          },
          grid: {
            color: "rgba(0,0,0,0.05)",
          },
        },
      },
      interaction: {
        intersect: false,
        mode: "nearest",
      },
      plugins: {
        tooltip: {
          callbacks: {
            title: (items) => {
              const first = items[0];
              return first ? `Value: ${Number(first.parsed.x).toFixed(2)}` : "";
            },
            label: (context) => {
              const densityLine = `Density: ${context.parsed.y.toFixed(3)}`;
              const dist = selectedData[context.datasetIndex];
              if (!dist) return densityLine;
              return [
                `Time: ${formatTime(dist.timestamp)}`,
                densityLine,
                `Mean: ${Number(dist.mean).toFixed(2)}`,
                `Std Dev: ${Number(dist.stdDev).toFixed(2)}`,
              ];
            },
          },
        },
        annotation: {
          // One dashed vertical line per distribution at its mean, coloured like its curve.
          annotations: Object.fromEntries(
            selectedData.map((dist, i) => {
              const hue = seriesHue(i);
              const meanValue = Number(dist.mean);
              return [
                `mean-line-${i}`,
                {
                  type: "line" as const,
                  xMin: meanValue,
                  xMax: meanValue,
                  borderColor: `hsl(${hue}, 70%, 50%)`,
                  borderWidth: 1,
                  borderDash: [5, 5],
                  label: {
                    display: true,
                    content: `μ=${meanValue.toFixed(2)}`,
                    position: "end",
                    backgroundColor: `hsl(${hue}, 70%, 50%)`,
                    color: "white",
                    padding: 4,
                    font: { size: 11 },
                  },
                },
              ];
            })
          ),
        },
      },
    };

    const chart = new ChartJS(ctx, {
      type: "line",
      data,
      options,
    });
    chartRef.current = chart;

    // Each run owns exactly the chart it created. Destroying it here and clearing the ref
    // means a later run never destroys an already-destroyed instance, including after an
    // early return that created no chart.
    return () => {
      chart.destroy();
      if (chartRef.current === chart) {
        chartRef.current = null;
      }
    };
  }, [activeMetric, ensemblePredictionDistributions]);

  return (
    <Paper
      elevation={2}
      style={{
        width: "100%",
        height: "100%",
        padding: "24px",
        display: "flex",
        flexDirection: "column",
      }}
    >
      <Box
        sx={{
          display: "flex",
          justifyContent: "space-between",
          alignItems: "center",
          marginBottom: "16px",
        }}
      >
        <Box sx={{ display: "flex", alignItems: "center", gap: "16px" }}>
          <TrendingUp size={20} />
          <Select
            size="small"
            value={activeMetric}
            onChange={(e) => setSelectedMetric(e.target.value)}
            sx={{ minWidth: 200 }}
          >
            {uniqueMetrics.map((metric) => (
              <MenuItem key={metric} value={metric}>
                {metric}
              </MenuItem>
            ))}
          </Select>
        </Box>
        <Chip label={`${activeDistributionCount} Distributions`} size="small" />
      </Box>
      <Box
        ref={containerRef}
        sx={{
          position: "relative",
          width: "100%",
          flex: 1,
          minHeight: "400px",
          height: "calc(100% - 60px)",
        }}
      >
        <canvas ref={canvasRef} style={{ width: "100%", height: "100%" }} />
      </Box>
    </Paper>
  );
};

export default EnsembleGraph;
