import { useEffect, useMemo, useState } from "react";
import React from "react";
import { useDocumentContext } from "../../context/DocumentContext";
import { useAnalyticsData } from "../../hooks/useAnalyticsData";
import { ChartsContainer } from "./style";
import {
  Paper,
  Box,
  CircularProgress,
  Typography,
  MenuItem,
  Select,
} from "@mui/material";
import AccuracyTable from "./AccuracyTable";
import LineChart from "./LineChart";
import AccuraciesChart from "./AccuraciesChart";
import BoxPlot from "./BoxPlot";
import EnsembleGraph from "./EnsembleGraph";
import MonthRangePicker from "../MonthRangePicker";
import { subMonths, addMonths } from "date-fns";
import { ScriptableContext } from "chart.js";
import { startOfMonthUTC, processTimeseries, average } from "./utilities";

/**
 * Palette for per-bank series. A bank's colour is chosen by its position in
 * the analytics response. Every entry is a CSS `rgb()` string.
 */
const colorScheme: string[] = [
  /* blue        */ "rgb(54, 162, 235)",
  /* red         */ "rgb(255, 99, 132)",
  /* yellow      */ "rgb(255, 206, 86)",
  /* teal        */ "rgb(75, 192, 192)",
  /* purple      */ "rgb(153, 102, 255)",
  /* orange      */ "rgb(255, 159, 64)",
  /* medium blue */ "rgb(101, 116, 205)",
  /* pink        */ "rgb(242, 124, 195)",
  /* light blue  */ "rgb(91, 192, 222)",
  /* steel blue  */ "rgb(115, 147, 179)",
];

/**
 * Translucent white overlay with a spinner and a "Training Forecast Model..."
 * caption. It is absolutely positioned, so it covers its nearest positioned
 * ancestor (the forecast Paper sets `position: relative`).
 */
function LoadingOverlay() {
  return (
    <Box
      sx={{
        // Stretch over the whole positioned parent and sit above its content.
        position: "absolute",
        top: 0,
        right: 0,
        bottom: 0,
        left: 0,
        zIndex: 1,
        backgroundColor: "rgba(255, 255, 255, 0.8)",
        // Stack the spinner above the caption, centred both ways.
        display: "flex",
        flexDirection: "column",
        justifyContent: "center",
        alignItems: "center",
        gap: 1,
      }}
    >
      <CircularProgress size={24} />
      <Typography color="textSecondary" variant="body2">
        Training Forecast Model...
      </Typography>
    </Box>
  );
}

/** Colour for banks beyond the end of `colorScheme`. */
const FALLBACK_BANK_COLOR = "rgb(211, 211, 211)";

/** Shadow shared by every chart panel. */
const PANEL_SHADOW =
  "0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -2px rgba(0,0,0,0.05)";

/** Colour for the bank at `index` in the analytics response. */
const bankColor = (index: number): string =>
  colorScheme[index] ?? FALLBACK_BANK_COLOR;

/**
 * Chart.js `segment.borderDash` callback. It dashes every segment whose first
 * point falls in the current UTC month or later, so future values are visually
 * distinct from history.
 */
const dashFutureSegments = (ctx: ScriptableContext<"line">): number[] =>
  // The context passed for segment options exposes `p0`, which is not declared
  // on `ScriptableContext<"line">`; the cast is confined to this one helper.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  (ctx as any).p0.parsed.x >= startOfMonthUTC(new Date()).getTime()
    ? [5, 5]
    : [];

/**
 * Converts timeseries points into Chart.js `{ x, y }` points, where `x` is
 * epoch milliseconds. A missing series yields an empty array.
 */
function toChartPoints<V>(
  series: ReadonlyArray<{ value: V; asOfTimestamp: Date }> | null | undefined
): { x: number; y: V }[] {
  return (series ?? []).map(({ value, asOfTimestamp }) => ({
    x: asOfTimestamp.getTime(),
    y: value,
  }));
}

/**
 * Analytics dashboard for the selected documents. It shows:
 * - ground truth vs. per-bank predictions for the chosen metric,
 * - per-bank accuracy over time, as boxplots and as an average ranking,
 * - the ensemble forecast distribution, with an overlay while it loads.
 */
