'use client'
import { useRef, useState, useCallback, ReactNode } from 'react';
import cn from 'classnames';
import { ChartContext, ChartContextType, YScale } from './ChartContext';
import { LoadingSpinner } from '../LoadingSpinner';
import { HeatMapProvider, HeatMapProviderProps } from '../HeatMap';
import { BarProvider, BarProviderProps } from '../Bar';
import moduleStyles from './Chart.module.scss';

/**
 * Props accepted by {@link Chart}.
 */
export type ChartProps = {
  /** Content rendered inside the chart's context providers and drawing surface. */
  children: ReactNode | ReactNode[];
  /** When true the chart draws into a `<canvas>`; otherwise it renders an `<svg>`. */
  useCanvas?: boolean;
  /** Extra class name(s) for the outer wrapper `<div>`. */
  className?: string;
  /** Class name(s) for the inner `<svg>` / `<canvas>` element. */
  chartClassName?: string;
  /** Props forwarded to the `BarProvider` that wraps the children. */
  barConfig?: BarProviderProps;
  /** Props forwarded to the `HeatMapProvider` that wraps the children. */
  heatMapConfig?: HeatMapProviderProps;
};

/**
 * Root chart container.
 *
 * Measures its parent element, lays out the registered x/y axes, merges the
 * ranges of all registered numeric y scales, and publishes the result to
 * descendants through `ChartContext`. Descendants are additionally wrapped in
 * `BarProvider` and `HeatMapProvider`, configured from `barConfig` and
 * `heatMapConfig`.
 *
 * When `useCanvas` is set the chart renders into a `<canvas>` whose backing
 * store is scaled by the device pixel ratio; otherwise it renders into an
 * `<svg>` whose viewBox matches the measured size. A loading spinner is shown
 * until at least one y scale has been registered.
 *
 * NOTE: the size is read from the DOM during render, so it only refreshes when
 * something else triggers a re-render (e.g. axis/scale registration). Nothing
 * in this component listens for resizes.
 */
