import React, { useState } from "react";

import { scaleLinear, scaleTime } from "d3";

import { useDimensions } from "utils/hooks";
import { translate } from "utils/svg";

import NoData from "./NoData";
import Axis, { getTickFormatFn } from "./Axis";
import LoadingErrorBoundary from "../LoadingErrorBoundary";
import Foreground from "./Foreground";
import Legend from "./Legend";
import {
  Layout,
  Data,
  xValue,
  prepareStacked,
  AnyDatum,
  yTop,
  SeriesConfig,
  isStacked,
  isLeftAxis,
  isRightAxis,
  extent,
  SeriesInfo,
  adjustTimestamps,
  applySeriesDefaults,
  applyAxisDefaults,
  InteractionPoint,
  Scale,
  formatTimeAdaptive,
  Datum,
  UserAxisConfig,
} from "./util";
import { LineSeries } from "./Series";
import Palette from "utils/palette";

import styles from "./style.module.scss";

/**
 * Optional user configuration for each axis the graph can draw.
 * Every key is optional; an omitted axis simply receives no user config.
 */
type AllAxesConfig = Partial<
  Record<"left" | "bottom" | "right", UserAxisConfig>
>;

/**
 * Public props accepted by {@link Graph}.
 */
export type Props = {
  /** Raw data, looked up by each series' `key`. */
  data: Data;
  /** Series to plot; entries whose `key` has no entry in `data` are skipped. */
  series: SeriesConfig[];
  /** Optional per-axis configuration; defaults are applied to each axis provided. */
  axes?: AllAxesConfig;
  /**
   * Unit of the epoch timestamps in `data`. Defaults to `"second"`, in which
   * case the timestamps are adjusted before being handed to the time scale.
   */
  epochTimeUnit?: "second" | "millisecond";
  /** Compact mode: zero gutters and no axes or legend. Defaults to `false`. */
  small?: boolean;
  /**
   * Optional render callback for content painted beneath the axes, legend and
   * series. It is positioned at the plot body's origin and receives the body's
   * size together with the x (time) scale.
   */
  backdrop?: (props: Layout["size"] & { xScale: Scale }) => React.ReactElement;
  /** Forwarded to the foreground layer (presumably fired on click — confirm in Foreground). */
  onClick?: (target: InteractionPoint) => void;
  /** Forwarded to the foreground layer (presumably fired when a range is selected — confirm in Foreground). */
  onRangeSelected?: (start: InteractionPoint, end: InteractionPoint) => void;
};

/** Maps a series key to `true` when that series is currently hidden. */
type DisabledSeries = Record<string, boolean>;

/**
 * Renders the graph into an SVG of the given pixel size.
 *
 * The canvas is split into a plot body surrounded by gutters that hold the
 * legend (top), the left and right y-axes, and the bottom time axis. In
 * `small` mode every gutter is zero and the axes and legend are not drawn.
 *
 * Series whose key is missing from `data` are dropped. Which series are
 * enabled is local state, seeded from each series' `defaultDisabled` flag (or
 * from its data being empty) and toggled by clicking the legend.
 */
const GraphImpl: React.FunctionComponent<
  Props & {
    width: number;
    height: number;
  }
