From b8eeaa4739932b04212b340802dda1f53e31615c Mon Sep 17 00:00:00 2001 From: Rbanh Date: Mon, 24 Feb 2025 18:38:44 -0500 Subject: [PATCH] Major code restructuring: Reorganized project structure, added AI components - Added new module structure (ai, config, core, ui) - Moved snake and food logic into core module - Added training configuration - Updated gitignore for project-specific files - Modified tests to match new structure --- .gitignore | 12 + ROADMAP.md | 106 ++++-- requirements.txt | 11 +- src/__init__.py | 15 +- src/ai/config_ui.py | 235 ++++++++++++ src/ai/controller.py | 70 ++++ src/ai/environment.py | 787 +++++++++++++++++++++++++++++++++++++++ src/ai/model_networks.py | 93 +++++ src/ai/train.py | 423 +++++++++++++++++++++ src/ai/visualize.py | 540 +++++++++++++++++++++++++++ src/cli.py | 16 +- src/config/__init__.py | 36 ++ src/config/colors.py | 25 ++ src/config/constants.py | 31 ++ src/config/settings.py | 52 +++ src/core/__init__.py | 16 + src/core/food.py | 82 ++++ src/core/game_session.py | 239 ++++++++++++ src/core/snake.py | 357 ++++++++++++++++++ src/food.py | 50 --- src/game.py | 399 +++++++------------- src/main.py | 10 +- src/menu.py | 2 + src/snake.py | 153 -------- src/ui/__init__.py | 20 + src/ui/game_area.py | 138 +++++++ src/ui/menu.py | 127 +++++++ src/ui/menu_item.py | 84 +++++ src/ui/pause_menu.py | 129 +++++++ src/ui/settings_menu.py | 159 ++++++++ tests/__init__.py | 4 +- tests/test_food.py | 2 +- tests/test_snake.py | 3 +- training_config.json | 15 + 34 files changed, 3928 insertions(+), 513 deletions(-) create mode 100644 src/ai/config_ui.py create mode 100644 src/ai/controller.py create mode 100644 src/ai/environment.py create mode 100644 src/ai/model_networks.py create mode 100644 src/ai/train.py create mode 100644 src/ai/visualize.py create mode 100644 src/config/__init__.py create mode 100644 src/config/colors.py create mode 100644 src/config/constants.py create mode 100644 src/config/settings.py create mode 100644 
src/core/__init__.py create mode 100644 src/core/food.py create mode 100644 src/core/game_session.py create mode 100644 src/core/snake.py delete mode 100644 src/food.py delete mode 100644 src/snake.py create mode 100644 src/ui/__init__.py create mode 100644 src/ui/game_area.py create mode 100644 src/ui/menu.py create mode 100644 src/ui/menu_item.py create mode 100644 src/ui/pause_menu.py create mode 100644 src/ui/settings_menu.py create mode 100644 training_config.json diff --git a/.gitignore b/.gitignore index 755b4a9..37c6be5 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ ENV/ .virtualenv/ virtualenv/ *venv/ +Activate.ps1 # Python __pycache__/ @@ -39,3 +40,14 @@ wheels/ # OS .DS_Store Thumbs.db + +# Project specific +logs/ +models/ +.pytest_cache/ +path/ +training_data/ +*.log +*.pth +*.ckpt +*.h5 diff --git a/ROADMAP.md b/ROADMAP.md index ea0c5d3..5ed09ab 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -3,19 +3,26 @@ This is a living document that tracks the development progress of the AI Snake Game project. 
## Current Status -Last Updated: [Menu System Implementation] -- Added main menu interface with game mode selection -- Implemented menu navigation and selection -- Added game mode display during gameplay -- Integrated menu system with game states +Last Updated: [Control and Menu Improvements] +- Fixed menu interaction issues: + - Added proper mouse click handling + - Fixed menu item highlighting + - Improved keyboard and mouse navigation +- Improved snake movement system: + - Implemented input buffer for direction changes + - Fixed rapid direction change issues + - Improved responsiveness while maintaining safety - Previous updates: + - Added settings menu with game options + - Added wrap-around mode toggle + - Added speed increase toggle + - Added main menu interface + - Added game mode selection + - Implemented menu navigation + - Added game mode display - Added game restart functionality - - Added "Press any key to restart" message - - Implemented game state reset - Added Food class with spawning mechanics - - Integrated food system with main game loop - Implemented basic scoring system - - Added game over screen with score display ## Phase 1: Project Setup and Basic Structure ✅ - [x] Create project structure and virtual environment @@ -27,11 +34,18 @@ Last Updated: [Menu System Implementation] ## Phase 2: Core Game Development ✅ - [x] Implement basic game window using Pygame - [x] Create Snake class with movement mechanics + - [x] Basic movement and collision + - [x] Input buffering for direction changes + - [x] Fix rapid direction change issues + - [x] Prevent 180-degree turns - [x] Implement food spawning system - [x] Create Food class - [x] Add random food placement - [x] Add collision detection with food -- [x] Add collision detection (walls and self) +- [x] Add collision detection + - [x] Wall collisions + - [x] Self collisions + - [x] Add wrap-around mode option - [x] Implement scoring system - [x] Add score display - [ ] Add high score tracking @@ -39,12 
+53,19 @@ Last Updated: [Menu System Implementation] - [x] Implement game over state transition - [x] Add restart functionality -## Phase 3: Game Polish and UI 🔄 +## Phase 3: Game Polish and UI ✅ - [x] Add main menu interface - [x] Create menu layout - [x] Add game mode selection - - [ ] Add settings options + - [x] Add settings menu + - [x] Fix menu interaction issues + - [x] Mouse click handling + - [x] Menu item highlighting + - [x] Keyboard/mouse navigation - [x] Implement pause functionality +- [x] Add settings options + - [x] Wrap-around mode toggle + - [x] Speed increase toggle - [ ] Add visual effects and animations - [ ] Snake movement animation - [ ] Food spawn animation @@ -58,36 +79,63 @@ Last Updated: [Menu System Implementation] - [x] Add game over screen with restart option ## Phase 4: AI Implementation 🔄 -- [ ] Design AI algorithm (pathfinding) - - [ ] Research and select appropriate algorithm - - [ ] Implement basic pathfinding -- [ ] Implement AI snake movement logic - - [ ] Create AI controller class - - [ ] Implement decision making -- [ ] Add AI difficulty levels - - [ ] Easy mode (basic pathfinding) - - [ ] Medium mode (improved decision making) +- [ ] Set up ML training environment + - [ ] Create gym-like interface for the game + - [ ] Define observation space (snake + food state) + - [ ] Define action space (4 directions) + - [ ] Implement reward system +- [ ] Implement ML infrastructure + - [ ] Add ML dependencies (PyTorch/TensorFlow) + - [ ] Create neural network architecture + - [ ] Set up training pipeline +- [ ] Train AI models for different difficulties + - [ ] Easy mode (basic food seeking) + - [ ] Medium mode (balanced survival/food seeking) - [ ] Hard mode (optimal pathfinding) +- [ ] Create AI controller class + - [ ] Model loading and inference + - [ ] Real-time decision making + - [ ] Performance optimization - [x] Create mode selection (Player vs AI) -- [ ] Optimize AI performance +- [ ] Add training utilities + - [ ] Save/load 
model checkpoints + - [ ] Training progress visualization + - [ ] Model performance metrics ## Phase 5: Final Polish -- [ ] Add configuration options - - [ ] Game speed settings +- [x] Add configuration options + - [x] Game speed settings + - [x] Wrap-around option - [ ] Visual settings - [ ] Sound settings - [ ] Implement save/load functionality for high scores - [ ] Bug fixing and performance optimization + - [x] Fix menu interaction bugs + - [x] Fix movement control issues + - [x] Optimize input handling - [ ] Code cleanup and documentation - - [ ] Add docstrings + - [x] Add docstrings - [ ] Create API documentation - [ ] Update README with final features -- [ ] Final testing and refinements + +## Resolved Issues +1. Menu Interaction + - Issue: Menu items not responding to clicks + - Solution: Added proper mouse click event handling in game loop + - Solution: Fixed menu item highlighting and selection state + +2. Snake Movement + - Issue: Rapid direction changes causing self-collision + - Issue: Missed inputs during quick turns + - Solution: Implemented input buffer system + - Solution: Process one direction change per grid movement + - Solution: Store up to 2 pending direction changes ## Next Steps 1. Implement AI snake movement - - Design pathfinding algorithm - - Create AI controller class + - Set up ML training environment + - Create neural network architecture + - Train initial model 2. 
Add high score system - Implement score persistence - Add high score display @@ -96,7 +144,7 @@ Last Updated: [Menu System Implementation] - Add game sounds ## Known Issues -- None reported yet +- None reported ## Future Considerations - Multiplayer support diff --git a/requirements.txt b/requirements.txt index 8db134c..1003524 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,12 @@ pygame==2.5.2 -numpy==1.24.3 +numpy==1.26.3 black==23.12.1 # for code formatting pytest==7.4.3 # for testing -pytest-xdist==3.5.0 # for parallel test execution \ No newline at end of file +pytest-xdist==3.5.0 # for parallel test execution +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.1.2+cu118 # CUDA 11.8 support +stable-baselines3[extra]==2.2.1 # for reinforcement learning +tensorboard==2.15.1 # for training visualization +gymnasium==0.29.1 # for RL environment interface +tqdm==4.66.1 +rich==13.7.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index ad42af4..80daf36 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -9,11 +9,12 @@ This package contains all the core game components including: - AI controllers (coming soon) """ +from src import config +from src import core +from src import ai +from src import ui from src.game import Game, GameState -from src.snake import Snake, Direction -from src.food import Food -from src.menu import Menu, GameMode, MenuItem -from src.cli import main as cli_main +from src.ui.menu import GameMode __version__ = '0.1.0' __author__ = 'Rbanh' @@ -28,5 +29,9 @@ __all__ = [ 'Menu', 'GameMode', 'MenuItem', - 'cli_main' + 'cli_main', + 'config', + 'core', + 'ai', + 'ui' ] \ No newline at end of file diff --git a/src/ai/config_ui.py b/src/ai/config_ui.py new file mode 100644 index 0000000..2739b12 --- /dev/null +++ b/src/ai/config_ui.py @@ -0,0 +1,235 @@ +""" +Training Configuration UI + +This module provides a graphical interface for configuring AI training parameters. 
class ConfigUI:
    """Pygame window for editing, validating and persisting training parameters.

    Editable values live in ``self.parameters`` (name -> spec dict with the keys
    ``value``, ``type``, ``min``, ``max`` and ``description``). They are stored
    on disk as a flat ``{name: value}`` JSON object in ``self.config_file``.
    """

    # Row layout shared by hit-testing and drawing so the two cannot disagree.
    _ROW_TOP = 100
    _ROW_HEIGHT = 60
    _SCROLL_STEP = 30
    # Accepted while typing in addition to numeric characters; "e"/"E"/"-" make
    # values that str() renders in scientific notation (e.g. 1e-05) editable.
    _EXTRA_INPUT_CHARS = (".", "e", "E", "-")

    @staticmethod
    def _default_parameters() -> Dict[str, Dict[str, Any]]:
        """Return a fresh copy of the built-in parameter specifications."""
        return {
            'timesteps': {
                'value': 1000000,
                'type': 'int',
                'min': 1000,
                'max': 10000000,
                'description': 'Total timesteps to train for',
            },
            'learning_rate': {
                'value': 0.0003,
                'type': 'float',
                'min': 0.00001,
                'max': 0.01,
                'description': 'Learning rate for training',
            },
            'batch_size': {
                'value': 64,
                'type': 'int',
                'min': 32,
                'max': 512,
                'description': 'Batch size for training',
            },
            'n_envs': {
                'value': 8,
                'type': 'int',
                'min': 1,
                'max': 32,
                'description': 'Number of parallel environments',
            },
            'n_steps': {
                'value': 2048,
                'type': 'int',
                'min': 128,
                'max': 8192,
                'description': 'Number of steps per update',
            },
        }

    def __init__(self, width: int = 800, height: int = 600):
        """Initialize pygame, create the window and load any saved configuration.

        Args:
            width: Window width in pixels.
            height: Window height in pixels.
        """
        pygame.init()
        self.width = width
        self.height = height
        self.screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Training Configuration")

        # Fonts
        self.title_font = pygame.font.Font(None, 48)
        self.header_font = pygame.font.Font(None, 36)
        self.text_font = pygame.font.Font(None, 24)

        # Colors
        self.colors = {
            'background': (0, 0, 0),
            'text': (200, 200, 200),
            'highlight': (0, 255, 0),
            'button': (40, 40, 40),
            'button_hover': (60, 60, 60),
            'input_bg': (30, 30, 30),
            'input_active': (50, 50, 50),
        }

        # Parameters
        self.parameters = self._default_parameters()

        # UI state
        self.active_input = None  # name of the parameter being edited, if any
        self.input_text = ""      # raw text typed into the active field
        self.scroll_offset = 0
        self.max_scroll = max(
            0, len(self.parameters) * self._ROW_HEIGHT - (height - 200)
        )

        # Load saved config if exists
        self.config_file = "training_config.json"
        self.load_config()

    def load_config(self) -> None:
        """Load saved values from ``self.config_file`` on a best-effort basis.

        A missing, unreadable or malformed file leaves the current values
        untouched. Unknown keys are ignored, and a stored value is applied only
        when it passes ``validate_input`` for its parameter, so a hand-edited
        file cannot inject out-of-range or wrongly typed values.
        """
        if not os.path.exists(self.config_file):
            return
        try:
            with open(self.config_file, 'r') as f:
                saved_config = json.load(f)
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            return
        if not isinstance(saved_config, dict):
            return

        for key, value in saved_config.items():
            param = self.parameters.get(key)
            if param is None:
                continue
            checked = self.validate_input(param, str(value))
            if checked is not None:
                param['value'] = checked

    def save_config(self) -> None:
        """Write the current ``{name: value}`` mapping to ``self.config_file``."""
        config = {key: param['value'] for key, param in self.parameters.items()}
        with open(self.config_file, 'w') as f:
            json.dump(config, f, indent=4)

    def draw_text_input(self, rect: "pygame.Rect", value: Any, active: bool) -> None:
        """Draw one input field; the active field shows the typed text and a caret."""
        color = self.colors['input_active'] if active else self.colors['input_bg']
        pygame.draw.rect(self.screen, color, rect)
        pygame.draw.rect(self.screen, self.colors['text'], rect, 1)

        text = self.input_text + "|" if active else str(value)
        text_surface = self.text_font.render(text, True, self.colors['text'])
        text_rect = text_surface.get_rect(midleft=(rect.left + 5, rect.centery))
        self.screen.blit(text_surface, text_rect)

    def draw_button(self, rect: "pygame.Rect", text: str, hover: bool = False) -> None:
        """Draw a bordered button with centered ``text``."""
        color = self.colors['button_hover'] if hover else self.colors['button']
        pygame.draw.rect(self.screen, color, rect)
        pygame.draw.rect(self.screen, self.colors['text'], rect, 1)

        text_surface = self.text_font.render(text, True, self.colors['text'])
        text_rect = text_surface.get_rect(center=rect.center)
        self.screen.blit(text_surface, text_rect)

    def validate_input(self, param: Dict[str, Any], value: str) -> Optional[Any]:
        """Convert ``value`` to the parameter's type and range-check it.

        Args:
            param: Parameter spec providing ``type``, ``min`` and ``max``.
            value: Raw text to convert.

        Returns:
            The converted ``int``/``float``, or ``None`` when the text cannot be
            parsed or lies outside ``[min, max]``.
        """
        try:
            val = int(value) if param['type'] == 'int' else float(value)
        except (TypeError, ValueError):
            return None

        # A positive range test also rejects NaN, which compares false both ways.
        if not (param['min'] <= val <= param['max']):
            return None
        return val

    def _commit_active_input(self) -> None:
        """Apply the typed text to the active parameter (if valid) and end editing."""
        if self.active_input is None:
            return
        param = self.parameters[self.active_input]
        val = self.validate_input(param, self.input_text)
        # Compare against None so a legitimate falsy value (0, 0.0) is kept.
        if val is not None:
            param['value'] = val
        self.active_input = None

    def _visible_rows(self):
        """Yield ``(name, param, y)`` for each parameter row currently on screen."""
        y = self._ROW_TOP - self.scroll_offset
        for name, param in self.parameters.items():
            if 0 <= y <= self.height - 100:
                yield name, param, y
            y += self._ROW_HEIGHT

    @staticmethod
    def _input_rect(y: int) -> "pygame.Rect":
        """Return the rectangle of the input field for the row starting at ``y``."""
        return pygame.Rect(300, y, 200, 30)

    def _draw(self, mouse_pos, save_rect: "pygame.Rect", cancel_rect: "pygame.Rect") -> None:
        """Render title, visible parameter rows and both buttons, then flip."""
        self.screen.fill(self.colors['background'])

        title_surface = self.title_font.render(
            "Training Configuration", True, self.colors['highlight']
        )
        title_rect = title_surface.get_rect(midtop=(self.width // 2, 20))
        self.screen.blit(title_surface, title_rect)

        for name, param, y in self._visible_rows():
            name_surface = self.header_font.render(name, True, self.colors['text'])
            desc_surface = self.text_font.render(
                param['description'], True, self.colors['text']
            )
            self.screen.blit(name_surface, (20, y))
            self.screen.blit(desc_surface, (20, y + 30))
            self.draw_text_input(
                self._input_rect(y), param['value'], name == self.active_input
            )

        self.draw_button(save_rect, "Save", save_rect.collidepoint(mouse_pos))
        self.draw_button(cancel_rect, "Cancel", cancel_rect.collidepoint(mouse_pos))
        pygame.display.flip()

    def run(self) -> Optional[Dict[str, Any]]:
        """Run the configuration UI until the user saves, cancels or closes it.

        Returns:
            The ``{name: value}`` config dict if Save was clicked, ``None`` if
            the dialog was cancelled or closed.
        """
        save_clicked = False
        running = True
        mouse_pos = (0, 0)
        clock = pygame.time.Clock()

        # The window size is fixed after __init__, so the button rectangles are too.
        save_rect = pygame.Rect(self.width // 2 - 150, self.height - 60, 140, 40)
        cancel_rect = pygame.Rect(self.width // 2 + 10, self.height - 60, 140, 40)

        try:
            while running:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        running = False

                    elif event.type == pygame.MOUSEMOTION:
                        # Track the pointer so button hover highlighting works.
                        mouse_pos = event.pos

                    elif event.type == pygame.MOUSEBUTTONDOWN:
                        # Only the left button is a click. Wheel motion also
                        # arrives as button 4/5 presses and must not press Save.
                        if event.button != 1:
                            continue
                        mouse_pos = event.pos

                        for name, param, y in self._visible_rows():
                            if self._input_rect(y).collidepoint(mouse_pos):
                                self.active_input = name
                                self.input_text = str(param['value'])

                        if save_rect.collidepoint(mouse_pos):
                            # Do not lose text typed but not confirmed with Enter.
                            self._commit_active_input()
                            self.save_config()
                            save_clicked = True
                            running = False
                        elif cancel_rect.collidepoint(mouse_pos):
                            running = False

                    elif event.type == pygame.MOUSEBUTTONUP:
                        if event.button == 4:  # Mouse wheel up
                            self.scroll_offset = max(
                                0, self.scroll_offset - self._SCROLL_STEP
                            )
                        elif event.button == 5:  # Mouse wheel down
                            self.scroll_offset = min(
                                self.max_scroll, self.scroll_offset + self._SCROLL_STEP
                            )

                    elif event.type == pygame.KEYDOWN and self.active_input is not None:
                        if event.key in (pygame.K_RETURN, pygame.K_KP_ENTER):
                            self._commit_active_input()
                        elif event.key == pygame.K_BACKSPACE:
                            self.input_text = self.input_text[:-1]
                        elif (event.unicode.isnumeric()
                              or event.unicode in self._EXTRA_INPUT_CHARS):
                            self.input_text += event.unicode

                self._draw(mouse_pos, save_rect, cancel_rect)
                clock.tick(60)  # avoid spinning a CPU core on an idle dialog
        finally:
            # Always release the display, even if drawing or saving raised.
            pygame.quit()

        if save_clicked:
            return {key: param['value'] for key, param in self.parameters.items()}
        return None
class AIController:
    """AI controller that uses trained models to play the game."""

    def __init__(self, difficulty: str = "medium"):
        """
        Initialize the AI controller.

        Args:
            difficulty: "easy", "medium", or "hard"

        Raises:
            FileNotFoundError: If no trained model exists for ``difficulty``.
        """
        self.difficulty = difficulty
        self.model = None
        self.env = SnakeEnv()  # For state processing
        self._load_model()

    def _load_model(self):
        """Load the model for the configured difficulty.

        ``best_model.zip`` is preferred; ``final_model.zip`` is the fallback.

        Raises:
            FileNotFoundError: If neither file exists under
                ``models/<difficulty>/``.
        """
        candidates = (
            f"models/{self.difficulty}/best_model.zip",
            # Fall back to final model if best model doesn't exist
            f"models/{self.difficulty}/final_model.zip",
        )
        for model_path in candidates:
            if os.path.exists(model_path):
                self.model = PPO.load(model_path)
                return

        raise FileNotFoundError(
            f"No model found for difficulty {self.difficulty}. "
            "Please train the model first."
        )

    def get_action(self, game_state: dict) -> "Direction":
        """
        Get the next action based on the current game state.

        Args:
            game_state: Dictionary containing:
                - snake: Snake object
                - food: Food object
                - width: Game width
                - height: Game height

        Returns:
            Direction enum indicating the chosen action
        """
        # Update environment with current game state
        self.env.snake = game_state["snake"]
        self.env.food = game_state["food"]
        self.env.width = game_state["width"]
        self.env.height = game_state["height"]

        # Get state observation
        # NOTE(review): the part of SnakeEnv visible in this patch builds its
        # observations in `_get_normalized_observation` from `self.session`, and
        # no `_get_state` is defined there. Confirm this method exists and that
        # it reads the attributes assigned above.
        state = self.env._get_state()

        # Get action from model
        action, _ = self.model.predict(state, deterministic=True)

        # `action_space` is a gymnasium Discrete space and cannot be indexed.
        # The environment's own action -> Direction table is the mapping that
        # `SnakeEnv.step` uses. `.item()` unwraps the array that predict returns.
        return self.env.action_to_direction[int(np.asarray(action).item())]
+It includes: +- State observation space +- Action space +- Reward system +- Environment dynamics +""" + +import numpy as np +import random +import gymnasium as gym +from gymnasium import spaces +from typing import Tuple, List, Dict, Any, Optional +import pygame +from src.core import Snake +from src.core import Direction +from src.core import GameSession +from src.ui import GameArea +from src.config import GameRules + +def manhattan_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> int: + """Calculate Manhattan distance between two points.""" + return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) + +def manhattan_distance_wrap(pos1: Tuple[int, int], pos2: Tuple[int, int], size: int) -> int: + """Calculate Manhattan distance between two points with wrap-around.""" + dx = abs(pos1[0] - pos2[0]) + dy = abs(pos1[1] - pos2[1]) + return min(dx, size - dx) + min(dy, size - dy) + +class SnakeEnv(gym.Env): + """A Gymnasium environment for training snake AI agents.""" + + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} + + def __init__(self, size=30, render_mode=None, difficulty="easy"): + super().__init__() + self.size = size + self.render_mode = render_mode + self.difficulty = difficulty + + # Action to Direction mapping + self.action_to_direction = { + 0: Direction.UP, + 1: Direction.RIGHT, + 2: Direction.DOWN, + 3: Direction.LEFT + } + + # Direction to index mapping for one-hot encoding + self.direction_to_index = { + Direction.UP: 0, + Direction.RIGHT: 1, + Direction.DOWN: 2, + Direction.LEFT: 3 + } + + # Initialize reward scales with emphasis on precision + self.reward_scales = { + 'food': 20.0, # Base food reward (will be scaled exponentially) + 'death': -5.0, # Reduced death penalty to encourage exploration + 'distance': 0.5, # Increased distance reward + 'survival': -0.01, # Small survival penalty + 'milestone': 10.0, # Larger milestone rewards + 'efficiency': 0.2, # Small efficiency reward + 'exploration': 0.5, # Increased 
exploration reward + 'safety': 1.0, # Safety reward for avoiding danger + 'timeout': -1.0, # Reduced timeout penalty + 'wrap_bonus': 0.2, # Small wrap-around bonus + 'near_miss': -0.1, # Very small near miss penalty + 'repetitive': -2 # Increased base penalty for repetitive movement + } + + # Near miss detection + self.near_miss_threshold = 1 # Only penalize very close misses + + # Time limit parameters - more lenient + self.base_time_limit = 200 # Double base time limit + self.length_time_bonus = 100 # Double length bonus + self.max_time_limit = 1000 # Much longer maximum time + + # Exploration decay + self.initial_exploration_bonus = 2.0 # Higher initial exploration + self.exploration_decay = 0.9995 # Slower decay + self.min_exploration_bonus = 0.5 # Higher minimum exploration + self.current_exploration_bonus = self.initial_exploration_bonus + self.total_episodes = 0 # Track number of episodes for decay + + # Curriculum learning parameters + self.curriculum_stage = 0 + self.success_threshold = { + 0: 3, # Stage 0: Basic movement (need more consistent success) + 1: 5, # Stage 1: Food collection + 2: 7, # Stage 2: Longer snake + 3: 10 # Stage 3: Full difficulty + } + self.consecutive_successes = 0 + self.stage_requirements = { + 0: 1, # Stage 0: Get one food + 1: 3, # Stage 1: Get three food + 2: 5, # Stage 2: Get five food + 3: 7 # Stage 3: Get seven food + } + + # Dynamic episode length based on curriculum stage + self.base_max_steps = 300 # More steps for exploration + self.max_steps = self.base_max_steps + + # Window dimensions for rendering + self.window_width = 1024 + self.window_height = 768 + + # Create game area and rules + self.game_area = GameArea(self.window_width, self.window_height) + self.rules = GameRules() + self.rules.wrap_around = False # No wrap-around for AI training + + # Define action and observation spaces + self.action_space = spaces.Discrete(4) # Up, Right, Down, Left + + num_channels = 26 # Total number of channels we created in 
_get_normalized_observation + self.observation_space = spaces.Box( + low=0, + high=1, + shape=(num_channels, size, size), # (channels, height, width) + dtype=np.float32 + ) + + # Game session + self.session = None + self.steps = 0 + self.steps_since_food = 0 + self.current_time = 0 + self.last_score = 0 + self.last_distance = float('inf') + self.last_food_steps = 0 + self.steps_since_direction_change = 0 # Track steps since last direction change + self.last_direction = None # Track previous direction + self.recent_positions = [] + self.steps_in_same_direction = 0 # Track steps without direction change + self.last_min_food_distance = float('inf') # Initialize minimum distance tracking + + # For rendering + self.window = None + self.clock = None + + self.score_milestone_rewards = {5: 10.0, 10: 20.0, 15: 30.0} # Bonus rewards at score milestones + self.last_milestone = 0 # Track last milestone reached + self.base_survival_penalty = -1.0 # Much larger constant penalty per step + + # Movement control parameters + self.max_direction_steps = 10 # Reduced from 20 to 10 + self.direction_step_increase = 2 # Reduced from 5 to 2 + self.min_safe_turns = 2 # Reduced from 3 to 2 + self.repetitive_threshold = 10 # Base threshold before repetitive penalty kicks in + self.repetitive_scale = 1.2 # Base scale for penalty growth + self.repetitive_wrap_scale = 2.0 # Much faster growth rate in wrap mode + self.max_repetitive_penalty = -5.0 # Increased maximum repetitive movement penalty + + # Initialize base rewards with adjusted values + rewards = { + 'food': 0.0, + 'death': 0.0, + 'distance': 0.0, + 'survival': self.base_survival_penalty * 0.5, + 'direction': 0.0, + 'efficiency': 0.0, + 'progress': 0.0, + 'alignment': 0.0, + 'repetitive': 0.0 + } + + def _get_normalized_observation(self) -> np.ndarray: + """ + Create a multi-channel 2D observation from the game state. 
def _get_normalized_observation(self) -> np.ndarray:
    """
    Create a multi-channel 2D observation from the game state.

    Spatial channels encode the grid layout (snake & food), while extra
    scalar features are broadcast over additional channels.

    Channel layout (26 channels, each ``size`` x ``size``, values in [0, 1]):
        0      snake body mask
        1      food mask
        2-5    dx, dy, wrap_dx, wrap_dy (head -> food offsets)
        6-9    danger flags for UP, RIGHT, DOWN, LEFT
        10     snake length / grid area
        11-14  one-hot current direction
        15-18  one-hot dominant food direction
        19-25  time pressure, wrap mode, progress, recent food,
               curriculum stage, advanced-length flag, manhattan distance

    Returns:
        ``float32`` array of shape ``(26, size, size)``.

    Raises:
        IndexError: If a snake segment lies beyond the grid's upper bound.
    """
    state = self.session.get_state()
    body = state["snake_body"]
    food_position = state["food_position"]
    head = body[0]
    size = self.size

    # --- Spatial channels ---
    snake_channel = np.zeros((size, size), dtype=np.float32)
    for pos in body:
        # Check before writing: the original assert ran after the indexing, so
        # numpy's bare IndexError always fired first and the message was lost.
        # NOTE(review): negative coordinates are still accepted and wrap via
        # numpy indexing, as before. Confirm whether GameSession can report them.
        if pos[0] >= size or pos[1] >= size:
            raise IndexError(f"Snake position {pos} exceeds grid size {self.size}")
        snake_channel[pos[1], pos[0]] = 1.0

    food_channel = np.zeros((size, size), dtype=np.float32)
    food_channel[food_position[1], food_position[0]] = 1.0

    # --- Head -> food offsets ---
    dx = food_position[0] - head[0]
    dy = food_position[1] - head[1]

    def shortest_offset(delta):
        """Return the offset along one axis, taking the wrapped route if shorter."""
        if not self.rules.wrap_around:
            return delta
        if delta > size / 2:
            return delta - size
        if delta < -size / 2:
            return delta + size
        return delta

    def to_unit(offset):
        """Map a signed offset in (-size, size) onto [0, 1], with 0.5 meaning zero."""
        return (offset / size + 1.0) / 2.0

    # The offsets are signed, but the observation space is Box(0, 1) and the
    # result is clipped to it. Feeding offset / size directly collapsed every
    # negative offset (food left of / above the head) to 0. Shifting into
    # [0, 1] keeps the sign information.
    dx_unit = to_unit(dx)
    dy_unit = to_unit(dy)
    wrap_dx_unit = to_unit(shortest_offset(dx))
    wrap_dy_unit = to_unit(shortest_offset(dy))

    # Danger indicators (vector of 4 values)
    dangers = self._get_danger_observations_all_directions(head, body)

    # --- Other scalar features ---
    snake_length_norm = len(body) / (size ** 2)

    direction_one_hot = [0.0] * 4
    direction_one_hot[self.direction_to_index[state["snake_direction"]]] = 1.0

    # One-hot of the axis-dominant direction from the head to the food.
    if abs(dx) > abs(dy):
        toward_food = Direction.RIGHT if dx > 0 else Direction.LEFT
    else:
        toward_food = Direction.DOWN if dy > 0 else Direction.UP
    food_dir = [0.0] * 4
    food_dir[self.direction_to_index[toward_food]] = 1.0

    current_time_limit = min(
        self.max_time_limit,
        self.base_time_limit + (len(body) - 1) * self.length_time_bonus
    )
    time_pressure = np.clip(self.steps_since_food / current_time_limit, 0, 1)
    wrap_mode = float(self.rules.wrap_around)
    progress = min(1.0, len(body) / 10)
    recent_food = min(1.0, self.steps_since_food / 50)
    curriculum_stage_norm = float(self.curriculum_stage) / 3
    advanced_length = float(len(body) > 5)
    manhattan_norm = min(1.0, manhattan_distance(head, food_position) / (size * 2))

    # --- Assemble: 2 spatial channels + one constant plane per scalar ---
    scalars = [
        dx_unit,                 # Channel 2
        dy_unit,                 # Channel 3
        wrap_dx_unit,            # Channel 4
        wrap_dy_unit,            # Channel 5
        *dangers,                # Channels 6-9
        snake_length_norm,       # Channel 10
        *direction_one_hot,      # Channels 11-14
        *food_dir,               # Channels 15-18
        time_pressure,           # Channel 19
        wrap_mode,               # Channel 20
        progress,                # Channel 21
        recent_food,             # Channel 22
        curriculum_stage_norm,   # Channel 23
        advanced_length,         # Channel 24
        manhattan_norm,          # Channel 25
    ]
    observation = np.empty((2 + len(scalars), size, size), dtype=np.float32)
    observation[0] = snake_channel
    observation[1] = food_channel
    # Broadcasting fills each remaining plane with its scalar in one assignment.
    observation[2:] = np.asarray(scalars, dtype=np.float32)[:, None, None]

    # Final observation: shape (num_channels, size, size)
    return np.clip(observation, 0, 1)


def _get_danger_observations_all_directions(self, head: Tuple[int, int], body: List[Tuple[int, int]]) -> List[float]:
    """Get danger observations in all four directions.

    Returns:
        Four flags ordered [UP, RIGHT, DOWN, LEFT]. A flag is 1.0 when the cell
        next to the head in that direction is a wall (only without wrap-around)
        or any body segment other than the head, else 0.0.
    """
    dangers = []
    segments = body[1:]  # the head cannot collide with itself

    for direction in (Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT):
        step_x, step_y = direction.to_vector()
        next_x = head[0] + step_x
        next_y = head[1] + step_y

        if self.rules.wrap_around:
            next_x %= self.size
            next_y %= self.size
        elif not (0 <= next_x < self.size and 0 <= next_y < self.size):
            # Wall collision
            dangers.append(1.0)
            continue

        # Self collision
        dangers.append(1.0 if (next_x, next_y) in segments else 0.0)

    return dangers
def _calculate_area_penalty(self, head_pos):
    """Return a non-positive reward term for loitering in a small region.

    ``head_pos`` is appended to ``self.recent_positions`` (mutated in place)
    and the oldest entry is dropped once the history exceeds
    ``self.max_recent_positions``. Until at least 10 positions are recorded
    the penalty is 0.0. After that it is the sum of:

    * a crowding term, up to -0.2, applied when the bounding box of the
      recorded positions covers fewer than 9 cells, and
    * a revisit term, up to -0.3, proportional to the share of recorded
      positions that are duplicates.
    """
    history = self.recent_positions
    history.append(head_pos)
    if len(history) > self.max_recent_positions:
        del history[0]

    visited = len(history)
    if visited < 10:  # not enough samples to judge yet
        return 0.0

    # Bounding box (in cells) around everything in the history window.
    span_x = max(p[0] for p in history) - min(p[0] for p in history) + 1
    span_y = max(p[1] for p in history) - min(p[1] for p in history) + 1
    box_cells = span_x * span_y

    # Anything tighter than 3x3 is penalised; a single cell gets the full -0.2.
    crowding = -0.2 * (9 - box_cells) / 8 if box_cells < 9 else 0.0

    # Fraction of the window spent on cells that were already visited.
    distinct = len(set(history))
    revisits = -0.3 * (1 - distinct / visited)

    return crowding + revisits
rewards['food'] = base_food_reward * length_multiplier + + # Add milestone rewards for achieving certain scores + if state["score"] in [5, 10, 15, 20]: + rewards['milestone'] = self.reward_scales['milestone'] * (state["score"] / 5) + else: + # Only apply survival penalty if we haven't eaten food in a while + if self.steps_since_food > current_time_limit / 2: + penalty_factor = (self.steps_since_food - current_time_limit/2) / (current_time_limit/2) + rewards['survival'] = self.reward_scales['survival'] * penalty_factor + + # Check if direction changed and update counters + direction_changed = direction != self.last_direction + if direction_changed: + self.steps_in_same_direction = 0 + else: + self.steps_in_same_direction += 1 + # Apply growing penalty for repetitive movement + if self.steps_in_same_direction > (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold): + excess_steps = self.steps_in_same_direction - (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold) + # Use much faster growth rate in wrap-around mode + growth_rate = self.repetitive_wrap_scale if self.rules.wrap_around else self.repetitive_scale + repetitive_penalty = self.reward_scales['repetitive'] * (growth_rate ** excess_steps) + + # Apply additional multiplier in wrap-around mode that grows with steps + if self.rules.wrap_around: + wrap_multiplier = 1.0 + (excess_steps * 0.5) # Multiplier grows with each step + repetitive_penalty *= wrap_multiplier + rewards['repetitive'] = repetitive_penalty # No maximum cap - let it grow unbounded in wrap-around mode + + else: + # Cap the penalty when not in wrap-around mode + rewards['repetitive'] = max(repetitive_penalty, self.max_repetitive_penalty) + + + self.last_direction = direction + + # Check for timeout or stuck in loop - only if we haven't just eaten food + if (self.steps_since_food >= current_time_limit and not food_eaten): + done = True + rewards['timeout'] = 
self.reward_scales['timeout'] + # Scale timeout penalty based on distance to food + if curr_food_dist < 5: # Harsher penalty for timing out near food + rewards['timeout'] *= 1.5 + # Additional penalty for getting stuck in a loop + if self.steps_in_same_direction >= 30: # Also update this threshold + rewards['timeout'] *= 1.2 + + # Death penalty - calculate AFTER food rewards + if done and not self.steps_since_food >= current_time_limit: + base_death_penalty = self.reward_scales['death'] + + # Check if death was due to self-collision + if curr_head in state["snake_body"][1:]: + # Base multiplier for self-collision + self_collision_multiplier = 2.0 + + # Additional penalty scaling with score + # At score 5: 2.5x penalty + # At score 10: 3.0x penalty + # At score 15: 3.5x penalty + score_multiplier = 1.0 + (state["score"] / 10) + + # Apply both multipliers + rewards['death'] = base_death_penalty * self_collision_multiplier * score_multiplier + + # Additional penalty if we died near food + if curr_food_dist < 5: + rewards['death'] *= 1.5 # Even harsher if we collide with ourselves near food + else: + # For wall collisions, keep original scaling but make it less punishing at higher scores + rewards['death'] = base_death_penalty * (1.0 - min(0.5, state["score"] / 20)) + + # Add an immediate negative reward for losing potential score + potential_loss_penalty = -0.5 * state["score"] # Larger penalty for dying with higher scores + rewards['death'] += potential_loss_penalty + + # Track minimum distance to food and check for near misses + if curr_food_dist < self.last_min_food_distance: + self.last_min_food_distance = curr_food_dist + elif curr_food_dist > self.last_min_food_distance: + # We're moving away from our closest approach + if self.last_min_food_distance <= self.near_miss_threshold: + # We got very close but missed + rewards['distance'] += self.reward_scales['near_miss'] + self.last_min_food_distance = float('inf') # Reset tracking + + # Progressive distance reward 
- more reward for getting closer when near food + distance_change = prev_food_dist - curr_food_dist + if curr_food_dist < 5: # When close to food + distance_multiplier = 2.0 # Double the reward/penalty + else: + distance_multiplier = 1.0 + rewards['distance'] = distance_change * self.reward_scales['distance'] * distance_multiplier + + # Calculate direction penalty + if curr_food_dist <= 5: # Only apply when close to food + # Calculate optimal direction to food + dx = curr_food[0] - curr_head[0] + dy = curr_food[1] - curr_head[1] + + if self.rules.wrap_around: + # Adjust for wrap-around + if dx > self.size/2: dx -= self.size + elif dx < -self.size/2: dx += self.size + if dy > self.size/2: dy -= self.size + elif dy < -self.size/2: dy += self.size + + # Determine optimal direction(s) + optimal_directions = [] + if abs(dx) > abs(dy): + if dx > 0: optimal_directions.append(Direction.RIGHT) + elif dx < 0: optimal_directions.append(Direction.LEFT) + if abs(dy) >= abs(dx): + if dy > 0: optimal_directions.append(Direction.DOWN) + elif dy < 0: optimal_directions.append(Direction.UP) + + # Base penalty scales with proximity (closer = higher penalty) + base_penalty = (6 - curr_food_dist) * 0.1 # Scales from 0.1 to 0.5 + + if direction not in optimal_directions: + # Check if we're moving directly away from food + opposite_dirs = { + Direction.UP: Direction.DOWN, + Direction.DOWN: Direction.UP, + Direction.LEFT: Direction.RIGHT, + Direction.RIGHT: Direction.LEFT + } + optimal_opposites = [opposite_dirs[d] for d in optimal_directions] + + if direction in optimal_opposites: + # Double penalty for moving directly away from food + rewards['direction'] = -2 * base_penalty + else: + # Regular penalty for suboptimal direction + rewards['direction'] = -base_penalty + elif curr_food_dist <= 2: # Small reward for correct direction when very close + rewards['direction'] = 0.1 + + # Efficiency reward - encourage purposeful movement + if not done and self.steps_since_food < 50: # Only apply 
when actively hunting + rewards['efficiency'] = self.reward_scales['efficiency'] * (1.0 - self.steps_since_food / 50) + + # Calculate final reward - simple sum of components + reward = sum(rewards.values()) + + # Get observation + observation = self._get_normalized_observation() + + # No more episode-based truncation + truncated = False + + # Include reward components and exploration info in info dict for monitoring + info = { + "score": state["score"], + "reward_components": rewards, + "exploration_bonus": self.current_exploration_bonus, + "direction_changed": direction_changed, + "steps_in_same_direction": self.steps_in_same_direction, + "snake_length": snake_length, + "steps": self.steps # Include total steps for monitoring + } + + return observation, reward, done, truncated, info + + def render(self): + """Render the environment.""" + if self.window is None and self.render_mode == "human": + pygame.init() + self.window = pygame.display.set_mode((self.window_width, self.window_height)) + pygame.display.set_caption("Snake Environment") + self.clock = pygame.time.Clock() + + if self.window is not None: + self.window.fill((0, 0, 0)) + + # Draw game area + self.game_area.draw(self.window) + + # Draw game session + self.session.render(self.window, self.game_area) + + # Draw score + font = pygame.font.Font(None, 36) + score_text = f"Score: {self.session.score}" + score_surface = font.render(score_text, True, (0, 255, 0)) + score_rect = score_surface.get_rect(midtop=(self.window_width // 2, 20)) + self.window.blit(score_surface, score_rect) + + pygame.display.flip() + self.clock.tick(self.metadata["render_fps"]) + + if self.render_mode == "rgb_array": + return np.transpose( + np.array(pygame.surfarray.pixels3d(self.window)), + axes=(1, 0, 2) + ) + + def close(self): + """Close the environment.""" + if self.window is not None: + pygame.quit() + self.window = None + + def get_state(self) -> Dict: + """Get the current game state.""" + state = self.session.get_state() + 
def _get_adjacent_positions(self, position: Tuple[int, int]) -> List[Tuple[int, int]]:
    """
    Get all adjacent positions to the given position.

    Neighbours are produced in the order up, right, down, left. With
    wrap-around enabled all four are returned, folded back onto the grid;
    otherwise neighbours that fall outside the grid are dropped.

    Args:
        position: Current position (x, y)

    Returns:
        List of adjacent positions [(x, y), ...]

    Raises:
        ValueError: If position is empty/None or not a tuple of length 2
        ValueError: If a coordinate cannot be converted with ``int()``
            (the underlying conversion error is attached as ``__cause__``)
        ValueError: If ``self.size`` is missing or not a positive integer
    """
    if not position or not isinstance(position, tuple) or len(position) != 2:
        raise ValueError("Position must be a tuple of 2 coordinates")

    try:
        x, y = int(position[0]), int(position[1])
    except (TypeError, ValueError) as exc:
        # Chain explicitly so the traceback shows which coordinate failed
        # as the cause instead of an incidental "during handling" context.
        raise ValueError("Position coordinates must be integers") from exc

    size = getattr(self, 'size', None)
    if not isinstance(size, int) or size <= 0:
        raise ValueError("Environment size must be a positive integer")

    # Neighbours in all 4 directions (up, right, down, left)
    adjacent = [
        (x, y - 1),  # up
        (x + 1, y),  # right
        (x, y + 1),  # down
        (x - 1, y),  # left
    ]

    # Missing `rules` or missing `wrap_around` is treated as "walls are solid".
    if getattr(getattr(self, 'rules', None), 'wrap_around', False):
        return [(nx % size, ny % size) for nx, ny in adjacent]

    return [
        (nx, ny) for nx, ny in adjacent
        if 0 <= nx < size and 0 <= ny < size
    ]
class CustomRecurrentCNN(BaseFeaturesExtractor):
    """
    Feature extractor: three strided conv layers, a linear projection to 256
    units, and a single LSTM layer whose output is the feature vector.

    NOTE(review): as written, ``forward`` feeds the LSTM a sequence of
    length 1, passes no initial state, and discards the returned state, so
    every observation is processed independently. The LSTM therefore acts as
    a gated non-linear layer and does NOT carry memory between observations.
    Carrying state across steps would need the hidden/cell state to be
    threaded through by the caller.
    """
    def __init__(self, observation_space, features_dim=512):
        # features_dim is also the LSTM hidden size, i.e. the extractor output.
        super().__init__(observation_space, features_dim)

        n_input_channels = observation_space.shape[0]  # Number of input channels

        # Each conv halves the spatial size (kernel 3, stride 2, padding 1).
        # The shapes in the comments below are for a 30x30 grid.
        self.cnn = nn.Sequential(
            # Layer 1: (n_channels, 30, 30) -> (32, 15, 15)
            nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # Layer 2: (32, 15, 15) -> (64, 8, 8)
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            # Layer 3: (64, 8, 8) -> (64, 4, 4)
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),

            nn.Flatten()
        )

        # Measure the flattened CNN width with a dummy forward pass so the
        # extractor is not tied to one grid size (64 * 4 * 4 = 1024 for 30x30).
        with torch.no_grad():
            dummy_input = torch.zeros(1, *observation_space.shape)
            cnn_output = self.cnn(dummy_input)
            self._cnn_output_dim = cnn_output.shape[1]

        # Linear layer reducing the CNN output to the LSTM input width
        self.fc = nn.Linear(self._cnn_output_dim, 256)

        # LSTM layer
        self.lstm = nn.LSTM(256, features_dim, batch_first=True)

    def forward(self, observations):
        """Map (batch, channels, height, width) observations to (batch, features_dim)."""
        # CNN: (batch, channels, height, width) -> (batch, cnn_features)
        cnn_out = self.cnn(observations)

        # Linear layer: (batch, cnn_features) -> (batch, 256)
        fc_out = self.fc(cnn_out)

        # Add a length-1 time dimension for the LSTM: (batch, 256) -> (batch, 1, 256)
        lstm_input = fc_out.unsqueeze(1)

        # LSTM: (batch, 1, 256) -> (batch, 1, features_dim).
        # No initial state is given and the final state is dropped here.
        lstm_out, _ = self.lstm(lstm_input)

        # Remove time dimension: (batch, 1, features_dim) -> (batch, features_dim)
        return lstm_out.squeeze(1)
+ Uses fully connected layers since our input is a 1D vector of normalized features. + """ + + def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): + super().__init__(observation_space, features_dim) + self.input_dim = int(np.prod(observation_space.shape)) + + self.cnn = nn.Sequential( + nn.Linear(self.input_dim, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + nn.Linear(512, 512), + nn.ReLU(), + nn.BatchNorm1d(512), + nn.Linear(512, features_dim), + nn.ReLU(), + ) + + def forward(self, observations: torch.Tensor) -> torch.Tensor: + # Handle both single and batch observations + if len(observations.shape) == 1: + observations = observations.unsqueeze(0) + + # Ensure the input dimension matches what we expect + flat_obs = observations.reshape(observations.shape[0], self.input_dim) + return self.cnn(flat_obs) \ No newline at end of file diff --git a/src/ai/train.py b/src/ai/train.py new file mode 100644 index 0000000..d14f171 --- /dev/null +++ b/src/ai/train.py @@ -0,0 +1,423 @@ +""" +Training script for Snake AI using Stable Baselines3. + +This script sets up and trains RL agents using the SnakeEnv environment. +It supports multiple difficulty levels and provides real-time visualization +of the training process. 
+""" + +import os +import time +import numpy as np +import torch +from stable_baselines3 import PPO +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.utils import set_random_seed +from stable_baselines3.common.monitor import Monitor +import gymnasium as gym +import torch.nn as nn +import argparse +from queue import Queue +from threading import Thread +from typing import Dict, Any, Optional +import json +from queue import Empty + +from src.ai.environment import SnakeEnv +from src.ai.visualize import run_dashboard +from src.ai.model_networks import CustomRecurrentCNN, CustomCNN + +class VisualizationCallback(BaseCallback): + """Callback for updating the training visualization.""" + + def __init__(self, viz_queue: Queue, total_timesteps: int, eval_freq: int = 1000, n_steps: int = 512, n_envs: int = 32, initial_timesteps: int = 0): + super().__init__() + self.viz_queue = viz_queue + self.total_timesteps = total_timesteps + self.eval_freq = eval_freq + self.n_steps = n_steps + self.n_envs = n_envs + self.initial_timesteps = initial_timesteps # Store initial timesteps + + # Training metrics + self.episode_rewards = [] + self.episode_lengths = [] + + # Evaluation metrics + self.eval_scores = [] + self.high_score = 0 + self.recent_high_score = 0 + + # Current episode tracking + self.current_reward = 0 + self.current_length = 0 + + # Evaluation games (one with wrap, one without) + self.eval_envs = [ + Monitor(SnakeEnv(render_mode="rgb_array")) for _ in range(2) + ] + self.eval_envs[0].unwrapped.rules.wrap_around = False + self.eval_envs[1].unwrapped.rules.wrap_around = True + + self.eval_obs = [None, None] + self.eval_done = [True, True] + self.eval_scores_current = [0, 0] + + # Real-time demo environment + self.demo_env = Monitor(SnakeEnv(render_mode="rgb_array")) + self.demo_env.unwrapped.rules.wrap_around = True + self.demo_obs = None + self.demo_done = True 
+ self.last_demo_time = 0 + self.demo_speed = 50 + + def _on_step(self) -> bool: + """Update visualization on each step.""" + try: + current_time = time.time() * 1000 # Convert to milliseconds + + # Update episode tracking + self.current_length += 1 + reward = self.locals.get("rewards", [0])[0] + self.current_reward += reward + + # Get action probabilities and value estimate for demo game + if self.demo_obs is not None: + # Ensure proper tensor shape (batch_size, obs_dim) + demo_obs_tensor = torch.as_tensor(self.demo_obs).float() + if len(demo_obs_tensor.shape) == 1: + demo_obs_tensor = demo_obs_tensor.unsqueeze(0) + + # Move tensor to same device as model's policy + device = self.model.policy.device + demo_obs_tensor = demo_obs_tensor.to(device) + + with torch.no_grad(): + # Get action distribution and value estimate + dist = self.model.policy.get_distribution(demo_obs_tensor) + action_probs = dist.distribution.probs + value_estimate = self.model.policy.predict_values(demo_obs_tensor) + + # Move results back to CPU and get first item (since we added batch dimension) + action_probs = action_probs[0].cpu().numpy() + value_estimate = value_estimate[0].cpu().numpy() + else: + action_probs = np.zeros(4) + value_estimate = np.array([0.0]) + + # Check if episode ended + dones = self.locals.get("dones", [False]) + if any(dones): + self.episode_rewards.append(self.current_reward) + self.episode_lengths.append(self.current_length) + self.current_reward = 0 + self.current_length = 0 + + # Run evaluation games (keep these deterministic as they're meant to show best performance) + for i, (env, obs, done) in enumerate(zip(self.eval_envs, self.eval_obs, self.eval_done)): + if done: + if obs is not None: # If this wasn't the first reset + self.eval_scores.append(self.eval_scores_current[i]) + # Update high scores based on evaluation performance + self.high_score = max(self.high_score, self.eval_scores_current[i]) + self.recent_high_score = max(self.recent_high_score, 
self.eval_scores_current[i]) + self.eval_obs[i] = env.reset()[0] + self.eval_done[i] = False + self.eval_scores_current[i] = 0 + + eval_action, _ = self.model.predict(self.eval_obs[i], deterministic=True) + eval_action = eval_action.item() if hasattr(eval_action, 'item') else eval_action + self.eval_obs[i], _, done, truncated, info = env.step(eval_action) + self.eval_done[i] = done or truncated + # Update current evaluation score + self.eval_scores_current[i] = info.get("score", 0) + + # Run real-time demo game + if current_time - self.last_demo_time >= self.demo_speed: + if self.demo_done: + self.demo_obs = self.demo_env.reset()[0] + self.demo_done = False + + demo_action, _ = self.model.predict(self.demo_obs, deterministic=True) + demo_action = demo_action.item() if hasattr(demo_action, 'item') else demo_action + self.demo_obs, _, done, truncated, info = self.demo_env.step(demo_action) + self.demo_done = done or truncated + self.last_demo_time = current_time + + # Get current states from actual training environments + # Sample one no-wrap and one wrap environment from the training environments + no_wrap_idx = 0 # First half are no-wrap + wrap_idx = self.n_envs // 2 # Second half are wrap + + training_state_1 = self.training_env.get_attr("session")[no_wrap_idx].get_state() + training_state_1["wrap_around"] = False + + training_state_2 = self.training_env.get_attr("session")[wrap_idx].get_state() + training_state_2["wrap_around"] = True + + eval_state_1 = self.eval_envs[0].unwrapped.session.get_state() + eval_state_1["wrap_around"] = self.eval_envs[0].unwrapped.rules.wrap_around + + eval_state_2 = self.eval_envs[1].unwrapped.session.get_state() + eval_state_2["wrap_around"] = self.eval_envs[1].unwrapped.rules.wrap_around + + demo_state = self.demo_env.unwrapped.session.get_state() + demo_state["wrap_around"] = self.demo_env.unwrapped.rules.wrap_around + + # Update viz_data + viz_data = { + 'training_state': [training_state_1, training_state_2], + 'eval_states': 
[eval_state_1, eval_state_2], + 'demo_state': demo_state, + 'weights_info': { + 'action_probs': action_probs.tolist() if 'action_probs' in locals() else [0.25] * 4, + 'value_estimate': float(value_estimate[0]) if 'value_estimate' in locals() else 0.0, + 'action_labels': ['Up', 'Right', 'Down', 'Left'] + }, + 'training_info': { + 'total_timesteps': self.num_timesteps, + 'initial_timesteps': self.initial_timesteps, # Add initial timesteps + 'target_timesteps': self.total_timesteps, + 'episode_reward': float(self.current_reward), + 'mean_reward': float(np.mean(self.episode_rewards[-100:])) if self.episode_rewards else 0.0, + 'episode_length': int(self.current_length), + 'mean_length': float(np.mean(self.episode_lengths[-100:])) if self.episode_lengths else 0.0, + 'mean_eval_score': float(np.mean(self.eval_scores[-100:])) if self.eval_scores else 0.0, + 'rewards_history': self.episode_rewards[-1000:], + 'lengths_history': self.episode_lengths[-1000:], + 'eval_scores_history': self.eval_scores[-1000:], + 'high_score': self.high_score, + 'recent_high_score': self.recent_high_score + } + } + + # Send update to visualization + self.viz_queue.put(('update', viz_data)) + + # Reset recent high score periodically + if self.num_timesteps % (self.n_steps * self.n_envs) == 0: + print(f"\nEvaluation High Score: {self.high_score}") + print(f"Recent Evaluation High Score: {self.recent_high_score}") + self.recent_high_score = 0 + + return True + + except Exception as e: + print(f"Error in visualization callback: {e}") + import traceback + traceback.print_exc() + return True + +def make_env(seed: int = 0, wrap_walls: bool = False) -> gym.Env: + """Create a training environment. 
def train_model(
    total_timesteps: int,
    viz_queue: Queue,
    model_path: Optional[str] = None,
    cuda: bool = True,
    learning_rate: float = 3e-4,    # Standard PPO learning rate
    batch_size: int = 256,          # More reasonable batch size
    n_envs: int = 16,               # Balanced number of environments
    n_steps: int = 2048             # Standard PPO steps
) -> None:
    """
    Train the snake AI model.

    Builds ``n_envs`` sub-process environments (second half with wrap-around),
    creates or loads a PPO model, then trains in chunks of
    ``n_steps * n_envs`` timesteps, checking ``viz_queue`` for a
    ``'stop_training'`` message between chunks. Errors are printed rather
    than raised; a ``('stop', None)`` message is always sent to ``viz_queue``
    on exit.

    Args:
        total_timesteps: Number of timesteps to train for
        viz_queue: Queue for sending visualization updates
        model_path: Path to save/load the model
        cuda: Whether to use CUDA for training
        learning_rate: Learning rate for the optimizer
        batch_size: Batch size for training
        n_envs: Number of parallel environments
        n_steps: Number of steps per environment
    """
    # Bound before the try so the finally-block can tell whether the vector
    # env was ever created; otherwise a failure during setup would raise a
    # second error in `finally` and the dashboard would never be told to stop.
    env = None
    try:
        # Set up environments - half with wrap, half without
        envs = []
        for i in range(n_envs):
            wrap_walls = i >= n_envs // 2  # Half the envs with wrap-around
            envs.append(make_env(i, wrap_walls))

        env = SubprocVecEnv(envs)

        # PPO hyperparameters. learning_rate / n_steps / batch_size come from
        # the caller so that the rollout size PPO uses matches the chunk size
        # and the callback cadence computed from `n_steps` below.
        ppo_params = dict(
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=4,              # Balanced number of epochs
            gamma=0.98,              # Slightly lower discount for more immediate rewards
            gae_lambda=0.9,          # Lower GAE lambda for more immediate advantage estimates
            clip_range=0.1,          # Smaller clip range for more conservative updates
            clip_range_vf=0.1,       # Also clip value function
            normalize_advantage=True,
            ent_coef=0.2,            # Moderate entropy for exploration
            vf_coef=0.8,             # Balanced value function coefficient
            max_grad_norm=0.3,       # More conservative gradient clipping
            target_kl=None,
            verbose=1,
            device="cuda" if cuda and torch.cuda.is_available() else "cpu"
        )

        # Create or load model with updated architecture
        policy_kwargs = dict(
            features_extractor_class=CustomRecurrentCNN,
            features_extractor_kwargs=dict(features_dim=512),
            net_arch=dict(
                pi=[512, 256, 128],  # Policy network layers
                vf=[512, 256, 128]   # Value network layers
            ),
            activation_fn=nn.ReLU,
            normalize_images=False
        )

        # Track initial timesteps for progress tracking
        initial_timesteps = 0

        if model_path and os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            try:
                model = PPO.load(
                    model_path,
                    env=env,
                    custom_objects={"policy_kwargs": policy_kwargs},
                    **ppo_params
                )
                initial_timesteps = model.num_timesteps
                print(f"Loaded model with {initial_timesteps:,} timesteps of training")
            except Exception as e:
                # Deliberate best-effort: an unreadable/incompatible checkpoint
                # falls back to training from scratch.
                print(f"Failed to load model with error: {e}")
                print("Creating new model instead")
                model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)
        else:
            print("Creating new model")
            model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)

        # Set up visualization callback with initial timesteps
        viz_callback = VisualizationCallback(
            viz_queue=viz_queue,
            total_timesteps=total_timesteps,
            eval_freq=1000,
            n_steps=n_steps,
            n_envs=n_envs,
            initial_timesteps=initial_timesteps
        )

        # Training loop with stop check
        should_stop = False
        timesteps_per_batch = n_steps * n_envs
        # NOTE(review): integer division - a total smaller than one batch
        # results in zero training iterations.
        num_batches = total_timesteps // timesteps_per_batch

        for i in range(num_batches):
            if should_stop:
                break

            # Train for one batch
            model.learn(
                total_timesteps=timesteps_per_batch,
                callback=viz_callback,
                progress_bar=True,
                reset_num_timesteps=False
            )

            # Check for stop signal from visualization.
            # NOTE(review): this pops every pending message and drops any
            # that is not 'stop_training' - confirm that is intended for a
            # queue that is also used to deliver updates.
            try:
                while True:  # Process all pending messages
                    msg_type, _ = viz_queue.get_nowait()
                    if msg_type == 'stop_training':
                        should_stop = True
                        break
            except Empty:
                pass

        # Save model if training completed or interrupted
        if model_path:
            print(f"Saving model to {model_path}")
            model.save(model_path)

    except Exception as e:
        print(f"Error in training process: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Clean up whatever was actually created
        if env is not None:
            env.close()
        # Signal visualization to stop if we haven't already
        viz_queue.put(('stop', None))
finish + viz_thread.join() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/ai/visualize.py b/src/ai/visualize.py new file mode 100644 index 0000000..1e54b63 --- /dev/null +++ b/src/ai/visualize.py @@ -0,0 +1,540 @@ +""" +Training Visualization Dashboard + +This module provides a real-time visualization of the training process, +showing both the training metrics and evaluation games. +""" + +import pygame +import numpy as np +from typing import Dict, Any, Optional, List, Tuple +import threading +from queue import Queue, Empty +import time + +class TrainingDashboard: + def __init__(self, width: int = 1920, height: int = 1080): + """Initialize the training dashboard.""" + self.width = width + self.height = height + + # Fonts + self.title_font = None + self.header_font = None + self.text_font = None + + # Colors + self.colors = { + 'background': (0, 0, 0), + 'text': (200, 200, 200), + 'highlight': (0, 255, 0), + 'grid': (40, 40, 40), + 'border': (0, 255, 0), + 'snake': (0, 255, 0), + 'food': (255, 0, 0), + 'graph_bg': (20, 20, 20), + 'graph_line': (0, 255, 0), + 'graph_grid': (40, 40, 40) + } + + # Layout + self.layout = { + 'margin': 20, + 'game_size': 400, + 'graph_height': 200, + 'metrics_width': 300 + } + + # Training metrics history + self.metrics_history = { + 'rewards': [], + 'scores': [], + 'lengths': [] + } + self.max_history = 1000 + + # Current state + self.current_state = None + self.eval_state = None + self.demo_state = None + self.training_info = None + self.model_info = None + + # Performance tracking + self.frame_times = [] + self.max_frame_times = 60 + + # Initialize the clock + self.clock = None + self.target_fps = 60 + + def run(self, update_queue: Queue) -> None: + """Run the dashboard, processing updates from the queue.""" + pygame.init() + try: + self.screen = pygame.display.set_mode((self.width, self.height)) + pygame.display.set_caption("Snake AI Training Dashboard") + + # Initialize fonts after pygame is 
initialized + self.title_font = pygame.font.Font(None, 48) + self.header_font = pygame.font.Font(None, 36) + self.text_font = pygame.font.Font(None, 24) + + # Initialize clock + self.clock = pygame.time.Clock() + + running = True + while running: + # Handle events + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key == pygame.K_ESCAPE: + running = False + + # Process updates from queue + try: + while True: # Process all available updates + update_type, data = update_queue.get_nowait() + if update_type == 'update': + self.update(data) + elif update_type == 'stop': + running = False + break + except Empty: + pass + + # Render dashboard + self.render() + self.clock.tick(self.target_fps) + finally: + pygame.quit() + + def update(self, data: Dict[str, Any]) -> None: + """Update the dashboard with new data.""" + self.current_state = data # Store entire data dict + if 'training_info' in data: + self.training_info = data['training_info'] + # Update metrics history + for key in ['rewards', 'scores', 'lengths']: + if key in self.training_info: + self.metrics_history[key].append(self.training_info[key]) + if len(self.metrics_history[key]) > self.max_history: + self.metrics_history[key].pop(0) + + def draw_game_view(self, state: Dict[str, Any], position: Tuple[int, int], size: int, title: str) -> None: + """Draw a game view at the specified position.""" + if not state: + return + + x, y = position + + # Draw game area background + game_rect = pygame.Rect(x, y, size, size) + pygame.draw.rect(self.screen, self.colors['grid'], game_rect) + pygame.draw.rect(self.screen, self.colors['border'], game_rect, 2) + + # Get time limit info directly from state + steps_since_food = state.get('steps_since_food', 0) + current_time_limit = state.get('current_time_limit', 100) # Default to base limit if not provided + time_left = max(0, current_time_limit - steps_since_food) + + # Calculate cell size based on 
grid dimensions + cell_size = size // 30 # Assuming 30x30 grid + + # Draw snake + if "snake_body" in state: + for segment in state["snake_body"]: + segment_rect = pygame.Rect( + x + segment[0] * cell_size, + y + segment[1] * cell_size, + cell_size, cell_size + ) + pygame.draw.rect(self.screen, self.colors['snake'], segment_rect) + + # Draw food + if "food_position" in state: + food_rect = pygame.Rect( + x + state["food_position"][0] * cell_size, + y + state["food_position"][1] * cell_size, + cell_size, cell_size + ) + pygame.draw.rect(self.screen, self.colors['food'], food_rect) + + # Draw score and timer above game area + score = state.get('score', 0) + score_text = f"Score: {score}" + timer_text = f"Time: {time_left}/{current_time_limit}" # Show both current and max time + combined_text = f"{score_text} | {timer_text}" + score_surface = self.text_font.render(combined_text, True, self.colors['text']) + score_rect = score_surface.get_rect(midtop=(x + size//2, y - 25)) + self.screen.blit(score_surface, score_rect) + + # Draw title below game area, including wrap mode for training view + if title == "Training Episode" and 'wrap_around' in state: + wrap_mode = "(Wrap)" if state['wrap_around'] else "(No Wrap)" + title = f"{title} {wrap_mode}" + title_text = title + title_surface = self.header_font.render(title_text, True, self.colors['text']) + title_rect = title_surface.get_rect(midtop=(x + size//2, y + size + 10)) + self.screen.blit(title_surface, title_rect) + + def draw_metrics(self, x: int, y: int, width: int) -> None: + """Draw training metrics.""" + if not self.training_info: + return + + font = pygame.font.Font(None, 24) + y_offset = 0 + + # Draw high scores + high_score = self.training_info.get("high_score", 0) + recent_high = self.training_info.get("recent_high_score", 0) + + metrics = [ + ("All-time High Score", high_score), + ("Recent High Score", recent_high), + ("", ""), # Empty line for spacing + ("Training Metrics", ""), + ("Total Steps", 
self.training_info.get('total_timesteps', 0)), + ("Episode Reward", f"{self.training_info.get('episode_reward', 0):.2f}"), + ("Mean Reward", f"{self.training_info.get('mean_reward', 0):.2f}"), + ("Episode Length", self.training_info.get('episode_length', 0)), + ("Mean Length", f"{self.training_info.get('mean_length', 0):.2f}"), + ("FPS", f"{len(self.frame_times) / sum(self.frame_times):.1f}" if self.frame_times else "0") + ] + + for i, (label, value) in enumerate(metrics): + text = f"{label}: {value}" if label else "" + surface = self.text_font.render(text, True, self.colors['text']) + self.screen.blit(surface, (x, y + i * 25)) + + def draw_graph(self, data: List[float], rect: pygame.Rect, title: str, + color: Tuple[int, int, int], min_val: float, max_val: float, + smoothing_window: int = 10) -> None: + """Draw a line graph of the given data. + + Args: + data: List of data points. + rect: Rectangle area for drawing the graph. + title: Title for the graph. + color: Color for the graph line. + min_val: Minimum value for scaling the y-axis. + max_val: Maximum value for scaling the y-axis. + smoothing_window: Optional smoothing window size for drawing + long term trends. Increase this value for a smoother, longer-term trend. + Default is 10. 
+ """ + if not data: + return + + # Create a separate surface for the graph + graph_surface = pygame.Surface((rect.width, rect.height)) + graph_surface.fill(self.colors['graph_bg']) + + # Draw grid lines on the graph surface + num_lines = 5 + for i in range(num_lines): + y_line = i * rect.height / (num_lines - 1) + pygame.draw.line(graph_surface, self.colors['graph_grid'], + (0, y_line), (rect.width, y_line)) + + # Add padding to the min/max values for better visualization + value_range = max_val - min_val + if value_range == 0: + value_range = 1.0 + padding = value_range * 0.1 + min_val_adjusted = min_val - padding + max_val_adjusted = max_val + padding + + # Smooth the data using the provided smoothing_window parameter + window_size = smoothing_window if len(data) >= smoothing_window else len(data) + if window_size > 1: + smoothed_data = np.convolve(data, np.ones(window_size) / window_size, mode='valid') + else: + smoothed_data = data + + points = [] + for i, value in enumerate(smoothed_data): + x_point = i * rect.width / len(smoothed_data) + y_point = rect.height - ((value - min_val_adjusted) * rect.height / + (max_val_adjusted - min_val_adjusted)) + points.append((x_point, y_point)) + + if len(points) > 1: + pygame.draw.lines(graph_surface, color, False, points, 2) + + # Draw border on the graph surface + pygame.draw.rect(graph_surface, self.colors['border'], + pygame.Rect(0, 0, rect.width, rect.height), 1) + + # Blit the graph surface onto the main screen + self.screen.blit(graph_surface, rect) + + # Draw title and current value + current_value = data[-1] if data else 0 + title_text = f"{title} (Current: {current_value:.2f})" + title_surface = self.text_font.render(title_text, True, self.colors['text']) + title_rect = title_surface.get_rect(midtop=(rect.centerx, rect.top - 20)) + self.screen.blit(title_surface, title_rect) + + # Draw value labels to the right of the graph + label_spacing = rect.height / (num_lines - 1) + label_x = rect.right + 10 + for i in 
range(num_lines): + value = max_val_adjusted - (i * (max_val_adjusted - min_val_adjusted) / (num_lines - 1)) + label_text = f"{value:.1f}" + label_surface = self.text_font.render(label_text, True, self.colors['text']) + label_rect = label_surface.get_rect( + left=label_x, + centery=rect.top + i * label_spacing + ) + self.screen.blit(label_surface, label_rect) + + def draw_weights_visualization(self, x: int, y: int, width: int, height: int, weights_info: Dict) -> None: + """Draw the neural network weights visualization.""" + if not weights_info: + return + + # Get data + action_probs = weights_info.get('action_probs', [0, 0, 0, 0]) + value_estimate = weights_info.get('value_estimate', 0.0) + action_labels = weights_info.get('action_labels', ['Up', 'Right', 'Down', 'Left']) + + # Draw background + rect = pygame.Rect(x, y, width, height) + pygame.draw.rect(self.screen, self.colors['graph_bg'], rect) + pygame.draw.rect(self.screen, self.colors['border'], rect, 1) + + # Draw title + title = "Action Probabilities" + title_surface = self.header_font.render(title, True, self.colors['text']) + title_rect = title_surface.get_rect(midtop=(x + width//2, y + 5)) + self.screen.blit(title_surface, title_rect) + + # Draw action probability bars + bar_height = 20 + bar_spacing = 30 + bar_start_y = y + 50 + max_bar_width = width - 120 # Leave space for labels + + for i, (prob, label) in enumerate(zip(action_probs, action_labels)): + # Draw label + label_surface = self.text_font.render(label, True, self.colors['text']) + self.screen.blit(label_surface, (x + 10, bar_start_y + i * bar_spacing)) + + # Draw bar background + bar_bg_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, max_bar_width, bar_height) + pygame.draw.rect(self.screen, self.colors['grid'], bar_bg_rect) + + # Draw probability bar + bar_width = int(prob * max_bar_width) + bar_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, bar_width, bar_height) + pygame.draw.rect(self.screen, self.colors['highlight'], 
bar_rect) + + # Draw probability value + prob_text = f"{prob:.2f}" + prob_surface = self.text_font.render(prob_text, True, self.colors['text']) + self.screen.blit(prob_surface, (x + 80 + max_bar_width, bar_start_y + i * bar_spacing)) + + # Draw value estimate + value_text = f"Value Estimate: {value_estimate:.2f}" + value_surface = self.text_font.render(value_text, True, self.colors['text']) + value_rect = value_surface.get_rect(midbottom=(x + width//2, y + height - 10)) + self.screen.blit(value_surface, value_rect) + + def draw_progress_bar(self, x: int, y: int, width: int, height: int, progress: float, text: str) -> None: + """Draw a progress bar with text. + + Args: + x, y: Position of the progress bar + width, height: Dimensions of the progress bar + progress: Progress value between 0 and 1 + text: Text to display above the progress bar + """ + # Draw background + bg_rect = pygame.Rect(x, y, width, height) + pygame.draw.rect(self.screen, self.colors['graph_bg'], bg_rect) + pygame.draw.rect(self.screen, self.colors['border'], bg_rect, 1) + + # Draw progress + if progress > 0: + progress_width = int(width * progress) + progress_rect = pygame.Rect(x, y, progress_width, height) + pygame.draw.rect(self.screen, self.colors['highlight'], progress_rect) + + # Draw text + text_surface = self.text_font.render(text, True, self.colors['text']) + text_rect = text_surface.get_rect(bottomleft=(x, y - 5)) + self.screen.blit(text_surface, text_rect) + + def render(self) -> None: + """Render the dashboard.""" + if not self.current_state: + return + + self.screen.fill(self.colors['background']) + + # Draw game views + if 'training_state' in self.current_state: + # Draw both training games + self.draw_game_view( + self.current_state['training_state'][0], # No wrap training + (50, 50), + 250, + "Training (No Wrap)" + ) + self.draw_game_view( + self.current_state['training_state'][1], # Wrap training + (350, 50), + 250, + "Training (Wrap)" + ) + + if 'eval_states' in 
self.current_state: + self.draw_game_view( + self.current_state['eval_states'][0], + (650, 50), + 250, + "Evaluation (No Wrap)" + ) + self.draw_game_view( + self.current_state['eval_states'][1], + (950, 50), + 250, + "Evaluation (Wrap)" + ) + + if 'demo_state' in self.current_state: + self.draw_game_view( + self.current_state['demo_state'], + (1250, 50), + 250, + "Demo Game" + ) + + # Draw metrics + if self.training_info: + metrics_data = [ + (self.training_info['rewards_history'], "Episode Rewards"), + (self.training_info['lengths_history'], "Episode Lengths"), + (self.training_info['eval_scores_history'], "Evaluation Scores") + ] + + for i, (data, label) in enumerate(metrics_data): + if data: # Only draw if we have data + rect = pygame.Rect(50, 400 + i * 200, 500, 150) + self.draw_graph( + data, + rect, + label, + self.colors['graph_line'], + min(data) if data else 0, + max(data) if data else 1 + ) + + # Draw current metrics text + metrics_text = [ + f"Total Steps: {self.training_info.get('total_timesteps', 0):,}", + f"Mean Reward: {self.training_info.get('mean_reward', 0):.2f}", + f"Mean Length: {self.training_info.get('mean_length', 0):.1f}", + f"Mean Eval Score: {self.training_info.get('mean_eval_score', 0):.2f}", + f"High Score: {self.training_info.get('high_score', 0)}", + f"Recent High Score: {self.training_info.get('recent_high_score', 0)}" + ] + + # Add vertical spacer between metrics sections + metrics_text.insert(3, "") # Insert empty string as spacer after first 3 metrics + + font = pygame.font.Font(None, 36) + for i, text in enumerate(metrics_text): + surface = font.render(text, True, self.colors['text']) + self.screen.blit(surface, (600, 400 + i * 30)) + + # Draw action probabilities if available + if 'weights_info' in self.current_state: + self.draw_weights_visualization( + 1100, 400, + 300, 300, + self.current_state['weights_info'] + ) + + # Draw trend indicator graph + if self.training_info: + # Calculate trend indicators + window = 100 # Use last 
100 episodes for trends + rewards = self.training_info['rewards_history'][-window:] + lengths = self.training_info['lengths_history'][-window:] + eval_scores = self.training_info['eval_scores_history'][-window:] + + if rewards and lengths and eval_scores: + # Normalize each metric to 0-1 range for fair combination + norm_rewards = [(r - min(rewards)) / (max(rewards) - min(rewards) + 1e-8) for r in rewards] + norm_lengths = [(l - min(lengths)) / (max(lengths) - min(lengths) + 1e-8) for l in lengths] + norm_scores = [(s - min(eval_scores)) / (max(eval_scores) - min(eval_scores) + 1e-8) for s in eval_scores] + + # Combine metrics with weights (rewards: 0.4, lengths: 0.3, eval_scores: 0.3) + trend_data = [0.4 * r + 0.3 * l + 0.3 * s + for r, l, s in zip(norm_rewards, norm_lengths, norm_scores)] + # Add trend direction indicator + + trending = 0 # -1, 0, 1 + + if len(trend_data) >= 2: + recent_trend = trend_data[-1] - trend_data[0] + trending = 1 if recent_trend > 0.1 else -1 if recent_trend < -0.1 else 0 + trend_text = "↑ Improving" if trending == 1 else "↓ Declining" if trending == -1 else "→ Stable" + trend_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0) + trend_surface = self.text_font.render(trend_text, True, trend_color) + trend_rect = trend_surface.get_rect(midtop=(1250, 860)) + self.screen.blit(trend_surface, trend_rect) + + graph_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0) + + # Draw trend graph 50 pixels below the action probabilities graph + trend_rect = pygame.Rect(1100, 750, 300, 150) + self.draw_graph( + trend_data, + trend_rect, + "Learning Trend", + graph_color, # color based on trend + 0, # Min value (normalized) + 1, # Max value (normalized) + smoothing_window=10 # Use a larger smoothing window for longer-term trends + ) + + # Draw training progress bar at the bottom of the screen + if self.training_info: + # Get current session steps (from start of this training 
run) + current_session_steps = self.training_info.get('total_timesteps', 0) - self.training_info.get('initial_timesteps', 0) + target_timesteps = self.training_info.get('target_timesteps', 1000000) + progress = min(1.0, current_session_steps / target_timesteps) + + # Show both current session progress and total model training + total_model_steps = self.training_info.get('total_timesteps', 0) + progress_text = ( + f"Current Session: {current_session_steps:,}/{target_timesteps:,} steps ({progress*100:.1f}%) | " + f"Total Model Training: {total_model_steps:,} steps" + ) + self.draw_progress_bar(50, self.height - 40, self.width - 100, 20, progress, progress_text) + + pygame.display.flip() + + def handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self.running = False + elif event.type == pygame.KEYDOWN: + if event.key == pygame.K_ESCAPE: + self.running = False + +def run_dashboard(update_queue: Queue) -> None: + """Run the training dashboard in a separate thread.""" + dashboard = TrainingDashboard() + try: + dashboard.run(update_queue) + finally: + # Send stop signal back to training process + update_queue.put(('stop_training', None)) + pygame.quit() \ No newline at end of file diff --git a/src/cli.py b/src/cli.py index cc7c0d4..001b15d 100644 --- a/src/cli.py +++ b/src/cli.py @@ -6,7 +6,8 @@ It handles command-line arguments for different game modes and configurations. 
""" import argparse -from src import Game, GameMode +from src.game import Game, GameState +from src.ui.menu import GameMode def main(): """ @@ -35,12 +36,19 @@ def main(): if args.test: game.width = 400 game.height = 300 - game.block_size = 20 - game.fps = 30 + game.game_session.grid_size = 20 + game.clock.tick(30) # Set FPS to 30 # Apply debug mode settings if args.debug: - game.debug = True + game.settings.debug_mode = True + if game.game_session: + game.game_session.show_grid = True + + # Apply AI mode settings + if args.ai_only: + game.game_mode = GameMode.AI_MEDIUM + game.state = GameState.PLAYING # Run the game try: diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..79f0613 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,36 @@ +""" +Game Configuration Package + +This package contains game configuration and constants. +""" + +from .constants import * +from .settings import GameRules, GameSettings + +__all__ = [ + 'GameRules', + 'GameSettings', + # Include all constants + 'WINDOW_WIDTH', + 'WINDOW_HEIGHT', + 'DEFAULT_GRID_WIDTH', + 'DEFAULT_GRID_HEIGHT', + 'DEFAULT_GRID_SIZE', + 'DEFAULT_PADDING', + 'DEFAULT_SCORE_HEIGHT', + 'BLACK', + 'WHITE', + 'GREEN', + 'DARK_GREEN', + 'GRAY', + 'DARK_GRAY', + 'GRID_COLOR', + 'TITLE_FONT_SIZE', + 'SUBTITLE_FONT_SIZE', + 'SCORE_FONT_SIZE', + 'DEBUG_FONT_SIZE', + 'FPS', + 'DEFAULT_MOVE_COOLDOWN', + 'MIN_MOVE_COOLDOWN', + 'MENU_ITEM_SPACING' +] \ No newline at end of file diff --git a/src/config/colors.py b/src/config/colors.py new file mode 100644 index 0000000..53ff256 --- /dev/null +++ b/src/config/colors.py @@ -0,0 +1,25 @@ +""" +Game Colors + +This module contains all the color values used throughout the game. 
+""" + +# Modern color palette +BLACK = (0, 0, 0) +WHITE = (255, 255, 255) +GREEN = (0, 255, 0) +DARK_GREEN = (0, 100, 0) +NEON_GREEN = (57, 255, 20) # Brighter, more neon green +LIME_GREEN = (50, 205, 50) # Softer green for variety +FOREST_GREEN = (0, 100, 0) # Dark green for contrast +DARKER_GREEN = (0, 40, 0) # Very dark green for borders +GRAY = (128, 128, 128) +DARK_GRAY = (25, 25, 25) # Darker background for better contrast +GRID_COLOR = (40, 40, 40) +SUBTLE_GRID_COLOR = (35, 35, 35) +RED = (255, 0, 0) +NEON_RED = (255, 20, 57) # Brighter, more neon red for food +GLOW_GREEN = (150, 255, 150) # Light green for glow effects +SNAKE_GRADIENT_START = (57, 255, 20) # Head color +SNAKE_GRADIENT_END = (0, 150, 0) # Tail color + diff --git a/src/config/constants.py b/src/config/constants.py new file mode 100644 index 0000000..1b08175 --- /dev/null +++ b/src/config/constants.py @@ -0,0 +1,31 @@ +""" +Game Constants + +This module contains all the constant values used throughout the game. +""" + +# Window settings +WINDOW_WIDTH = 1024 +WINDOW_HEIGHT = 768 + +# Grid settings +DEFAULT_GRID_WIDTH = 30 +DEFAULT_GRID_HEIGHT = 30 +DEFAULT_GRID_SIZE = 20 + +# Game area settings +DEFAULT_PADDING = 40 +DEFAULT_SCORE_HEIGHT = 60 +# Font sizes +TITLE_FONT_SIZE = 72 +SUBTITLE_FONT_SIZE = 36 +SCORE_FONT_SIZE = 36 +DEBUG_FONT_SIZE = 24 + +# Game timing +FPS = 60 +DEFAULT_MOVE_COOLDOWN = 100 +MIN_MOVE_COOLDOWN = 50 + +# Menu settings +MENU_ITEM_SPACING = 50 \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000..6db2abd --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,52 @@ +""" +Game Settings + +This module contains the game settings and rules that can be configured. 
+""" + +class GameRules: + """Game rules that can be configured through the settings menu.""" + + def __init__(self): + self.wrap_around = True # Whether snake wraps around screen edges + self.speed_increase = True # Whether snake speeds up as it grows + self.min_move_cooldown = 50 # Minimum movement delay in milliseconds + self.initial_move_cooldown = 100 # Initial movement delay + self.starting_length = 3 # Starting length of the snake + + def update_rule(self, rule_name, value): + """ + Update a game rule with a new value. + + Args: + rule_name (str): Name of the rule to update + value: New value for the rule + + Returns: + bool: True if update was successful, False otherwise + """ + if hasattr(self, rule_name): + # Define validation rules for specific attributes + validators = { + 'starting_length': lambda x: 1 <= x <= 10, + 'wrap_around': lambda x: isinstance(x, bool), + 'speed_increase': lambda x: isinstance(x, bool), + 'min_move_cooldown': lambda x: x > 0, + 'initial_move_cooldown': lambda x: x > 0 + } + + # Validate the value if there's a validator for this rule + if rule_name in validators and not validators[rule_name](value): + return False + + setattr(self, rule_name, value) + return True + return False + +class GameSettings: + """Global game settings.""" + + def __init__(self): + self.debug_mode = False # Whether to show debug information + self.show_grid = self.debug_mode # Whether to show the grid lines + self.rules = GameRules() # Game rules instance \ No newline at end of file diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..74e8f29 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,16 @@ +""" +Game Core Package + +This package contains the core game mechanics and entities. 
+""" + +from .snake import Snake, Direction +from .food import Food +from .game_session import GameSession + +__all__ = [ + 'Snake', + 'Direction', + 'Food', + 'GameSession' +] \ No newline at end of file diff --git a/src/core/food.py b/src/core/food.py new file mode 100644 index 0000000..922f8fe --- /dev/null +++ b/src/core/food.py @@ -0,0 +1,82 @@ +""" +Food Module + +This module provides the Food class that represents the food item in the game. +""" + +import pygame +import random +from typing import Tuple, List +from src.ui import GameArea +from src.config.colors import * + +class Food: + def __init__(self, color: Tuple[int, int, int] = NEON_RED): + """ + Initialize food with a color. + + Args: + color: RGB color tuple for the food + """ + self.color = color + self.position = (0, 0) # Will be set by spawn_at_position() + + def spawn_at_position(self, available_positions: List[Tuple[int, int]]) -> None: + """ + Spawn food at a random position from the available positions. + + Args: + available_positions: List of valid grid positions where food can spawn + """ + if not available_positions: + # No valid positions (snake fills screen) - game should be won + return + + # Choose random position from valid positions + self.position = random.choice(available_positions) + + def draw(self, screen: pygame.Surface, game_area: GameArea) -> None: + """ + Draw the food on the screen with a glowing effect + + Args: + screen: Pygame surface to draw on + game_area: GameArea instance for coordinate conversion + """ + screen_x, screen_y = game_area.get_screen_pos(self.position[0], self.position[1]) + + # Draw outer glow + glow_size = game_area.grid_size + 4 + glow_surface = pygame.Surface((glow_size, glow_size), pygame.SRCALPHA) + + # Create radial gradient for glow + for i in range(3): + alpha = 100 - (i * 30) # Fade out alpha + size = glow_size - (i * 2) + pos = (glow_size - size) // 2 + pygame.draw.circle(glow_surface, (*NEON_RED[:3], alpha), (glow_size//2, glow_size//2), 
size//2) + + # Draw glow + screen.blit(glow_surface, (screen_x - 2, screen_y - 2)) + + # Draw main food body + food_rect = pygame.Rect( + screen_x + 2, + screen_y + 2, + game_area.grid_size - 4, + game_area.grid_size - 4 + ) + pygame.draw.rect(screen, NEON_RED, food_rect, border_radius=4) + + # Draw highlight + highlight_rect = pygame.Rect( + screen_x + 4, + screen_y + 4, + game_area.grid_size - 8, + game_area.grid_size - 8 + ) + pygame.draw.rect(screen, (*NEON_RED, 200), highlight_rect, border_radius=3) + + def check_collision(self, position: Tuple[int, int]) -> bool: + """Check if the given position collides with the food""" + return self.position == position \ No newline at end of file diff --git a/src/core/game_session.py b/src/core/game_session.py new file mode 100644 index 0000000..bf43155 --- /dev/null +++ b/src/core/game_session.py @@ -0,0 +1,239 @@ +""" +Game Session Manager + +This module provides a reusable game session class that encapsulates the core gameplay mechanics. +It can be used by both the main game and the training environment to ensure consistent behavior. +""" + +import pygame +from typing import Tuple, Dict, Optional +from src.core import Snake, Direction, Food +from src.ui import GameArea +from src.config import ( + WINDOW_WIDTH, + WINDOW_HEIGHT, + FPS +) +from src.config.colors import * +from src.config.settings import GameRules + +class GameSession: + """Manages a single game session with consistent rules and boundaries.""" + + def __init__(self, grid_width: int, grid_height: int, rules : GameRules, window_width: int = WINDOW_WIDTH, window_height: int = WINDOW_HEIGHT): + """ + Initialize a new game session. 
+ + Args: + grid_width: Number of grid cells horizontally + grid_height: Number of grid cells vertically + rules: GameRules instance containing game settings + window_width: Width of the game window + window_height: Height of the game window + """ + self.grid_width = grid_width + self.grid_height = grid_height + self.rules = rules + self.window_width = window_width + self.window_height = window_height + + # Game objects + self.snake = None + self.food = None + + # Game state + self.score = 0 + self.is_game_over = False + self.move_cooldown = rules.initial_move_cooldown + self.last_move_time = 0 + + self.reset() + + def reset(self) -> Dict: + """ + Reset the game session to its initial state, including the game area. + + Returns: + Dict containing the initial game state + """ + # Initialize game area with window dimensions + self.game_area = GameArea( + window_width=self.window_width, + window_height=self.window_height, + grid_width=self.grid_width, + grid_height=self.grid_height + ) + + # Initialize snake in the middle of the grid + start_x = self.grid_width // 2 + start_y = self.grid_height // 2 + self.snake = Snake( + start_pos=(start_x, start_y), + grid_width=self.grid_width, + grid_height=self.grid_height + ) + + # Initialize food + self.food = Food() + self._spawn_food() + + # Reset game state + self.score = 0 + self.is_game_over = False + self.move_cooldown = self.rules.initial_move_cooldown + self.last_move_time = 0 + + return self.get_state() + + def _spawn_food(self) -> None: + """Spawn food in a random empty grid cell.""" + # Get all occupied positions + occupied = set(self.snake.body) + + # Get all possible positions + all_positions = [(x, y) for x in range(self.grid_width) + for y in range(self.grid_height)] + + # Filter out occupied positions + available = [pos for pos in all_positions if pos not in occupied] + + if available: + self.food.spawn_at_position(available) + + def step(self, action: Optional[Direction] = None, current_time: int = 0) -> 
Tuple[Dict, float, bool]: + """ + Advance the game state by one step. + + Args: + action: Optional direction to change to + current_time: Current game time in milliseconds + + Returns: + Tuple of (game_state, reward, done) + """ + if self.is_game_over: + return self.get_state(), 0, True + + # Apply action if provided + if action is not None: + self.snake.change_direction(action) + + # Update snake movement animation + self.snake.update_movement(1.0 / FPS) # + + # Check if enough time has passed for next move + if current_time - self.last_move_time < self.move_cooldown: + return self.get_state(), 0, False + + self.last_move_time = current_time + + # Process any pending direction changes first + if self.snake.input_buffer: + # Try first direction + next_direction = self.snake.input_buffer[0] + if self.snake.is_direction_valid(next_direction): + self.snake.direction = next_direction + self.snake.input_buffer.pop(0) + # If invalid and we have another input, try that one first + elif len(self.snake.input_buffer) > 1: + alt_direction = self.snake.input_buffer[1] + if self.snake.is_direction_valid(alt_direction): + self.snake.direction = alt_direction + self.snake.input_buffer.pop(1) # Remove the successful second input + # Remove the first input regardless as it's had its chance + self.snake.input_buffer.pop(0) + else: + # Single invalid input, just remove it + self.snake.input_buffer.pop(0) + + # Now get movement vector with updated direction + dx, dy = self.snake.direction.to_vector() + new_head = (self.snake.body[0][0] + dx, self.snake.body[0][1] + dy) + + # Check for immediate wall collision (before moving) + if not self.rules.wrap_around: + if (new_head[0] < 0 or new_head[0] >= self.grid_width or + new_head[1] < 0 or new_head[1] >= self.grid_height): + self.is_game_over = True + return self.get_state(), -1, True + + # check for collision with snake body + if new_head in self.snake.body[1:]: + self.is_game_over = True + return self.get_state(), -1, True + + # Move 
snake + head_x, head_y = self.snake.move() + + # Handle wrap-around if enabled + if self.rules.wrap_around: + head_x = head_x % self.grid_width + head_y = head_y % self.grid_height + self.snake.body[0] = (head_x, head_y) + self.snake.target_positions[0] = (head_x, head_y) # Update target position for smooth movement + + # Check self collision + if (head_x, head_y) in self.snake.body[1:]: + self.is_game_over = True + return self.get_state(), -1, True + + # Check food collision + if (head_x, head_y) == self.food.position: + self.score += 1 + self.snake.grow() + self._spawn_food() + + # Increase speed if enabled + if self.rules.speed_increase: + self.move_cooldown = max( + self.rules.min_move_cooldown, + self.move_cooldown - 2 + ) + return self.get_state(), 1, False + + return self.get_state(), 0, False + + def get_state(self) -> Dict: + """ + Get the current game state. + + Returns: + Dict containing the game state + """ + return { + "grid_width": self.grid_width, + "grid_height": self.grid_height, + "snake_head": self.snake.body[0], + "snake_body": self.snake.body, + "snake_direction": self.snake.direction, + "food_position": self.food.position, + "score": self.score, + "game_over": self.is_game_over, + "move_cooldown": self.move_cooldown + } + + def render(self, screen: pygame.Surface) -> None: + """ + Render the game state to the screen. 
class Direction(Enum):
    """Cardinal movement directions on the game grid."""

    UP = auto()
    DOWN = auto()
    LEFT = auto()
    RIGHT = auto()

    def to_vector(self) -> Tuple[int, int]:
        """Return the (dx, dy) grid step for this direction (UP decreases y)."""
        if self == Direction.UP:
            return (0, -1)
        elif self == Direction.DOWN:
            return (0, 1)
        elif self == Direction.LEFT:
            return (-1, 0)
        else:  # Direction.RIGHT
            return (1, 0)


