import React, {
  useEffect,
  useImperativeHandle,
  useMemo,
  useState,
} from 'react';
import {
  flexRender,
  useReactTable,
  ColumnDef,
  SortingState,
  getCoreRowModel,
  RowSelectionState,
  getSortedRowModel,
  ColumnFiltersState,
  getFilteredRowModel,
  FilterFn,
  SortingFn,
  Table,
  Row,
  RowSelectionOptions,
} from '@tanstack/react-table';
import {
  RankingInfo,
  rankItem,
  compareItems,
} from '@tanstack/match-sorter-utils';
import cx from 'classnames';

import { Props as TagsListProps } from '../tagsList/TagsList';
import styles from './MultiSelectTable.module.css';
import Checkbox from '../atoms/checkbox/Checkbox';
import usePrevious from '../../hooks/usePrevious';
import { isLabelledValue } from './helpers/isLabelledValue';

// Module augmentation: registers this file's custom filter/sorting function
// names and filter metadata shape with TanStack Table's type registry, so
// the string keys passed in `filterFns` / `sortingFns` to `useReactTable`
// below type-check, and so `addMeta` / `columnFiltersMeta` carry `itemRank`.
declare module '@tanstack/table-core' {
  // Filter functions registered via `filterFns` (see `tagsFilter`, `fuzzyFilter`).
  interface FilterFns {
    tags: FilterFn<unknown>;
    fuzzy: FilterFn<unknown>;
  }
  // Per-row filter metadata: `fuzzyFilter` stores its match-sorter rank here
  // and `fuzzySort` reads it back.
  interface FilterMeta {
    itemRank: RankingInfo;
  }
  // Sorting functions registered via `sortingFns` (see `fuzzySort`,
  // `countSort`, `numberSort`).
  interface SortingFns {
    fuzzy: SortingFn<unknown>;
    count: SortingFn<unknown>;
    number: SortingFn<unknown>;
  }
}

/**
 * Filter for tag-list columns.
 *
 * `value` is the list of selected tag labels. A row passes when at least one
 * selected label equals the `label` of one of the row's tag options. An
 * empty (or non-array) selection applies no filtering and passes every row.
 * A row whose cell holds no tag list cannot match any tag and is filtered out
 * (previously this threw a TypeError on `undefined.find`).
 */
const tagsFilter: FilterFn<any> = (row, columnId, value) => {
  const selectedTags: string[] = Array.isArray(value) ? value : [];
  if (selectedTags.length === 0) return true;

  const cellValue = row.getValue(columnId);
  if (!Array.isArray(cellValue)) return false;

  const rowTagsOptions: NonNullable<TagsListProps['options']> = cellValue;
  return selectedTags.some((tag) =>
    rowTagsOptions.some(({ label }) => tag === label)
  );
};

/**
 * Fuzzy filter backed by match-sorter's `rankItem`.
 *
 * Labelled cell values are ranked by their `label`; any other value is ranked
 * as-is. The ranking result is recorded on the row through `addMeta` (it is
 * read back by `fuzzySort`), and the row is kept only if the ranking passed.
 */
const fuzzyFilter: FilterFn<any> = (row, columnId, value, addMeta) => {
  const cellValue = row.getValue(columnId);
  const rankedValue = isLabelledValue(cellValue) ? cellValue.label : cellValue;

  const itemRank = rankItem(rankedValue, value);
  addMeta({ itemRank });

  return itemRank.passed;
};

/**
 * Sorting function paired with `fuzzyFilter`.
 *
 * When BOTH rows carry a match-sorter rank for this column (stored by
 * `fuzzyFilter` via `addMeta`), rows are ordered by rank. If either rank is
 * missing, or the ranks tie, the cell values are compared with `<`, using the
 * `label` of labelled values.
 */