> = ({
  width,
  height,
  axes: rawAxes,
  series: rawSeries,
  data: rawData,
  epochTimeUnit = "second",
  small = false,
  backdrop,
  onClick,
  onRangeSelected,
}) => {
  // Drop series with no entry in the data, then fill in per-series defaults.
  const series = rawSeries
    .filter((s) => rawData[s.key])
    .map(applySeriesDefaults);
  // Apply axis defaults to each axis the caller configured; axes that are
  // absent or explicitly undefined are left out.
  const axes = Object.entries(rawAxes || {}).reduce<AllAxesConfig>(
    (newAxes, [currType, currConfig]) => {
      if (currConfig !== undefined) {
        newAxes[currType] = applyAxisDefaults(currConfig);
      }
      return newAxes;
    },
    {},
  );
  // Computed from all series, not only enabled ones, so the right gutter and
  // right axis do not depend on legend toggles.
  const hasRightSeries = series.filter(isRightAxis).length > 0;

  // Pixel gutters around the plot body; `small` graphs use the whole canvas.
  const leftGutter = small ? 0 : 65;
  const rightGutter = small ? 0 : hasRightSeries ? 80 : 25;
  const topGutter = small ? 0 : 35;
  const bottomGutter = small ? 0 : 30;

  const bodyDims: Layout = {
    x: leftGutter,
    y: topGutter,
    size: {
      width: width - leftGutter - rightGutter,
      height: height - topGutter - bottomGutter,
    },
  };
  const legendDims: Layout = {
    x: leftGutter,
    y: 5,
    size: {
      width: bodyDims.size.width,
      height: topGutter - 5,
    },
  };
  const leftAxisDims: Layout = {
    x: 0,
    y: topGutter,
    size: {
      width: leftGutter,
      height: bodyDims.size.height,
    },
  };
  const bottomAxisDims: Layout = {
    x: leftGutter,
    y: topGutter + bodyDims.size.height,
    size: {
      width: bodyDims.size.width,
      height: bottomGutter,
    },
  };
  const rightAxisDims: Layout = {
    x: leftGutter + bodyDims.size.width,
    y: topGutter,
    size: {
      width: rightGutter,
      height: bodyDims.size.height,
    },
  };

  // Keep only the data belonging to series that survived the filter above.
  const relevantData = Object.entries(rawData)
    .filter(([k]) => series.find((s) => s.key === k))
    .reduce((obj, [k, v]) => {
      obj[k] = v;
      return obj;
    }, {});
  // Not an elegant solution, but there doesn't seem to be a great way to get d3's
  // timeScale to play nice with (second granularity) unix timestamps
  const data =
    epochTimeUnit === "millisecond"
      ? relevantData
      : adjustTimestamps(relevantData);

  // Lazy initializer: the initial map is only computed on the first render.
  // Later changes to `series` do not reset this state.
  const [seriesDisabled, setSeriesDisabled] = useState<DisabledSeries>(() =>
    series.reduce((map, curr) => {
      map[curr.key] = curr.defaultDisabled || data[curr.key].length === 0;
      return map;
    }, {}),
  );

  const activeSeries = series.filter((s) => !seriesDisabled[s.key]);
  const stackedSeries = activeSeries.filter(isStacked);
  const stackedData = prepareStacked(data, stackedSeries) || {};

  // Stacked series are drawn from the stacked data; others from the raw data.
  const getSeriesData = (s: SeriesConfig): AnyDatum[] => {
    if (isStacked(s)) {
      return stackedData[s.key];
    } else {
      return data[s.key];
    }
  };

  // Largest horizontal padding requested by any enabled series type. Seeding
  // Math.max with 0 keeps this at 0 when no series is enabled (e.g. all were
  // toggled off in the legend). Math.max() with no arguments returns
  // -Infinity, which would turn the x range into [-Infinity, Infinity].
  const xRangePadding = Math.max(
    0,
    ...activeSeries.map((s) => s.type?.xPadding || 0),
  );
  const bottomDomain = extent(Object.values(data), xValue);
  const bottomWidth = bottomAxisDims.size.width;
  const bottomScale = scaleTime()
    .domain(bottomDomain)
    .range([xRangePadding, bottomWidth - xRangePadding]);
  if (xRangePadding > 0) {
    // N.B.: for graphs like bar graphs, we need to ensure the range has enough room for
    // the width of the bars at the ends: we do that by first padding the range, and then
    // adjusting the domain based on that padded range
    const adjustedStart = bottomScale.invert(0).getTime();
    const adjustedEnd = bottomScale.invert(bottomWidth).getTime();
    bottomScale.domain([adjustedStart, adjustedEnd]).range([0, bottomWidth]);
  }

  // Both y scales always start at zero and are rounded outward with nice().
  const leftSeries = activeSeries.filter(isLeftAxis).map(getSeriesData);
  const leftDomain = extent(leftSeries, yTop);
  const leftStartZero = [0, leftDomain[1]];
  const leftScale = scaleLinear()
    .domain(leftStartZero)
    .nice()
    .range([leftAxisDims.size.height, 0]);

  const rightSeries = activeSeries.filter(isRightAxis).map(getSeriesData);
  const rightDomain = extent(rightSeries, yTop);
  const rightStartZero = [0, rightDomain[1]];
  const rightScale = scaleLinear()
    .domain(rightStartZero)
    .nice()
    .range([rightAxisDims.size.height, 0]);

  const leftTickFormat = axes?.left?.format;
  const leftFormatFn = getTickFormatFn(axes?.left?.tipFormat, leftScale);

  const rightTickFormat = axes?.right?.format;
  const rightFormatFn = getTickFormatFn(axes?.right?.tipFormat, rightScale);

  const bottomTickFormat = axes?.bottom?.format || formatTimeAdaptive;

  // N.B.: here we also include inactive series, because we still want those
  // to, e.g., appear in the legend
  const palette = new Palette();
  const seriesInfo = series.map<SeriesInfo>((s: SeriesConfig) => {
    const leftAxis = !s.yAxis || s.yAxis === "left";
    const color = palette.getStrokeAndFill(s.color);
    return {
      key: s.key,
      label: s.label || s.key,
      tipLabel: s.tipLabel || s.label || s.key,
      disabled: seriesDisabled[s.key],
      type: s.type || LineSeries,
      tipFormat: leftAxis ? leftFormatFn : rightFormatFn,
      scale: leftAxis ? leftScale : rightScale,
      stroke: color.stroke,
      fill: color.fill,
      className: s.className,
      opts: s.opts,
    };
  });
  const activeSeriesInfo = seriesInfo.filter((s) => !s.disabled);

  // Legend clicks flip the enabled/disabled state of a single series.
  const handleSeriesLegendClick = (key: string) => {
    setSeriesDisabled((curr) => ({ ...curr, [key]: !curr[key] }));
  };

  return (
    <svg width={width} height={height} preserveAspectRatio="none">
      <rect fill="white" width={width} height={height} />
      {backdrop && (
        <g transform={translate(bodyDims.x, bodyDims.y)}>
          {backdrop({ ...bodyDims.size, xScale: bottomScale })}
        </g>
      )}
      {!small && (
        <g transform={translate(leftAxisDims.x, leftAxisDims.y)}>
          <Axis
            placement="left"
            scale={leftScale}
            tickFormat={leftTickFormat}
            {...leftAxisDims.size}
          />
        </g>
      )}
      {!small && hasRightSeries && (
        <g transform={translate(rightAxisDims.x, rightAxisDims.y)}>
          <Axis
            placement="right"
            scale={rightScale}
            tickFormat={rightTickFormat}
            {...rightAxisDims.size}
          />
        </g>
      )}
      {!small && (
        <g transform={translate(bottomAxisDims.x, bottomAxisDims.y)}>
          <Axis
            placement="bottom"
            scale={bottomScale}
            tickFormat={bottomTickFormat}
            {...bottomAxisDims.size}
          />
        </g>
      )}
      {!small && (
        <g transform={translate(legendDims.x, legendDims.y)}>
          <Legend
            onToggleDisabled={handleSeriesLegendClick}
            series={seriesInfo}
            {...legendDims.size}
          />
        </g>
      )}
      <g transform={translate(bodyDims.x, bodyDims.y)}>
        <Foreground
          series={activeSeriesInfo}
          data={data}
          stackedData={stackedData}
          xScale={bottomScale}
          onClick={onClick}
          onRangeSelected={onRangeSelected}
          {...bodyDims.size}
          small={small}
        />
      </g>
    </svg>
  );
};

