import { useRef, useCallback, useState, useEffect } from "react"
import {
	ReactFlow,
	useNodesState,
	useEdgesState,
	Controls,
	useReactFlow,
	Background,
	NodeTypes,
	Edge,
	Connection,
} from "@xyflow/react"

import "@xyflow/react/dist/style.css"

import Sidebar from "./Sidebar"
import {
	openPromptConfigAtom,
	openRetrievalConfigAtom,
	promptRangesAtom,
	retrievalRangesAtom,
	typeAtom,
} from "../atoms/index"
import { useAtom } from "jotai"
import GraphHeader from "./GraphHeader"
import GraphLog from "./GraphLog"
import Variables from "../nodes/Variables"
import Prompt from "../nodes/Prompt"
import JurisprudenceSearch from "../nodes/JurisprudenceSearch"
import Parser from "../nodes/Parser"
import Preview from "../nodes/Preview"

import useAgentBuilder from "hooks/useAgentBuilder"
import { useMutation, useQuery } from "@tanstack/react-query"
import { useParams } from "react-router-dom"
import {
	nodesDefaultSizes,
	sanitizeNode,
	serializeNode,
	typesTranslate,
} from "../utils"
import Stepper from "./Stepper"
import DeleteNodeModal from "./DeleteNodeModal"
import PromptConfig from "./PromptConfig"
import { defaultPromptRanges, defaultRetrievalRanges } from "../constants"
import { GraphData } from "types/agentBuilder"
import { generateRandomName } from "utils/generateRandomLetters"
import LoadingSpinner from "pages/MyDocumentsPage/subcomponents/LoadingSpinner"
import Retrieval from "../nodes/Retrieval"
import RetrievalConfig from "./RetrievalConfig"
import Input from "../nodes/Input"
import useNodeExecutionStatus from "../hooks/useNodeExecutionStatus"
import Transform from "../nodes/Transform"
import Toast from "components/toast"
import { showToast } from "components/toast/functions"
import { ToastType } from "components/toast/types"
/**
 * Payload carried in a node's `data` field (see `Node.data`).
 *
 * Extends `Record<string, unknown>`, so additional arbitrary keys are allowed
 * alongside the declared ones.
 */
export interface NodeData extends Record<string, unknown> {
	// Node-specific payload. Untyped here; `onDrop` in this file seeds it with
	// `{ name }` for newly created nodes.
	nodeData: any
	// Optional label; not set anywhere in this file.
	label?: string
	// Id of the graph the node belongs to (taken from the route param in `onDrop`).
	graphId: string
}

/**
 * Node shape used by this editor for entries kept in the nodes state.
 *
 * NOTE(review): the name shadows the global DOM `Node` type inside this
 * module — presumably intentional; confirm before importing either elsewhere.
 */
export interface Node {
	// Unique node id; `getId()` supplies a temporary one for freshly dropped nodes.
	id: string
	// Node kind; expected to match one of the keys of `nodeTypes`.
	type: string
	data: NodeData
	// Position on the canvas (in `onDrop` this comes from `screenToFlowPosition`).
	position: { x: number; y: number }
	// Node dimensions (in `onDrop` these are copied from `nodesDefaultSizes`).
	size: { width: number; height: number }
}

/**
 * Minimal subset of a drop event that `onDrop` reads: the pointer position in
 * screen coordinates and the ability to cancel the default action.
 */
interface DropEvent {
	clientX: number
	clientY: number
	preventDefault: () => void
}

/**
 * Minimal subset of a drag-over event that `onDragOver` uses: it cancels the
 * default action and writes `dataTransfer.dropEffect`.
 *
 * NOTE(review): the name shadows the global DOM `DragEvent` type inside this
 * module.
 */
interface DragEvent {
	preventDefault: () => void
	dataTransfer: {
		dropEffect: string
	}
}

/**
 * Registry mapping a node's `type` string to the component that renders it.
 * Passed to `<ReactFlow nodeTypes={...}>` below.
 *
 * Defined at module scope, so the same object reference is reused on every
 * render of `Flow`.
 */
const nodeTypes: NodeTypes = {
	variables: Variables,
	prompt: Prompt,
	jurisprudence: JurisprudenceSearch,
	parser: Parser,
	preview: Preview,
	retrieval: Retrieval,
	predefinedInput: Input,
	transform: Transform,
}

// Module-level counter backing `getId`; it keeps counting for the lifetime of
// the module and is shared by every `Flow` instance.
let id = 0

/**
 * Returns the next temporary client-side node id: `dndnode_0`, `dndnode_1`, ...
 */
function getId(): string {
	const current = id
	id += 1
	return "dndnode_" + current
}