const Charts = () => {
  const [targetMetric, setTargetMetric] = useState<string>("");
  // Lazy initialisers so the default range is only computed on mount.
  const [startDate, setStartDate] = useState<Date>(() =>
    startOfMonthUTC(subMonths(new Date(), 12))
  );
  const [endDate, setEndDate] = useState(() =>
    startOfMonthUTC(addMonths(new Date(), 6))
  );

  const { selectedDocuments } = useDocumentContext();
  const {
    timeseriesData,
    forecastData,
    isTimeseriesLoading,
    isForecastLoading,
    error,
  } = useAnalyticsData();

  const handleDateChange = (start: Date, end: Date) => {
    setStartDate(start);
    setEndDate(end);
  };

  // Only depends on timeseries data
  const targetMetricOptions = useMemo(() => {
    return (
      timeseriesData?.metricsHistoricalData?.map(
        ({ metricName }) => metricName
      ) ?? []
    );
  }, [timeseriesData]);

  const timeseriesAnalytics = useMemo(
    () => ({
      groundTruth: processTimeseries(
        timeseriesData?.metricsHistoricalData?.find(
          ({ metricName }) => metricName === targetMetric
        )?.timeSeriesData ?? [],
        startDate,
        endDate
      ),
      bankPredictions:
        timeseriesData?.allBanksAnalytics?.map((bank) => ({
          name: bank.bank,
          predictions: processTimeseries(
            bank.metricAnalytics?.find(
              ({ metricName }) => metricName === targetMetric
            )?.predictionsTimeseries ?? [],
            startDate,
            endDate
          ),
        })) ?? [],
    }),
    [timeseriesData, targetMetric, startDate, endDate]
  );

  const timeseriesDatasets = useMemo(
    () => ({
      groundTruth: {
        label: "Ground Truth",
        data: toChartPoints(timeseriesAnalytics.groundTruth),
        borderColor: "rgba(0, 0, 0, 0.50)",
        backgroundColor: "rgba(0, 0, 0, 0.50)",
        borderWidth: 6,
        segment: { borderDash: dashFutureSegments },
      },
      bankPredictions: timeseriesAnalytics.bankPredictions.map(
        (bank, index) => ({
          label: bank.name,
          data: toChartPoints(bank.predictions),
          borderColor: bankColor(index),
          backgroundColor: bankColor(index),
          borderWidth: 2,
          segment: { borderDash: dashFutureSegments },
        })
      ),
    }),
    [timeseriesAnalytics]
  );

  const accuracyDatasets = useMemo(() => {
    const banks = timeseriesData?.allBanksAnalytics;
    if (!banks) {
      return { bankAccuracies: [], accuracyBoxplots: [] };
    }

    // Resolve each bank's analytics for the selected metric once, instead of
    // repeating the `find` for every field below.
    const perBank = banks.map((bank, index) => ({
      label: bank.bank,
      color: bankColor(index),
      analytics: bank.metricAnalytics?.find(
        ({ metricName }) => metricName === targetMetric
      ),
    }));

    return {
      bankAccuracies: perBank.map(({ label, color, analytics }) => ({
        label,
        data: toChartPoints(
          processTimeseries(
            analytics?.accuracyTimeseries ?? [],
            startDate,
            endDate
          )
        ),
        borderColor: color,
        backgroundColor: color,
        borderWidth: 2,
      })),
      accuracyBoxplots: perBank.map(({ label, color, analytics }) => {
        const boxplot = analytics?.accuracyBoxplot;
        return {
          label,
          data: [
            {
              // Missing statistics default to 0, matching the previous behaviour.
              min: parseFloat(boxplot?.min ?? "0"),
              q1: parseFloat(boxplot?.q1 ?? "0"),
              median: parseFloat(boxplot?.median ?? "0"),
              q3: parseFloat(boxplot?.q3 ?? "0"),
              max: parseFloat(boxplot?.max ?? "0"),
            },
          ],
          borderColor: color,
          backgroundColor: color,
        };
      }),
    };
  }, [timeseriesData, targetMetric, startDate, endDate]);

  const averageBankAccuracies = useMemo(() => {
    const banks = timeseriesData?.allBanksAnalytics;
    if (!banks) return [];

    return banks
      .map((bank) => {
        const accuracies = processTimeseries(
          bank.metricAnalytics?.find(
            ({ metricName }) => metricName === targetMetric
          )?.accuracyTimeseries ?? [],
          startDate,
          endDate
        );
        return {
          name: bank.bank,
          accuracy: average(accuracies.map(({ value }) => value)),
        };
      })
      // Banks with no accuracy points in range have no average; drop them.
      .filter(
        (x): x is { name: string; accuracy: number } => x.accuracy !== null
      )
      .sort((a, b) => b.accuracy - a.accuracy);
  }, [timeseriesData, targetMetric, startDate, endDate]);

  useEffect(() => {
    if (error) console.error("Analytics Error:", error);
  }, [error]);

  // Select the first metric by default. Also re-select it when fresh data no
  // longer contains the current metric (e.g. after the document selection
  // changes); otherwise the Select keeps an out-of-range value and the charts
  // render empty. While no options are loaded, keep the current selection.
  useEffect(() => {
    const [firstMetric] = targetMetricOptions;
    if (
      firstMetric &&
      !targetMetricOptions.some((name) => name === targetMetric)
    ) {
      setTargetMetric(firstMetric);
    }
  }, [targetMetricOptions, targetMetric]);

  if (isTimeseriesLoading) {
    return (
      <ChartsContainer>
        <Box
          sx={{
            display: "flex",
            flexDirection: "column",
            alignItems: "center",
            justifyContent: "center",
            gap: 1,
            py: 10,
            textAlign: "center",
          }}
        >
          <CircularProgress size={24} />
          <Typography variant="h6" color="textSecondary">
            Loading Report Data...
          </Typography>
        </Box>
      </ChartsContainer>
    );
  }

  if (selectedDocuments.length === 0) {
    return (
      <ChartsContainer>
        <Box
          sx={{
            display: "flex",
            flexDirection: "column",
            alignItems: "center",
            justifyContent: "center",
            gap: 1,
            py: 4,
            textAlign: "center",
          }}
        >
          <Typography variant="h6" color="textSecondary">
            No documents selected
          </Typography>
          <Typography variant="body2" color="textSecondary">
            Upload and/or select a report to get started.
          </Typography>
        </Box>
      </ChartsContainer>
    );
  }

  if (!timeseriesData) return null;

  return (
    <ChartsContainer>
      <div
        style={{
          display: "flex",
          justifyContent: "space-between",
          marginBottom: "20px",
        }}
      >
        <MonthRangePicker
          startMonth={startDate}
          endMonth={endDate}
          onChange={handleDateChange}
        />
        <Select
          size="small"
          value={targetMetric}
          onChange={(event) => setTargetMetric(event.target.value)}
          sx={{
            "& .MuiInputBase-root": { height: 40 },
            "& .MuiSelect-select": {
              paddingTop: "8px",
              paddingBottom: "8px",
            },
            backgroundColor: "white",
          }}
          inputProps={{ notched: false }}
          variant="outlined"
          displayEmpty
        >
          {targetMetricOptions.map((metricName) => (
            <MenuItem
              value={metricName}
              key={metricName}
              sx={{ fontSize: "14px", minHeight: "32px" }}
            >
              {metricName}
            </MenuItem>
          ))}
        </Select>
      </div>

      <div style={{ display: "flex", gap: "20px", marginBottom: "20px" }}>
        <Paper
          sx={{
            flex: 2,
            p: 2,
            borderRadius: 2,
            boxShadow: PANEL_SHADOW,
            position: "relative",
          }}
        >
          <LineChart
            title={targetMetric}
            datasets={[
              timeseriesDatasets.groundTruth,
              ...timeseriesDatasets.bankPredictions,
            ]}
            min={startDate.getTime()}
            max={endDate.getTime()}
            currentMonth={startOfMonthUTC(new Date()).getTime()}
          />
          <BoxPlot datasets={accuracyDatasets.accuracyBoxplots} />
        </Paper>

        <Paper
          sx={{
            flex: 1,
            borderRadius: 2,
            boxShadow: PANEL_SHADOW,
            position: "relative",
          }}
        >
          <AccuracyTable accuracies={averageBankAccuracies} />
          <AccuraciesChart
            datasets={accuracyDatasets.bankAccuracies}
            min={startDate.getTime()}
            max={endDate.getTime()}
          />
        </Paper>
      </div>

      {/* `timeseriesData` is known to be non-null here (early return above). */}
      <Paper
        sx={{
          p: 2,
          borderRadius: 2,
          boxShadow: PANEL_SHADOW,
          position: "relative",
        }}
      >
        <EnsembleGraph
          ensemblePredictionDistributions={
            forecastData?.ensemblePredictionDistributions ?? []
          }
        />
        {isForecastLoading && <LoadingOverlay />}
      </Paper>
    </ChartsContainer>
  );
};

export default React.memo(Charts);
