"""
================================================================================
HYPERBOLISCHER ZELLULARAUTOMAT v2
================================================================================
Variante mit Topologie-Auswahl, Kachel-Tiling und Hex-Modus.

Dieses Programm baut auf dem Original (zellularautomat.py) auf und ergaenzt:

  1. TOPOLOGIE-AUSWAHL
     - Hyperbolic:   Cayley-projizierte Halbebene (Original)
     - Donut:       NxM-Rechteck mit Wrap-around in beide Achsen (Torus)
     - Rectangle:   NxM-Rechteck mit statischen Raendern (kein Wrap)
     - Hex:         Sechseckgitter mit 6 Nachbarn (axial)

  2. KACHEL-TILING (nur Hyperbolic)
     - 1x1, 2x2, 3x3, 4x4 unabhaengige Poincare-Scheiben im Raster.
     - Jede Kachel ist ein eigener unabhaengiger Automat (kein Wrap
       zwischen Kacheln, sonst Verzerrung).

  3. HEX-MODUS (flaches Gitter wie Donut/Rectangle, nicht hyperbolisch)
     - Eigene Topologie, kein Umschalter auf den anderen Topologien.
     - Drei dokumentierte Hex-Life-Regeln: B2/S34, B3/S12, B4/S34.
     - 1D-Regeln (R30, R110) sind im Hex-Modus deaktiviert.

TASTATUR
--------
  T       Topologie wechseln (Hyperbolic -> Donut -> Rectangle -> Hex)
  1/2/3   Regel in der aktuellen Topologie
  +/-     Geschwindigkeit (1-30 Gen/s)
  Leertaste  Pause / Weiter
  R       Zellen zufaellig verteilen (50 %)
  Maus   Klick auf eine Zelle toggelt
  Esc     Beenden
================================================================================"""

import math
import sys
import numpy as np

try:
    import pygame
except ImportError:
    sys.stderr.write(
        "FEHLER: pygame ist nicht installiert.\n"
        "Bitte aktiviere das venv und installiere die Abhaengigkeiten:\n"
        "  source .venv/bin/activate && pip install pygame numpy\n"
    )
    sys.exit(1)


# ==============================================================================
# CONSTANTS – colour palette, geometry
# ==============================================================================
# Colour palette (RGB tuples).
# NOTE(review): PETROL and BLACK are not referenced anywhere else in this file.
TEAL       = (0, 128, 128)
FOREST     = (34, 139, 34)
PETROL     = (0, 98, 98)
PETROL_DK  = (0, 51, 51)
WHITE      = (255, 255, 255)
BLACK      = (0, 0, 0)
UI_ALPHA   = 180  # alpha (0-255) of the translucent UI panels and buttons

