import { range, sum } from "lodash-es";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";

/** Options accepted by `useGridVirtualization`. */
interface IUseGridVirtualizationProps {
  /** Scrollable element hosting the grid; its client width/height define the viewport. */
  containerRef: React.RefObject<HTMLDivElement>;
  /** Height of the header strip that sits above the pinned top rows. */
  headerHeight: number;

  /** Total number of rows, including rows whose data is not loaded yet. */
  rowCount: number;
  /** Returns the height of the row at `index`. */
  getRowHeight: (index: number) => number;
  /** Row indexes pinned to the top or bottom edge; these are always rendered. */
  pinnedRows?: {
    top?: number[];
    bottom?: number[];
  };
  /** Extra rows kept rendered beyond the viewport, also used as the load-ahead window. Defaults to 10. */
  rowOverscan?: number;

  /** Total number of columns. */
  columnCount: number;
  /** Width of each column, indexed by column index. */
  columnWidths: number[];
  /** Column indexes pinned to the left or right edge; these are always rendered. */
  pinnedColumns?: {
    left?: number[];
    right?: number[];
  };
  /** Horizontal overscan, scaled by the widest column width. Defaults to 2. */
  columnOverscan?: number;

  /** Reports whether the row at `index` already has data. */
  isRowLoaded: (index: number) => boolean;
  /** Fetches more rows; invoked when unloaded rows approach the viewport. At most one call is in flight. */
  onLoadMore?: () => Promise<void>;
}

/** Values returned by `useGridVirtualization`. */
interface IUseGridVirtualizationResult {
  /** Rows to render: top-pinned first, then visible scrolling rows, then bottom-pinned. */
  visibleRows: IVirtualRow[];
  /** Columns to render: left-pinned first, then visible scrolling columns, then right-pinned. */
  visibleColumns: IVirtualColumn[];
  /** Width of the scrollable content; never smaller than the container width. */
  totalWidth: number;
  /** Height of the scrollable content; never smaller than the container height. */
  totalHeight: number;
  /** Scroll handler to attach to the container element. */
  onScroll: (e: React.UIEvent<HTMLDivElement>) => Promise<void>;
  /** Scrolls the container (smoothly) to bring the given cell into view. */
  scrollToCell: (rowId: number, colId: number) => void;
}

/** Layout information for one row that should be rendered. */
export interface IVirtualRow {
  /** Index of the row in the grid's data. */
  index: number;
  /** Row height. */
  height: number;
  /**
   * Vertical offset of the row. For scrolling rows it is measured within the
   * unpinned rows (the first unpinned row sits at 0). For top-pinned rows it starts
   * just below the header. For bottom-pinned rows it is measured from the top of the
   * container so that the rows stack against its bottom edge.
   */
  top: number;
  /** Which edge the row is pinned to; absent for scrolling rows. */
  pinned?: "top" | "bottom";
}

/**
 * Layout information for one column that should be rendered, discriminated by `pinned`:
 * - `"left"` and `null` carry `left`, the offset from the content's left edge;
 * - `"right"` carries `right`, the offset from the content's right edge.
 */
export type IVirtualColumn =
  | { index: number; width: number; pinned: "left"; left: number }
  | { index: number; width: number; pinned: "right"; right: number }
  | { index: number; width: number; pinned: null; left: number };

/**
 * Virtualizes a two-dimensional grid rendered inside `containerRef`.
 *
 * Tracks the container's scroll position and client size, and derives which rows and
 * columns fall within the viewport plus overscan. Pinned rows (top/bottom) and pinned
 * columns (left/right) are always included.
 *
 * When `onLoadMore` is provided, it is called once an unloaded row (per `isRowLoaded`)
 * comes within the overscan window below the viewport. This is checked on mount, when
 * the load-check inputs or container height change, and while scrolling down. At most
 * one call is in flight at a time.
 *
 * @returns The scroll handler to attach to the container, the rows and columns to
 *   render, the total content size, and a `scrollToCell` helper.
 */