export function Chart(props: ChartProps) {
  const {
    className,
    children,
    chartClassName,
    useCanvas,
    barConfig,
    heatMapConfig,
    ...rest
  } = props;

  const canvasRef = useRef<HTMLCanvasElement>(null);
  const svgRef = useRef<SVGSVGElement>(null);

  // Axis registries live in refs and are mutated in place; `forceUpdate`
  // makes React re-render so the layout below is recomputed.
  const yAxes = useRef<Map<string, { valueLength: number }>>(new Map());
  const xAxes = useRef<Map<string, Record<string, never>>>(new Map());
  // At most one y axis may use the numeric scale; this holds its ID.
  const numericYAxisID = useRef<string>();
  // Incremented every time the canvas is cleared; exposed to children via context.
  const canvasRefreshCounter = useRef(0);
  const [allYScales, setAllYScales] = useState<YScale[]>([]);
  // The counter value is never read; updating it only forces a re-render.
  const [, setForcedUpdateCounter] = useState(0);

  // Functional update keeps this callback's identity stable (so the register /
  // unregister callbacks below do not change on every render) and ensures
  // several calls before the next render are not collapsed into one increment.
  const forceUpdate = useCallback(() => {
    setForcedUpdateCounter((count) => count + 1);
  }, []);

  const clearCanvas = useCallback(() => {
    if (!useCanvas || !canvasRef.current) return;
    const context = canvasRef.current.getContext('2d');
    if (!context) return;
    context.clearRect(0, 0, canvasRef.current.width, canvasRef.current.height);
    canvasRefreshCounter.current += 1;
  }, [useCanvas]);

  const registerYScale = useCallback((yScale: YScale) => {
    clearCanvas();
    setAllYScales((old) => [...old, yScale]);
  }, [clearCanvas]);

  // Filter the latest state rather than the render-time `allYScales`; with the
  // stale closure, two unregistrations in the same commit overwrote each other.
  const unregisterYScale = useCallback((yScale: YScale) => {
    clearCanvas();
    setAllYScales((old) => old.filter((other) => other !== yScale));
  }, [clearCanvas]);

  /**
   * Registers a y axis (no-op if the ID is already registered).
   *
   * @throws Error if `usingNumericScale` is set while a different axis already
   *   owns the numeric scale.
   */
  const registerYAxis = useCallback((yAxisID: string, usingNumericScale = false, longestExpectedValueLabelLength = 4) => {
    if (usingNumericScale && numericYAxisID.current && numericYAxisID.current !== yAxisID) {
      throw new Error('Numeric Y Axis already registered');
    }

    if (yAxes.current.has(yAxisID)) return;

    yAxes.current.set(yAxisID, {
      valueLength: longestExpectedValueLabelLength
    });

    if (usingNumericScale) numericYAxisID.current = yAxisID;

    clearCanvas();
    forceUpdate();
  }, [clearCanvas, forceUpdate]);

  const unregisterYAxis = useCallback((yAxisID: string) => {
    yAxes.current.delete(yAxisID);

    if (numericYAxisID.current === yAxisID) {
      numericYAxisID.current = undefined;
    }

    clearCanvas();
    forceUpdate();
  }, [clearCanvas, forceUpdate]);

  const registerXAxis = useCallback((xAxisID: string) => {
    if (xAxes.current.has(xAxisID)) return;

    xAxes.current.set(xAxisID, {});
    clearCanvas();
    forceUpdate();
  }, [clearCanvas, forceUpdate]);

  const unregisterXAxis = useCallback((xAxisID: string) => {
    xAxes.current.delete(xAxisID);
    clearCanvas();
    forceUpdate();
  }, [clearCanvas, forceUpdate]);

  // Union of the ranges of every scale without `valueLabels`. With no such
  // scale this stays { min: Infinity, max: -Infinity }.
  const numericScale = allYScales
    .filter(({ valueLabels }) => !valueLabels)
    .reduce((scale, { min, max }) => ({
      ...scale,
      min: Math.min(scale.min, min),
      max: Math.max(scale.max, max)
    }), { min: Infinity, max: -Infinity });

  // Y axes are placed left to right in registration order. Each is 7 units
  // wide per expected label character, plus 15 units before every axis after
  // the first.
  let yAxesWidth = 0;
  const yAxesPositioning = Array.from(yAxes.current.keys()).reduce((pos, axisId, index) => {
    const left = yAxesWidth;
    const yAxis = yAxes.current.get(axisId);
    if (!yAxis) return pos;

    let width = (yAxis.valueLength * 7);
    if (index > 0) width += 15;
    yAxesWidth += width;

    return {
      ...pos,
      [axisId]: {
        left,
        width
      }
    };
  }, {});

  // Measure the drawing surface's parent (falling back to the svg's own client
  // size). On the first render the refs are not attached yet, so this is 0x0.
  let width = 0;
  let height = 0;

  if (useCanvas && canvasRef.current) {
    const parentBounding = canvasRef.current.parentElement?.getBoundingClientRect();
    if (parentBounding) {
      width = parentBounding.width;
      height = parentBounding.height;
    }
  } else if (svgRef.current) {
    const parentBounding = svgRef.current.parentElement?.getBoundingClientRect();
    if (parentBounding) {
      width = parentBounding.width;
      height = parentBounding.height;
    } else {
      width = svgRef.current.clientWidth;
      height = svgRef.current.clientHeight;
    }
  }

  const chartPadTop = 5;
  let chartWidth = width;
  let chartPadLeft = 0;
  let chartHeight = height - chartPadTop;

  // Reserve the y-axis column plus a 15-unit gutter on the left.
  if (yAxesWidth > 0) {
    chartWidth -= yAxesWidth + 15;
    chartPadLeft += yAxesWidth + 15;
  }

  // X axes stack upward from the bottom: the first takes 5 units off the plot
  // height and each further axis another 10. Each axis's `top` is recorded
  // after its own reduction.
  const xAxesPositioning = Array.from(xAxes.current.keys()).reduce((pos, axisId, index) => {
    chartHeight -= index === 0 ? 5 : 10;

    return {
      ...pos,
      [axisId]: {
        top: chartHeight + chartPadTop,
        height: 10
      }
    };
  }, {});

  if (xAxes.current.size > 0) chartHeight -= 15;

  const contextValue: ChartContextType = {
    chartWidth,
    chartHeight,
    chartPadTop,
    chartPadLeft,

    canvasRef,
    isCanvas: !!useCanvas,
    canvasRefreshCounter: canvasRefreshCounter.current,

    registerYScale,
    unregisterYScale,

    registerYAxis,
    unregisterYAxis,

    registerXAxis,
    unregisterXAxis,

    numericScale,

    yAxesPositioning,
    xAxesPositioning
  };

  const chartChildren = (
    <ChartContext.Provider value={contextValue}>
      <BarProvider {...barConfig}>
        <HeatMapProvider {...heatMapConfig}>
          {children}
        </HeatMapProvider>
      </BarProvider>
    </ChartContext.Provider>
  );

  const loadingSpinner = allYScales.length === 0 ? (
    <LoadingSpinner className={moduleStyles['loading-spinner']} />
  ) : null;

  if (useCanvas) {
    // `window` can be undefined when this renders outside a browser (e.g.
    // server-side pre-rendering); fall back to a pixel ratio of 1 there.
    const dpr = (typeof window !== 'undefined' && window.devicePixelRatio) || 1;

    // Backing store is scaled by the pixel ratio; CSS size stays at width/height.
    const canvasWidth = width * dpr;
    const canvasHeight = height * dpr;
    const style = { width, height };

    return (
      <div
        className={cn(moduleStyles['chart'], className)}>
        {loadingSpinner}
        <canvas
          width={canvasWidth}
          height={canvasHeight}
          style={style}
          className={chartClassName}
          ref={canvasRef}
          {...rest}>
          {chartChildren}
        </canvas>
      </div>
    );
  }

  return (
    <div
      className={cn(moduleStyles['chart'], className)}>
      {loadingSpinner}
      <svg
        xmlns="http://www.w3.org/2000/svg"
        viewBox={`0 0 ${width} ${height}`}
        className={chartClassName}
        ref={svgRef}
        {...rest}>
        {chartChildren}
      </svg>
    </div>
  );
}