const fuzzySort: SortingFn<any> = (rowA, rowB, columnId) => {
  const rankA = rowA.columnFiltersMeta[columnId]?.itemRank;
  const rankB = rowB.columnFiltersMeta[columnId]?.itemRank;

  // compareItems reads `.rank` from both arguments, so only call it when both
  // rows were ranked. Checking rowA alone let an unranked rowB throw.
  if (rankA && rankB) {
    const dir = compareItems(rankA, rankB);
    if (dir !== 0) return dir;
  }

  // Alphanumeric fallback when ranks are unavailable or equal.
  let rowAValue = rowA.getValue(columnId) as string;
  if (isLabelledValue(rowAValue)) {
    rowAValue = rowAValue.label;
  }

  let rowBValue = rowB.getValue(columnId) as string;
  if (isLabelledValue(rowBValue)) {
    rowBValue = rowBValue.label;
  }

  if (rowAValue === rowBValue) return 0;
  return rowAValue < rowBValue ? -1 : 1;
};

/**
 * Numeric sorting function: coerces both cell values with `Number()` and
 * returns their difference (negative when rowA's value is smaller).
 *
 * The comparator never returns NaN:
 * - equal values (including equal infinities, where `Infinity - Infinity` is
 *   NaN) compare as 0;
 * - values that coerce to NaN (e.g. non-numeric strings, `undefined`) sort
 *   after every number in ascending order and compare equal to each other.
 */
const numberSort: SortingFn<any> = (rowA, rowB, columnId) => {
  const a = Number(rowA.getValue(columnId));
  const b = Number(rowB.getValue(columnId));

  if (a === b) return 0;

  const aIsNaN = Number.isNaN(a);
  const bIsNaN = Number.isNaN(b);
  if (aIsNaN || bIsNaN) {
    if (aIsNaN && bIsNaN) return 0;
    return aIsNaN ? 1 : -1;
  }

  return a - b;
};

/** Number of entries in an array cell value; non-array values count as 0. */
const arrayLengthOf = (value: unknown): number =>
  Array.isArray(value) ? value.length : 0;

/**
 * Sorting function for array-valued cells (such as tag lists): orders rows by
 * how many entries the cell holds.
 *
 * Non-array values (e.g. a missing list) count as 0 entries. Previously any
 * non-array made the comparison return 0, which made the comparator
 * intransitive when arrays and non-arrays were mixed in one column.
 */
const countSort: SortingFn<any> = (rowA, rowB, columnId) =>
  arrayLengthOf(rowA.getValue(columnId)) -
  arrayLengthOf(rowB.getValue(columnId));

/**
 * Props for the `MultiSelectTable` component.
 */
export type Props<RowData extends Record<string, any>> = {
  /** Receives the underlying TanStack `Table` instance. */
  tableRef?: React.Ref<Table<RowData> | null>;
  /** Rows to render. */
  data: RowData[];
  /** Column definitions; a checkbox "select" column is prepended to these. */
  columns: ColumnDef<RowData>[];
  /** Produces the id TanStack Table uses for each row (and its selection key). */
  getRowId: (originalRow: RowData, index: number, parent?: Row<RowData>) => string;
  /** Called with the original objects of the selected rows when selection changes. */
  onSelectedRowsChange?: (rows: RowData[]) => void;
  /** Seed value for the row-selection state. */
  initialRowSelection?: RowSelectionState;
  /** Forwarded to TanStack's `enableRowSelection`; defaults to `!disabled`. */
  enableRowSelection?: RowSelectionOptions<RowData>['enableRowSelection'];
  /** Disables the selection checkboxes. */
  disabled?: boolean;
  /** Extra class name merged onto the `<table>` element. */
  className?: string;
};

/**
 * Selectable data table built on TanStack Table.
 *
 * Prepends a checkbox "select" column to `columns` and keeps sorting, column
 * filter and row-selection state locally. It registers the custom
 * `fuzzy`/`tags` filter functions and the `fuzzy`/`count`/`number` sorting
 * functions. The table instance is exposed through `tableRef`.
 *
 * Header and cell renderers (including those in `columns`) are expected to
 * return their own `<th>` / `<td>` element, because their output is placed
 * directly inside each `<tr>`.
 */