/**
 * Measures its container `div` and renders the graph at the measured size,
 * wrapped in a `LoadingErrorBoundary`. The container is rendered empty while
 * no dimensions are available yet.
 */
const Wrapper: React.FunctionComponent<Props> = (props) => {
  const [containerRef, measured] = useDimensions<HTMLDivElement>();
  const wrapperClassName = props.small ? styles.wrapperSmall : styles.wrapper;

  // Evaluates to `measured` itself (rendering nothing) while it is falsy.
  const content = measured && (
    <LoadingErrorBoundary>
      <GraphImpl width={measured.width} height={measured.height} {...props} />
    </LoadingErrorBoundary>
  );

  return (
    <div ref={containerRef} className={wrapperClassName}>
      {content}
    </div>
  );
};

/**
 * Public graph component.
 *
 * Renders a `NoData` placeholder when `data` is missing, when its only key is
 * `"No Data"`, or when none of its values is a non-empty array. Otherwise it
 * renders the measured graph.
 */
const Graph: React.FunctionComponent<Props> = ({ data, ...rest }) => {
  const containsPoints = (values: Datum[]) =>
    Array.isArray(values) && values.length > 0;

  // `!!data` must come first so Object.keys/Object.values never see a
  // missing value.
  const hasRenderableData =
    !!data &&
    !(Object.keys(data).length === 1 && "No Data" in data) &&
    Object.values(data).some(containsPoints);

  if (hasRenderableData) {
    return <Wrapper data={data} {...rest} />;
  }
  return (
    <div className={styles.wrapper}>
      <NoData />
    </div>
  );
};

export default Graph;
