import { useContext, useState, useEffect, useId } from 'react';
import { ChartContext, ChartDataSourceContext } from '../Chart';
import { translateColorToSafeString } from '../../helpers';

/** Props accepted by the `YAxis` chart component. All fields are optional; `YAxis` supplies defaults. */
interface YAxisProps {
  /** Number of intervals a numeric scale is split into; `numericTicks - 1` interior grid lines are drawn. */
  numericTicks?: number;
  /** Formats the (already rounded) value of a numeric tick into its label text. */
  tickFormat?: (tick?: string | number) => string;
  /** Decimal places used when rounding the numeric tick step and tick values. */
  decimals?: number;
  /** Optional axis title drawn rotated alongside the axis (SVG rendering only). */
  label?: string;
  /** Grid line colour. */
  gridStroke?: string;
  /** Grid line stroke width. */
  gridStrokeWidth?: number;
  /** Colour of tick labels and the axis title. */
  textColor?: string;
}

/**
 * Default tick label formatter.
 *
 * Numbers and numeric strings are rendered with exactly one decimal place. This is fixed and does
 * not follow the axis `decimals` prop. Any other input, including blank strings, is returned unchanged as text.
 *
 * The numeric check and the conversion both use `Number` semantics, so they always agree. Previously the
 * check used `isNaN` (which coerces `''` to 0) but the conversion used `parseFloat` (which returns NaN for `''`).
 * That mismatch rendered blank strings as "NaN".
 */
const defaultTickFormat = (tick: string | number): string => {
  const numeric = typeof tick === 'number'
    ? tick
    // `Number('')` and `Number('  ')` are 0; treat blank strings as non-numeric instead.
    : (tick.trim() === '' ? NaN : Number(tick));

  if (Number.isNaN(numeric)) return tick.toString();
  return numeric.toFixed(1);
};

/** One horizontal grid line plus its label. `top` is a fraction of the plot height measured from the top edge. */
type YAxisTick = {
  key: number,
  top: number,
  text: string
};

/** Rounds `value` to `decimals` decimal places, using the same rounding as the axis step and tick values. */
const roundToDecimals = (value: number, decimals: number): number =>
  Math.round(value * Math.pow(10, decimals)) / Math.pow(10, decimals);

/**
 * Estimates the character length of the widest label this axis will draw, so the chart can reserve space.
 * Returns `undefined` when it cannot be estimated: no scale, no categorical labels, or a non-numeric `max`.
 * In that case the axis is not registered.
 */
function measureLongestLabel(
  valueScale: { max?: unknown, valueLabels?: unknown } | null | undefined,
  decimals: number
): number | undefined {
  if (!valueScale) return undefined;
  const { max, valueLabels } = valueScale;

  if (valueLabels && typeof valueLabels === 'object') {
    // Categorical labels are read elsewhere via `.keys()` / `.get()`, i.e. as a Map. `Object.values` on a Map is
    // always empty, so Map values are read explicitly. Plain objects are still supported.
    const labels: unknown[] = valueLabels instanceof Map
      ? Array.from(valueLabels.values())
      : Object.values(valueLabels);
    if (labels.length === 0) return undefined;
    return labels.reduce<number>((longest, text) => Math.max(longest, String(text).length), 0);
  }

  if (typeof max !== 'number') return undefined;
  return max.toFixed(decimals).length;
}

/**
 * Builds the interior ticks of a numeric scale. Ticks are placed at `min + i * step` for `i` in
 * `1 .. numericTicks - 1`, with `step` rounded to `decimals` places. The min/max edges themselves are not drawn.
 * Returns `null` when the range is -Infinity, which the axis treats as "nothing to draw".
 */
function buildNumericTicks(
  valueScale: { min: number, max: number },
  numericTicks: number,
  decimals: number,
  tickFormat: (tick: number) => string
): YAxisTick[] | null {
  const range = valueScale.max - valueScale.min;
  if (range === -Infinity) return null;

  // NOTE(review): if range / numericTicks rounds to 0 at this precision, step is 0 and every tick collapses onto
  // the bottom edge with the same label.
  const step = roundToDecimals(range / numericTicks, decimals);
  const ticks: YAxisTick[] = [];

  for (let i = 1; i < numericTicks; i++) {
    ticks.push({
      key: i,
      // A zero range would divide by zero; pin the tick to the bottom edge instead.
      top: range === 0 ? 1 : 1 - ((i * step) / range),
      text: tickFormat(roundToDecimals(valueScale.min + (i * step), decimals))
    });
  }

  return ticks;
}

/** Builds one tick per entry of a categorical scale's `valueLabels`, in the map's iteration order. */
function buildValueLabelTicks(valueLabels: ReadonlyMap<number, unknown>, min: number, max: number): YAxisTick[] {
  const range = max - min;

  return Array.from(valueLabels.keys(), (position) => {
    const text = valueLabels.get(position);
    return {
      key: position,
      // NOTE(review): the position is divided by the range without subtracting `min`, which assumes min === 0
      // for categorical scales — confirm against the series renderers.
      top: 1 - (position / range),
      text: text == null ? '' : String(text)
    };
  });
}

/**
 * Y axis for a chart. It registers itself with the enclosing chart (via `ChartContext`) so the chart can
 * reserve horizontal space. It then draws horizontal grid lines and tick labels, either onto the chart's
 * canvas (when `isCanvas`) or as SVG nodes.
 *
 * Numeric scales get `numericTicks - 1` evenly spaced ticks. Categorical scales (`valueScale.valueLabels`)
 * get one tick per label. Nothing is rendered until registration succeeds.
 */