/**
 * Graph editor screen for a single agent graph.
 *
 * Loads the graph identified by the `graphId` route param together with its
 * checkpoints, mirrors the saved nodes and edges into the React Flow state,
 * and persists user edits (dropped nodes, new connections, deleted edges)
 * through the functions returned by `useAgentBuilder`. It also hosts the
 * header, log, sidebar, stepper, prompt/retrieval config panels, the
 * delete-confirmation modal and the toast container.
 *
 * Renders `null` until the graph query has produced data.
 */
const Flow = () => {
	const reactFlowWrapper = useRef<HTMLDivElement>(null)
	const { getGraph, addNode, addNewEdge, deleteEdge, getGraphCheckpoints } =
		useAgentBuilder()
	const { graphId = "" } = useParams()
	const { data, isRefetching } = useQuery(
		["graph", graphId],
		() => getGraph(graphId),
		{
			refetchOnWindowFocus: false,
		},
	)
	const { data: checkpoints } = useQuery(
		["checkpoints", graphId],
		() => getGraphCheckpoints(graphId),
		{
			refetchOnWindowFocus: false,
		},
	)

	const [nodes, setInitialNodes, onNodesChange] = useNodesState([] as Node[])
	const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([])
	const { screenToFlowPosition, addEdges, setNodes, getNode } =
		useReactFlow()
	const [type] = useAtom(typeAtom)
	const [, setPromptRanges] = useAtom(promptRangesAtom)
	const [, setRetrievalRanges] = useAtom(retrievalRangesAtom)
	const [openPromptConfig, setOpenPromptConfig] =
		useAtom(openPromptConfigAtom)
	const [openRetrievalConfig, setOpenRetrievalConfig] = useAtom(
		openRetrievalConfigAtom,
	)

	const [openStepper, setOpenStepper] = useState(false)
	const [openDeleteNodeModal, setOpenDeleteNodeModal] = useState(false)
	const [nodeToDelete, setNodeToDelete] = useState<Node | null>(null)
	const [loading, setLoading] = useState(false)

	// All three mutations report failures the same way: show the error message
	// in an error toast.
	const showMutationError = (error: unknown) => {
		showToast((error as Error)?.message, ToastType.Error)
	}

	const addNodeMutation = useMutation({
		mutationFn: addNode,
		onError: showMutationError,
	})

	const addNewEdgeMutation = useMutation({
		mutationFn: addNewEdge,
		onError: showMutationError,
	})

	const deleteEdgeMutation = useMutation({
		mutationFn: deleteEdge,
		onError: showMutationError,
	})

	// Pulled out of the mutation results so they can be listed as hook
	// dependencies without depending on the whole (per-render) result object.
	const { mutateAsync: createNode } = addNodeMutation
	const { mutateAsync: createEdge } = addNewEdgeMutation

	/** Allows dropping onto the canvas and marks the drag as a "move". */
	const onDragOver = useCallback((event: DragEvent) => {
		event.preventDefault()
		event.dataTransfer.dropEffect = "move"
	}, [])

	/**
	 * Creates a node of the currently selected `type` at the drop position,
	 * persists it, registers default prompt/retrieval ranges for it, and appends
	 * the persisted node to the canvas. Does nothing when no type is selected
	 * or when the create request fails.
	 */
	const onDrop = useCallback(
		async (event: DropEvent) => {
			event.preventDefault()
			if (!type) {
				return
			}
			const position = screenToFlowPosition({
				x: event.clientX,
				y: event.clientY,
			})

			const nodeName = `${
				typesTranslate[type as keyof typeof typesTranslate]
			}_${generateRandomName()}`

			const newNode: Node = {
				id: getId(),
				type,
				position,
				data: {
					graphId,
					nodeData: {
						name: nodeName,
					},
				},
				size: {
					...nodesDefaultSizes[
						type as keyof typeof nodesDefaultSizes
					],
				},
			}

			const serializedNode = serializeNode(newNode, graphId)

			let createdNode: Awaited<ReturnType<typeof createNode>>
			try {
				createdNode = await createNode(serializedNode)
			} catch {
				// The mutation's `onError` has already shown a toast; stop here so
				// the rejection is not left unhandled and no local state is added
				// for a node that was never persisted.
				return
			}

			// Ranges are keyed by the id of the persisted node, not the temporary
			// `dndnode_*` id.
			setPromptRanges((prev) => [
				...prev,
				{
					nodeId: createdNode.id,
					promptRanges: defaultPromptRanges[0].promptRanges,
					llm: "gpt-4o-mini-2024-07-18",
				},
			])

			setRetrievalRanges((prev) => [
				...prev,
				{
					nodeId: createdNode.id,
					retrievalRanges: defaultRetrievalRanges[0].retrievalRanges,
				},
			])

			const sanitizedNode = sanitizeNode([createdNode])

			setInitialNodes((nds) => nds.concat(sanitizedNode))
		},
		// `graphId` must be listed: otherwise a route change would keep
		// serializing new nodes against the previous graph id.
		[
			screenToFlowPosition,
			type,
			graphId,
			createNode,
			setPromptRanges,
			setRetrievalRanges,
			setInitialNodes,
		],
	)

	/** Opens the delete-confirmation modal for `nodeId`, if that node exists. */
	function handleDeleteNode(nodeId: string) {
		const node = getNode(nodeId)
		if (!node) return
		setNodeToDelete(node as Node)
		setOpenDeleteNodeModal(true)
	}

	/**
	 * Persists a new connection. The edge is added to the canvas immediately
	 * with a placeholder id (`""`), then given the id returned by the create
	 * request. If the request fails, the placeholder edge is removed again.
	 */
	async function onSaveEdges(e: Connection) {
		addEdges({
			source: e.source,
			target: e.target,
			id: "",
		} as Edge)

		let savedEdge: Awaited<ReturnType<typeof createEdge>>
		try {
			savedEdge = await createEdge({
				from_node: e.source,
				to_node: e.target,
			})
		} catch {
			// The mutation's `onError` has already shown a toast. Roll back the
			// optimistic edge; matching on the placeholder id keeps edges that
			// already have a real id untouched.
			setEdges((current) =>
				current.filter(
					(edge) =>
						!(
							edge.id === "" &&
							edge.source === e.source &&
							edge.target === e.target
						),
				),
			)
			return
		}

		const savedEdgeId = savedEdge.id
		setEdges((current) =>
			current.map((edge) => {
				if (edge.source === e.source && edge.target === e.target) {
					return {
						...edge,
						id: savedEdgeId,
					}
				}
				return edge
			}),
		)
	}

	/**
	 * Builds the prompt-range entries for every `PROMPT` node in `data`: the
	 * default prompt ranges with each known range's `value` replaced by the
	 * value stored on the node. Ranges with any other id are returned unchanged.
	 */
	function getPromptRanges(data: GraphData) {
		const promptNodes = data.nodes.filter(
			(node) => node.node_type === "PROMPT",
		)
		const newRanges = promptNodes.map((node) => {
			const temperature = node.temperature
			const top_p = node.top_p
			const frequency_penalty = node.frequency_penalty
			const presence_penalty = node.presence_penalty
			const max_tokens = node.max_tokens

			return {
				nodeId: node.id,
				llm: node.llm_model,
				promptRanges: defaultPromptRanges[0].promptRanges.map(
					(range) => {
						if (range.id === "temperature") {
							return {
								...range,
								value: temperature,
							}
						}
						if (range.id === "top_p") {
							return {
								...range,
								value: top_p,
							}
						}
						if (range.id === "frequency_penalty") {
							return {
								...range,
								value: frequency_penalty,
							}
						}
						if (range.id === "presence_penalty") {
							return {
								...range,
								value: presence_penalty,
							}
						}
						if (range.id === "max_tokens") {
							return {
								...range,
								value: max_tokens,
							}
						}
						return range
					},
				),
			}
		})
		return newRanges
	}

	/**
	 * Builds the retrieval-range entries for every `RETRIEVAL` node in `data`:
	 * the default retrieval ranges with `top_k` and `similarity_score` replaced
	 * by the values stored on the node.
	 */
	function getRetrievalRanges(data: GraphData) {
		const retrievalNodes = data.nodes.filter(
			(node) => node.node_type === "RETRIEVAL",
		)
		const newRanges = retrievalNodes.map((node) => {
			const top_k = node.top_k
			const similarity_score = node.similarity_score

			return {
				nodeId: node.id,
				retrievalRanges: defaultRetrievalRanges[0].retrievalRanges.map(
					(range) => {
						if (range.id === "top_k") {
							return {
								...range,
								value: top_k,
							}
						}
						if (range.id === "similarity_score") {
							return {
								...range,
								value: similarity_score,
							}
						}
						return range
					},
				),
			}
		})
		return newRanges
	}

	// Sync the fetched graph into local state once a fetch/refetch has settled:
	// nodes, edges (flattened from each node's outgoing edges) and the
	// prompt/retrieval ranges. Also clears the full-screen loading overlay.
	useEffect(() => {
		if (data && !isRefetching) {
			const savedNodes = sanitizeNode(data.nodes)
			setNodes([...savedNodes])
			const newPromptRanges = getPromptRanges(data)
			const newRetrievalRanges = getRetrievalRanges(data)
			const transformedEdges = data.nodes.flatMap((node) =>
				node.outgoing_edges.map((edge) => ({
					source: edge.from_node,
					target: edge.to_node,
					id: edge.id,
				})),
			)

			setPromptRanges(newPromptRanges)
			setRetrievalRanges(newRetrievalRanges)
			setEdges(transformedEdges)
			setLoading(false)
		}
	}, [data, isRefetching])

	// Forget the pending node whenever the delete modal is closed.
	useEffect(() => {
		if (!openDeleteNodeModal) {
			setNodeToDelete(null)
		}
	}, [openDeleteNodeModal])

	useNodeExecutionStatus()

	if (!data) return null

	return (
		<div className="dndflow relative !font-Red-Hat-Display">
			{loading && (
				<div className="fixed top-0 left-0 w-full h-full bg-white bg-opacity-70 z-50 flex items-center justify-center">
					<LoadingSpinner />
				</div>
			)}
			<GraphHeader
				setOpenStepper={setOpenStepper}
				graphId={graphId}
				graph={data}
				checkpoints={checkpoints}
			/>
			<div
				className="reactflow-wrapper h-[calc(100vh-52px)]  w-screen"
				ref={reactFlowWrapper}
			>
				<ReactFlow
					nodes={nodes}
					edges={edges}
					onNodesChange={(changes) => {
						// Only the first change decides whether this batch is a
						// removal. It is read defensively so an empty batch is
						// forwarded as a no-op instead of throwing.
						const firstChange = changes[0]
						if (
							firstChange?.type === "remove" &&
							getNode(firstChange.id)
						) {
							// The node still exists: ask for confirmation instead
							// of applying the removal right away.
							handleDeleteNode(firstChange.id)
							return
						}
						onNodesChange(changes)
					}}
					onEdgesChange={(changes) => {
						// NOTE(review): batches with more than one change are
						// dropped — presumably intentional; confirm.
						if (changes.length === 1) {
							onEdgesChange(changes)
						}
					}}
					onEdgesDelete={(e) => {
						// NOTE(review): only single-edge deletions are sent to
						// the API — confirm how multi-edge deletions are persisted.
						if (e.length === 1) {
							deleteEdgeMutation.mutate(e[0].id)
						}
					}}
					onBeforeDelete={async (e) => {
						// Veto deletions that touch a `predefinedInput` or
						// `variables` node, or an edge going from a `variables`
						// node to a `predefinedInput` node.
						const touchesInitialNode = e.nodes.some(
							(node) =>
								node.type === "predefinedInput" ||
								node.type === "variables",
						)
						const touchesInitialEdge = e.edges.some(
							(edge) =>
								getNode(edge.source)?.type === "variables" &&
								getNode(edge.target)?.type ===
									"predefinedInput",
						)

						return !(touchesInitialNode || touchesInitialEdge)
					}}
					edgesReconnectable={false}
					onConnect={(e) => {
						onSaveEdges(e)
					}}
					onDrop={onDrop}
					onDragOver={onDragOver}
					fitView
					nodeTypes={nodeTypes}
					style={{ backgroundColor: "#F7F9FB" }}
					defaultEdgeOptions={{
						style: { stroke: "#0074FF", strokeWidth: 2 },
					}}
					connectionLineStyle={{
						stroke: "#0074FF",
						strokeWidth: 2,
					}}
					minZoom={0.1}
				>
					<Controls
						position="bottom-right"
						orientation="horizontal"
					/>
					<Background />
				</ReactFlow>
			</div>
			<GraphLog graphId={graphId} />
			<Sidebar />

			{openStepper && (
				<div className="bg-white w-[330px] h-[calc(100vh-52px)] absolute z-50 right-0 top-[52px] border-l-[1px] border-[#F0F0F0]">
					<Stepper
						checkpoints={checkpoints || []}
						setOpenStepper={setOpenStepper}
						setLoading={setLoading}
						graph={data}
					/>
				</div>
			)}

			{openPromptConfig && (
				<PromptConfig
					setOpenPromptConfig={setOpenPromptConfig}
					openPromptConfig={openPromptConfig}
				/>
			)}

			{openRetrievalConfig && (
				<RetrievalConfig
					openRetrievalConfig={openRetrievalConfig}
					setOpenRetrievalConfig={setOpenRetrievalConfig}
				/>
			)}

			{openDeleteNodeModal && nodeToDelete && (
				<DeleteNodeModal
					open={openDeleteNodeModal}
					setOpen={setOpenDeleteNodeModal}
					node={nodeToDelete}
					setNode={setNodeToDelete}
				/>
			)}
			<Toast />
		</div>
	)
}

export default Flow