const MultiSelectTable = <RowData extends Record<string, any>>({
  tableRef,
  data,
  columns: columnsProp,
  getRowId,
  onSelectedRowsChange = () => {},
  initialRowSelection = {},
  enableRowSelection,
  disabled,
  className,
}: Props<RowData>) => {
  const [sorting, setSorting] = useState<SortingState>([]);
  // `initialRowSelection` only seeds the state; later prop changes are ignored.
  const [rowSelection, setRowSelection] = useState<RowSelectionState>(
    initialRowSelection
  );
  const [columnFilters, setColumnFilters] = useState<ColumnFiltersState>([]);

  // The explicit element type gives the `header`/`cell` callbacks below a
  // contextual type; without it `table` and `row` are implicitly `any`.
  const columns = useMemo<ColumnDef<RowData>[]>(
    () => [
      {
        id: 'select',
        header: ({ table }) => (
          <th key="select" className={styles.selectColumn}>
            <Checkbox
              type="checkbox"
              checked={table.getIsAllRowsSelected()}
              indeterminate={table.getIsSomeRowsSelected()}
              onChange={table.getToggleAllRowsSelectedHandler()}
              disabled={disabled}
            />
          </th>
        ),
        cell: ({ row }) => (
          <td key="select" className={styles.selectColumn}>
            <Checkbox
              type="checkbox"
              checked={row.getIsSelected()}
              disabled={!row.getCanSelect() || disabled}
              indeterminate={row.getIsSomeSelected()}
              onChange={row.getToggleSelectedHandler()}
            />
          </td>
        ),
      },
      ...columnsProp,
    ],
    [columnsProp, disabled]
  );

  const table = useReactTable({
    data,
    columns,
    filterFns: {
      fuzzy: fuzzyFilter,
      tags: tagsFilter,
    },
    sortingFns: {
      fuzzy: fuzzySort,
      count: countSort,
      number: numberSort,
    },
    state: {
      rowSelection,
      sorting,
      columnFilters,
    },
    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    getRowId,
    onSortingChange: setSorting,
    onColumnFiltersChange: setColumnFilters,
    onRowSelectionChange: setRowSelection,
    // An explicit `enableRowSelection` prop wins; otherwise rows are
    // selectable unless the whole table is disabled.
    enableRowSelection: enableRowSelection ?? !disabled,
  });

  useImperativeHandle(tableRef, () => table, [table]);

  // Report the selected rows' original objects when the selection state
  // changes. The `previousRowSelection &&` guard skips the call while
  // usePrevious has no value yet (presumably the first render — depends on
  // usePrevious).
  const previousRowSelection = usePrevious(rowSelection);
  useEffect(() => {
    if (previousRowSelection && previousRowSelection !== rowSelection) {
      onSelectedRowsChange(
        table.getSelectedRowModel().rows.map((row) => row.original)
      );
    }
  }, [table, rowSelection, previousRowSelection, onSelectedRowsChange]);

  return (
    <table className={cx(styles.table, className)}>
      <thead>
        {table.getHeaderGroups().map((headerGroup) => (
          <tr key={headerGroup.id}>
            {/* Keyed wrapper: renderer output may not carry its own key. */}
            {headerGroup.headers.map((header) => (
              <React.Fragment key={header.id}>
                {flexRender(header.column.columnDef.header, header.getContext())}
              </React.Fragment>
            ))}
          </tr>
        ))}
      </thead>

      <tbody>
        {table.getRowModel().rows.map((row) => (
          <tr
            key={row.id}
            className={cx({ [styles.selected]: row.getIsSelected() })}
          >
            {row.getVisibleCells().map((cell) => (
              <React.Fragment key={cell.id}>
                {flexRender(cell.column.columnDef.cell, cell.getContext())}
              </React.Fragment>
            ))}
          </tr>
        ))}
      </tbody>
    </table>
  );
};

export default MultiSelectTable;
