import {
  ColumnDef,
  flexRender,
  getCoreRowModel,
  useReactTable,
} from '@tanstack/react-table'
import classnames from 'classnames'
import { useEffect, useMemo, useState } from 'react'
import { useSearchParams } from 'react-router-dom'

import { Pagination } from '../../models'
import PaginationNav from '../PaginationNav/PaginationNav'

/**
 * Renders the table head: one `<tr>` per TanStack header group.
 *
 * Each `<th>` spans `header.colSpan` columns so that grouped (parent) headers
 * sit above exactly the leaf columns they group. Extra classes come from the
 * column definition's `meta.headerClassName`.
 */
const Headers = <T,>({
  table,
}: {
  table: ReturnType<typeof useReactTable<T>>
}) => {
  return (
    <>
      {table.getHeaderGroups().map((headerGroup) => (
        <tr key={headerGroup.id}>
          {headerGroup.headers.map((header) => (
            <th
              key={header.id}
              colSpan={header.colSpan}
              className={classnames(
                'pl-4 pr-3 text-left text-sm font-semibold text-gray-900 py-3 px-3 sm:px-6',
                (header.column.columnDef.meta as ColumnMeta)?.headerClassName
              )}
            >
              {/* Placeholder headers only pad out grouped layouts; they have no content. */}
              {header.isPlaceholder
                ? null
                : flexRender(
                    header.column.columnDef.header,
                    header.getContext()
                  )}
            </th>
          ))}
        </tr>
      ))}
    </>
  )
}

/**
 * Skeleton placeholder shown while data is loading.
 *
 * Emits `rowCount` pulsing rows. Each row is a single cell stretched across
 * `columnCount` columns and filled with a grey bar.
 */
const LoadingRows = ({
  columnCount,
  rowCount,
}: {
  columnCount: number
  rowCount: number
}) => {
  const skeletons = Array.from({ length: rowCount }, (_unused, rowIndex) => (
    <tr key={rowIndex} className="animate-pulse">
      <td className="h-14 px-6 py-3 bg-white" colSpan={columnCount}>
        <div className="h-full w-full rounded-md bg-gray-200" />
      </td>
    </tr>
  ))

  return <>{skeletons}</>
}

/**
 * A single full-width row carrying a message, used when the table has no rows.
 *
 * `columnCount` is applied as the cell's `colSpan` so the message spans the
 * whole table body.
 */
const EmptyRow = (props: { text: string; columnCount: number }) => {
  const { text, columnCount } = props

  return (
    <tr>
      <td
        className="text-center text-sm text-brand-gray-dark py-6"
        colSpan={columnCount}
      >
        {text}
      </td>
    </tr>
  )
}

/**
 * Renders the table body.
 *
 * - While `isLoading` is true, three skeleton rows are shown.
 * - When there are no rows and `emptyText` is given, a single message row is shown.
 * - Otherwise one `<tr>` per row in the row model, with one `<td>` per visible cell.
 *   Each cell picks up `meta.cellClassName` from its column definition.
 *
 * Rows get hover/pointer styling and call `onRowClick` with the row's
 * original data when a click handler is supplied.
 */
const Rows = <T,>({
  table,
  onRowClick,
  isLoading,
  emptyText,
}: {
  table: ReturnType<typeof useReactTable<T>>
  onRowClick?: (data: T) => void
  isLoading?: boolean
  emptyText?: string
}) => {
  // colSpan must be a column *count*. `table.getTotalSize()` is the summed
  // pixel width of the columns (150px each by default), which would make the
  // placeholder cells span hundreds of non-existent columns.
  const columnCount = table.getVisibleLeafColumns().length

  if (isLoading) {
    return <LoadingRows columnCount={columnCount} rowCount={3} />
  }

  const rows = table.getRowModel().rows

  if (rows.length === 0 && emptyText) {
    return <EmptyRow text={emptyText} columnCount={columnCount} />
  }

  return (
    <>
      {rows.map((row) => (
        <tr
          key={row.id}
          className={classnames('h-14', {
            'hover:bg-gray-50 cursor-pointer': !!onRowClick,
          })}
          onClick={() => onRowClick?.(row.original)}
        >
          {row.getVisibleCells().map((cell) => (
            <td
              key={cell.id}
              className={classnames(
                'whitespace-nowrap text-sm text-gray-700 py-3 px-3 sm:px-6',
                (cell.column.columnDef.meta as ColumnMeta)?.cellClassName
              )}
            >
              {flexRender(cell.column.columnDef.cell, cell.getContext())}
            </td>
          ))}
        </tr>
      ))}
    </>
  )
}

/** Props accepted by the `Table` component. */
interface TableProps<T> {
  /** Column definitions handed to TanStack Table. */
  columns: ColumnDef<T>[]
  /** Row data; each element is rendered as one body row. */
  data: T[]
  /** When true, skeleton rows are rendered instead of the data. */
  isLoading?: boolean
  /** Message shown in a single full-width row when there are no rows to render. */
  emptyText?: string
  /** Pagination info; the footer navigation is rendered when `pageCount > 1`. */
  pagination?: Pagination
  /** Invoked with a 1-based page number when the table's page state changes. */
  onPageChange?: (page: number) => void
  /** Invoked with the row's original data when a body row is clicked. */
  onRowClick?: (data: T) => void
}

