import React, { useMemo, useState, useRef, useCallback } from "react";
import { createRoot } from "react-dom/client";

/**
 * PathwayGraph — mock for landing page (SVG version)
 * --------------------------------------------------
 * Pure React + SVG. No external graph library.
 * When you move to your real project, you can either:
 *   (a) keep this SVG implementation (zero deps, full control), or
 *   (b) port to React Flow — the data shape and THEME tokens stay the same.
 *
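 * Expected backend response shape:
 *   {
 *     "path":        ["Sildenafil", "PDE5A", "PRKG1", "PAH"],
 *     "scores":      [0.87, 1.0, 0.74],
 *     "hops":        1,
 *     "final_score": 0.67
 *   }
 * scores[i] weights the edge path[i] → path[i+1], so `scores` has one entry
 * fewer than `path`. An optional "motifs" array (see MOCK_DATA) adds side
 * nodes attached to path entries; "hops" and "final_score" are not rendered.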
 */

/* ═══════════════════════════════════════════════════════════════════
   THEME TOKENS — replace with your brand system
   ═══════════════════════════════════════════════════════════════════ */
/**
 * Palette and font stacks for the graph. Most colours and fonts used by Node,
 * Edge and PathwayGraph are read from here (node drop-shadows are hard-coded
 * in Node), so rebranding is mostly a matter of editing these values.
 *
 * - canvasBg:      background of the graph container.
 * - gridDot:       despite the name, used as the thin border around edge labels;
 *                  no dot grid is drawn in this file.
 * - primary / primarySoft: not referenced by the components in this file.
 * - nodes.<kind>:  fill / stroke / label colour per node kind (drug, protein,
 *                  disease, motif). nodes.motif.text also colours the drag hint.
 * - edgeStrong / edgeMid / edgeWeak: edge colour for scores >= 0.7, >= 0.3 and
 *                  below 0.3 respectively (see edgeColor).
 * - edgeLabelBg / edgeLabelText: the "GDA: x.xx" label box drawn on each edge.
 * - fontDisplay:   drug and disease labels.
 * - fontBody:      protein/motif labels and the container default.
 * - fontMono:      edge labels, motif scores and the drag hint.
 */
const THEME = {
  canvasBg: "#FFFFFF",         // --color-surface (white)
  gridDot: "#E8E8E8",          // --color-border (used for edge-label borders)

  primary: "#036C4D",          // --color-mint
  primarySoft: "#F0F7F4",      // --color-surface-mint

  // Per-kind node colours: shape fill, outline stroke, label text.
  nodes: {
    drug:    { fill: "#FFFFFF", stroke: "#036C4D", text: "#1A1A1A" },
    protein: { fill: "#F0F7F4", stroke: "#1B7858", text: "#1A1A1A" },
    disease: { fill: "#FFFFFF", stroke: "#B8472E", text: "#3A1810" },
    motif:   { fill: "#FFFFFF", stroke: "#999999", text: "#666666" },
  },

  // Edge colour tiers by score (>= 0.7, >= 0.3, otherwise).
  edgeStrong: "#036C4D",
  edgeMid:    "#999999",
  edgeWeak:   "#D4D4C8",
  edgeLabelBg: "#FFFFFF",
  edgeLabelText: "#1A1A1A",

  fontDisplay: "'Helvetica Neue', Helvetica, Inter, sans-serif",
  fontBody:    "'Helvetica Neue', Helvetica, Inter, sans-serif",
  fontMono:    "'JetBrains Mono', ui-monospace, monospace",
};

/* ═══════════════════════════════════════════════════════════════════
   MOCK DATA
   ═══════════════════════════════════════════════════════════════════ */
/**
 * Demo payload used when PathwayGraph receives no `data` prop. It follows the
 * backend response shape described in the file header, plus an optional
 * `motifs` array that buildGraph turns into side nodes:
 *   attachedTo — label of the path node to attach to (entries whose label is
 *                not in `path` are skipped),
 *   label      — text shown inside the motif box,
 *   score      — printed under the label and used to colour/weight the link,
 *   side       — placement relative to the parent: "top", "top-right",
 *                "top-left", "bottom", "right" or "left"; unrecognised values
 *                fall back to "top".
 * `hops` and `final_score` are carried along but not read by buildGraph.
 */
const MOCK_DATA = {
  path: ["Sildenafil", "PDE5A", "PRKG1", "PAH"],
  // One score per consecutive pair in `path` (3 edges for 4 nodes).
  scores: [0.91, 0.81, 0.81],
  hops: 1,
  final_score: 0.67,
  motifs: [
    { attachedTo: "PDE5A", label: "Thrombus Formation", score: 0.81, side: "top" },
    { attachedTo: "PDE5A", label: "Fibrinolysis",       score: 0.01, side: "right" },
    { attachedTo: "PRKG1", label: "Thrombus Formation", score: 0.81, side: "top-right" },
    { attachedTo: "PRKG1", label: "Fibrinolysis",       score: 0.01, side: "right" },
  ],
};

/* ═══════════════════════════════════════════════════════════════════
   HELPERS
   ═══════════════════════════════════════════════════════════════════ */
/**
 * Classify a path position: the first entry is the drug, the last is the
 * disease, everything in between is a protein. A single-entry path counts as
 * a drug, because the first-position rule is checked first.
 *
 * @param {number} index  Position in the path.
 * @param {number} total  Path length.
 * @returns {"drug"|"disease"|"protein"}
 */
function inferNodeKind(index, total) {
  const isFirst = index === 0;
  const isLast = index === total - 1;
  if (isFirst) return "drug";
  return isLast ? "disease" : "protein";
}

/**
 * Pick an edge colour from the THEME tier that `score` falls into:
 * >= 0.7 strong, >= 0.3 mid, anything else (including NaN) weak.
 *
 * @param {number} score
 * @returns {string} A THEME colour.
 */
function edgeColor(score) {
  const tiers = [
    [0.7, THEME.edgeStrong],
    [0.3, THEME.edgeMid],
  ];
  const match = tiers.find(([minimum]) => score >= minimum);
  return match ? match[1] : THEME.edgeWeak;
}

/**
 * Stroke width for an edge: a 1.2px baseline plus 2px per unit of score.
 *
 * @param {number} score
 * @returns {number} Width in px.
 */
function edgeWidth(score) {
  const baseline = 1.2;
  const perUnit = 2;
  return baseline + perUnit * score;
}

// Bounding-box size (px) and outline shape per node kind. Node draws each
// shape centred on the node's (x, y) inside this w×h box, and intersectNode
// uses the same numbers to end edges at the outline. For circles only `w` is
// used (radius = w / 2); triangles have their apex at the top-centre of the box.
const NODE_SIZES = {
  drug:    { w: 110, h: 110, shape: "circle" },
  protein: { w: 130, h: 80,  shape: "rect" },
  disease: { w: 130, h: 100, shape: "triangle" },
  motif:   { w: 140, h: 70,  shape: "rect" },
};

// Horizontal inner padding (px) for node labels. Node subtracts it from both
// sides of the node width before calling wrapLabel. It applies to every kind,
// including the triangular disease node; Node falls back to 12 for any kind
// missing here.
const NODE_PADDING = {
  drug: 14,
  protein: 12,
  disease: 14,
  motif: 12,
};

/**
 * Split a label into at most two lines so it fits the given width.
 *
 * Width is estimated with a rough average-glyph heuristic (0.55 × fontSize).
 * That is good enough for short bio terms; it is not real text measurement.
 *
 * Every possible two-line split at a space is considered:
 *   - among splits where both lines fit, the most balanced one wins;
 *   - if no split fits, the split whose longer line is shortest wins, so the
 *     overflow is as small as possible.
 *
 * @param {string} text      Label to wrap.
 * @param {number} maxWidth  Available width in px (node width minus padding).
 * @param {number} fontSize  Font size in px.
 * @returns {string[]} One or two lines. A single word that is too long is
 *   returned unwrapped; the caller lets it overflow.
 */
function wrapLabel(text, maxWidth, fontSize) {
  const avgCharWidth = fontSize * 0.55;
  const maxChars = Math.floor(maxWidth / avgCharWidth);
  if (text.length <= maxChars) return [text];

  const words = text.split(" ");
  if (words.length === 1) return [text]; // can't wrap a single word

  let fitSplit = -1;
  let fitDiff = Infinity;
  let fallbackSplit = 1;
  let fallbackLongest = Infinity;

  // `left` is the length of words[0..i) joined by spaces. Because
  // words.join(" ") === text, the right-hand line is whatever remains
  // minus the separating space.
  let left = -1;
  for (let i = 1; i < words.length; i++) {
    left += words[i - 1].length + 1;
    const right = text.length - left - 1;
    const longest = Math.max(left, right);
    if (longest <= maxChars) {
      const diff = Math.abs(left - right);
      if (diff < fitDiff) {
        fitDiff = diff;
        fitSplit = i;
      }
    } else if (longest < fallbackLongest) {
      fallbackLongest = longest;
      fallbackSplit = i;
    }
  }

  const split = fitSplit === -1 ? fallbackSplit : fitSplit;
  return [words.slice(0, split).join(" "), words.slice(split).join(" ")];
}

/* ═══════════════════════════════════════════════════════════════════
   LAYOUT
   ═══════════════════════════════════════════════════════════════════ */
// Where a motif node sits relative to its parent path node, keyed by the
// motif's `side`. Unknown sides fall back to "top".
const MOTIF_OFFSETS = Object.freeze({
  "top":       { dx:    0, dy: -180 },
  "top-right": { dx:  130, dy: -160 },
  "top-left":  { dx: -130, dy: -160 },
  "bottom":    { dx:    0, dy:  180 },
  "right":     { dx:  200, dy:  -30 },
  "left":      { dx: -200, dy:  -30 },
});

/**
 * Turn a backend pathway payload into positioned nodes and edges.
 *
 * Path entries are laid out left-to-right on a single row, and scores[i]
 * becomes the main edge path[i] → path[i+1]. Each motif becomes an extra node
 * offset from the path node whose label equals `attachedTo` (the first match
 * wins). It is linked by a non-main edge from the motif to that node. Motifs
 * whose `attachedTo` is not in the path are skipped.
 *
 * @param {object} [data]                { path, scores, motifs? }; missing or
 *                                       null arrays are treated as empty.
 * @param {object} [layout]              Optional layout overrides.
 * @param {number} [layout.xStep=230]    Horizontal distance between path nodes.
 * @param {number} [layout.baseY=320]    y coordinate of the path row.
 * @param {number} [layout.xOffset=140]  x coordinate of the first path node.
 * @returns {{ nodes: object[], edges: object[] }}
 */
function buildGraph(data = {}, layout = {}) {
  const path = data?.path ?? [];
  const scores = data?.scores ?? [];
  const motifs = data?.motifs ?? [];
  const { xStep = 230, baseY = 320, xOffset = 140 } = layout;

  const nodes = path.map((label, i) => ({
    id: `n-${i}`,
    label,
    kind: inferNodeKind(i, path.length),
    x: xOffset + i * xStep,
    y: baseY,
  }));

  // One main edge per consecutive pair of path nodes. Surplus scores would
  // point at a node that doesn't exist, so they are dropped.
  const edgeCount = Math.max(path.length - 1, 0);
  const edges = scores.slice(0, edgeCount).map((score, i) => ({
    id: `e-${i}`,
    source: `n-${i}`,
    target: `n-${i + 1}`,
    score,
    isMain: true,
  }));

  motifs.forEach((m, idx) => {
    const parentIdx = path.indexOf(m.attachedTo);
    if (parentIdx === -1) return;
    const parent = nodes[parentIdx];

    // Own-property check so a side like "toString" can't resolve to an
    // inherited Object.prototype member (which would yield NaN coordinates).
    const off = Object.hasOwn(MOTIF_OFFSETS, m.side)
      ? MOTIF_OFFSETS[m.side]
      : MOTIF_OFFSETS.top;

    const motifId = `m-${idx}`;
    nodes.push({
      id: motifId,
      label: m.label,
      kind: "motif",
      x: parent.x + off.dx,
      y: parent.y + off.dy,
      score: m.score,
    });

    edges.push({
      id: `em-${idx}`,
      source: motifId,
      target: parent.id,
      score: m.score,
      isMain: false,
    });
  });

  return { nodes, edges };
}

/* ═══════════════════════════════════════════════════════════════════
   GEOMETRY
   ═══════════════════════════════════════════════════════════════════ */
// Tolerance on the edge parameter so a ray passing exactly through a vertex
// isn't lost to floating-point round-off on both adjacent edges.
const RAY_VERTEX_EPS = 1e-9;

/**
 * Smallest t > 0 at which the ray (cx, cy) + t·(dx, dy) crosses an edge of the
 * polygon `vertices` (an array of [x, y] pairs, closed implicitly). The origin
 * is expected to lie inside a convex polygon.
 *
 * @returns {number} t, or Infinity if no edge is hit (e.g. zero direction).
 */
function rayExitConvexPolygon(cx, cy, dx, dy, vertices) {
  let best = Infinity;
  for (let i = 0; i < vertices.length; i++) {
    const [x1, y1] = vertices[i];
    const [x2, y2] = vertices[(i + 1) % vertices.length];
    const ex = x2 - x1;
    const ey = y2 - y1;
    // Solve origin + t·d = p1 + u·e with 2-D cross products.
    const denom = dx * ey - dy * ex;
    if (denom === 0) continue; // ray parallel to this edge
    const qx = x1 - cx;
    const qy = y1 - cy;
    const t = (qx * ey - qy * ex) / denom;
    const u = (qx * dy - qy * dx) / denom;
    if (t > 0 && u >= -RAY_VERTEX_EPS && u <= 1 + RAY_VERTEX_EPS && t < best) {
      best = t;
    }
  }
  return best;
}

/**
 * Point where the segment from `node`'s centre toward (fromX, fromY) leaves
 * the node's drawn outline, so edges start/end on the shape rather than its
 * centre.
 *
 * - circle:   exact, radius = w / 2.
 * - triangle: exact ray/triangle intersection, using the same vertices Node
 *             draws (apex top-centre, base along the bottom of the w×h box).
 * - rect:     exact against the w×h box; rounded corners are ignored.
 * If the target coincides with the centre, the centre is returned.
 *
 * @param {{kind: string, x: number, y: number}} node
 * @param {number} fromX
 * @param {number} fromY
 * @returns {{x: number, y: number}}
 */
function intersectNode(node, fromX, fromY) {
  const size = NODE_SIZES[node.kind];
  const cx = node.x;
  const cy = node.y;
  const dx = fromX - cx;
  const dy = fromY - cy;

  if (size.shape === "circle") {
    const len = Math.sqrt(dx * dx + dy * dy) || 1;
    const r = size.w / 2;
    return { x: cx + (dx / len) * r, y: cy + (dy / len) * r };
  }

  const halfW = size.w / 2;
  const halfH = size.h / 2;

  if (size.shape === "triangle") {
    // The bounding box is much wider than the triangle near its apex, so
    // clipping to the box would leave a visible gap before the outline.
    const t = rayExitConvexPolygon(cx, cy, dx, dy, [
      [cx, cy - halfH],
      [cx + halfW, cy + halfH],
      [cx - halfW, cy + halfH],
    ]);
    if (!Number.isFinite(t)) return { x: cx, y: cy };
    return { x: cx + dx * t, y: cy + dy * t };
  }

  // Rectangle: scale the direction until it touches the nearer box side.
  const absDx = Math.abs(dx) || 0.0001;
  const absDy = Math.abs(dy) || 0.0001;
  const scale = Math.min(halfW / absDx, halfH / absDy);
  return { x: cx + dx * scale, y: cy + dy * scale };
}

/* ═══════════════════════════════════════════════════════════════════
   NODE
   ═══════════════════════════════════════════════════════════════════ */
function Node({ node, onPointerDown, isDragging }) {
  const palette = THEME.nodes[node.kind];
  const size = NODE_SIZES[node.kind];
  const isMotif = node.kind === "motif";

  const cx = node.x;
  const cy = node.y;

  let shape;
  if (size.shape === "circle") {
    shape = (
      <circle
        cx={cx} cy={cy} r={size.w / 2}
        fill={palette.fill}
        stroke={palette.stroke}
        strokeWidth={1.8}
      />
    );
  } else if (size.shape === "triangle") {
    const pts = `${cx},${cy - size.h / 2} ${cx + size.w / 2},${cy + size.h / 2} ${cx - size.w / 2},${cy + size.h / 2}`;
    shape = (
      <polygon
        points={pts}
        fill={palette.fill}
        stroke={palette.stroke}
        strokeWidth={1.8}
      />
    );
  } else {
    shape = (
      <rect
        x={cx - size.w / 2}
        y={cy - size.h / 2}
        width={size.w}
        height={size.h}
        rx={12}
        fill={palette.fill}
        stroke={palette.stroke}
        strokeWidth={1.5}
      />
    );
  }

  const fontSize = isMotif ? 14 : node.kind === "drug" ? 20 : 18;
  const fontFamily = (node.kind === "drug" || node.kind === "disease")
    ? THEME.fontDisplay
    : THEME.fontBody;

  // Compute available width for text (node width minus padding on both sides)
  const padding = NODE_PADDING[node.kind] || 12;
  const availableWidth = size.w - padding * 2;
  const lines = wrapLabel(node.label, availableWidth, fontSize);

  // Vertical anchoring per shape
  const lineHeight = fontSize * 1.15;
  const totalTextHeight = lines.length * lineHeight;
  const baseTextY = node.kind === "disease"
    ? cy + 22 - (totalTextHeight - lineHeight) / 2
    : cy - (totalTextHeight - lineHeight) / 2;

  return (
    <g
      className="pathway-node"
      onPointerDown={(e) => onPointerDown(e, node.id)}
      style={{
        cursor: isDragging ? "grabbing" : "grab",
        touchAction: "none",
        filter: isMotif
          ? "drop-shadow(0 1px 2px rgba(0,0,0,0.04))"
          : "drop-shadow(0 4px 12px rgba(31,111,74,0.10))",
      }}
    >
      {shape}
      <text
        x={cx}
        y={baseTextY}
        textAnchor="middle"
        dominantBaseline="middle"
        fontFamily={fontFamily}
        fontSize={fontSize}
        fontWeight={node.kind === "drug" ? 500 : 600}
        fill={palette.text}
        style={{ userSelect: "none", pointerEvents: "none" }}
      >
        {lines.map((line, i) => (
          <tspan key={i} x={cx} dy={i === 0 ? 0 : lineHeight}>
            {line}
          </tspan>
        ))}
      </text>
      {isMotif && node.score !== undefined && (
        <text
          x={cx}
          y={baseTextY + totalTextHeight - lineHeight + 14}
          textAnchor="middle"
          fontFamily={THEME.fontMono}
          fontSize={11}
          fill={palette.stroke}
          opacity={0.7}
          style={{ userSelect: "none", pointerEvents: "none" }}
        >
          {node.score.toFixed(2)}
        </text>
      )}
    </g>
  );
}

/* ═══════════════════════════════════════════════════════════════════
   EDGE
   ═══════════════════════════════════════════════════════════════════ */
function Edge({ edge, sourceNode, targetNode }) {
  const start = intersectNode(sourceNode, targetNode.x, targetNode.y);
  const end   = intersectNode(targetNode, sourceNode.x, sourceNode.y);

  const color = edgeColor(edge.score);
  const width = edgeWidth(edge.score);
  const dashed = edge.score < 0.3;

  const mx = (start.x + end.x) / 2;
  const my = (start.y + end.y) / 2;

  const markerEnd = edge.isMain ? `url(#arrow-${color.replace("#", "")})` : undefined;

  const labelText = `GDA: ${edge.score.toFixed(2)}`;
  const labelWidth = labelText.length * 7 + 12;

  return (
    <g>
      <line
        x1={start.x} y1={start.y}
        x2={end.x}   y2={end.y}
        stroke={color}
        strokeWidth={width}
        strokeDasharray={dashed ? "4 4" : undefined}
        markerEnd={markerEnd}
      />
      <g transform={`translate(${mx}, ${my})`}>
        <rect
          x={-labelWidth / 2}
          y={-11}
          width={labelWidth}
          height={22}
          rx={4}
          fill={THEME.edgeLabelBg}
          fillOpacity={0.95}
          stroke={THEME.gridDot}
          strokeWidth={0.5}
        />
        <text
          textAnchor="middle"
          dominantBaseline="middle"
          fontFamily={THEME.fontMono}
          fontSize={12}
          fontWeight={500}
          fill={THEME.edgeLabelText}
          style={{ userSelect: "none", pointerEvents: "none" }}
        >
          {labelText}
        </text>
      </g>
    </g>
  );
}

/* ═══════════════════════════════════════════════════════════════════
   MAIN
   ═══════════════════════════════════════════════════════════════════ */
function PathwayGraph({ data = MOCK_DATA }) {
  const initial = useMemo(() => buildGraph(data), [data]);
  const [nodes, setNodes] = useState(initial.nodes);
  const edges = initial.edges;
  const svgRef = useRef(null);
  const dragState = useRef(null);
  const [draggingId, setDraggingId] = useState(null);

  const screenToSvg = (clientX, clientY) => {
    const svg = svgRef.current;
    const pt = svg.createSVGPoint();
    pt.x = clientX;
    pt.y = clientY;
    return pt.matrixTransform(svg.getScreenCTM().inverse());
  };

  const onPointerDown = useCallback((e, nodeId) => {
    e.preventDefault();
    e.stopPropagation();
    if (e.currentTarget && e.currentTarget.setPointerCapture) {
      try { e.currentTarget.setPointerCapture(e.pointerId); } catch (_) {}
    }
    const svgP = screenToSvg(e.clientX, e.clientY);
    const node = nodes.find((n) => n.id === nodeId);
    if (!node) return;
    dragState.current = {
      id: nodeId,
      pointerId: e.pointerId,
      offsetX: svgP.x - node.x,
      offsetY: svgP.y - node.y,
    };
    setDraggingId(nodeId);
  }, [nodes]);

  const onPointerMove = useCallback((e) => {
    if (!dragState.current) return;
    if (e.pointerId !== dragState.current.pointerId) return;
    const svgP = screenToSvg(e.clientX, e.clientY);
    const { id, offsetX, offsetY } = dragState.current;
    setNodes((prev) =>
      prev.map((n) =>
        n.id === id ? { ...n, x: svgP.x - offsetX, y: svgP.y - offsetY } : n
      )
    );
  }, []);

  const onPointerUp = useCallback(() => {
    dragState.current = null;
    setDraggingId(null);
  }, []);

  const nodeMap = useMemo(() => {
    const m = {};
    nodes.forEach((n) => { m[n.id] = n; });
    return m;
  }, [nodes]);

  const arrowColors = [...new Set(edges.filter((e) => e.isMain).map((e) => edgeColor(e.score)))];

  return (
    <div
      style={{
        width: "100%",
        height: "100%",
        background: THEME.canvasBg,
        position: "relative",
        overflow: "hidden",
        fontFamily: THEME.fontBody,
      }}
    >
      <div
        style={{
          position: "absolute",
          bottom: 20,
          right: 24,
          zIndex: 5,
          fontFamily: THEME.fontMono,
          fontSize: 10,
          color: THEME.nodes.motif.text,
          letterSpacing: 0.5,
          opacity: 0.6,
          pointerEvents: "none",
        }}
      >
        ↔ drag nodes to rearrange
      </div>

      <svg
        ref={svgRef}
        viewBox="50 80 820 320"
        preserveAspectRatio="xMidYMid meet"
        style={{ width: "100%", height: "100%", display: "block", touchAction: "none" }}
        onPointerMove={onPointerMove}
        onPointerUp={onPointerUp}
        onPointerCancel={onPointerUp}
        onPointerLeave={onPointerUp}
      >
        <defs>
          {arrowColors.map((c) => (
            <marker
              key={c}
              id={`arrow-${c.replace("#", "")}`}
              viewBox="0 0 10 10"
              refX="9"
              refY="5"
              markerWidth="7"
              markerHeight="7"
              orient="auto-start-reverse"
            >
              <path d="M 0 0 L 10 5 L 0 10 z" fill={c} />
            </marker>
          ))}
        </defs>

        {edges.map((e) => {
          const s = nodeMap[e.source];
          const t = nodeMap[e.target];
          if (!s || !t) return null;
          return <Edge key={e.id} edge={e} sourceNode={s} targetNode={t} />;
        })}

        {nodes.map((n) => (
          <Node
            key={n.id}
            node={n}
            onPointerDown={onPointerDown}
            isDragging={draggingId === n.id}
          />
        ))}
      </svg>
    </div>
  );
}

/* ═══════════════════════════════════════════════════════════════════
   MOUNT — render into #pathway-graph
   ═══════════════════════════════════════════════════════════════════ */
const container = document.getElementById("pathway-graph");
if (container) {
  createRoot(container).render(<PathwayGraph />);
} else {
  console.warn("[pathway-mount] #pathway-graph not found");
}