export const useGridVirtualization = ({
  containerRef,
  rowCount,
  columnCount,
  getRowHeight,
  headerHeight,
  columnWidths,
  pinnedColumns,
  pinnedRows,
  isRowLoaded,
  onLoadMore,
  rowOverscan = 10,
  columnOverscan = 2,
}: IUseGridVirtualizationProps): IUseGridVirtualizationResult => {
  const [scrollTop, setScrollTop] = useState(0);
  const [scrollLeft, setScrollLeft] = useState(0);
  const [containerWidth, setContainerWidth] = useState(0);
  const [containerHeight, setContainerHeight] = useState(0);
  const [isLoadingMore, setIsLoadingMore] = useState(false);
  // The in-flight onLoadMore promise. It is kept in a ref so that a second trigger in
  // the same tick sees it before the isLoadingMore state update has rendered.
  const loadMoreRef = useRef<Promise<void> | null>(null);
  // The last observed scrollTop, used by onScroll to detect the scroll direction.
  const prevScrollTopRef = useRef(0);

  // Content width: the sum of all column widths, but never narrower than the container.
  const totalWidth = useMemo(() => {
    const columnsWidth = sum(columnWidths);

    return Math.max(columnsWidth, containerWidth);
  }, [columnWidths, containerWidth]);

  // Content height: header + pinned top rows + pinned bottom rows + every other row,
  // but never shorter than the container.
  const totalHeight = useMemo(() => {
    const pinnedTopIndexes = new Set(pinnedRows?.top ?? []);
    const pinnedBottomIndexes = new Set(pinnedRows?.bottom ?? []);

    // Includes the header in the calculation as it's a pinned top row
    const topPinnedRowTotalHeight = headerHeight + sum((pinnedRows?.top ?? []).map(getRowHeight));
    const bottomPinnedRowTotalHeight = sum((pinnedRows?.bottom ?? []).map(getRowHeight));

    const nonPinnedRowTotalHeight = sum(
      range(rowCount)
        .filter((i) => !pinnedTopIndexes.has(i) && !pinnedBottomIndexes.has(i))
        .map(getRowHeight),
    );

    const height = topPinnedRowTotalHeight + bottomPinnedRowTotalHeight + nonPinnedRowTotalHeight;

    return Math.max(height, containerHeight);
  }, [getRowHeight, rowCount, pinnedRows?.top, pinnedRows?.bottom, headerHeight, containerHeight]);

  // Measure the container on mount and on every window resize.
  // NOTE(review): resizes of the container that are not caused by a window resize are
  // not observed here.
  useEffect(() => {
    const updateContainerDimensions = (): void => {
      const width = containerRef.current?.clientWidth ?? 0;
      const height = containerRef.current?.clientHeight ?? 0;

      setContainerWidth(width);
      setContainerHeight(height);
    };

    updateContainerDimensions();

    window.addEventListener("resize", updateContainerDimensions);

    return (): void => {
      window.removeEventListener("resize", updateContainerDimensions);
    };
  }, [containerRef]);

  // Horizontal offset of every column within the content.
  // - Left-pinned columns start at 0, in the order given.
  // - Unpinned columns follow the left-pinned block.
  // - Right-pinned columns are packed against totalWidth.
  const columnOffsets = useMemo(() => {
    const offsets: number[] = [];
    let leftPinnedWidth = 0;
    let rightPinnedWidth = 0;
    let scrollableOffset = 0;

    const leftPinnedIndexes = new Set(pinnedColumns?.left ?? []);
    const rightPinnedIndexes = new Set(pinnedColumns?.right ?? []);

    // Calculate total pinned widths first
    pinnedColumns?.left?.forEach((index) => {
      leftPinnedWidth += columnWidths[index] ?? 0;
    });

    pinnedColumns?.right?.forEach((index) => {
      rightPinnedWidth += columnWidths[index] ?? 0;
    });

    // Offsets for left-pinned columns
    let leftOffset = 0;

    pinnedColumns?.left?.forEach((index) => {
      offsets[index] = leftOffset;
      leftOffset += columnWidths[index] ?? 0;
    });

    // Offsets for scrollable columns, placed after the left-pinned block
    for (let i = 0; i < columnWidths.length; i++) {
      if (!leftPinnedIndexes.has(i) && !rightPinnedIndexes.has(i)) {
        offsets[i] = scrollableOffset + leftPinnedWidth;
        scrollableOffset += columnWidths[i] ?? 0;
      }
    }

    // Offsets for right-pinned columns, packed against the right edge of the content
    let rightOffset = totalWidth - rightPinnedWidth;

    pinnedColumns?.right?.forEach((index) => {
      offsets[index] = rightOffset;
      rightOffset += columnWidths[index] ?? 0;
    });

    return offsets;
  }, [columnWidths, pinnedColumns?.left, pinnedColumns?.right, totalWidth]);

  // Vertical position of every row.
  // - Unpinned rows are laid out sequentially in their own coordinate space, which
  //   starts at 0 (the header and pinned top rows are not included).
  // - Top-pinned rows stack from headerHeight.
  // - Bottom-pinned rows keep position 0 here; visibleRows positions them itself.
  const rowPositions = useMemo(() => {
    const positions: number[] = Array.from({ length: rowCount }, () => 0);
    const pinnedTopIndexes = new Set(pinnedRows?.top ?? []);
    const pinnedBottomIndexes = new Set(pinnedRows?.bottom ?? []);

    let currentPosition = 0;

    for (let i = 0; i < rowCount; i++) {
      if (!pinnedTopIndexes.has(i) && !pinnedBottomIndexes.has(i)) {
        positions[i] = currentPosition;
        currentPosition += getRowHeight(i);
      }
    }

    let pinnedTopOffset = headerHeight;

    pinnedRows?.top?.forEach((index) => {
      positions[index] = pinnedTopOffset;
      pinnedTopOffset += getRowHeight(index);
    });

    return positions;
  }, [getRowHeight, rowCount, pinnedRows?.top, pinnedRows?.bottom, headerHeight]);

  // Rows to render: top-pinned, then unpinned rows inside the overscanned viewport,
  // then bottom-pinned (ascending index).
  const visibleRows = useMemo(() => {
    const rows: IVirtualRow[] = [];
    const pinnedTopIndexes = new Set(pinnedRows?.top ?? []);
    const pinnedBottomIndexes = new Set(pinnedRows?.bottom ?? []);

    // Top-pinned rows stack downward starting just below the header.
    let pinnedTopOffset = headerHeight;

    pinnedRows?.top?.forEach((index) => {
      const rowHeight = getRowHeight(index);

      rows.push({ index, top: pinnedTopOffset, height: rowHeight, pinned: "top" });
      pinnedTopOffset += rowHeight;
    });

    // The visible window is in unpinned-row coordinates, widened by rowOverscan rows.
    // The overscan height is estimated from the height of row 0.
    const overscanHeight = rowOverscan * getRowHeight(0);
    const visibleStartOffset = Math.max(0, scrollTop - overscanHeight);
    const visibleEndOffset = scrollTop + containerHeight + overscanHeight;

    for (let i = 0; i < rowCount; i++) {
      if (pinnedTopIndexes.has(i) || pinnedBottomIndexes.has(i)) continue;

      const rowHeight = getRowHeight(i);
      const rowTop = rowPositions[i];

      if (rowTop != null && rowTop + rowHeight >= visibleStartOffset && rowTop <= visibleEndOffset) {
        rows.push({ index: i, top: rowTop, height: rowHeight });
      }
    }

    // Bottom-pinned rows stack upward from the container's bottom edge, with the highest
    // index lowest. They are emitted in ascending index order.
    if (pinnedBottomIndexes.size > 0) {
      const bottomRows: IVirtualRow[] = [];
      let bottomOffset = containerHeight;

      for (const index of Array.from(pinnedBottomIndexes).sort((a, b) => b - a)) {
        const rowHeight = getRowHeight(index);

        bottomOffset -= rowHeight;
        bottomRows.push({ index, top: bottomOffset, height: rowHeight, pinned: "bottom" });
      }

      rows.push(...bottomRows.reverse());
    }

    return rows;
  }, [
    pinnedRows?.top,
    pinnedRows?.bottom,
    headerHeight,
    scrollTop,
    rowOverscan,
    getRowHeight,
    containerHeight,
    rowCount,
    rowPositions,
  ]);

  // Columns to render: left-pinned (in the given order), then unpinned columns inside
  // the overscanned viewport, then right-pinned (in the given order). The pinned index
  // lists are used directly, so pinned columns need not be the first or last indexes.
  const visibleColumns = useMemo(() => {
    const columns: IVirtualColumn[] = [];
    const leftPinned = pinnedColumns?.left ?? [];
    const rightPinned = pinnedColumns?.right ?? [];
    const pinnedIndexes = new Set(leftPinned.concat(rightPinned));

    // The widest column (0 for an empty grid) sizes the overscan margin. A negative
    // margin would drop columns that are on screen, so it is clamped at 0.
    const maxColumnWidth = columnWidths.reduce((max, width) => Math.max(max, width), 0);
    const overscanMargin = Math.max(0, columnOverscan - 1) * maxColumnWidth;
    const startOffset = scrollLeft - overscanMargin;
    const endOffset = scrollLeft + containerWidth + overscanMargin;

    leftPinned.forEach((index) => {
      columns.push({
        index,
        left: columnOffsets[index] ?? 0,
        width: columnWidths[index] ?? 0,
        pinned: "left",
      });
    });

    for (let i = 0; i < columnCount; i++) {
      if (pinnedIndexes.has(i)) continue;

      const colStart = columnOffsets[i];

      if (colStart == null) continue;

      const width = columnWidths[i] ?? 0;

      if (colStart + width >= startOffset && colStart <= endOffset) {
        columns.push({ index: i, left: colStart, width, pinned: null });
      }
    }

    rightPinned.forEach((index) => {
      const width = columnWidths[index] ?? 0;

      columns.push({
        index,
        // Distance from the right edge of the content.
        right: totalWidth - (columnOffsets[index] ?? 0) - width,
        width,
        pinned: "right",
      });
    });

    return columns;
  }, [
    scrollLeft,
    containerWidth,
    pinnedColumns?.left,
    pinnedColumns?.right,
    columnCount,
    columnOffsets,
    columnWidths,
    columnOverscan,
    totalWidth,
  ]);

  // Decides whether onLoadMore should be called for the given scroll position.
  // Returns true only if:
  // - no load is in progress,
  // - an unloaded row exists within rowOverscan rows below the last visible row, and
  // - that row is within a quarter viewport and half the overscan of it.
  // Row indexes are estimated from the height of row 0.
  const checkNeedsLoad = useCallback(
    (currentScrollTop: number) => {
      if (!onLoadMore || isLoadingMore || loadMoreRef.current) {
        return false;
      }

      const rowHeight = getRowHeight(0);
      const viewportHeight = containerRef.current?.clientHeight ?? 0;
      const lastVisibleRow = Math.floor((currentScrollTop + viewportHeight) / rowHeight);
      const visibleRowCount = Math.ceil(viewportHeight / rowHeight);
      const overscanEndRow = Math.min(lastVisibleRow + rowOverscan, rowCount - 1);

      // Find the first unloaded row in the overscan area
      let firstUnloadedRow = -1;

      for (let i = lastVisibleRow + 1; i <= overscanEndRow; i++) {
        if (!isRowLoaded(i)) {
          firstUnloadedRow = i;
          break;
        }
      }

      if (firstUnloadedRow === -1) {
        return false;
      }

      const distanceToUnloaded = firstUnloadedRow - lastVisibleRow;
      const loadThreshold = Math.ceil(visibleRowCount / 4);

      return distanceToUnloaded <= loadThreshold && distanceToUnloaded <= Math.ceil(rowOverscan / 2);
    },
    [onLoadMore, isLoadingMore, getRowHeight, containerRef, rowOverscan, rowCount, isRowLoaded],
  );

  // Calls onLoadMore unless a call is already in flight. The call is tracked in both the
  // ref (a synchronous guard) and isLoadingMore (so dependent hooks re-run when it
  // settles). The loading state is reset even if onLoadMore throws synchronously.
  // Rejections propagate to the caller.
  const runLoadMore = useCallback(async (): Promise<void> => {
    if (!onLoadMore || loadMoreRef.current) return;

    let promise: Promise<void> | undefined;

    setIsLoadingMore(true);

    try {
      promise = onLoadMore();
      loadMoreRef.current = promise;
      await promise;
    } finally {
      if (promise && loadMoreRef.current === promise) {
        loadMoreRef.current = null;
      }
      setIsLoadingMore(false);
    }
  }, [onLoadMore]);

  // Check whether to load more rows:
  // - on mount;
  // - whenever the load-check inputs change, including when a previous load finishes
  //   (checkNeedsLoad depends on isLoadingMore);
  // - when the measured container height changes.
  useEffect(() => {
    const container = containerRef.current;

    if (!container || !onLoadMore) return;

    if (checkNeedsLoad(container.scrollTop)) {
      // NOTE(review): a rejection from onLoadMore is not caught here and surfaces as an
      // unhandled promise rejection.
      void runLoadMore();
    }
  }, [containerRef, checkNeedsLoad, onLoadMore, runLoadMore, containerHeight]);

  // Records the scroll position and, while scrolling down, loads more rows when needed.
  const onScroll = useCallback(
    async (e: React.UIEvent<HTMLDivElement>) => {
      const target = e.currentTarget;
      const newScrollTop = target.scrollTop;
      const isScrollingDown = newScrollTop > prevScrollTopRef.current;

      setScrollTop(newScrollTop);
      setScrollLeft(target.scrollLeft);
      prevScrollTopRef.current = newScrollTop;

      // Load checks are only made while scrolling down.
      if (isScrollingDown && checkNeedsLoad(newScrollTop)) {
        await runLoadMore();
      }
    },
    [checkNeedsLoad, runLoadMore],
  );

  // Smoothly scrolls so that the given cell lies inside the area not covered by pinned
  // rows/columns.
  // - The vertical axis is adjusted only for unpinned rows.
  // - The horizontal axis is adjusted only for unpinned columns.
  // - A cell that is pinned on both axes needs no scrolling.
  // NOTE(review): unpinned row positions start at 0 (see rowPositions), while column
  // offsets include the left-pinned width. Verify that the vertical bounds below match
  // how rows are actually rendered.
  const scrollToCell = useCallback(
    (rowId: number, colId: number) => {
      const container = containerRef.current;

      if (!container) return;

      const targetTop = rowPositions[rowId];
      const targetLeft = columnOffsets[colId];

      if (targetTop == null || targetLeft == null) return;

      const isRowPinned = (pinnedRows?.top ?? []).includes(rowId) || (pinnedRows?.bottom ?? []).includes(rowId);
      const isColumnPinned =
        (pinnedColumns?.left ?? []).includes(colId) || (pinnedColumns?.right ?? []).includes(colId);

      if (isRowPinned && isColumnPinned) return;

      let finalScrollTop = scrollTop;
      let finalScrollLeft = scrollLeft;

      if (!isRowPinned) {
        // The header counts as part of the pinned top area.
        const pinnedTopHeight = headerHeight + sum((pinnedRows?.top ?? []).map((index) => getRowHeight(index)));
        const pinnedBottomHeight = sum((pinnedRows?.bottom ?? []).map((index) => getRowHeight(index)));
        const targetBottom = targetTop + getRowHeight(rowId);
        const visibleTop = scrollTop + pinnedTopHeight;
        const visibleBottom = scrollTop + containerHeight - pinnedBottomHeight;

        if (targetTop < visibleTop) {
          // Align the row just below the pinned top area.
          finalScrollTop = targetTop - pinnedTopHeight;
        } else if (targetBottom > visibleBottom) {
          // Align the row just above the pinned bottom area.
          finalScrollTop = targetBottom - containerHeight + pinnedBottomHeight;
        }
      }

      if (!isColumnPinned) {
        const pinnedLeftWidth = sum((pinnedColumns?.left ?? []).map((index) => columnWidths[index] ?? 0));
        const pinnedRightWidth = sum((pinnedColumns?.right ?? []).map((index) => columnWidths[index] ?? 0));
        const targetRight = targetLeft + (columnWidths[colId] ?? 0);
        const visibleLeft = scrollLeft + pinnedLeftWidth;
        const visibleRight = scrollLeft + containerWidth - pinnedRightWidth;

        if (targetLeft < visibleLeft) {
          finalScrollLeft = targetLeft - pinnedLeftWidth;
        } else if (targetRight > visibleRight) {
          finalScrollLeft = targetRight - containerWidth + pinnedRightWidth;
        }
      }

      // Keep both scroll positions within the scrollable range.
      finalScrollTop = Math.max(0, Math.min(totalHeight - containerHeight, finalScrollTop));
      finalScrollLeft = Math.max(0, Math.min(totalWidth - containerWidth, finalScrollLeft));

      if (finalScrollTop !== scrollTop || finalScrollLeft !== scrollLeft) {
        container.scrollTo({
          top: finalScrollTop,
          left: finalScrollLeft,
          behavior: "smooth",
        });
      }
    },
    [
      containerRef,
      scrollTop,
      scrollLeft,
      rowPositions,
      columnOffsets,
      getRowHeight,
      columnWidths,
      containerHeight,
      containerWidth,
      totalHeight,
      totalWidth,
      pinnedRows?.top,
      pinnedRows?.bottom,
      pinnedColumns?.left,
      pinnedColumns?.right,
      headerHeight,
    ],
  );

  return {
    onScroll,
    visibleRows,
    visibleColumns,
    totalWidth,
    totalHeight,
    scrollToCell,
  };
};