/**
 * Shape expected in a column definition's `meta` field.
 * Both values are optional extra CSS classes merged via `classnames`.
 */
interface ColumnMeta {
  /** Added to every body `<td>` of the column. */
  cellClassName?: string
  /** Added to the column's header `<th>`. */
  headerClassName?: string
}

/**
 * Holds the TanStack pagination state for `Table` and syncs it outward.
 *
 * - State is seeded from `pagination`: the 1-based `page` becomes a 0-based
 *   `pageIndex`, and `pageSize` defaults to 0 when absent.
 * - `onPageChange` receives the 1-based page whenever the state object or the
 *   callback identity changes (this includes the first run after mount).
 * - The `page` query parameter mirrors `pagination.page`. It is removed for
 *   page 1 or when no page is known.
 *
 * NOTE(review): the state is seeded only from the *initial* `pagination`.
 * Later changes to the prop are not adopted here, so callers appear to be
 * expected to treat the table's page state as the source of truth. Confirm.
 *
 * @returns the pagination state and its setter, for `useReactTable`.
 */
const usePagination = (
  pagination: Pagination | undefined,
  onPageChange?: (page: number) => void
) => {
  const [searchParams, setSearchParams] = useSearchParams()

  const [paginationState, setPaginationState] = useState({
    pageIndex: pagination?.page ? pagination.page - 1 : 0,
    pageSize: pagination?.pageSize ?? 0,
  })

  // Report the (1-based) page whenever the pagination state changes.
  useEffect(() => {
    onPageChange?.(paginationState.pageIndex + 1)
  }, [paginationState, onPageChange])

  // Mirror the current page into the `page` query parameter.
  useEffect(() => {
    const page = pagination?.page ?? 0
    const desiredPage = page > 1 ? String(page) : null

    // setSearchParams navigates (pushing a history entry) even when the
    // search string is unchanged, so only call it when the URL actually differs.
    if (searchParams.get('page') === desiredPage) {
      return
    }

    // Work on a copy: the instance returned by useSearchParams is owned by
    // the router and should not be mutated in place.
    const nextParams = new URLSearchParams(searchParams)
    if (desiredPage === null) {
      nextParams.delete('page')
    } else {
      nextParams.set('page', desiredPage)
    }
    setSearchParams(nextParams)
    // `searchParams` is intentionally not a dependency: the URL follows
    // `pagination`. Re-running on every URL change would overwrite
    // user-initiated URL changes such as back/forward navigation.
  }, [pagination, setSearchParams])

  return {
    paginationState,
    setPaginationState,
  }
}

/**
 * Generic data table built on TanStack Table with server-side (manual) pagination.
 *
 * Renders headers, body rows (or loading/empty placeholders), and a footer
 * pagination nav when `pagination.pageCount > 1`. Page changes are driven
 * through the table's pagination state and reported via `onPageChange`.
 */
const Table = <T,>({
  columns,
  data,
  pagination,
  isLoading,
  emptyText,
  onPageChange,
  onRowClick,
}: TableProps<T>) => {
  // TanStack Table needs referentially stable `data`/`columns`. A fresh `[]`
  // on each render would force the row model to be rebuilt every render and
  // can trigger re-render loops, so both fallbacks are memoized.
  const fallbackColumns: ColumnDef<T>[] = useMemo(() => [], [])
  const fallbackData: T[] = useMemo(() => [], [])

  const { paginationState, setPaginationState } = usePagination(
    pagination,
    onPageChange
  )

  const table = useReactTable({
    data: data ?? fallbackData,
    columns: columns ?? fallbackColumns,
    getCoreRowModel: getCoreRowModel(),
    manualPagination: true,
    rowCount: pagination?.total,
    onPaginationChange: setPaginationState,
    state: {
      pagination: paginationState,
    },
  })

  // Number of rendered columns. `getTotalSize()` is a pixel width, not a
  // count, and must not be used for colSpan.
  const columnCount = table.getVisibleLeafColumns().length

  return (
    <table className="w-full table-auto divide-y divide-gray-200 bg-white">
      <thead>
        <Headers table={table} />
      </thead>
      <tbody className="divide-y divide-gray-200">
        <Rows
          table={table}
          onRowClick={onRowClick}
          isLoading={isLoading}
          emptyText={emptyText}
        />
      </tbody>
      <tfoot>
        {pagination && pagination.pageCount > 1 && (
          <tr>
            <td colSpan={columnCount}>
              <PaginationNav
                pagination={pagination}
                onPrevious={() => table.previousPage()}
                onNext={() => table.nextPage()}
              />
            </td>
          </tr>
        )}
      </tfoot>
    </table>
  )
}

export default Table