# Window size in pixels.
WINDOW_W   = 1100
WINDOW_H   = 900
# NOTE(review): CENTER and DISK_R are not referenced anywhere else in this
# file; build_hyperbolic derives its disk radius from the window size and the
# tiling instead. Candidates for removal.
CENTER     = (WINDOW_W // 2, WINDOW_H // 2)
DISK_R     = 200
# Hyperbolic topology: every tile samples GRID_W x GRID_H points
#   z = (ORIGIN_X + i*DX) + (ORIGIN_Y + j*DY) * 1j,  0 <= i < GRID_W, 0 <= j < GRID_H
# of the upper half plane (ORIGIN_Y > 0, DY > 0) and maps them into the
# Poincare disk with cayley_forward.
GRID_W     = 40
GRID_H     = 30
ORIGIN_X   = -1.6
ORIGIN_Y   = 0.15
DX         = 0.08
DY         = 0.18

# Font family and sizes passed to pygame.font.SysFont.
FONT_NAME  = "couriernew"
FONT_SIZE  = 16
FONT_SMALL = 13

# Topology identifiers; the list order is the cycle order of the T key.
TOPO_HYPERBOLIC = "hyperbolic"
TOPO_DONUT      = "donut"
TOPO_RECTANGLE  = "rectangle"
TOPO_HEX        = "hex"
TOPOLOGIES      = [TOPO_HYPERBOLIC, TOPO_DONUT, TOPO_RECTANGLE, TOPO_HEX]

# Tiles per side offered for the hyperbolic topology only (1x1 ... 4x4).
TILING_OPTIONS = [1, 2, 3, 4]

# Rules. R30/R110 are elementary 1D rules applied row by row; the others are
# Life-like birth/survival strings parsed by parse_bs.
RULE_R30   = "R30"
RULE_R110  = "R110"
RULE_GOL   = "B3/S23"
RULE_B2S34 = "B2/S34"
RULE_B3S12 = "B3/S12"
RULE_B4S34 = "B4/S34"

# Rules offered per topology family (square grids vs. hex grid); the list
# position is what the keys 1/2/3 select.
SQUARE_RULES = [RULE_R110, RULE_R30, RULE_GOL]
HEX_RULES    = [RULE_B2S34, RULE_B3S12, RULE_B4S34]


# ==============================================================================
# MATHEMATIK – Cayley-Transformation (identisch zum Original)
# ==============================================================================
def cayley_forward(z):
    """Map a point *z* of the upper half plane into the unit (Poincare) disk.

    Cayley transform w = (z - i) / (z + i). For Python complex input,
    z == -i raises ZeroDivisionError; that point has Im(z) < 0 and is never
    produced by the half-plane sampling of build_hyperbolic.
    """
    numerator = z - 1j
    denominator = z + 1j
    return numerator / denominator


# ==============================================================================
# REGEL-TABELLEN
# ==============================================================================
# Elementary (Wolfram) 1D rules as lookup tables (left, centre, right) -> new
# state. Bit k of the rule number is the result for the neighbourhood whose
# binary value 4*left + 2*centre + right equals k. Insertion order runs from
# (1, 1, 1) down to (0, 0, 0).
R30_TABLE = {
    (a, b, c): (30 >> (4 * a + 2 * b + c)) & 1
    for a in (1, 0) for b in (1, 0) for c in (1, 0)
}

R110_TABLE = {
    (a, b, c): (110 >> (4 * a + 2 * b + c)) & 1
    for a in (1, 0) for b in (1, 0) for c in (1, 0)
}


def parse_bs(s):
    """Split a Life-like rule string into birth and survival neighbour counts.

    'B3/S23' -> ({3}, {2, 3}). The first character of each half (the 'B' /
    'S' prefix) is skipped; every remaining character must be a digit.
    """
    born_part, survive_part = s.split("/")
    births = {int(ch) for ch in born_part[1:]}
    survivals = {int(ch) for ch in survive_part[1:]}
    return births, survivals


# ==============================================================================
# GRID-OBJEKT – einheitliche Schnittstelle fuer alle Topologien
# ==============================================================================
class Grid:
    """Per-cell state, drawing geometry and neighbour table of one topology.

    Every per-cell array shares the same flat cell index 0 .. n-1.
    """

    def __init__(self, n, states, screen, radii, visible, neighbours):
        self.n = n
        self.flash = np.zeros(n, dtype=np.float32)  # highlight intensity per cell
        self.states = states
        self.next_states = states.copy()
        self.screen = screen          # (n, 2) pixel centre of each cell
        self.radii = radii            # (n,) drawing radius in pixels
        self.visible = visible        # (n,) bool mask of drawable cells
        self.neighbours = neighbours  # (n, k) neighbour indices, -1 = none

    def reset_random(self, p=0.5):
        """Make each cell alive independently with probability *p*; clear flash."""
        draws = np.random.random(self.n)
        self.states = (draws < p).astype(np.uint8)
        self.next_states = self.states.copy()
        self.flash.fill(0.0)

    def reset_empty(self):
        """Kill every cell (in place) and clear the flash values."""
        for buffer in (self.states, self.next_states):
            buffer.fill(0)
        self.flash.fill(0.0)

    def toggle(self, ix):
        """Flip cell *ix* and flash it; out-of-range indices are ignored."""
        if not 0 <= ix < self.n:
            return
        self.states[ix] ^= 1
        self.flash[ix] = 1.0


# ==============================================================================
# TOPOLOGIE-BUILDER
# ==============================================================================
def build_hyperbolic(tiling):
    """Build tiling x tiling independent Poincare disks laid out in a raster.

    Each tile samples the same GRID_W x GRID_H half-plane points and maps them
    into its own disk with cayley_forward. Every tile is a separate automaton:
    its 4 neighbours (left, right, up, down) wrap around inside the tile, never
    into a neighbouring tile (the boundary distortion would break the geometry).
    Points with |w| >= 0.999 are marked invisible.
    """
    per_tile = GRID_W * GRID_H
    total = tiling * tiling * per_tile
    states = np.zeros(total, dtype=np.uint8)
    screen = np.zeros((total, 2), dtype=np.float32)
    radii = np.zeros(total, dtype=np.float32)
    visible = np.ones(total, dtype=bool)
    neighbours = -np.ones((total, 4), dtype=np.int32)  # 4 neighbours in the half plane

    # Distribute the disks evenly over the window.
    margin = 30
    step_x = (WINDOW_W - 2 * margin) / tiling
    step_y = (WINDOW_H - 2 * margin) / tiling
    disk_r = min(step_x, step_y) * 0.42

    # Disk coordinates and in-tile neighbour offsets are the same for every
    # tile, so compute them once (row-major: j outer, i inner).
    cells = []
    for j in range(GRID_H):
        for i in range(GRID_W):
            w = cayley_forward(complex(ORIGIN_X + i * DX, ORIGIN_Y + j * DY))
            local_nb = (
                j * GRID_W + (i - 1) % GRID_W,
                j * GRID_W + (i + 1) % GRID_W,
                ((j - 1) % GRID_H) * GRID_W + i,
                ((j + 1) % GRID_H) * GRID_W + i,
            )
            cells.append((w, abs(w), local_nb))

    for tile in range(tiling * tiling):
        ty, tx = divmod(tile, tiling)
        cx = margin + step_x * (tx + 0.5)
        cy = margin + step_y * (ty + 0.5)
        base = tile * per_tile
        for offset, (w, abs_w, local_nb) in enumerate(cells):
            idx = base + offset
            if abs_w >= 0.999:
                visible[idx] = False
            screen[idx] = (cx + w.real * disk_r, cy - w.imag * disk_r)
            radii[idx] = max(0.5, min(3.0, (1.0 - abs_w) * 2.5 + 0.5))
            neighbours[idx] = [base + k for k in local_nb]
    return Grid(total, states, screen, radii, visible, neighbours)


def build_donut():
    """Build an 80x60 torus of square cells (wrap-around on both axes).

    Returns:
        (Grid, cell): the grid, whose neighbour columns are [left, right, up,
        down] (all wrapping), and the cell edge length in pixels.
    """
    cols, rows, cell = 80, 60, 8
    count = cols * rows
    left = (WINDOW_W - cell * cols) // 2
    top = (WINDOW_H - cell * rows) // 2

    # Flat index ix = row * cols + col.
    row_ix, col_ix = np.divmod(np.arange(count), cols)

    screen = np.empty((count, 2), dtype=np.float32)
    screen[:, 0] = left + col_ix * cell + cell / 2
    screen[:, 1] = top + row_ix * cell + cell / 2

    # NumPy's % follows Python semantics (non-negative for a positive divisor),
    # so index -1 wraps to the last column/row.
    neighbours = np.stack([
        row_ix * cols + (col_ix - 1) % cols,
        row_ix * cols + (col_ix + 1) % cols,
        ((row_ix - 1) % rows) * cols + col_ix,
        ((row_ix + 1) % rows) * cols + col_ix,
    ], axis=1).astype(np.int32)

    grid = Grid(count, np.zeros(count, dtype=np.uint8), screen,
                np.full(count, cell * 0.45, dtype=np.float32),
                np.ones(count, dtype=bool), neighbours)
    return grid, cell


def build_rectangle():
    """Build an 80x60 rectangle of square cells with static borders (no wrap).

    Same layout as the donut grid, but a border cell has -1 ("no neighbour")
    in its neighbour table instead of wrapping to the opposite edge.

    Returns:
        (Grid, cell): the grid, whose neighbour columns are [left, right, up,
        down], and the cell edge length in pixels.
    """
    W, H = 80, 60
    total = W * H
    cell = 8
    off_x = (WINDOW_W - cell * W) // 2
    off_y = (WINDOW_H - cell * H) // 2

    ix = np.arange(total)
    j, i = np.divmod(ix, W)  # row and column of every cell (ix = j * W + i)

    screen = np.empty((total, 2), dtype=np.float32)
    screen[:, 0] = off_x + i * cell + cell / 2
    screen[:, 1] = off_y + j * cell + cell / 2

    # Neighbours as flat cell indices in the order [left, right, up, down];
    # -1 marks a missing neighbour at the border.
    neighbours = np.stack([
        np.where(i > 0, ix - 1, -1),
        np.where(i < W - 1, ix + 1, -1),
        np.where(j > 0, ix - W, -1),
        np.where(j < H - 1, ix + W, -1),
    ], axis=1).astype(np.int32)

    grid = Grid(total, np.zeros(total, dtype=np.uint8), screen,
                np.full(total, cell * 0.45, dtype=np.float32),
                np.ones(total, dtype=bool), neighbours)
    return grid, cell


def _hex_neighbours(cols, rows):
    """Return the (cols*rows, 6) int32 neighbour table of an "odd-r" hex grid.

    Cells are row-major (index = r * cols + q). Odd rows are drawn shifted
    right by half a cell, so the diagonal neighbours depend on row parity:
    an even row touches columns q-1 and q of the rows above and below, an odd
    row touches columns q and q+1. Missing neighbours (border) are -1.
    """
    even = ((1, 0), (0, -1), (-1, -1), (-1, 0), (-1, 1), (0, 1))
    odd = ((1, 0), (1, -1), (0, -1), (-1, 0), (0, 1), (1, 1))
    table = -np.ones((cols * rows, 6), dtype=np.int32)
    for r in range(rows):
        offsets = odd if r % 2 else even
        for q in range(cols):
            ix = r * cols + q
            for k, (dq, dr) in enumerate(offsets):
                nq, nr = q + dq, r + dr
                if 0 <= nq < cols and 0 <= nr < rows:
                    table[ix, k] = nr * cols + nq
    return table


def build_hex():
    """Hex grid (pointy-top, "odd-r" offset layout) with 6 neighbours per cell.

    Cells are stored row-major in a (q, r) rectangle: q addresses columns and
    r rows, index = r * COLS + q. Odd rows are drawn shifted right by half a
    cell, and the neighbour table follows the same layout, so logical and
    visual neighbours coincide. Borders are static (no wrap-around).

    Returns:
        Grid with a (COLS*ROWS, 6) neighbour table (-1 = no neighbour).
    """
    COLS = 60  # q
    ROWS = 40  # r
    total = COLS * ROWS
    states = np.zeros(total, dtype=np.uint8)

    # Pointy-top hex geometry
    size = 8
    dx = math.sqrt(3) * size  # horizontal distance between adjacent centres
    dy = 1.5 * size           # vertical distance between rows
    grid_w_px = dx * COLS + dx / 2
    grid_h_px = dy * ROWS + size
    off_x = (WINDOW_W - grid_w_px) // 2 + dx / 2
    off_y = (WINDOW_H - grid_h_px) // 2 + size

    screen = np.zeros((total, 2), dtype=np.float32)
    for r in range(ROWS):
        for q in range(COLS):
            cx = off_x + q * dx + (r % 2) * (dx / 2)  # odd rows shifted right
            cy = off_y + r * dy
            screen[r * COLS + q] = (cx, cy)
    radii = np.full(total, size * 0.85, dtype=np.float32)
    visible = np.ones(total, dtype=bool)

    return Grid(total, states, screen, radii, visible, _hex_neighbours(COLS, ROWS))


# ==============================================================================
# RULE-STEP
# ==============================================================================
def step_1d_rowwise(grid, table):
    """Apply an elementary 1D rule (R30/R110) to every grid row at once.

    A cell's row neighbours are the left/right entries of its neighbour table
    (columns 0 and 1, as produced by the square-grid builders). As a result:

    - each hyperbolic tile row is updated on its own and wraps inside its tile;
    - the donut wraps around;
    - on the rectangle a missing neighbour (-1, border) counts as 0.

    Args:
        grid: object with ``states`` (0/1 array) and ``neighbours`` whose
            columns 0 and 1 hold left/right cell indices (-1 = none).
        table: mapping (left, centre, right) -> new state, e.g. R30_TABLE.

    Returns:
        New state array with the same dtype as ``grid.states``.
    """
    s = grid.states
    left_ix = grid.neighbours[:, 0]
    right_ix = grid.neighbours[:, 1]
    # s[-1] is a valid (wrong) read for missing neighbours; np.where masks it to 0.
    left = np.where(left_ix >= 0, s[left_ix], 0).astype(np.intp)
    right = np.where(right_ix >= 0, s[right_ix], 0).astype(np.intp)
    centre = s.astype(np.intp)
    # 8-entry lookup table indexed by the neighbourhood code 4*l + 2*c + r.
    lut = np.array(
        [table[((k >> 2) & 1, (k >> 1) & 1, k & 1)] for k in range(8)],
        dtype=s.dtype,
    )
    return lut[(left << 2) | (centre << 1) | right]


def step_2d(grid, births, survives):
    """One generation of a Life-like B/S rule on an arbitrary neighbour table.

    Works for any neighbour count k (4 on the square grids, 6 on the hex
    grid); entries of -1 in ``grid.neighbours`` are ignored. Vectorised with
    NumPy because it runs every generation over the whole grid.

    Args:
        grid: object with ``states`` (cells equal to 1 are alive) and
            ``neighbours`` of shape (n, k).
        births: neighbour counts that bring a dead cell to life.
        survives: neighbour counts that keep a live cell alive.

    Returns:
        New 0/1 state array with the same dtype as ``grid.states``.
    """
    s = grid.states
    nb = grid.neighbours
    # s[nb] reads s[-1] for missing neighbours; the (nb >= 0) mask discards it.
    live = ((s[nb] == 1) & (nb >= 0)).sum(axis=1)
    # np.isin needs a sequence, not a set (a set becomes a 0-d object array).
    born = np.isin(live, sorted(births))
    keep = np.isin(live, sorted(survives))
    return np.where(s == 1, keep, born).astype(s.dtype)


# ==============================================================================
# AUTOMATON
# ==============================================================================
class Automaton:
    """Simulation state: topology, tiling, rule, grid, generation and speed."""

    def __init__(self):
        self.topo = TOPO_HYPERBOLIC
        self.tiling = 1
        self.rule = RULE_GOL
        self._rebuild()
        self.running = True
        self.gen_per_sec = 8

    def _rebuild(self):
        """Build a fresh grid for the current topology/tiling, randomise it, reset gen."""
        self.grid = self._build_grid()
        self.grid.reset_random(0.5)
        self.gen = 0

    def _build_grid(self):
        """Return a new Grid for ``self.topo``; raise ValueError if unknown."""
        if self.topo == TOPO_HYPERBOLIC:
            return build_hyperbolic(self.tiling)
        if self.topo == TOPO_DONUT:
            g, _ = build_donut()
            return g
        if self.topo == TOPO_RECTANGLE:
            g, _ = build_rectangle()
            return g
        if self.topo == TOPO_HEX:
            return build_hex()
        raise ValueError(self.topo)

    def available_rules(self):
        """Rules selectable in the current topology (hex vs. square grids)."""
        if self.topo == TOPO_HEX:
            return HEX_RULES
        return SQUARE_RULES

    def set_rule(self, rule):
        """Switch to *rule* if the current topology offers it; return success."""
        if rule in self.available_rules():
            self.rule = rule
            return True
        return False

    def set_topo(self, topo):
        """Switch topology, pick a compatible rule and start from a random grid."""
        if topo == self.topo:
            return
        if topo == TOPO_HEX and self.rule not in HEX_RULES:
            self.rule = HEX_RULES[0]
        elif topo != TOPO_HEX and self.rule not in SQUARE_RULES:
            self.rule = SQUARE_RULES[2]  # GoL
        self.topo = topo
        self._rebuild()

    def set_tiling(self, n):
        """Change the hyperbolic tiling to n x n (ignored in other topologies)."""
        if n == self.tiling or self.topo != TOPO_HYPERBOLIC:
            return
        self.tiling = n
        self._rebuild()

    def step(self):
        """Advance one generation and flag every changed cell for flashing."""
        if self.topo != TOPO_HEX and self.rule in (RULE_R30, RULE_R110):
            table = R30_TABLE if self.rule == RULE_R30 else R110_TABLE
            new = step_1d_rowwise(self.grid, table)
        else:
            # Life-like B/S rule; the hex grid always lands here (6 neighbours).
            births, survives = parse_bs(self.rule)
            new = step_2d(self.grid, births, survives)
        changed = new != self.grid.states
        self.grid.flash[changed] = 1.0
        self.grid.next_states = new
        self.grid.states = new
        self.gen += 1

    def find_cell(self, sx, sy):
        """Return the index of the visible cell nearest to pixel (sx, sy), or None.

        A cell is hit when the point lies within its drawing radius plus a
        2-px tolerance; among several hits the closest wins (lowest index on
        ties). Vectorised, so there is no grid-size cut-off.
        """
        g = self.grid
        d2 = (g.screen[:, 0] - sx) ** 2 + (g.screen[:, 1] - sy) ** 2
        hit = g.visible & (d2 < (g.radii + 2) ** 2)
        if not hit.any():
            return None
        return int(np.argmin(np.where(hit, d2, np.inf)))


# ==============================================================================
# UI
# ==============================================================================
class UI:
    """Control panels, button rows, speed slider and cell rendering.

    Draws onto the given pygame surface and turns mouse clicks into calls on
    the automaton (topology, tiling, rule, speed, cell toggle).
    """

    def __init__(self, surface, automaton):
        self.s = surface
        self.automaton = automaton
        # 1x1 transparent surface returned by _txt when no font is available.
        self._font_blank = pygame.Surface((1, 1), pygame.SRCALPHA)
        self.font = self.font_small = self.font_legend = None
        try:
            if not pygame.font.get_init():
                pygame.font.init()
            self.font       = pygame.font.SysFont(FONT_NAME, FONT_SIZE)
            self.font_small = pygame.font.SysFont(FONT_NAME, FONT_SMALL)
            self.font_legend = pygame.font.SysFont(FONT_NAME, 14)
        except (NotImplementedError, pygame.error):
            # Font module unavailable (e.g. headless/dummy driver): all text
            # rendering falls back to the blank surface above.
            pass

        # Layout panels
        self.panel_top   = pygame.Rect(20, 20, WINDOW_W - 40, 130)
        self.panel_bot   = pygame.Rect(20, WINDOW_H - 100, WINDOW_W - 40, 80)
        self.panel_help  = pygame.Rect(20, WINDOW_H - 180, WINDOW_W - 40, 60)

        # Three button rows inside the top panel, each on its own y.
        bw, bh, gap = 150, 30, 10
        x0 = self.panel_top.x + 16
        row_ys = (self.panel_top.y + 14,   # row 1: topology
                  self.panel_top.y + 52,   # row 2: tiling (hyperbolic only)
                  self.panel_top.y + 90)   # row 3: rule

        def button_row(items, y):
            return [(item, pygame.Rect(x0 + k * (bw + gap), y, bw, bh))
                    for k, item in enumerate(items)]

        self.topo_buttons = button_row(TOPOLOGIES, row_ys[0])
        self.tiling_buttons = button_row(TILING_OPTIONS, row_ys[1])
        # Rule-row slots; the rule each slot stands for follows the current
        # topology (see _rule_slots), so hex rules are clickable too.
        self.rule_buttons = button_row(SQUARE_RULES, row_ys[2])

        # Row labels go right of the widest button row so they never overlap.
        widest = max(len(TOPOLOGIES), len(TILING_OPTIONS), len(SQUARE_RULES))
        self.label_x = x0 + widest * (bw + gap) + 10
        self.row_label_ys = tuple(y + 7 for y in row_ys)

        # Speed slider in the upper half of the bottom panel; the status line
        # goes in the lower half so the two texts do not collide.
        self.slider_x0 = self.panel_bot.x + 230
        self.slider_x1 = self.panel_bot.right - 30
        self.slider_y  = self.panel_bot.y + 22

    def _glass_rect(self, rect, alpha=UI_ALPHA):
        """Blit a translucent teal panel with a 1-px white border at *rect*."""
        panel = pygame.Surface(rect.size, pygame.SRCALPHA)
        panel.fill((*TEAL, alpha))
        pygame.draw.rect(panel, WHITE, panel.get_rect(), 1)
        self.s.blit(panel, rect.topleft)

    def _txt(self, font, text, color):
        """Safe text render: returns a blank surface when font is None."""
        if font is None:
            return self._font_blank
        return font.render(text, True, color)

    def _draw_button(self, rect, label, active, dim=False):
        """Draw one button; *active* adds a green frame, *dim* a darker fill."""
        bg = (*TEAL, UI_ALPHA) if not dim else (*PETROL_DK, UI_ALPHA // 2)
        panel = pygame.Surface(rect.size, pygame.SRCALPHA)
        panel.fill(bg)
        pygame.draw.rect(panel, WHITE, panel.get_rect(), 1)
        if active:
            pygame.draw.rect(panel, FOREST, panel.get_rect(), 2)
        self.s.blit(panel, rect.topleft)
        text = self._txt(self.font_small, label, WHITE)
        self.s.blit(text, (rect.x + 8, rect.y + 7))

    def _rule_slots(self):
        """Pair each rule-button rect with the rule it shows in this topology."""
        rects = [rect for _, rect in self.rule_buttons]
        return list(zip(self.automaton.available_rules(), rects))

    def draw(self):
        """Draw panels, button rows, speed slider, status line and help text."""
        a = self.automaton
        self._glass_rect(self.panel_top)
        self._glass_rect(self.panel_bot)
        self._glass_rect(self.panel_help)

        # Row labels
        for text, y in zip(("Topologie / Topology", "Kacheln / Tiling", "Regel / Rule"),
                           self.row_label_ys):
            self.s.blit(self._txt(self.font_small, text, FOREST), (self.label_x, y))

        # Buttons
        topo_labels = {"hyperbolic": "Hyperbolic", "donut": "Donut",
                       "rectangle": "Rectangle", "hex": "Hex"}
        for t, rect in self.topo_buttons:
            self._draw_button(rect, topo_labels[t], active=(t == a.topo))
        tiling_enabled = a.topo == TOPO_HYPERBOLIC
        for n, rect in self.tiling_buttons:
            self._draw_button(rect, f"{n}x{n}", active=(n == a.tiling),
                              dim=not tiling_enabled)
        for rule, rect in self._rule_slots():
            self._draw_button(rect, rule, active=(rule == a.rule))

        # Speed slider
        pygame.draw.line(self.s, WHITE, (self.slider_x0, self.slider_y),
                         (self.slider_x1, self.slider_y), 1)
        frac = (a.gen_per_sec - 1) / 29
        knob_x = int(self.slider_x0 + frac * (self.slider_x1 - self.slider_x0))
        pygame.draw.circle(self.s, FOREST, (knob_x, self.slider_y), 8, 0)
        pygame.draw.circle(self.s, WHITE, (knob_x, self.slider_y), 8, 1)
        speed_label = self._txt(self.font_small,
            f"Tempo / Speed: {a.gen_per_sec} Gen/s", WHITE)
        self.s.blit(speed_label, (self.panel_bot.x + 20, self.slider_y - 8))

        # Status line (lower half of the bottom panel)
        status = (f"Topo: {a.topo}  |  "
                  f"Tiles: {a.tiling}x{a.tiling}  |  "
                  f"Rule: {a.rule}  |  "
                  f"Gen: {a.gen}  |  "
                  f"{'PLAY' if a.running else 'PAUSE'}")
        st = self._txt(self.font_small, status, FOREST)
        self.s.blit(st, (self.panel_bot.x + 20, self.panel_bot.y + 48))

        # Help
        help_text = ("T: Topologie  |  1/2/3: Regel  |  +/-: Tempo  |  "
                     "Space: Pause  |  R: Reset  |  Klick: Toggle  |  Esc: Quit")
        ht = self._txt(self.font_legend, help_text, WHITE)
        self.s.blit(ht, (self.panel_help.x + 16, self.panel_help.centery - 8))

    def draw_grid(self):
        """Clear the window and draw every visible cell in its topology's shape."""
        a = self.automaton
        g = a.grid
        # Background
        self.s.fill(PETROL_DK)
        # pygame.draw never blends: the alpha of an RGBA colour is ignored on
        # the display surface. Pre-blend "TEAL at alpha 60 over PETROL_DK".
        dead_col = tuple(bg + (fg - bg) * 60 // 255 for fg, bg in zip(TEAL, PETROL_DK))
        # Pointy-top hexagon corners (-30°, 30°, ..., 270°) on the unit circle.
        hex_unit = [(math.cos(math.pi / 180 * (60 * k - 30)),
                     math.sin(math.pi / 180 * (60 * k - 30))) for k in range(6)]
        for ix in np.flatnonzero(g.visible):
            sx, sy = g.screen[ix]
            r = g.radii[ix]
            col = FOREST if g.states[ix] == 1 else dead_col
            if a.topo == TOPO_HEX:
                pts = [(int(sx + r * ux), int(sy + r * uy)) for ux, uy in hex_unit]
                pygame.draw.polygon(self.s, col, pts, 0)
                pygame.draw.polygon(self.s, WHITE, pts, 1)
            elif a.topo in (TOPO_DONUT, TOPO_RECTANGLE):
                rect = pygame.Rect(int(sx - r), int(sy - r), int(2 * r), int(2 * r))
                pygame.draw.rect(self.s, col, rect, 0)
                pygame.draw.rect(self.s, WHITE, rect, 1)
            else:
                # Hyperbolic: circles
                pygame.draw.circle(self.s, col, (int(sx), int(sy)), max(1, int(r)))
        g.flash *= 0.92  # decay

    def handle_click(self, mx, my):
        """Dispatch a mouse click to a button, the slider, or the cell under it."""
        a = self.automaton
        # Topology
        for t, rect in self.topo_buttons:
            if rect.collidepoint(mx, my):
                a.set_topo(t)
                return
        # Tiling (dimmed and inactive outside the hyperbolic topology)
        for n, rect in self.tiling_buttons:
            if rect.collidepoint(mx, my):
                if a.topo == TOPO_HYPERBOLIC:
                    a.set_tiling(n)
                return
        # Rules of the current topology
        for rule, rect in self._rule_slots():
            if rect.collidepoint(mx, my):
                a.set_rule(rule)
                return
        # Speed slider
        if (self.slider_y - 10 <= my <= self.slider_y + 10
                and self.slider_x0 <= mx <= self.slider_x1):
            frac = (mx - self.slider_x0) / (self.slider_x1 - self.slider_x0)
            a.gen_per_sec = max(1, min(30, int(1 + frac * 29)))
            return
        # Clicks elsewhere on a control panel must not toggle the cell beneath it.
        if any(p.collidepoint(mx, my)
               for p in (self.panel_top, self.panel_bot, self.panel_help)):
            return
        # Otherwise: toggle the cell under the cursor
        ix = a.find_cell(mx, my)
        if ix is not None and ix >= 0:
            a.grid.toggle(ix)


# ==============================================================================
# MAIN
# ==============================================================================
def main():
    """Open the window and run the event / simulation / render loop until quit."""
    pygame.init()
    screen = pygame.display.set_mode((WINDOW_W, WINDOW_H))
    pygame.display.set_caption("Zellularautomat v2 – Hyperbolic / Donut / Rectangle / Hex")
    auto = Automaton()
    ui = UI(screen, auto)
    clock = pygame.time.Clock()

    rule_keys = [pygame.K_1, pygame.K_2, pygame.K_3]
    faster_keys = (pygame.K_PLUS, pygame.K_KP_PLUS, pygame.K_EQUALS)
    slower_keys = (pygame.K_MINUS, pygame.K_KP_MINUS)

    def on_key(key):
        """Handle one KEYDOWN; return False when the program should exit."""
        if key == pygame.K_ESCAPE:
            return False
        if key == pygame.K_SPACE:
            auto.running = not auto.running
        elif key == pygame.K_r:
            auto.grid.reset_random(0.5)
            auto.gen = 0
        elif key == pygame.K_t:
            current = TOPOLOGIES.index(auto.topo)
            auto.set_topo(TOPOLOGIES[(current + 1) % len(TOPOLOGIES)])
        elif key in rule_keys:
            choices = auto.available_rules()
            slot = rule_keys.index(key)
            if slot < len(choices):
                auto.set_rule(choices[slot])
        elif key in faster_keys:
            auto.gen_per_sec = min(30, auto.gen_per_sec + 1)
        elif key in slower_keys:
            auto.gen_per_sec = max(1, auto.gen_per_sec - 1)
        return True

    last_step = 0
    keep_going = True
    while keep_going:
        for ev in pygame.event.get():
            if ev.type == pygame.QUIT:
                keep_going = False
            elif ev.type == pygame.KEYDOWN:
                if not on_key(ev.key):
                    keep_going = False
            elif ev.type == pygame.MOUSEBUTTONDOWN:
                ui.handle_click(*ev.pos)

        # Advance the simulation at most gen_per_sec times per second.
        now = pygame.time.get_ticks()
        if auto.running and now - last_step >= 1000 // auto.gen_per_sec:
            auto.step()
            last_step = now

        ui.draw_grid()
        ui.draw()
        pygame.display.flip()
        clock.tick(60)
    pygame.quit()


if __name__ == "__main__":
    main()