export function YAxis(props: YAxisProps) {
  const {
    numericTicks = 5,
    tickFormat = defaultTickFormat,
    decimals = 1,
    gridStroke = '#F1F1F5',
    gridStrokeWidth = 2,
    textColor = '#6F7278',
    label = ''
  } = props;

  const { registerYAxis, unregisterYAxis, yAxesPositioning, chartWidth, chartHeight, chartPadTop, chartPadLeft, isCanvas, canvasRef, canvasRefreshCounter } = useContext(ChartContext);
  const { valueScale } = useContext(ChartDataSourceContext);
  const usingNumericScale = valueScale && !valueScale.valueLabels;
  // Computed as a primitive so the registration effect can depend on it without re-running for every new scale object.
  const longestExpectedValueLabelLength = measureLongestLabel(valueScale, decimals);

  const id = useId();
  const [registered, setRegistered] = useState(false);

  // Register with the chart. Re-register whenever the scale kind or the expected label width changes.
  useEffect(() => {
    let succeeded = false;

    if (longestExpectedValueLabelLength !== undefined) {
      try {
        registerYAxis(id, usingNumericScale, longestExpectedValueLabelLength);
        succeeded = true;
      } catch {
        // Best-effort: an axis the chart cannot register simply renders nothing.
        succeeded = false;
      }
    }

    setRegistered(succeeded);
    return () => unregisterYAxis(id);
  }, [id, usingNumericScale, longestExpectedValueLabelLength]);

  // Canvas rendering.
  // NOTE(review): the dependency list omits `registered`, `yAxesPositioning` and the styling props. It presumably
  // relies on `canvasRefreshCounter` to trigger a redraw after the chart clears the canvas — confirm before adding deps.
  useEffect(() => {
    if (!isCanvas || !canvasRef?.current) return;
    if (!registered) return;

    const canvas = canvasRef.current;
    const ctx = canvas.getContext('2d');
    if (!ctx) return;

    let ticks: YAxisTick[] | null = null;
    if (usingNumericScale) {
      ticks = buildNumericTicks(valueScale, numericTicks, decimals, tickFormat);
    } else if (valueScale?.valueLabels) {
      ticks = buildValueLabelTicks(valueScale.valueLabels, valueScale.min, valueScale.max);
    }
    if (!ticks) return;

    const dpr = window.devicePixelRatio || 1;
    const positioning = yAxesPositioning[id];
    const gridColor = translateColorToSafeString(gridStroke, canvas);
    const labelColor = translateColorToSafeString(textColor, canvas);

    for (const tick of ticks) {
      const y = ((tick.top * chartHeight) + chartPadTop) * dpr;

      ctx.beginPath();
      ctx.strokeStyle = gridColor;
      // Set explicitly to match the SVG `strokeWidth`. Otherwise the grid inherits whatever lineWidth was last set
      // (1 after the first label).
      ctx.lineWidth = gridStrokeWidth * dpr;
      ctx.moveTo(chartPadLeft * dpr, y);
      ctx.lineTo((chartWidth + chartPadLeft) * dpr, y);
      ctx.stroke();

      if (positioning) {
        ctx.beginPath();
        ctx.strokeStyle = labelColor;
        ctx.fillStyle = labelColor;
        ctx.lineWidth = 1;
        ctx.textAlign = 'end';
        ctx.textBaseline = 'middle';
        ctx.font = `${dpr * 12}px Barlow`;
        ctx.fillText(tick.text, (positioning.left + positioning.width) * dpr, y);
        ctx.stroke();
      }
    }
  }, [isCanvas, canvasRef, canvasRefreshCounter, chartWidth, chartHeight, chartPadTop, chartPadLeft, valueScale, usingNumericScale]);

  if (isCanvas) return null;
  if (!registered) return null;

  const positioning = yAxesPositioning[id];
  let ticks: YAxisTick[] = [];
  let tickFontSize = 12;

  if (usingNumericScale) {
    const numericTicksOrNull = buildNumericTicks(valueScale, numericTicks, decimals, tickFormat);
    if (!numericTicksOrNull) return null;
    ticks = numericTicksOrNull;
  } else if (valueScale?.valueLabels) {
    ticks = buildValueLabelTicks(valueScale.valueLabels, valueScale.min, valueScale.max);
    // Categorical tick labels are drawn smaller when an axis title is shown.
    tickFontSize = label ? 8 : 12;
  }

  const axisNodes = [];

  for (const tick of ticks) {
    const y = (tick.top * chartHeight) + chartPadTop;

    axisNodes.push(
      <polyline
        key={`y-${id}-${tick.key}`}
        stroke={gridStroke}
        strokeWidth={gridStrokeWidth}
        points={`${chartPadLeft},${y} ${chartWidth + chartPadLeft},${y}`} />
    );

    if (positioning) {
      axisNodes.push(
        <text
          key={`y-label-${id}-${tick.key}`}
          fill={textColor}
          style={{
            fontVariantNumeric: 'tabular-nums',
            fontSize: tickFontSize
          }}
          textAnchor="end"
          dominantBaseline="middle"
          x={positioning.left + positioning.width}
          y={y}>
          {tick.text}
        </text>
      );
    }
  }

  // The rotated title is anchored to the axis position, so skip it until the chart has computed one.
  if (label && positioning) {
    const titleX = positioning.left + 15;
    const titleY = chartHeight / 2;

    axisNodes.push(
      <text
        key={`y-label-${id}`}
        fill={textColor}
        style={{
          fontVariantNumeric: 'tabular-nums',
          fontSize: 12
        }}
        textAnchor="middle"
        dominantBaseline="middle"
        transform={`rotate(-90, ${titleX} ${titleY})`}
        x={titleX}
        y={titleY}>
        {label}
      </text>
    );
  }

  return (
    <>
      {axisNodes}
    </>
  );
}