class Snake:
    """Grid-based snake that tracks logical cells plus smoothed visual positions."""

    def __init__(self, start_pos: Tuple[int, int], grid_width: int, grid_height: int, length: int = 3):
        """
        Initialize snake at the given grid position.

        Args:
            start_pos: Starting position in grid coordinates (x, y)
            grid_width: Width of the game grid
            grid_height: Height of the game grid
            length: Initial length of the snake
        """
        self.direction = Direction.RIGHT
        self.body = [start_pos]  # Head is at index 0
        self.growing = False
        self.grid_width = grid_width
        self.grid_height = grid_height

        # Remaining segments trail to the left of the head. Starting at 1
        # avoids adding a duplicate of the head cell.
        for i in range(1, length):
            self.body.append((start_pos[0] - i, start_pos[1]))

        # Visual (float) positions are what gets drawn; they chase the
        # integer grid cells stored in target_positions.
        self.visual_positions = [(float(pos[0]), float(pos[1])) for pos in self.body]
        self.move_progress = 1.0  # Progress of current move (0.0 to 1.0)
        self.move_speed = 0.2  # Progress added per update_movement() call
        self.target_positions = list(self.body)  # Target grid positions

        # Input buffer for direction changes; this class only appends to it.
        self.input_buffer = []  # Store up to 2 pending direction changes
        self.max_buffer_size = 2

        # Create base segment surface with gradient
        self.segment_size = 32  # Base size for segment surface
        self.base_surface = pygame.Surface((self.segment_size, self.segment_size), pygame.SRCALPHA)

        # Paint concentric circles from the rim inwards so the centre ends
        # up brightest (intensity is 0 at the rim and 1 at radius 0).
        center = self.segment_size // 2
        for radius in range(center, -1, -1):
            intensity = 1.0 - (radius / center) ** 0.8  # Adjust power for gradient shape
            color = (
                int(SNAKE_GRADIENT_START[0] * intensity),
                int(SNAKE_GRADIENT_START[1] * intensity),
                int(SNAKE_GRADIENT_START[2] * intensity),
                255
            )
            pygame.draw.circle(self.base_surface, color, (center, center), radius)

    def move(self) -> Tuple[int, int]:
        """
        Move the snake one grid cell in current direction.

        The tail cell is dropped unless grow() was called since the last
        move. No bounds or collision checks happen here.

        Returns:
            New head position (x, y)
        """
        new_head = self._get_new_head_position(self.body[0])

        # Update grid positions
        self.body.insert(0, new_head)
        if not self.growing:
            self.body.pop()
        else:
            self.growing = False

        # Every visual segment now animates towards its new grid cell.
        self.target_positions = list(self.body)
        if len(self.visual_positions) < len(self.target_positions):
            # The snake grew: the new tail segment starts on top of the old tail.
            self.visual_positions.append(self.visual_positions[-1])

        # Reset move progress for smooth animation
        self.move_progress = 0.0

        return new_head

    def update_movement(self, dt: float):
        """
        Advance the visual interpolation towards the target positions.

        Args:
            dt: Accepted for interface compatibility but currently unused;
                progress advances by ``move_speed`` per call.
        """
        if self.move_progress >= 1.0:
            return

        # Smoothstep easing (3t^2 - 2t^3) for gentle acceleration/deceleration
        self.move_progress = min(1.0, self.move_progress + self.move_speed)
        t = self.move_progress
        overall_progress = t * t * (3 - 2 * t)

        segment_count = len(self.visual_positions)
        for i in range(segment_count):
            if i >= len(self.target_positions):
                continue

            current = self.visual_positions[i]
            target = self.target_positions[i]

            dx = target[0] - current[0]
            dy = target[1] - current[1]

            # A jump of more than two cells is treated as a wrap-around:
            # snap instead of sliding across the whole board.
            if abs(dx) > 2 or abs(dy) > 2:
                self.visual_positions[i] = target
                continue

            # Segments further from the head start moving slightly later;
            # the total delay is spread over a fixed 0.2 window.
            delay_window = 0.2
            segment_delay = (i / segment_count) * delay_window

            segment_progress = max(0.0, (overall_progress - segment_delay) / (1.0 - segment_delay))
            segment_progress = min(1.0, segment_progress)

            # Step a fraction of the remaining distance (1.0 lands on target).
            new_x = current[0] + dx * segment_progress
            new_y = current[1] + dy * segment_progress
            self.visual_positions[i] = (new_x, new_y)

    def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]:
        """Calculate new head position based on current direction"""
        x, y = head
        dx, dy = self.direction.to_vector()
        return (x + dx, y + dy)

    def grow(self):
        """Mark the snake to grow on next move"""
        self.growing = True

    def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool:
        """
        Check if snake has collided with walls or itself.

        Args:
            width: Game area width
            height: Game area height
            wrap_around: If True, snake wraps around screen edges instead of colliding

        Returns:
            True on a wall hit (only when wrap_around is False) or when the
            head shares a cell with another body segment.
        """
        head = self.body[0]

        if not wrap_around:
            # Check wall collision
            if (head[0] < 0 or head[0] >= width or
                    head[1] < 0 or head[1] >= height):
                return True

        # Check self collision (skip head)
        if head in self.body[1:]:
            return True

        return False

    def wrap_position(self, width: int, height: int):
        """Wrap snake's head position around screen edges"""
        head_x, head_y = self.body[0]
        # Python's % is non-negative for a positive modulus, so -1 maps to width - 1.
        self.body[0] = (head_x % width, head_y % height)

    def is_direction_valid(self, new_direction: Direction) -> bool:
        """
        Apply a direction change unless it is a reversal or a no-op.

        Despite the name, this mutates ``self.direction`` when the change
        is accepted.

        Args:
            new_direction: The direction to change to

        Returns:
            True if direction was changed, False otherwise
        """
        opposite_directions = {
            Direction.UP: Direction.DOWN,
            Direction.DOWN: Direction.UP,
            Direction.LEFT: Direction.RIGHT,
            Direction.RIGHT: Direction.LEFT
        }

        if opposite_directions[new_direction] == self.direction or new_direction == self.direction:
            return False

        self.direction = new_direction
        return True

    def change_direction(self, new_direction: Direction):
        """
        Buffer a direction change to be applied on next move.

        A request equal to the most recently buffered one is dropped, as is
        any request made while the buffer is full.

        Args:
            new_direction: The direction to change to
        """
        if self.input_buffer and self.input_buffer[-1] == new_direction:
            return

        if len(self.input_buffer) < self.max_buffer_size:
            self.input_buffer.append(new_direction)

    @staticmethod
    def _eye_centers(center, angle: float, offset) -> List[Tuple[int, int]]:
        """Return the integer pixel centres of both eyes, mirrored across the heading axis."""
        eye_angle = angle + math.pi / 2  # perpendicular to the heading
        # Both components are scaled by `side`; scaling only x put the two
        # eyes on the same pixel whenever the snake moved horizontally.
        return [
            (
                int(center[0] + math.cos(eye_angle) * offset * side),
                int(center[1] + math.sin(eye_angle) * offset * side),
            )
            for side in (-1, 1)
        ]

    def draw(self, screen: "pygame.Surface", game_area):
        """
        Draw snake with clean, connected design and cute tongue.

        Args:
            screen: Surface to draw on
            game_area: Object providing ``grid_size`` and ``get_screen_pos(x, y)``

        Nothing is drawn while fewer than two visual segments exist.
        """
        if len(self.visual_positions) < 2:
            return

        base_size = game_area.grid_size
        segment_size = base_size - 4
        total = len(self.visual_positions)

        # Draw snake body segments
        for i in range(total - 1):
            pos1 = self.visual_positions[i]
            pos2 = self.visual_positions[i + 1]

            screen_positions = self._get_wrapped_segment_positions(pos1, pos2, game_area)
            if not screen_positions:
                continue

            for sp1, sp2 in screen_positions:
                # Segments darken towards the tail, never below 40% brightness.
                progress = max(0.4, 1.0 - (i / total))

                color = (
                    int(SNAKE_GRADIENT_START[0] * progress),
                    int(SNAKE_GRADIENT_START[1] * progress),
                    int(SNAKE_GRADIENT_START[2] * progress)
                )

                center1 = (sp1[0] + base_size//2, sp1[1] + base_size//2)
                center2 = (sp2[0] + base_size//2, sp2[1] + base_size//2)

                # Draw a thicker base line first
                pygame.draw.line(screen, color, center1, center2, segment_size + 4)

                # Draw larger circles at joints to ensure corner coverage
                pygame.draw.circle(screen, color, center1, (segment_size + 2)//2)
                pygame.draw.circle(screen, color, center2, (segment_size + 2)//2)

                # If this is a corner (check by looking at next segment)
                if i < total - 2:
                    pos3 = self.visual_positions[i + 2]
                    if (pos2[0] != pos1[0] and pos3[1] != pos2[1]) or \
                            (pos2[1] != pos1[1] and pos3[0] != pos2[0]):
                        # Draw extra circle at the corner to ensure coverage
                        pygame.draw.circle(screen, color, center2, (segment_size + 6)//2)

        # Draw head
        head_pos = self.visual_positions[0]
        screen_pos = game_area.get_screen_pos(head_pos[0], head_pos[1])
        center = (screen_pos[0] + base_size//2, screen_pos[1] + base_size//2)
        pygame.draw.circle(screen, SNAKE_GRADIENT_START, center, segment_size//2)

        # Heading points from the second segment towards the head; the
        # early return above guarantees index 1 exists.
        next_pos = self.visual_positions[1]
        dx = head_pos[0] - next_pos[0]
        dy = head_pos[1] - next_pos[1]
        angle = math.atan2(dy, dx)

        # Draw eyes
        eye_size = 2
        for eye_center in self._eye_centers(center, angle, segment_size//4):
            pygame.draw.circle(screen, (0, 0, 0), eye_center, eye_size)

        # Draw flickering tongue; sin() of the tick count drives the flick.
        ticks = pygame.time.get_ticks()
        tongue_flick = math.sin(ticks * 0.01) * 0.8 + 0.2

        # Tongue base position (at front of head)
        tongue_base_x = center[0] + math.cos(angle) * (segment_size//2)
        tongue_base_y = center[1] + math.sin(angle) * (segment_size//2)

        tongue_length = (3 + tongue_flick * 6)
        fork_length = 4
        fork_angle = math.pi/3

        # Calculate tongue tip
        tongue_tip_x = tongue_base_x + math.cos(angle) * tongue_length
        tongue_tip_y = tongue_base_y + math.sin(angle) * tongue_length

        # Calculate fork tips
        left_fork_x = tongue_tip_x + math.cos(angle + fork_angle) * fork_length
        left_fork_y = tongue_tip_y + math.sin(angle + fork_angle) * fork_length
        right_fork_x = tongue_tip_x + math.cos(angle - fork_angle) * fork_length
        right_fork_y = tongue_tip_y + math.sin(angle - fork_angle) * fork_length

        tongue_color = (255, 0, 0)  # Bright red

        # Draw main tongue line
        pygame.draw.line(screen, tongue_color,
                         (tongue_base_x, tongue_base_y),
                         (tongue_tip_x, tongue_tip_y), 2)

        # Draw forked tips
        pygame.draw.line(screen, tongue_color,
                         (tongue_tip_x, tongue_tip_y),
                         (left_fork_x, left_fork_y), 2)
        pygame.draw.line(screen, tongue_color,
                         (tongue_tip_x, tongue_tip_y),
                         (right_fork_x, right_fork_y), 2)

    def _get_wrapped_segment_positions(self, pos1, pos2, game_area):
        """
        Get screen positions for segment rendering, handling wrap-around.

        Returns a one-element list holding the (screen_pos1, screen_pos2)
        pair, or an empty list when the two positions are more than two
        cells apart on either axis (treated as a wrap, so no connector is
        drawn across the board).
        """
        dx = pos2[0] - pos1[0]
        dy = pos2[1] - pos1[1]

        # If segments are too far apart (wrapping), don't draw connecting segment
        if abs(dx) > 2 or abs(dy) > 2:
            return []

        # Convert to screen coordinates
        screen_pos1 = game_area.get_screen_pos(pos1[0], pos1[1])
        screen_pos2 = game_area.get_screen_pos(pos2[0], pos2[1])

        return [(screen_pos1, screen_pos2)]
- - Args: - width: Game area width - height: Game area height - occupied_positions: List of positions (typically snake body positions) where food shouldn't spawn - """ - # Calculate all valid grid positions - valid_positions = [ - (x, y) - for x in range(0, width, self.block_size) - for y in range(0, height, self.block_size) - if (x, y) not in occupied_positions - ] - - if not valid_positions: - # No valid positions (snake fills screen) - game should be won - return - - # Choose random position from valid positions - self.position = random.choice(valid_positions) - - def draw(self, screen: pygame.Surface) -> None: - """Draw the food on the screen""" - pygame.draw.rect( - screen, - self.color, - pygame.Rect( - self.position[0], - self.position[1], - self.block_size, - self.block_size - ) - ) - - def check_collision(self, position: Tuple[int, int]) -> bool: - """Check if the given position collides with the food""" - return self.position == position \ No newline at end of file diff --git a/src/game.py b/src/game.py index 268dfb7..b2b6c4c 100644 --- a/src/game.py +++ b/src/game.py @@ -1,190 +1,109 @@ +""" +Main Game Module + +This module provides the main game class that handles the game loop, +state management, and user interface. 
class GameState(Enum):
    """Top-level screens/phases the game can be in."""

    # Explicit values; identical to the 1..5 sequence auto() assigned.
    MENU = 1
    SETTINGS = 2
    PLAYING = 3
    GAME_OVER = 4
    PAUSED = 5
update(self): - # Handle mouse hover - mouse_pos = pygame.mouse.get_pos() - for i, item in enumerate(self.menu_items): - if item.rect.collidepoint(mouse_pos): - # Update selected index when mouse hovers - self.selected_index = i - item.hover = True - item._setup_font() - else: - # Keep keyboard selection visible - item.hover = (i == self.selected_index) - item._setup_font() - - def handle_input(self, event): - if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1: - # Handle mouse clicks - mouse_pos = pygame.mouse.get_pos() - for i, item in enumerate(self.menu_items): - if item.rect.collidepoint(mouse_pos): - self.selected_index = i - if item.action == 'toggle_wrap': - self.rules.wrap_around = not self.rules.wrap_around - item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}" - item._setup_font() - elif item.action == 'toggle_speed': - self.rules.speed_increase = not self.rules.speed_increase - item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}" - item._setup_font() - elif item.action == 'back': - return 'back' - - elif event.type == pygame.KEYDOWN: - if event.key == pygame.K_ESCAPE: - return 'back' - elif event.key == pygame.K_RETURN: - item = self.menu_items[self.selected_index] - if item.action == 'toggle_wrap': - self.rules.wrap_around = not self.rules.wrap_around - item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}" - item._setup_font() - elif item.action == 'toggle_speed': - self.rules.speed_increase = not self.rules.speed_increase - item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}" - item._setup_font() - elif item.action == 'back': - return 'back' - elif event.key in (pygame.K_UP, pygame.K_DOWN): - # Update selected index - if event.key == pygame.K_UP: - self.selected_index = (self.selected_index - 1) % len(self.menu_items) - else: - self.selected_index = (self.selected_index + 1) % len(self.menu_items) - - # Update hover states - for i, item in 
enumerate(self.menu_items): - item.hover = (i == self.selected_index) - item._setup_font() - - return None - - def draw(self, screen): - # Draw background - screen.fill((0, 0, 0)) - - # Draw title - screen.blit(self.title_surface, self.title_rect) - - # Draw menu items - for item in self.menu_items: - item.draw(screen) - - # Draw controls - controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back" - font = pygame.font.Font(None, 24) - controls_surface = font.render(controls_text, True, (100, 100, 100)) - screen.blit(controls_surface, - (self.width - controls_surface.get_width() - 10, - self.height - 30)) - class Game: + """Main game class that manages the game loop and state.""" + def __init__(self): - # Initialize pygame and font system - if not pygame.get_init(): - pygame.init() - if not pygame.font.get_init(): - pygame.font.init() + """Initialize the game.""" + pygame.init() + pygame.font.init() - # Initialize display - self.width = 800 - self.height = 600 + # Window setup + self.width = WINDOW_WIDTH + self.height = WINDOW_HEIGHT self.screen = pygame.display.set_mode((self.width, self.height)) - pygame.display.set_caption("AI Snake Game") + pygame.display.set_caption("Snake Game") - # Initialize clock + # Initialize game components self.clock = pygame.time.Clock() - self.fps = 60 - - # Game rules - self.rules = GameRules() - - # Game objects - self.block_size = 20 - self.menu = Menu(self.width, self.height) - self.reset_game() + self.font = pygame.font.Font(None, SCORE_FONT_SIZE) + self.settings = GameSettings() + self.menu = Menu() + self.settings_menu = SettingsMenu(rules=self.settings.rules) + self.pause_menu = PauseMenu() # Game state - self.state = GameState.MENU # Start in menu state - self.running = True + self.state = GameState.MENU + self.game_session = None self.game_mode = None + self.running = True - # Debug mode - self.debug = False - - # Add settings menu - self.settings_menu = SettingsMenu(self.width, self.height, 
def reset_game(self, grid_width: int = 30, grid_height: int = 30):
    """
    Initialize or reset the game session.

    The new session is built first and then bound to ``self.game_session``;
    rebinding releases the previous session, so no explicit ``del`` is
    needed and the attribute never goes missing if construction fails.

    Args:
        grid_width: Number of grid columns for the new session (default 30)
        grid_height: Number of grid rows for the new session (default 30)
    """
    self.game_session = GameSession(
        grid_width=grid_width,
        grid_height=grid_height,
        rules=self.settings.rules,
        window_width=self.width,
        window_height=self.height
    )
def update(self):
    """Advance whichever screen is currently active by one frame."""
    now = pygame.time.get_ticks()
    active = self.state

    # Menu-like screens only need their own per-frame update.
    if active == GameState.MENU:
        self.menu.update()
        return
    if active == GameState.SETTINGS:
        self.settings_menu.update()
        return
    if active == GameState.PAUSED:
        self.pause_menu.update()
        return
    if active != GameState.PLAYING:
        return

    # Step the session with no direction argument; the third element of
    # the result says whether the game has ended.
    _state, _reward, finished = self.game_session.step(current_time=now)
    if finished:
        self.state = GameState.GAME_OVER
def draw_overlay(self, text):
    """Dim the whole window and print ``text`` centred on top of it."""
    window_size = (self.width, self.height)
    target = self.screen

    # Black sheet at alpha 128 covering the entire window.
    dimmer = pygame.Surface(window_size)
    dimmer.fill(BLACK)
    dimmer.set_alpha(128)
    target.blit(dimmer, (0, 0))

    # Caption rendered with the shared font, centred in the window.
    middle = (window_size[0] // 2, window_size[1] // 2)
    label = self.font.render(text, True, NEON_GREEN)
    target.blit(label, label.get_rect(center=middle))
def run(self):
    """Run the main loop until ``self.running`` becomes false."""
    self.running = True
    while True:
        if not self.running:
            break
        # One frame: input, simulation, drawing, then cap the frame rate.
        self.handle_events()
        self.update()
        self.render()
        self.clock.tick(FPS)
+""" + import pygame import sys -from game import Game +from src.game import Game def main(): + """Run the game with default settings.""" # Initialize Pygame pygame.init() diff --git a/src/menu.py b/src/menu.py index a748ba0..b32fcc2 100644 --- a/src/menu.py +++ b/src/menu.py @@ -81,6 +81,8 @@ class Menu: for i, item in enumerate(self.menu_items): if item.rect.collidepoint(mouse_pos): self.selected_index = i + item.hover = True + item._setup_font() return item.action elif event.type == pygame.KEYDOWN: diff --git a/src/snake.py b/src/snake.py deleted file mode 100644 index 40fcdbd..0000000 --- a/src/snake.py +++ /dev/null @@ -1,153 +0,0 @@ -import pygame -from enum import Enum, auto -from typing import List, Tuple - -class Direction(Enum): - UP = auto() - DOWN = auto() - LEFT = auto() - RIGHT = auto() - -class Snake: - def __init__(self, start_pos: Tuple[int, int], block_size: int): - self.block_size = block_size - self.direction = Direction.RIGHT - self.body = [start_pos] # Head is at index 0 - self.growing = False - - # Movement cooldown - self.move_cooldown = 100 # milliseconds - self.last_move_time = 0 - - # Direction change cooldown - self.direction_change_cooldown = 50 # milliseconds - self.last_direction_change = 0 - self.queued_direction = None - self.last_valid_move_time = 0 # Track when we last actually moved - - def move(self, current_time: int) -> bool: - """Move the snake if enough time has passed. 
Returns True if moved.""" - if current_time - self.last_move_time < self.move_cooldown: - return False - - # Apply queued direction change if it exists and is valid - if self.queued_direction: - if current_time - self.last_valid_move_time >= self.direction_change_cooldown: - if self._apply_direction_change(self.queued_direction): - self.last_direction_change = current_time - self.queued_direction = None - - # Update position - head = self.body[0] - new_head = self._get_new_head_position(head) - - # Insert new head - self.body.insert(0, new_head) - - # Remove tail if not growing - if not self.growing: - self.body.pop() - else: - self.growing = False - - self.last_move_time = current_time - self.last_valid_move_time = current_time - return True - - def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]: - x, y = head - if self.direction == Direction.UP: - return (x, y - self.block_size) - elif self.direction == Direction.DOWN: - return (x, y + self.block_size) - elif self.direction == Direction.LEFT: - return (x - self.block_size, y) - else: # Direction.RIGHT - return (x + self.block_size, y) - - def grow(self): - """Mark the snake to grow on next move""" - self.growing = True - - def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool: - """ - Check if snake has collided with walls or itself. 
- - Args: - width: Game area width - height: Game area height - wrap_around: If True, snake wraps around screen edges instead of colliding - """ - head = self.body[0] - - if not wrap_around: - # Check wall collision - if (head[0] < 0 or head[0] >= width or - head[1] < 0 or head[1] >= height): - return True - - # Check self collision (skip head) - if head in self.body[1:]: - return True - - return False - - def wrap_position(self, width: int, height: int): - """Wrap snake's head position around screen edges""" - head_x, head_y = self.body[0] - wrapped_head = ( - head_x % width if head_x != width else 0, - head_y % height if head_y != height else 0 - ) - self.body[0] = wrapped_head - - def _apply_direction_change(self, new_direction: Direction) -> bool: - """ - Internal method to actually change direction. - Returns True if direction was changed, False otherwise. - """ - # Prevent 180-degree turns when snake is longer than 1 - if len(self.body) > 1: - opposites = { - Direction.UP: Direction.DOWN, - Direction.DOWN: Direction.UP, - Direction.LEFT: Direction.RIGHT, - Direction.RIGHT: Direction.LEFT - } - if opposites[new_direction] == self.direction: - return False - - self.direction = new_direction - return True - - def change_direction(self, new_direction: Direction, current_time: int): - """ - Queue a direction change if it's valid and cooldown has passed. 
- - Args: - new_direction: The desired new direction - current_time: Current game time in milliseconds - """ - # If we haven't moved since the last direction change, queue it - if current_time - self.last_valid_move_time < self.direction_change_cooldown: - self.queued_direction = new_direction - return - - # Try to change direction immediately - if self._apply_direction_change(new_direction): - self.last_direction_change = current_time - - def draw(self, screen: pygame.Surface): - """Draw the snake on the screen""" - # Draw head in a slightly different color - head_color = (0, 200, 0) # Darker green for head - body_color = (0, 255, 0) # Regular green for body - - # Draw body segments - for i, segment in enumerate(self.body): - color = head_color if i == 0 else body_color - pygame.draw.rect( - screen, - color, - pygame.Rect(segment[0], segment[1], self.block_size, self.block_size) - ) \ No newline at end of file diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000..8bbd38a --- /dev/null +++ b/src/ui/__init__.py @@ -0,0 +1,20 @@ +""" +Game UI Package + +This package contains all user interface components. +""" + +from .menu import Menu, GameMode +from .menu_item import MenuItem +from .settings_menu import SettingsMenu +from .pause_menu import PauseMenu +from .game_area import GameArea + +__all__ = [ + 'Menu', + 'GameMode', + 'MenuItem', + 'SettingsMenu', + 'PauseMenu', + 'GameArea' +] \ No newline at end of file diff --git a/src/ui/game_area.py b/src/ui/game_area.py new file mode 100644 index 0000000..a18eab2 --- /dev/null +++ b/src/ui/game_area.py @@ -0,0 +1,138 @@ +""" +Game Area Module + +This module provides the GameArea class that handles the rendering of the game area +and coordinate conversions between screen and grid coordinates. 
class GameArea:
    """Render the playfield and convert between screen and grid coordinates.

    The area is ``grid_width x grid_height`` cells of ``grid_size`` pixels,
    centred horizontally in the window and placed below the score strip.
    """

    # Debug switch read by draw(): True draws full grid lines, False draws the
    # faded dot pattern. Declared on the class so draw() no longer has to
    # probe for it with hasattr(); callers may still set it per instance.
    show_grid = False

    def __init__(self, window_width, window_height, grid_size=20, grid_width=30, grid_height=30,
                 padding=DEFAULT_PADDING, score_height=DEFAULT_SCORE_HEIGHT):
        """
        Compute the playfield geometry.

        Args:
            window_width: Window width in pixels; used to centre the area.
            window_height: Window height in pixels (accepted but not used here).
            grid_size: Edge length of one grid cell in pixels.
            grid_width: Number of cells horizontally.
            grid_height: Number of cells vertically.
            padding: Gap in pixels between the score strip and the area.
            score_height: Height in pixels reserved above the area.
        """
        # Grid properties
        self.grid_size = grid_size
        self.grid_width = grid_width
        self.grid_height = grid_height

        # Area properties
        self.padding = padding
        self.score_height = score_height
        self.border_thickness = 4

        # Pixel dimensions of the whole playfield
        self.width = self.grid_width * self.grid_size
        self.height = self.grid_height * self.grid_size

        # Centre horizontally; sit directly below the score strip vertically
        self.x = (window_width - self.width) // 2
        self.y = self.score_height + self.padding

        # Outer bounds of the playfield
        self.rect = pygame.Rect(self.x, self.y, self.width, self.height)

        # Gameplay area inset by the border on every side
        self.inner_rect = pygame.Rect(
            self.x + self.border_thickness,
            self.y + self.border_thickness,
            self.width - (self.border_thickness * 2),
            self.height - (self.border_thickness * 2)
        )

        self.screen_offset = (self.x, self.y)

    def get_grid_pos(self, screen_x, screen_y):
        """Convert screen coordinates to the (column, row) of the containing cell."""
        grid_x = (screen_x - self.x) // self.grid_size
        grid_y = (screen_y - self.y) // self.grid_size
        return grid_x, grid_y

    def get_screen_pos(self, grid_x, grid_y):
        """Convert a grid position to the screen coordinates of the cell's top-left corner."""
        screen_x = self.x + (grid_x * self.grid_size)
        screen_y = self.y + (grid_y * self.grid_size)
        return screen_x, screen_y

    def is_within_bounds(self, grid_x, grid_y):
        """Return True if the grid position lies inside the playfield."""
        return 0 <= grid_x < self.grid_width and 0 <= grid_y < self.grid_height

    def draw(self, screen):
        """Draw the border gradient, background, grid pattern and inner glow."""
        # Outer border: nested filled rectangles blending DARKER_GREEN -> FOREST_GREEN
        for i in range(self.border_thickness):
            border_rect = pygame.Rect(
                self.x + i,
                self.y + i,
                self.width - (i * 2),
                self.height - (i * 2)
            )
            color_factor = i / self.border_thickness
            border_color = tuple(
                int(dark * (1 - color_factor) + light * color_factor)
                for dark, light in zip(DARKER_GREEN[:3], FOREST_GREEN[:3])
            )
            pygame.draw.rect(screen, border_color, border_rect)

        # Inner background
        pygame.draw.rect(screen, DARK_GRAY, self.inner_rect)

        if self.show_grid:
            # Debug grid lines
            for col in range(self.grid_width + 1):
                line_x = self.x + col * self.grid_size
                pygame.draw.line(screen, GRID_COLOR,
                                 (line_x, self.y), (line_x, self.y + self.height), 1)

            for row in range(self.grid_height + 1):
                line_y = self.y + row * self.grid_size
                pygame.draw.line(screen, GRID_COLOR,
                                 (self.x, line_y), (self.x + self.width, line_y), 1)
        else:
            # Subtle grid dots that fade towards the edges. The centre and the
            # half-extents do not depend on the dot, so compute them once
            # instead of once per dot per frame.
            half_width = self.width / 2
            half_height = self.height / 2
            center_x = self.x + half_width
            center_y = self.y + half_height

            for col in range(self.grid_width + 1):
                pos_x = self.x + col * self.grid_size
                dist_x = abs(pos_x - center_x) / half_width
                for row in range(self.grid_height + 1):
                    pos_y = self.y + row * self.grid_size
                    dist_y = abs(pos_y - center_y) / half_height
                    dist = min(1.0, (dist_x + dist_y) / 2)

                    # Dim the dot by up to 50% the further it is from the centre
                    dot_color = tuple(int(c * (1 - dist * 0.5)) for c in SUBTLE_GRID_COLOR)
                    pygame.draw.circle(screen, dot_color, (pos_x, pos_y), 1)

        # Inner border glow: three 1px outlines growing outwards, dim -> full NEON_GREEN
        glow_colors = [
            tuple(int(channel * 0.4) for channel in NEON_GREEN[:3]),
            tuple(int(channel * 0.6) for channel in NEON_GREEN[:3]),
            NEON_GREEN
        ]

        for i, color in enumerate(glow_colors):
            glow_rect = pygame.Rect(
                self.inner_rect.x - i,
                self.inner_rect.y - i,
                self.inner_rect.width + (i * 2),
                self.inner_rect.height + (i * 2)
            )
            pygame.draw.rect(screen, color, glow_rect, 1)

    def grid_to_screen(self, grid_pos: Tuple[int, int]) -> Tuple[int, int]:
        """Convert grid coordinates to the screen coordinates of the cell's centre."""
        x = self.rect.left + grid_pos[0] * self.grid_size + self.grid_size // 2
        y = self.rect.top + grid_pos[1] * self.grid_size + self.grid_size // 2
        return (x, y)
class GameMode(Enum):
    """Available game modes."""
    PLAYER = auto()
    AI_EASY = auto()
    AI_MEDIUM = auto()
    AI_HARD = auto()
    SETTINGS = auto()


class Menu:
    """Main menu interface."""

    def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
        """
        Initialize the main menu.

        Args:
            width: Window width in pixels
            height: Window height in pixels
        """
        self.width = width
        self.height = height
        self.setup_menu_items()

        # Create fonts
        self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
        self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)

        # Create title surfaces
        self.title_surface = self.title_font.render("Snake Game", True, (0, 255, 0))
        self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))

        # Initialize first item as selected
        self.selected_index = 0
        self._sync_hover()

    def setup_menu_items(self):
        """Setup the menu items."""
        start_y = self.height // 2
        spacing = MENU_ITEM_SPACING
        center_x = self.width // 2

        self.menu_items = [
            MenuItem("Player Game", (center_x, start_y), GameMode.PLAYER),
            MenuItem("AI Game (Easy)", (center_x, start_y + spacing), GameMode.AI_EASY),
            MenuItem("AI Game (Medium)", (center_x, start_y + spacing * 2), GameMode.AI_MEDIUM),
            MenuItem("AI Game (Hard)", (center_x, start_y + spacing * 3), GameMode.AI_HARD),
            MenuItem("Settings", (center_x, start_y + spacing * 4), GameMode.SETTINGS),
            MenuItem("Quit", (center_x, start_y + spacing * 5), 'quit')
        ]

    def _sync_hover(self):
        """Highlight exactly the selected item, re-rendering only items whose state changed."""
        for i, item in enumerate(self.menu_items):
            should_hover = (i == self.selected_index)
            if item.hover != should_hover:
                item.hover = should_hover
                item._setup_font()

    def update(self):
        """Update menu state (call once per frame)."""
        # Resolve the selection first, then apply hover states in a second
        # pass. Updating both in one loop left the previously selected item
        # highlighted for a frame whenever it came before the hovered one.
        mouse_pos = pygame.mouse.get_pos()
        for i, item in enumerate(self.menu_items):
            if item.rect.collidepoint(mouse_pos):
                self.selected_index = i
        self._sync_hover()

    def handle_input(self, event):
        """
        Handle input events.

        Args:
            event: A pygame event

        Returns:
            The action of the chosen item (a GameMode, or 'quit' for the Quit
            item), or None if no selection was made
        """
        if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
            # Handle mouse clicks
            mouse_pos = pygame.mouse.get_pos()
            for i, item in enumerate(self.menu_items):
                if item.rect.collidepoint(mouse_pos):
                    # Keep keyboard selection in step with the click, as the
                    # pause and settings menus do
                    self.selected_index = i
                    return item.action

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_RETURN:
                return self.menu_items[self.selected_index].action
            elif event.key in (pygame.K_UP, pygame.K_DOWN):
                # Move the selection, wrapping around at either end
                step = -1 if event.key == pygame.K_UP else 1
                self.selected_index = (self.selected_index + step) % len(self.menu_items)
                self._sync_hover()

        return None

    def draw(self, screen):
        """Draw the menu to the screen."""
        # Draw background
        screen.fill(BLACK)

        # Draw title
        screen.blit(self.title_surface, self.title_rect)

        # Draw menu items
        for item in self.menu_items:
            item.draw(screen)

        # Draw controls hint in the bottom-right corner
        controls_text = "Arrow keys or mouse to navigate, Enter to select"
        controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
        screen.blit(controls_surface,
                    (self.width - controls_surface.get_width() - 10,
                     self.height - 30))
class MenuItem:
    """A clickable menu item with hover effects."""

    def __init__(self, text, position, action, font_size_normal=36, font_size_hover=48):
        """
        Initialize a menu item.

        Args:
            text: The text to display
            position: (x, y) tuple for center position
            action: Identifier for the action to take when clicked
            font_size_normal: Font size when not hovered (default 36)
            font_size_hover: Font size when hovered (default 48)
        """
        self.text = text
        self.position = position
        self.action = action
        self.hover = False
        self.font_size_normal = font_size_normal
        self.font_size_hover = font_size_hover
        self._rendered_state = None
        self._setup_font()

    def _setup_font(self):
        """Render the item for its current text and hover state.

        The menus call this for every item on every frame, so the call is a
        no-op when nothing that affects the rendering has changed since the
        last render; otherwise a new Font would be built and the text
        re-rendered each frame for no visible difference.
        """
        size = self.font_size_hover if self.hover else self.font_size_normal
        state = (self.text, self.hover, size, tuple(self.position))
        if getattr(self, "_rendered_state", None) == state:
            return

        self.font = pygame.font.Font(None, size)
        self.surface = self.font.render(self.text, True, GREEN if self.hover else WHITE)
        self.rect = self.surface.get_rect(center=self.position)
        self._rendered_state = state

    def draw(self, screen):
        """Draw the menu item to the screen."""
        screen.blit(self.surface, self.rect)


class NumberMenuItem(MenuItem):
    """A menu item that displays an adjustable number as ``"<label>: <value>"``."""

    def __init__(self, text, position, action, min_value, max_value,
                 step=1, font_size_normal=36, font_size_hover=48):
        """
        Initialize a number menu item.

        Args:
            text: The text to display, in the form "<label>: <integer>"
            position: (x, y) tuple for center position
            action: Identifier for the action to take when clicked
            min_value: Minimum allowed value
            max_value: Maximum allowed value
            step: Amount to increment/decrement by (default 1)
            font_size_normal: Font size when not hovered (default 36)
            font_size_hover: Font size when hovered (default 48)
        """
        super().__init__(text, position, action, font_size_normal, font_size_hover)
        self.min_value = min_value
        self.max_value = max_value
        self.step = step
        self.current_value = int(text.split(": ")[1])  # Extract number from text

    def increment(self):
        """Increase the value by one step; do nothing if that would exceed max_value."""
        if self.current_value + self.step <= self.max_value:
            self.current_value += self.step
            self._update_text()

    def decrement(self):
        """Decrease the value by one step; do nothing if that would go below min_value."""
        if self.current_value - self.step >= self.min_value:
            self.current_value -= self.step
            self._update_text()

    def _update_text(self):
        """Rebuild the displayed text from the label and the current value."""
        base_text = self.text.split(": ")[0]
        self.text = f"{base_text}: {self.current_value}"
        self._setup_font()
class PauseMenu:
    """Pause menu interface shown on top of the running game."""

    def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
        """
        Initialize the pause menu.

        Args:
            width: Window width
            height: Window height
        """
        self.width = width
        self.height = height
        self.setup_menu_items()

        self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
        self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)

        # Create title surfaces
        self.title_surface = self.title_font.render("Paused", True, (0, 255, 0))
        self.title_rect = self.title_surface.get_rect(center=(width//2, height//3))

        # Initialize first item as selected
        self.selected_index = 0
        self._sync_hover()

    def setup_menu_items(self):
        """Setup the menu items."""
        start_y = self.height // 2
        spacing = MENU_ITEM_SPACING
        center_x = self.width // 2

        self.menu_items = [
            MenuItem("Resume",
                     (center_x, start_y),
                     'resume'),
            MenuItem("Return to Menu",
                     (center_x, start_y + spacing),
                     'menu')
        ]

    def _sync_hover(self):
        """Highlight exactly the selected item, re-rendering only items whose state changed."""
        for i, item in enumerate(self.menu_items):
            should_hover = (i == self.selected_index)
            if item.hover != should_hover:
                item.hover = should_hover
                item._setup_font()

    def update(self):
        """Update menu state (call once per frame)."""
        # Resolve the selection first, then apply hover states in a second
        # pass. Updating both in one loop left the previously selected item
        # highlighted for a frame whenever it came before the hovered one.
        mouse_pos = pygame.mouse.get_pos()
        for i, item in enumerate(self.menu_items):
            if item.rect.collidepoint(mouse_pos):
                self.selected_index = i
        self._sync_hover()

    def handle_input(self, event):
        """
        Handle input events.

        Args:
            event: A pygame event

        Returns:
            str or None: The action to take ('resume' or 'menu') or None if no action
        """
        if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
            # Handle mouse clicks
            mouse_pos = pygame.mouse.get_pos()
            for i, item in enumerate(self.menu_items):
                if item.rect.collidepoint(mouse_pos):
                    self.selected_index = i
                    return item.action

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_ESCAPE:
                return 'resume'
            elif event.key == pygame.K_RETURN:
                return self.menu_items[self.selected_index].action
            elif event.key in (pygame.K_UP, pygame.K_DOWN):
                # Move the selection, wrapping around at either end
                step = -1 if event.key == pygame.K_UP else 1
                self.selected_index = (self.selected_index + step) % len(self.menu_items)
                self._sync_hover()

        return None

    def draw(self, screen):
        """Draw the menu over whatever is already on the screen."""
        # Draw semi-transparent background so the paused game stays visible
        overlay = pygame.Surface((screen.get_width(), screen.get_height()))
        overlay.fill(BLACK)
        overlay.set_alpha(128)
        screen.blit(overlay, (0, 0))

        # Draw title
        screen.blit(self.title_surface, self.title_rect)

        # Draw menu items
        for item in self.menu_items:
            item.draw(screen)

        # Draw controls hint in the bottom-right corner
        controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to resume"
        controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
        screen.blit(controls_surface,
                    (self.width - controls_surface.get_width() - 10,
                     self.height - 30))
class SettingsMenu:
    """Settings menu interface for editing the game rules."""

    def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT, rules=None):
        """
        Initialize the settings menu.

        Args:
            width: Window width
            height: Window height
            rules: Rules object to modify. Must provide ``wrap_around``,
                ``speed_increase`` and ``starting_length`` attributes and an
                ``update_rule(name, value)`` method. The default of None is
                not usable: building the menu items reads these attributes.
        """
        self.width = width
        self.height = height
        self.rules = rules
        self.setup_menu_items()

        self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
        self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)

        # Create title surfaces
        self.title_surface = self.title_font.render("Settings", True, (0, 255, 0))
        self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))

        # Initialize first item as selected
        self.selected_index = 0
        self._sync_hover()

    def setup_menu_items(self):
        """Setup the menu items from the current rule values."""
        start_y = self.height // 2
        spacing = MENU_ITEM_SPACING
        center_x = self.width // 2

        self.menu_items = [
            MenuItem(f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}",
                     (center_x, start_y),
                     'toggle_wrap'),
            MenuItem(f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}",
                     (center_x, start_y + spacing),
                     'toggle_speed'),
            NumberMenuItem(f"Starting Length: {self.rules.starting_length}",
                           (center_x, start_y + spacing * 2),
                           'set_starting_length',
                           1, 10),
            MenuItem("Back to Menu",
                     (center_x, start_y + spacing * 3),
                     'back')
        ]

    def _sync_hover(self):
        """Highlight exactly the selected item, re-rendering only items whose state changed."""
        for i, item in enumerate(self.menu_items):
            should_hover = (i == self.selected_index)
            if item.hover != should_hover:
                item.hover = should_hover
                item._setup_font()

    def _activate(self, item, decrease=False):
        """Apply an item's action to the rules and refresh its label.

        Shared by the mouse and keyboard paths so both always go through
        ``rules.update_rule`` and always leave the label showing the rule's
        current value.

        Args:
            item: The menu item being activated
            decrease: If True, perform the "decrease" action; only the
                starting-length item has one, other items are left alone

        Returns:
            'back' if the menu should close, otherwise None
        """
        action = item.action

        if action == 'set_starting_length':
            if decrease:
                item.decrement()
            else:
                item.increment()
            self.rules.update_rule('starting_length', item.current_value)
            return None

        if decrease:
            return None

        if action == 'toggle_wrap':
            self.rules.update_rule('wrap_around', not self.rules.wrap_around)
            # Read the value back so the label shows what the rules now hold
            item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}"
            item._setup_font()
        elif action == 'toggle_speed':
            self.rules.update_rule('speed_increase', not self.rules.speed_increase)
            item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
            item._setup_font()
        elif action == 'back':
            return 'back'

        return None

    def update(self):
        """Update menu state (call once per frame)."""
        # Resolve the selection first, then apply hover states in a second
        # pass. Updating both in one loop left the previously selected item
        # highlighted for a frame whenever it came before the hovered one.
        mouse_pos = pygame.mouse.get_pos()
        for i, item in enumerate(self.menu_items):
            if item.rect.collidepoint(mouse_pos):
                self.selected_index = i
        self._sync_hover()

    def handle_input(self, event):
        """
        Handle input events.

        Left click or Enter activates an item (toggle, increment, or back);
        right click or the Left arrow decrements the starting length, and the
        Right arrow increments it; Up/Down move the selection; Esc goes back.

        Args:
            event: A pygame event

        Returns:
            str or None: 'back' when the menu should close, otherwise None
        """
        if event.type == pygame.MOUSEBUTTONDOWN:
            # Handle mouse clicks
            mouse_pos = pygame.mouse.get_pos()
            for i, item in enumerate(self.menu_items):
                if item.rect.collidepoint(mouse_pos):
                    self.selected_index = i
                    if event.button == 1:
                        if self._activate(item) == 'back':
                            return 'back'
                    elif event.button == 3:
                        self._activate(item, decrease=True)

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_ESCAPE:
                return 'back'
            elif event.key == pygame.K_RETURN:
                return self._activate(self.menu_items[self.selected_index])
            elif event.key in (pygame.K_LEFT, pygame.K_RIGHT):
                # Keyboard equivalent of right/left click on the number item,
                # which otherwise could only be increased from the keyboard
                item = self.menu_items[self.selected_index]
                if item.action == 'set_starting_length':
                    self._activate(item, decrease=(event.key == pygame.K_LEFT))
            elif event.key in (pygame.K_UP, pygame.K_DOWN):
                # Move the selection, wrapping around at either end
                step = -1 if event.key == pygame.K_UP else 1
                self.selected_index = (self.selected_index + step) % len(self.menu_items)
                self._sync_hover()

        return None

    def draw(self, screen):
        """Draw the menu to the screen."""
        # Draw background
        screen.fill(BLACK)

        # Draw title
        screen.blit(self.title_surface, self.title_rect)

        # Draw menu items
        for item in self.menu_items:
            item.draw(screen)

        # Draw controls hint in the bottom-right corner
        controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back"
        controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
        screen.blit(controls_surface,
                    (self.width - controls_surface.get_width() - 10,
                     self.height - 30))
"n_steps": 2048, + "gamma": 0.99, + "ent_coef": 0.1, + "n_epochs": 10, + "gae_lambda": 0.95, + "clip_range": 0.2, + "vf_coef": 0.5, + "max_grad_norm": 0.5, + "normalize_advantage": true +} \ No newline at end of file