Major code restructuring: Reorganized project structure, added AI components - Added new module structure (ai, config, core, ui) - Moved snake and food logic into core module - Added training configuration - Updated gitignore for project-specific files - Modified tests to match new structure
This commit is contained in:
parent
251822ec35
commit
b8eeaa4739
12
.gitignore
vendored
12
.gitignore
vendored
@ -7,6 +7,7 @@ ENV/
|
||||
.virtualenv/
|
||||
virtualenv/
|
||||
*venv/
|
||||
Activate.ps1
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
@ -39,3 +40,14 @@ wheels/
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Project specific
|
||||
logs/
|
||||
models/
|
||||
.pytest_cache/
|
||||
path/
|
||||
training_data/
|
||||
*.log
|
||||
*.pth
|
||||
*.ckpt
|
||||
*.h5
|
||||
|
106
ROADMAP.md
106
ROADMAP.md
@ -3,19 +3,26 @@
|
||||
This is a living document that tracks the development progress of the AI Snake Game project.
|
||||
|
||||
## Current Status
|
||||
Last Updated: [Menu System Implementation]
|
||||
- Added main menu interface with game mode selection
|
||||
- Implemented menu navigation and selection
|
||||
- Added game mode display during gameplay
|
||||
- Integrated menu system with game states
|
||||
Last Updated: [Control and Menu Improvements]
|
||||
- Fixed menu interaction issues:
|
||||
- Added proper mouse click handling
|
||||
- Fixed menu item highlighting
|
||||
- Improved keyboard and mouse navigation
|
||||
- Improved snake movement system:
|
||||
- Implemented input buffer for direction changes
|
||||
- Fixed rapid direction change issues
|
||||
- Improved responsiveness while maintaining safety
|
||||
- Previous updates:
|
||||
- Added settings menu with game options
|
||||
- Added wrap-around mode toggle
|
||||
- Added speed increase toggle
|
||||
- Added main menu interface
|
||||
- Added game mode selection
|
||||
- Implemented menu navigation
|
||||
- Added game mode display
|
||||
- Added game restart functionality
|
||||
- Added "Press any key to restart" message
|
||||
- Implemented game state reset
|
||||
- Added Food class with spawning mechanics
|
||||
- Integrated food system with main game loop
|
||||
- Implemented basic scoring system
|
||||
- Added game over screen with score display
|
||||
|
||||
## Phase 1: Project Setup and Basic Structure ✅
|
||||
- [x] Create project structure and virtual environment
|
||||
@ -27,11 +34,18 @@ Last Updated: [Menu System Implementation]
|
||||
## Phase 2: Core Game Development ✅
|
||||
- [x] Implement basic game window using Pygame
|
||||
- [x] Create Snake class with movement mechanics
|
||||
- [x] Basic movement and collision
|
||||
- [x] Input buffering for direction changes
|
||||
- [x] Fix rapid direction change issues
|
||||
- [x] Prevent 180-degree turns
|
||||
- [x] Implement food spawning system
|
||||
- [x] Create Food class
|
||||
- [x] Add random food placement
|
||||
- [x] Add collision detection with food
|
||||
- [x] Add collision detection (walls and self)
|
||||
- [x] Add collision detection
|
||||
- [x] Wall collisions
|
||||
- [x] Self collisions
|
||||
- [x] Add wrap-around mode option
|
||||
- [x] Implement scoring system
|
||||
- [x] Add score display
|
||||
- [ ] Add high score tracking
|
||||
@ -39,12 +53,19 @@ Last Updated: [Menu System Implementation]
|
||||
- [x] Implement game over state transition
|
||||
- [x] Add restart functionality
|
||||
|
||||
## Phase 3: Game Polish and UI 🔄
|
||||
## Phase 3: Game Polish and UI ✅
|
||||
- [x] Add main menu interface
|
||||
- [x] Create menu layout
|
||||
- [x] Add game mode selection
|
||||
- [ ] Add settings options
|
||||
- [x] Add settings menu
|
||||
- [x] Fix menu interaction issues
|
||||
- [x] Mouse click handling
|
||||
- [x] Menu item highlighting
|
||||
- [x] Keyboard/mouse navigation
|
||||
- [x] Implement pause functionality
|
||||
- [x] Add settings options
|
||||
- [x] Wrap-around mode toggle
|
||||
- [x] Speed increase toggle
|
||||
- [ ] Add visual effects and animations
|
||||
- [ ] Snake movement animation
|
||||
- [ ] Food spawn animation
|
||||
@ -58,36 +79,63 @@ Last Updated: [Menu System Implementation]
|
||||
- [x] Add game over screen with restart option
|
||||
|
||||
## Phase 4: AI Implementation 🔄
|
||||
- [ ] Design AI algorithm (pathfinding)
|
||||
- [ ] Research and select appropriate algorithm
|
||||
- [ ] Implement basic pathfinding
|
||||
- [ ] Implement AI snake movement logic
|
||||
- [ ] Create AI controller class
|
||||
- [ ] Implement decision making
|
||||
- [ ] Add AI difficulty levels
|
||||
- [ ] Easy mode (basic pathfinding)
|
||||
- [ ] Medium mode (improved decision making)
|
||||
- [ ] Set up ML training environment
|
||||
- [ ] Create gym-like interface for the game
|
||||
- [ ] Define observation space (snake + food state)
|
||||
- [ ] Define action space (4 directions)
|
||||
- [ ] Implement reward system
|
||||
- [ ] Implement ML infrastructure
|
||||
- [ ] Add ML dependencies (PyTorch/TensorFlow)
|
||||
- [ ] Create neural network architecture
|
||||
- [ ] Set up training pipeline
|
||||
- [ ] Train AI models for different difficulties
|
||||
- [ ] Easy mode (basic food seeking)
|
||||
- [ ] Medium mode (balanced survival/food seeking)
|
||||
- [ ] Hard mode (optimal pathfinding)
|
||||
- [ ] Create AI controller class
|
||||
- [ ] Model loading and inference
|
||||
- [ ] Real-time decision making
|
||||
- [ ] Performance optimization
|
||||
- [x] Create mode selection (Player vs AI)
|
||||
- [ ] Optimize AI performance
|
||||
- [ ] Add training utilities
|
||||
- [ ] Save/load model checkpoints
|
||||
- [ ] Training progress visualization
|
||||
- [ ] Model performance metrics
|
||||
|
||||
## Phase 5: Final Polish
|
||||
- [ ] Add configuration options
|
||||
- [ ] Game speed settings
|
||||
- [x] Add configuration options
|
||||
- [x] Game speed settings
|
||||
- [x] Wrap-around option
|
||||
- [ ] Visual settings
|
||||
- [ ] Sound settings
|
||||
- [ ] Implement save/load functionality for high scores
|
||||
- [ ] Bug fixing and performance optimization
|
||||
- [x] Fix menu interaction bugs
|
||||
- [x] Fix movement control issues
|
||||
- [x] Optimize input handling
|
||||
- [ ] Code cleanup and documentation
|
||||
- [ ] Add docstrings
|
||||
- [x] Add docstrings
|
||||
- [ ] Create API documentation
|
||||
- [ ] Update README with final features
|
||||
- [ ] Final testing and refinements
|
||||
|
||||
## Resolved Issues
|
||||
1. Menu Interaction
|
||||
- Issue: Menu items not responding to clicks
|
||||
- Solution: Added proper mouse click event handling in game loop
|
||||
- Solution: Fixed menu item highlighting and selection state
|
||||
|
||||
2. Snake Movement
|
||||
- Issue: Rapid direction changes causing self-collision
|
||||
- Issue: Missed inputs during quick turns
|
||||
- Solution: Implemented input buffer system
|
||||
- Solution: Process one direction change per grid movement
|
||||
- Solution: Store up to 2 pending direction changes
|
||||
|
||||
## Next Steps
|
||||
1. Implement AI snake movement
|
||||
- Design pathfinding algorithm
|
||||
- Create AI controller class
|
||||
- Set up ML training environment
|
||||
- Create neural network architecture
|
||||
- Train initial model
|
||||
2. Add high score system
|
||||
- Implement score persistence
|
||||
- Add high score display
|
||||
@ -96,7 +144,7 @@ Last Updated: [Menu System Implementation]
|
||||
- Add game sounds
|
||||
|
||||
## Known Issues
|
||||
- None reported yet
|
||||
- None reported
|
||||
|
||||
## Future Considerations
|
||||
- Multiplayer support
|
||||
|
@ -1,5 +1,12 @@
|
||||
pygame==2.5.2
|
||||
numpy==1.24.3
|
||||
numpy==1.26.3
|
||||
black==23.12.1 # for code formatting
|
||||
pytest==7.4.3 # for testing
|
||||
pytest-xdist==3.5.0 # for parallel test execution
|
||||
pytest-xdist==3.5.0 # for parallel test execution
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.1.2+cu118 # CUDA 11.8 support
|
||||
stable-baselines3[extra]==2.2.1 # for reinforcement learning
|
||||
tensorboard==2.15.1 # for training visualization
|
||||
gymnasium==0.29.1 # for RL environment interface
|
||||
tqdm==4.66.1
|
||||
rich==13.7.0
|
@ -9,11 +9,12 @@ This package contains all the core game components including:
|
||||
- AI controllers (coming soon)
|
||||
"""
|
||||
|
||||
from src import config
|
||||
from src import core
|
||||
from src import ai
|
||||
from src import ui
|
||||
from src.game import Game, GameState
|
||||
from src.snake import Snake, Direction
|
||||
from src.food import Food
|
||||
from src.menu import Menu, GameMode, MenuItem
|
||||
from src.cli import main as cli_main
|
||||
from src.ui.menu import GameMode
|
||||
|
||||
__version__ = '0.1.0'
|
||||
__author__ = 'Rbanh'
|
||||
@ -28,5 +29,9 @@ __all__ = [
|
||||
'Menu',
|
||||
'GameMode',
|
||||
'MenuItem',
|
||||
'cli_main'
|
||||
'cli_main',
|
||||
'config',
|
||||
'core',
|
||||
'ai',
|
||||
'ui'
|
||||
]
|
235
src/ai/config_ui.py
Normal file
235
src/ai/config_ui.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""
|
||||
Training Configuration UI
|
||||
|
||||
This module provides a graphical interface for configuring AI training parameters.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional, Tuple, List
|
||||
|
||||
class ConfigUI:
|
||||
def __init__(self, width: int = 800, height: int = 600):
|
||||
"""Initialize the configuration UI."""
|
||||
pygame.init()
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.screen = pygame.display.set_mode((width, height))
|
||||
pygame.display.set_caption("Training Configuration")
|
||||
|
||||
# Fonts
|
||||
self.title_font = pygame.font.Font(None, 48)
|
||||
self.header_font = pygame.font.Font(None, 36)
|
||||
self.text_font = pygame.font.Font(None, 24)
|
||||
|
||||
# Colors
|
||||
self.colors = {
|
||||
'background': (0, 0, 0),
|
||||
'text': (200, 200, 200),
|
||||
'highlight': (0, 255, 0),
|
||||
'button': (40, 40, 40),
|
||||
'button_hover': (60, 60, 60),
|
||||
'input_bg': (30, 30, 30),
|
||||
'input_active': (50, 50, 50)
|
||||
}
|
||||
|
||||
# Parameters
|
||||
self.parameters = {
|
||||
'timesteps': {
|
||||
'value': 1000000,
|
||||
'type': 'int',
|
||||
'min': 1000,
|
||||
'max': 10000000,
|
||||
'description': 'Total timesteps to train for'
|
||||
},
|
||||
'learning_rate': {
|
||||
'value': 0.0003,
|
||||
'type': 'float',
|
||||
'min': 0.00001,
|
||||
'max': 0.01,
|
||||
'description': 'Learning rate for training'
|
||||
},
|
||||
'batch_size': {
|
||||
'value': 64,
|
||||
'type': 'int',
|
||||
'min': 32,
|
||||
'max': 512,
|
||||
'description': 'Batch size for training'
|
||||
},
|
||||
'n_envs': {
|
||||
'value': 8,
|
||||
'type': 'int',
|
||||
'min': 1,
|
||||
'max': 32,
|
||||
'description': 'Number of parallel environments'
|
||||
},
|
||||
'n_steps': {
|
||||
'value': 2048,
|
||||
'type': 'int',
|
||||
'min': 128,
|
||||
'max': 8192,
|
||||
'description': 'Number of steps per update'
|
||||
}
|
||||
}
|
||||
|
||||
# UI state
|
||||
self.active_input = None
|
||||
self.input_text = ""
|
||||
self.scroll_offset = 0
|
||||
self.max_scroll = max(0, len(self.parameters) * 60 - (height - 200))
|
||||
|
||||
# Load saved config if exists
|
||||
self.config_file = "training_config.json"
|
||||
self.load_config()
|
||||
|
||||
def load_config(self) -> None:
|
||||
"""Load configuration from file."""
|
||||
if os.path.exists(self.config_file):
|
||||
try:
|
||||
with open(self.config_file, 'r') as f:
|
||||
saved_config = json.load(f)
|
||||
for key, value in saved_config.items():
|
||||
if key in self.parameters:
|
||||
self.parameters[key]['value'] = value
|
||||
except:
|
||||
pass
|
||||
|
||||
def save_config(self) -> None:
|
||||
"""Save configuration to file."""
|
||||
config = {key: param['value'] for key, param in self.parameters.items()}
|
||||
with open(self.config_file, 'w') as f:
|
||||
json.dump(config, f, indent=4)
|
||||
|
||||
def draw_text_input(self, rect: pygame.Rect, value: Any, active: bool) -> None:
|
||||
"""Draw a text input field."""
|
||||
color = self.colors['input_active'] if active else self.colors['input_bg']
|
||||
pygame.draw.rect(self.screen, color, rect)
|
||||
pygame.draw.rect(self.screen, self.colors['text'], rect, 1)
|
||||
|
||||
text = str(value)
|
||||
if active:
|
||||
text = self.input_text + "|"
|
||||
|
||||
text_surface = self.text_font.render(text, True, self.colors['text'])
|
||||
text_rect = text_surface.get_rect(midleft=(rect.left + 5, rect.centery))
|
||||
self.screen.blit(text_surface, text_rect)
|
||||
|
||||
def draw_button(self, rect: pygame.Rect, text: str, hover: bool = False) -> None:
|
||||
"""Draw a button."""
|
||||
color = self.colors['button_hover'] if hover else self.colors['button']
|
||||
pygame.draw.rect(self.screen, color, rect)
|
||||
pygame.draw.rect(self.screen, self.colors['text'], rect, 1)
|
||||
|
||||
text_surface = self.text_font.render(text, True, self.colors['text'])
|
||||
text_rect = text_surface.get_rect(center=rect.center)
|
||||
self.screen.blit(text_surface, text_rect)
|
||||
|
||||
def validate_input(self, param: Dict[str, Any], value: str) -> Optional[Any]:
|
||||
"""Validate and convert input value."""
|
||||
try:
|
||||
if param['type'] == 'int':
|
||||
val = int(value)
|
||||
else:
|
||||
val = float(value)
|
||||
|
||||
if val < param['min'] or val > param['max']:
|
||||
return None
|
||||
return val
|
||||
except:
|
||||
return None
|
||||
|
||||
def run(self) -> Optional[Dict[str, Any]]:
|
||||
"""Run the configuration UI. Returns the config dict if saved, None if cancelled."""
|
||||
running = True
|
||||
save_clicked = False
|
||||
mouse_pos = (0, 0)
|
||||
|
||||
while running:
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
running = False
|
||||
|
||||
elif event.type == pygame.MOUSEBUTTONDOWN:
|
||||
mouse_pos = event.pos
|
||||
|
||||
# Check parameter inputs
|
||||
y = 100 - self.scroll_offset
|
||||
for name, param in self.parameters.items():
|
||||
input_rect = pygame.Rect(300, y, 200, 30)
|
||||
if input_rect.collidepoint(mouse_pos):
|
||||
self.active_input = name
|
||||
self.input_text = str(param['value'])
|
||||
y += 60
|
||||
|
||||
# Check buttons
|
||||
save_rect = pygame.Rect(self.width//2 - 150, self.height - 60, 140, 40)
|
||||
cancel_rect = pygame.Rect(self.width//2 + 10, self.height - 60, 140, 40)
|
||||
|
||||
if save_rect.collidepoint(mouse_pos):
|
||||
self.save_config()
|
||||
save_clicked = True
|
||||
running = False
|
||||
elif cancel_rect.collidepoint(mouse_pos):
|
||||
running = False
|
||||
|
||||
elif event.type == pygame.MOUSEBUTTONUP:
|
||||
if event.button == 4: # Mouse wheel up
|
||||
self.scroll_offset = max(0, self.scroll_offset - 30)
|
||||
elif event.button == 5: # Mouse wheel down
|
||||
self.scroll_offset = min(self.max_scroll, self.scroll_offset + 30)
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if self.active_input is not None:
|
||||
if event.key == pygame.K_RETURN:
|
||||
param = self.parameters[self.active_input]
|
||||
if val := self.validate_input(param, self.input_text):
|
||||
param['value'] = val
|
||||
self.active_input = None
|
||||
elif event.key == pygame.K_BACKSPACE:
|
||||
self.input_text = self.input_text[:-1]
|
||||
else:
|
||||
if event.unicode.isnumeric() or event.unicode == '.':
|
||||
self.input_text += event.unicode
|
||||
|
||||
# Draw UI
|
||||
self.screen.fill(self.colors['background'])
|
||||
|
||||
# Title
|
||||
title = "Training Configuration"
|
||||
title_surface = self.title_font.render(title, True, self.colors['highlight'])
|
||||
title_rect = title_surface.get_rect(midtop=(self.width//2, 20))
|
||||
self.screen.blit(title_surface, title_rect)
|
||||
|
||||
# Parameters
|
||||
y = 100 - self.scroll_offset
|
||||
for name, param in self.parameters.items():
|
||||
if 0 <= y <= self.height - 100:
|
||||
# Parameter name and description
|
||||
name_surface = self.header_font.render(name, True, self.colors['text'])
|
||||
desc_surface = self.text_font.render(param['description'], True, self.colors['text'])
|
||||
self.screen.blit(name_surface, (20, y))
|
||||
self.screen.blit(desc_surface, (20, y + 30))
|
||||
|
||||
# Input field
|
||||
input_rect = pygame.Rect(300, y, 200, 30)
|
||||
self.draw_text_input(input_rect, param['value'], name == self.active_input)
|
||||
y += 60
|
||||
|
||||
# Buttons
|
||||
save_rect = pygame.Rect(self.width//2 - 150, self.height - 60, 140, 40)
|
||||
cancel_rect = pygame.Rect(self.width//2 + 10, self.height - 60, 140, 40)
|
||||
|
||||
save_hover = save_rect.collidepoint(mouse_pos)
|
||||
cancel_hover = cancel_rect.collidepoint(mouse_pos)
|
||||
|
||||
self.draw_button(save_rect, "Save", save_hover)
|
||||
self.draw_button(cancel_rect, "Cancel", cancel_hover)
|
||||
|
||||
pygame.display.flip()
|
||||
|
||||
pygame.quit()
|
||||
|
||||
if save_clicked:
|
||||
return {key: param['value'] for key, param in self.parameters.items()}
|
||||
return None
|
70
src/ai/controller.py
Normal file
70
src/ai/controller.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""
|
||||
AI Controller for Snake Game
|
||||
|
||||
This module provides the AI controller that uses trained models to play the game.
|
||||
It handles model loading, state processing, and decision making during gameplay.
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from ai.environment import SnakeEnv, Direction
|
||||
|
||||
class AIController:
|
||||
"""AI controller that uses trained models to play the game."""
|
||||
|
||||
def __init__(self, difficulty: str = "medium"):
|
||||
"""
|
||||
Initialize the AI controller.
|
||||
|
||||
Args:
|
||||
difficulty: "easy", "medium", or "hard"
|
||||
"""
|
||||
self.difficulty = difficulty
|
||||
self.model = None
|
||||
self.env = SnakeEnv() # For state processing
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load the appropriate model based on difficulty."""
|
||||
model_path = f"models/{self.difficulty}/best_model.zip"
|
||||
if not os.path.exists(model_path):
|
||||
# Fall back to final model if best model doesn't exist
|
||||
model_path = f"models/{self.difficulty}/final_model.zip"
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"No model found for difficulty {self.difficulty}. "
|
||||
"Please train the model first."
|
||||
)
|
||||
|
||||
self.model = PPO.load(model_path)
|
||||
|
||||
def get_action(self, game_state: dict) -> Direction:
|
||||
"""
|
||||
Get the next action based on the current game state.
|
||||
|
||||
Args:
|
||||
game_state: Dictionary containing:
|
||||
- snake: Snake object
|
||||
- food: Food object
|
||||
- width: Game width
|
||||
- height: Game height
|
||||
|
||||
Returns:
|
||||
Direction enum indicating the chosen action
|
||||
"""
|
||||
# Update environment with current game state
|
||||
self.env.snake = game_state["snake"]
|
||||
self.env.food = game_state["food"]
|
||||
self.env.width = game_state["width"]
|
||||
self.env.height = game_state["height"]
|
||||
|
||||
# Get state observation
|
||||
state = self.env._get_state()
|
||||
|
||||
# Get action from model
|
||||
action, _ = self.model.predict(state, deterministic=True)
|
||||
|
||||
# Convert action index to Direction
|
||||
return self.env.action_space[action]
|
787
src/ai/environment.py
Normal file
787
src/ai/environment.py
Normal file
@ -0,0 +1,787 @@
|
||||
"""
|
||||
Snake Game Environment for Reinforcement Learning
|
||||
|
||||
This module provides a gym-like interface for training AI agents to play the snake game.
|
||||
It includes:
|
||||
- State observation space
|
||||
- Action space
|
||||
- Reward system
|
||||
- Environment dynamics
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from typing import Tuple, List, Dict, Any, Optional
|
||||
import pygame
|
||||
from src.core import Snake
|
||||
from src.core import Direction
|
||||
from src.core import GameSession
|
||||
from src.ui import GameArea
|
||||
from src.config import GameRules
|
||||
|
||||
def manhattan_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> int:
|
||||
"""Calculate Manhattan distance between two points."""
|
||||
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
|
||||
|
||||
def manhattan_distance_wrap(pos1: Tuple[int, int], pos2: Tuple[int, int], size: int) -> int:
|
||||
"""Calculate Manhattan distance between two points with wrap-around."""
|
||||
dx = abs(pos1[0] - pos2[0])
|
||||
dy = abs(pos1[1] - pos2[1])
|
||||
return min(dx, size - dx) + min(dy, size - dy)
|
||||
|
||||
class SnakeEnv(gym.Env):
|
||||
"""A Gymnasium environment for training snake AI agents."""
|
||||
|
||||
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
|
||||
|
||||
def __init__(self, size=30, render_mode=None, difficulty="easy"):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.render_mode = render_mode
|
||||
self.difficulty = difficulty
|
||||
|
||||
# Action to Direction mapping
|
||||
self.action_to_direction = {
|
||||
0: Direction.UP,
|
||||
1: Direction.RIGHT,
|
||||
2: Direction.DOWN,
|
||||
3: Direction.LEFT
|
||||
}
|
||||
|
||||
# Direction to index mapping for one-hot encoding
|
||||
self.direction_to_index = {
|
||||
Direction.UP: 0,
|
||||
Direction.RIGHT: 1,
|
||||
Direction.DOWN: 2,
|
||||
Direction.LEFT: 3
|
||||
}
|
||||
|
||||
# Initialize reward scales with emphasis on precision
|
||||
self.reward_scales = {
|
||||
'food': 20.0, # Base food reward (will be scaled exponentially)
|
||||
'death': -5.0, # Reduced death penalty to encourage exploration
|
||||
'distance': 0.5, # Increased distance reward
|
||||
'survival': -0.01, # Small survival penalty
|
||||
'milestone': 10.0, # Larger milestone rewards
|
||||
'efficiency': 0.2, # Small efficiency reward
|
||||
'exploration': 0.5, # Increased exploration reward
|
||||
'safety': 1.0, # Safety reward for avoiding danger
|
||||
'timeout': -1.0, # Reduced timeout penalty
|
||||
'wrap_bonus': 0.2, # Small wrap-around bonus
|
||||
'near_miss': -0.1, # Very small near miss penalty
|
||||
'repetitive': -2 # Increased base penalty for repetitive movement
|
||||
}
|
||||
|
||||
# Near miss detection
|
||||
self.near_miss_threshold = 1 # Only penalize very close misses
|
||||
|
||||
# Time limit parameters - more lenient
|
||||
self.base_time_limit = 200 # Double base time limit
|
||||
self.length_time_bonus = 100 # Double length bonus
|
||||
self.max_time_limit = 1000 # Much longer maximum time
|
||||
|
||||
# Exploration decay
|
||||
self.initial_exploration_bonus = 2.0 # Higher initial exploration
|
||||
self.exploration_decay = 0.9995 # Slower decay
|
||||
self.min_exploration_bonus = 0.5 # Higher minimum exploration
|
||||
self.current_exploration_bonus = self.initial_exploration_bonus
|
||||
self.total_episodes = 0 # Track number of episodes for decay
|
||||
|
||||
# Curriculum learning parameters
|
||||
self.curriculum_stage = 0
|
||||
self.success_threshold = {
|
||||
0: 3, # Stage 0: Basic movement (need more consistent success)
|
||||
1: 5, # Stage 1: Food collection
|
||||
2: 7, # Stage 2: Longer snake
|
||||
3: 10 # Stage 3: Full difficulty
|
||||
}
|
||||
self.consecutive_successes = 0
|
||||
self.stage_requirements = {
|
||||
0: 1, # Stage 0: Get one food
|
||||
1: 3, # Stage 1: Get three food
|
||||
2: 5, # Stage 2: Get five food
|
||||
3: 7 # Stage 3: Get seven food
|
||||
}
|
||||
|
||||
# Dynamic episode length based on curriculum stage
|
||||
self.base_max_steps = 300 # More steps for exploration
|
||||
self.max_steps = self.base_max_steps
|
||||
|
||||
# Window dimensions for rendering
|
||||
self.window_width = 1024
|
||||
self.window_height = 768
|
||||
|
||||
# Create game area and rules
|
||||
self.game_area = GameArea(self.window_width, self.window_height)
|
||||
self.rules = GameRules()
|
||||
self.rules.wrap_around = False # No wrap-around for AI training
|
||||
|
||||
# Define action and observation spaces
|
||||
self.action_space = spaces.Discrete(4) # Up, Right, Down, Left
|
||||
|
||||
num_channels = 26 # Total number of channels we created in _get_normalized_observation
|
||||
self.observation_space = spaces.Box(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=(num_channels, size, size), # (channels, height, width)
|
||||
dtype=np.float32
|
||||
)
|
||||
|
||||
# Game session
|
||||
self.session = None
|
||||
self.steps = 0
|
||||
self.steps_since_food = 0
|
||||
self.current_time = 0
|
||||
self.last_score = 0
|
||||
self.last_distance = float('inf')
|
||||
self.last_food_steps = 0
|
||||
self.steps_since_direction_change = 0 # Track steps since last direction change
|
||||
self.last_direction = None # Track previous direction
|
||||
self.recent_positions = []
|
||||
self.steps_in_same_direction = 0 # Track steps without direction change
|
||||
self.last_min_food_distance = float('inf') # Initialize minimum distance tracking
|
||||
|
||||
# For rendering
|
||||
self.window = None
|
||||
self.clock = None
|
||||
|
||||
self.score_milestone_rewards = {5: 10.0, 10: 20.0, 15: 30.0} # Bonus rewards at score milestones
|
||||
self.last_milestone = 0 # Track last milestone reached
|
||||
self.base_survival_penalty = -1.0 # Much larger constant penalty per step
|
||||
|
||||
# Movement control parameters
|
||||
self.max_direction_steps = 10 # Reduced from 20 to 10
|
||||
self.direction_step_increase = 2 # Reduced from 5 to 2
|
||||
self.min_safe_turns = 2 # Reduced from 3 to 2
|
||||
self.repetitive_threshold = 10 # Base threshold before repetitive penalty kicks in
|
||||
self.repetitive_scale = 1.2 # Base scale for penalty growth
|
||||
self.repetitive_wrap_scale = 2.0 # Much faster growth rate in wrap mode
|
||||
self.max_repetitive_penalty = -5.0 # Increased maximum repetitive movement penalty
|
||||
|
||||
# Initialize base rewards with adjusted values
|
||||
rewards = {
|
||||
'food': 0.0,
|
||||
'death': 0.0,
|
||||
'distance': 0.0,
|
||||
'survival': self.base_survival_penalty * 0.5,
|
||||
'direction': 0.0,
|
||||
'efficiency': 0.0,
|
||||
'progress': 0.0,
|
||||
'alignment': 0.0,
|
||||
'repetitive': 0.0
|
||||
}
|
||||
|
||||
def _get_normalized_observation(self) -> np.ndarray:
|
||||
"""
|
||||
Create a multi-channel 2D observation from the game state.
|
||||
Spatial channels encode the grid layout (snake & food),
|
||||
while extra scalar features are broadcast over additional channels.
|
||||
"""
|
||||
state = self.session.get_state()
|
||||
snake_body = state["snake_body"]
|
||||
food_position = state["food_position"]
|
||||
head = snake_body[0]
|
||||
body = snake_body
|
||||
|
||||
# --- Spatial Channels ---
|
||||
# Channel 0: Snake positions (1 for snake, 0 otherwise)
|
||||
snake_channel = np.zeros((self.size, self.size), dtype=np.float32)
|
||||
for pos in snake_body:
|
||||
snake_channel[pos[1], pos[0]] = 1.0
|
||||
assert pos[0] < self.size and pos[1] < self.size, \
|
||||
f"Snake position {pos} exceeds grid size {self.size}"
|
||||
|
||||
# Channel 1: Food position
|
||||
food_channel = np.zeros((self.size, self.size), dtype=np.float32)
|
||||
food_channel[food_position[1], food_position[0]] = 1.0
|
||||
|
||||
# --- Auxiliary Scalar Features ---
|
||||
# Compute distances and apply wrap-around adjustments
|
||||
dx = food_position[0] - head[0]
|
||||
dy = food_position[1] - head[1]
|
||||
if self.rules.wrap_around:
|
||||
if dx > self.size/2:
|
||||
wrap_dx = dx - self.size
|
||||
elif dx < -self.size/2:
|
||||
wrap_dx = dx + self.size
|
||||
else:
|
||||
wrap_dx = dx
|
||||
if dy > self.size/2:
|
||||
wrap_dy = dy - self.size
|
||||
elif dy < -self.size/2:
|
||||
wrap_dy = dy + self.size
|
||||
else:
|
||||
wrap_dy = dy
|
||||
else:
|
||||
wrap_dx = dx
|
||||
wrap_dy = dy
|
||||
|
||||
# Normalize
|
||||
dx_norm = dx / self.size
|
||||
dy_norm = dy / self.size
|
||||
wrap_dx_norm = wrap_dx / self.size
|
||||
wrap_dy_norm = wrap_dy / self.size
|
||||
|
||||
# Danger indicators (vector of 4 values)
|
||||
dangers = self._get_danger_observations_all_directions(head, body)
|
||||
|
||||
# Other auxiliary scalar features
|
||||
snake_length_norm = len(body) / (self.size**2)
|
||||
# One-hot encode current direction
|
||||
direction_one_hot = [0.0] * 4
|
||||
direction_one_hot[self.direction_to_index[state["snake_direction"]]] = 1.0
|
||||
|
||||
# One-hot encode relative food direction
|
||||
food_dir = [0.0] * 4
|
||||
if abs(dx_norm) > abs(dy_norm):
|
||||
if dx_norm > 0:
|
||||
food_dir[self.direction_to_index[Direction.RIGHT]] = 1.0
|
||||
else:
|
||||
food_dir[self.direction_to_index[Direction.LEFT]] = 1.0
|
||||
else:
|
||||
if dy_norm > 0:
|
||||
food_dir[self.direction_to_index[Direction.DOWN]] = 1.0
|
||||
else:
|
||||
food_dir[self.direction_to_index[Direction.UP]] = 1.0
|
||||
|
||||
current_time_limit = min(
|
||||
self.max_time_limit,
|
||||
self.base_time_limit + (len(body) - 1) * self.length_time_bonus
|
||||
)
|
||||
time_pressure = np.clip(self.steps_since_food / current_time_limit, 0, 1)
|
||||
wrap_mode = float(self.rules.wrap_around)
|
||||
progress = min(1.0, len(body) / 10)
|
||||
recent_food = min(1.0, self.steps_since_food / 50)
|
||||
curriculum_stage_norm = float(self.curriculum_stage) / 3
|
||||
advanced_length = float(len(body) > 5)
|
||||
manhattan_norm = min(1.0, manhattan_distance(head, food_position) / (self.size * 2))
|
||||
|
||||
# --- Create "broadcasted" channels for each scalar ---
|
||||
def broadcast_channel(value):
|
||||
return np.full((self.size, self.size), value, dtype=np.float32)
|
||||
|
||||
dx_channel = broadcast_channel(dx_norm)
|
||||
dy_channel = broadcast_channel(dy_norm)
|
||||
wrap_dx_channel = broadcast_channel(wrap_dx_norm)
|
||||
wrap_dy_channel = broadcast_channel(wrap_dy_norm)
|
||||
snake_length_chan = broadcast_channel(snake_length_norm)
|
||||
time_pressure_chan = broadcast_channel(time_pressure)
|
||||
wrap_mode_chan = broadcast_channel(wrap_mode)
|
||||
progress_chan = broadcast_channel(progress)
|
||||
recent_food_chan = broadcast_channel(recent_food)
|
||||
curriculum_chan = broadcast_channel(curriculum_stage_norm)
|
||||
advanced_length_chan = broadcast_channel(advanced_length)
|
||||
manhattan_chan = broadcast_channel(manhattan_norm)
|
||||
|
||||
# For one-hot features, create one channel per value
|
||||
direction_channels = [broadcast_channel(v) for v in direction_one_hot]
|
||||
food_dir_channels = [broadcast_channel(v) for v in food_dir]
|
||||
|
||||
# For dangers, one channel per direction (4 channels)
|
||||
danger_channels = [broadcast_channel(d) for d in dangers]
|
||||
|
||||
# --- Stack all channels ---
|
||||
# You can adjust the channel order as desired.
|
||||
channels = [
|
||||
snake_channel, # Channel 0
|
||||
food_channel, # Channel 1
|
||||
dx_channel, # Channel 2
|
||||
dy_channel, # Channel 3
|
||||
wrap_dx_channel, # Channel 4
|
||||
wrap_dy_channel, # Channel 5
|
||||
]
|
||||
channels.extend(danger_channels) # Channels 6-9
|
||||
channels.append(snake_length_chan) # Channel 10
|
||||
channels.extend(direction_channels) # Channels 11-14
|
||||
channels.extend(food_dir_channels) # Channels 15-18
|
||||
channels.append(time_pressure_chan) # Channel 19
|
||||
channels.append(wrap_mode_chan) # Channel 20
|
||||
channels.append(progress_chan) # Channel 21
|
||||
channels.append(recent_food_chan) # Channel 22
|
||||
channels.append(curriculum_chan) # Channel 23
|
||||
channels.append(advanced_length_chan) # Channel 24
|
||||
channels.append(manhattan_chan) # Channel 25
|
||||
|
||||
# Final observation: shape (num_channels, size, size)
|
||||
observation = np.stack(channels, axis=0)
|
||||
return np.clip(observation, 0, 1)
|
||||
|
||||
def _get_danger_observations_all_directions(self, head: Tuple[int, int], body: List[Tuple[int, int]]) -> List[float]:
|
||||
"""Get danger observations in all four directions."""
|
||||
dangers = [0.0] * 4 # [UP, RIGHT, DOWN, LEFT]
|
||||
|
||||
# Check each direction
|
||||
directions = [Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT]
|
||||
for i, direction in enumerate(directions):
|
||||
dx, dy = direction.to_vector()
|
||||
next_pos = (head[0] + dx, head[1] + dy)
|
||||
|
||||
# Check wall collision
|
||||
if not self.rules.wrap_around:
|
||||
if (next_pos[0] < 0 or next_pos[0] >= self.size or
|
||||
next_pos[1] < 0 or next_pos[1] >= self.size):
|
||||
dangers[i] = 1.0
|
||||
continue
|
||||
else:
|
||||
next_pos = (next_pos[0] % self.size, next_pos[1] % self.size)
|
||||
|
||||
# Check self collision
|
||||
if next_pos in body[1:]:
|
||||
dangers[i] = 1.0
|
||||
|
||||
return dangers
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""Reset with curriculum-based difficulty and exploration decay."""
|
||||
# Call parent reset without unpacking
|
||||
super().reset(seed=seed)
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
# Update exploration bonus
|
||||
self.total_episodes += 1
|
||||
self.current_exploration_bonus = max(
|
||||
self.min_exploration_bonus,
|
||||
self.initial_exploration_bonus * (self.exploration_decay ** self.total_episodes)
|
||||
)
|
||||
|
||||
# Initialize game session with curriculum-appropriate settings
|
||||
self.session = GameSession(
|
||||
self.size,
|
||||
self.size,
|
||||
self.rules
|
||||
)
|
||||
|
||||
# Reset counters
|
||||
self.steps = 0
|
||||
self.steps_since_food = 0
|
||||
self.current_time = 0
|
||||
self.last_score = 0
|
||||
self.last_distance = float('inf')
|
||||
self.last_direction = self.session.get_state()["snake_direction"] # Track initial direction
|
||||
self.recent_positions = []
|
||||
self.steps_in_same_direction = 0 # Track steps without direction change
|
||||
self.last_min_food_distance = float('inf') # Initialize minimum distance tracking
|
||||
|
||||
# Get initial observation
|
||||
observation = self._get_normalized_observation()
|
||||
info = {
|
||||
"curriculum_stage": self.curriculum_stage,
|
||||
"exploration_bonus": self.current_exploration_bonus
|
||||
}
|
||||
|
||||
return observation, info
|
||||
|
||||
def _calculate_area_penalty(self, head_pos):
|
||||
"""Calculate penalty for staying in a small area."""
|
||||
# Add current position to history
|
||||
self.recent_positions.append(head_pos)
|
||||
if len(self.recent_positions) > self.max_recent_positions:
|
||||
self.recent_positions.pop(0)
|
||||
|
||||
if len(self.recent_positions) < 10: # Need minimum history to calculate
|
||||
return 0.0
|
||||
|
||||
# Calculate bounding box of recent positions
|
||||
x_coords = [p[0] for p in self.recent_positions]
|
||||
y_coords = [p[1] for p in self.recent_positions]
|
||||
area_width = max(x_coords) - min(x_coords) + 1
|
||||
area_height = max(y_coords) - min(y_coords) + 1
|
||||
area = area_width * area_height
|
||||
|
||||
# Calculate unique positions visited
|
||||
unique_positions = len(set(self.recent_positions))
|
||||
|
||||
# Penalize small areas and repeated positions
|
||||
area_penalty = 0.0
|
||||
if area < 9: # 3x3 grid or smaller
|
||||
area_penalty = -0.2 * (9 - area) / 8 # Max penalty -0.2 for 1x1 area
|
||||
|
||||
# Additional penalty for revisiting same positions frequently
|
||||
repetition_penalty = -0.3 * (1 - unique_positions / len(self.recent_positions))
|
||||
|
||||
return area_penalty + repetition_penalty
|
||||
|
||||
def step(self, action):
|
||||
"""Improved reward structure with length-aware exploration."""
|
||||
self.steps += 1
|
||||
self.steps_since_food += 1
|
||||
|
||||
# Calculate current time limit based on snake length
|
||||
snake_length = len(self.session.get_state()["snake_body"])
|
||||
current_time_limit = min(
|
||||
self.max_time_limit,
|
||||
self.base_time_limit + (snake_length - 1) * self.length_time_bonus
|
||||
)
|
||||
|
||||
# Get state before action
|
||||
prev_state = self.session.get_state()
|
||||
prev_head = prev_state["snake_head"]
|
||||
prev_food_dist = manhattan_distance(prev_head, prev_state["food_position"])
|
||||
if self.rules.wrap_around:
|
||||
prev_food_dist = min(prev_food_dist,
|
||||
manhattan_distance_wrap(prev_head, prev_state["food_position"], self.size))
|
||||
|
||||
# Convert action to Direction and apply
|
||||
direction = self.action_to_direction[int(action)]
|
||||
|
||||
# Update game state
|
||||
state, base_reward, done = self.session.step(direction, self.current_time)
|
||||
|
||||
# Initialize reward components
|
||||
rewards = {
|
||||
'food': 0.0,
|
||||
'death': 0.0,
|
||||
'distance': 0.0,
|
||||
'survival': 0.0,
|
||||
'milestone': 0.0,
|
||||
'efficiency': 0.0,
|
||||
'exploration': 0.0,
|
||||
'safety': 0.0,
|
||||
'timeout': 0.0,
|
||||
'direction': 0.0,
|
||||
'repetitive': 0.0
|
||||
}
|
||||
|
||||
# Get current state info
|
||||
curr_head = state["snake_body"][0]
|
||||
curr_food = state["food_position"]
|
||||
curr_food_dist = manhattan_distance(curr_head, curr_food)
|
||||
if self.rules.wrap_around:
|
||||
curr_food_dist = min(curr_food_dist,
|
||||
manhattan_distance_wrap(curr_head, curr_food, self.size))
|
||||
|
||||
# Update current time
|
||||
self.current_time += self.session.move_cooldown
|
||||
|
||||
# Food reward with exponential scaling based on snake length
|
||||
food_eaten = state["score"] > prev_state["score"]
|
||||
if food_eaten:
|
||||
# Reset timer when food is eaten
|
||||
self.steps_since_food = 0
|
||||
|
||||
# Calculate exponential food reward
|
||||
base_food_reward = self.reward_scales['food']
|
||||
length_multiplier = 1.2 ** (len(state["snake_body"]) - 1) # 20% increase per length
|
||||
rewards['food'] = base_food_reward * length_multiplier
|
||||
|
||||
# Add milestone rewards for achieving certain scores
|
||||
if state["score"] in [5, 10, 15, 20]:
|
||||
rewards['milestone'] = self.reward_scales['milestone'] * (state["score"] / 5)
|
||||
else:
|
||||
# Only apply survival penalty if we haven't eaten food in a while
|
||||
if self.steps_since_food > current_time_limit / 2:
|
||||
penalty_factor = (self.steps_since_food - current_time_limit/2) / (current_time_limit/2)
|
||||
rewards['survival'] = self.reward_scales['survival'] * penalty_factor
|
||||
|
||||
# Check if direction changed and update counters
|
||||
direction_changed = direction != self.last_direction
|
||||
if direction_changed:
|
||||
self.steps_in_same_direction = 0
|
||||
else:
|
||||
self.steps_in_same_direction += 1
|
||||
# Apply growing penalty for repetitive movement
|
||||
if self.steps_in_same_direction > (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold):
|
||||
excess_steps = self.steps_in_same_direction - (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold)
|
||||
# Use much faster growth rate in wrap-around mode
|
||||
growth_rate = self.repetitive_wrap_scale if self.rules.wrap_around else self.repetitive_scale
|
||||
repetitive_penalty = self.reward_scales['repetitive'] * (growth_rate ** excess_steps)
|
||||
|
||||
# Apply additional multiplier in wrap-around mode that grows with steps
|
||||
if self.rules.wrap_around:
|
||||
wrap_multiplier = 1.0 + (excess_steps * 0.5) # Multiplier grows with each step
|
||||
repetitive_penalty *= wrap_multiplier
|
||||
rewards['repetitive'] = repetitive_penalty # No maximum cap - let it grow unbounded in wrap-around mode
|
||||
|
||||
else:
|
||||
# Cap the penalty when not in wrap-around mode
|
||||
rewards['repetitive'] = max(repetitive_penalty, self.max_repetitive_penalty)
|
||||
|
||||
|
||||
self.last_direction = direction
|
||||
|
||||
# Check for timeout or stuck in loop - only if we haven't just eaten food
|
||||
if (self.steps_since_food >= current_time_limit and not food_eaten):
|
||||
done = True
|
||||
rewards['timeout'] = self.reward_scales['timeout']
|
||||
# Scale timeout penalty based on distance to food
|
||||
if curr_food_dist < 5: # Harsher penalty for timing out near food
|
||||
rewards['timeout'] *= 1.5
|
||||
# Additional penalty for getting stuck in a loop
|
||||
if self.steps_in_same_direction >= 30: # Also update this threshold
|
||||
rewards['timeout'] *= 1.2
|
||||
|
||||
# Death penalty - calculate AFTER food rewards
|
||||
if done and not self.steps_since_food >= current_time_limit:
|
||||
base_death_penalty = self.reward_scales['death']
|
||||
|
||||
# Check if death was due to self-collision
|
||||
if curr_head in state["snake_body"][1:]:
|
||||
# Base multiplier for self-collision
|
||||
self_collision_multiplier = 2.0
|
||||
|
||||
# Additional penalty scaling with score
|
||||
# At score 5: 2.5x penalty
|
||||
# At score 10: 3.0x penalty
|
||||
# At score 15: 3.5x penalty
|
||||
score_multiplier = 1.0 + (state["score"] / 10)
|
||||
|
||||
# Apply both multipliers
|
||||
rewards['death'] = base_death_penalty * self_collision_multiplier * score_multiplier
|
||||
|
||||
# Additional penalty if we died near food
|
||||
if curr_food_dist < 5:
|
||||
rewards['death'] *= 1.5 # Even harsher if we collide with ourselves near food
|
||||
else:
|
||||
# For wall collisions, keep original scaling but make it less punishing at higher scores
|
||||
rewards['death'] = base_death_penalty * (1.0 - min(0.5, state["score"] / 20))
|
||||
|
||||
# Add an immediate negative reward for losing potential score
|
||||
potential_loss_penalty = -0.5 * state["score"] # Larger penalty for dying with higher scores
|
||||
rewards['death'] += potential_loss_penalty
|
||||
|
||||
# Track minimum distance to food and check for near misses
|
||||
if curr_food_dist < self.last_min_food_distance:
|
||||
self.last_min_food_distance = curr_food_dist
|
||||
elif curr_food_dist > self.last_min_food_distance:
|
||||
# We're moving away from our closest approach
|
||||
if self.last_min_food_distance <= self.near_miss_threshold:
|
||||
# We got very close but missed
|
||||
rewards['distance'] += self.reward_scales['near_miss']
|
||||
self.last_min_food_distance = float('inf') # Reset tracking
|
||||
|
||||
# Progressive distance reward - more reward for getting closer when near food
|
||||
distance_change = prev_food_dist - curr_food_dist
|
||||
if curr_food_dist < 5: # When close to food
|
||||
distance_multiplier = 2.0 # Double the reward/penalty
|
||||
else:
|
||||
distance_multiplier = 1.0
|
||||
rewards['distance'] = distance_change * self.reward_scales['distance'] * distance_multiplier
|
||||
|
||||
# Calculate direction penalty
|
||||
if curr_food_dist <= 5: # Only apply when close to food
|
||||
# Calculate optimal direction to food
|
||||
dx = curr_food[0] - curr_head[0]
|
||||
dy = curr_food[1] - curr_head[1]
|
||||
|
||||
if self.rules.wrap_around:
|
||||
# Adjust for wrap-around
|
||||
if dx > self.size/2: dx -= self.size
|
||||
elif dx < -self.size/2: dx += self.size
|
||||
if dy > self.size/2: dy -= self.size
|
||||
elif dy < -self.size/2: dy += self.size
|
||||
|
||||
# Determine optimal direction(s)
|
||||
optimal_directions = []
|
||||
if abs(dx) > abs(dy):
|
||||
if dx > 0: optimal_directions.append(Direction.RIGHT)
|
||||
elif dx < 0: optimal_directions.append(Direction.LEFT)
|
||||
if abs(dy) >= abs(dx):
|
||||
if dy > 0: optimal_directions.append(Direction.DOWN)
|
||||
elif dy < 0: optimal_directions.append(Direction.UP)
|
||||
|
||||
# Base penalty scales with proximity (closer = higher penalty)
|
||||
base_penalty = (6 - curr_food_dist) * 0.1 # Scales from 0.1 to 0.5
|
||||
|
||||
if direction not in optimal_directions:
|
||||
# Check if we're moving directly away from food
|
||||
opposite_dirs = {
|
||||
Direction.UP: Direction.DOWN,
|
||||
Direction.DOWN: Direction.UP,
|
||||
Direction.LEFT: Direction.RIGHT,
|
||||
Direction.RIGHT: Direction.LEFT
|
||||
}
|
||||
optimal_opposites = [opposite_dirs[d] for d in optimal_directions]
|
||||
|
||||
if direction in optimal_opposites:
|
||||
# Double penalty for moving directly away from food
|
||||
rewards['direction'] = -2 * base_penalty
|
||||
else:
|
||||
# Regular penalty for suboptimal direction
|
||||
rewards['direction'] = -base_penalty
|
||||
elif curr_food_dist <= 2: # Small reward for correct direction when very close
|
||||
rewards['direction'] = 0.1
|
||||
|
||||
# Efficiency reward - encourage purposeful movement
|
||||
if not done and self.steps_since_food < 50: # Only apply when actively hunting
|
||||
rewards['efficiency'] = self.reward_scales['efficiency'] * (1.0 - self.steps_since_food / 50)
|
||||
|
||||
# Calculate final reward - simple sum of components
|
||||
reward = sum(rewards.values())
|
||||
|
||||
# Get observation
|
||||
observation = self._get_normalized_observation()
|
||||
|
||||
# No more episode-based truncation
|
||||
truncated = False
|
||||
|
||||
# Include reward components and exploration info in info dict for monitoring
|
||||
info = {
|
||||
"score": state["score"],
|
||||
"reward_components": rewards,
|
||||
"exploration_bonus": self.current_exploration_bonus,
|
||||
"direction_changed": direction_changed,
|
||||
"steps_in_same_direction": self.steps_in_same_direction,
|
||||
"snake_length": snake_length,
|
||||
"steps": self.steps # Include total steps for monitoring
|
||||
}
|
||||
|
||||
return observation, reward, done, truncated, info
|
||||
|
||||
def render(self):
|
||||
"""Render the environment."""
|
||||
if self.window is None and self.render_mode == "human":
|
||||
pygame.init()
|
||||
self.window = pygame.display.set_mode((self.window_width, self.window_height))
|
||||
pygame.display.set_caption("Snake Environment")
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
if self.window is not None:
|
||||
self.window.fill((0, 0, 0))
|
||||
|
||||
# Draw game area
|
||||
self.game_area.draw(self.window)
|
||||
|
||||
# Draw game session
|
||||
self.session.render(self.window, self.game_area)
|
||||
|
||||
# Draw score
|
||||
font = pygame.font.Font(None, 36)
|
||||
score_text = f"Score: {self.session.score}"
|
||||
score_surface = font.render(score_text, True, (0, 255, 0))
|
||||
score_rect = score_surface.get_rect(midtop=(self.window_width // 2, 20))
|
||||
self.window.blit(score_surface, score_rect)
|
||||
|
||||
pygame.display.flip()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
|
||||
if self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.window)),
|
||||
axes=(1, 0, 2)
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Close the environment."""
|
||||
if self.window is not None:
|
||||
pygame.quit()
|
||||
self.window = None
|
||||
|
||||
def get_state(self) -> Dict:
|
||||
"""Get the current game state."""
|
||||
state = self.session.get_state()
|
||||
state['steps_since_food'] = self.steps_since_food
|
||||
|
||||
# Calculate and include current time limit
|
||||
snake_length = len(state["snake_body"])
|
||||
current_time_limit = min(
|
||||
self.max_time_limit,
|
||||
self.base_time_limit + (snake_length - 1) * self.length_time_bonus
|
||||
)
|
||||
state['current_time_limit'] = current_time_limit
|
||||
|
||||
return state
|
||||
|
||||
def _get_adjacent_positions(self, position: Tuple[int, int]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Get all adjacent positions to the given position.
|
||||
|
||||
Args:
|
||||
position: Current position (x, y)
|
||||
|
||||
Returns:
|
||||
List of adjacent positions [(x, y), ...]
|
||||
|
||||
Raises:
|
||||
ValueError: If position is None or not a tuple of 2 integers
|
||||
ValueError: If position coordinates are not integers
|
||||
"""
|
||||
if not position or not isinstance(position, tuple) or len(position) != 2:
|
||||
raise ValueError("Position must be a tuple of 2 coordinates")
|
||||
|
||||
try:
|
||||
x, y = int(position[0]), int(position[1])
|
||||
except (TypeError, ValueError):
|
||||
raise ValueError("Position coordinates must be integers")
|
||||
|
||||
# Get positions in all 4 directions (up, right, down, left)
|
||||
adjacent = [
|
||||
(x, y - 1), # up
|
||||
(x + 1, y), # right
|
||||
(x, y + 1), # down
|
||||
(x - 1, y) # left
|
||||
]
|
||||
|
||||
# Validate size attribute exists and is positive integer
|
||||
if not hasattr(self, 'size') or not isinstance(self.size, int) or self.size <= 0:
|
||||
raise ValueError("Environment size must be a positive integer")
|
||||
|
||||
# If wrap-around is enabled, adjust out-of-bounds positions
|
||||
if hasattr(self, 'rules') and hasattr(self.rules, 'wrap_around') and self.rules.wrap_around:
|
||||
adjacent = [
|
||||
(pos[0] % self.size, pos[1] % self.size)
|
||||
for pos in adjacent
|
||||
]
|
||||
return adjacent
|
||||
|
||||
# Otherwise, only return positions that are within bounds
|
||||
return [
|
||||
(pos[0], pos[1]) for pos in adjacent
|
||||
if 0 <= pos[0] < self.size and 0 <= pos[1] < self.size
|
||||
]
|
||||
|
||||
def _update_curriculum(self, score: int, done: bool) -> None:
|
||||
"""Update curriculum stage based on agent's performance."""
|
||||
if done:
|
||||
if score >= self.stage_requirements[self.curriculum_stage]:
|
||||
self.consecutive_successes += 1
|
||||
else:
|
||||
self.consecutive_successes = 0
|
||||
|
||||
# Check for stage advancement
|
||||
if (self.curriculum_stage < 3 and
|
||||
self.consecutive_successes >= self.success_threshold[self.curriculum_stage]):
|
||||
self.curriculum_stage += 1
|
||||
self.consecutive_successes = 0
|
||||
# Increase episode length with curriculum stage
|
||||
self.max_steps = self.base_max_steps * (1 + self.curriculum_stage * 0.5)
|
||||
|
||||
def _get_curriculum_rules(self) -> GameRules:
|
||||
"""Get game rules appropriate for current curriculum stage."""
|
||||
rules = GameRules()
|
||||
|
||||
# Adjust rules based on curriculum stage
|
||||
if self.curriculum_stage == 0:
|
||||
rules.wrap_around = True # Make it easier to avoid walls
|
||||
rules.speed_increase = False
|
||||
rules.initial_move_cooldown = 200 # Much slower movement
|
||||
elif self.curriculum_stage == 1:
|
||||
rules.wrap_around = True
|
||||
rules.speed_increase = False # Still no speed increase
|
||||
rules.initial_move_cooldown = 150
|
||||
elif self.curriculum_stage == 2:
|
||||
rules.wrap_around = True # Keep wrap-around until final stage
|
||||
rules.speed_increase = True
|
||||
rules.initial_move_cooldown = 120
|
||||
else: # Stage 3
|
||||
rules.wrap_around = False
|
||||
rules.speed_increase = True
|
||||
rules.initial_move_cooldown = 100
|
||||
|
||||
return rules
|
||||
|
||||
# Register the environment with Gymnasium
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
register(
|
||||
id='Snake-v0',
|
||||
entry_point='src.ai.environment:SnakeEnv',
|
||||
kwargs={
|
||||
'size': 30,
|
||||
'render_mode': None,
|
||||
'difficulty': 'easy'
|
||||
}
|
||||
)
|
93
src/ai/model_networks.py
Normal file
93
src/ai/model_networks.py
Normal file
@ -0,0 +1,93 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
|
||||
class CustomRecurrentCNN(BaseFeaturesExtractor):
|
||||
"""
|
||||
A simple CNN with an LSTM layer added on top. This enables the network to keep an internal
|
||||
memory of past observations so that it can adapt to changing rules/environment dynamics.
|
||||
"""
|
||||
def __init__(self, observation_space, features_dim=512):
|
||||
# Initialize with a features_dim that matches the LSTM's output.
|
||||
super(CustomRecurrentCNN, self).__init__(observation_space, features_dim)
|
||||
|
||||
# Calculate the expected CNN output size first
|
||||
n_input_channels = observation_space.shape[0] # Number of input channels
|
||||
|
||||
self.cnn = nn.Sequential(
|
||||
# Layer 1: (n_channels, 30, 30) -> (32, 15, 15)
|
||||
nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
|
||||
# Layer 2: (32, 15, 15) -> (64, 8, 8)
|
||||
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
|
||||
# Layer 3: (64, 8, 8) -> (64, 4, 4)
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Flatten()
|
||||
)
|
||||
|
||||
# Calculate CNN output dimension using a dummy forward pass
|
||||
with torch.no_grad():
|
||||
dummy_input = torch.zeros(1, *observation_space.shape)
|
||||
cnn_output = self.cnn(dummy_input)
|
||||
self._cnn_output_dim = cnn_output.shape[1] # Should be 64 * 4 * 4 = 1024
|
||||
|
||||
# Add linear layer to reduce CNN output to desired feature dimension
|
||||
self.fc = nn.Linear(self._cnn_output_dim, 256)
|
||||
|
||||
# LSTM layer
|
||||
self.lstm = nn.LSTM(256, features_dim, batch_first=True)
|
||||
|
||||
def forward(self, observations):
|
||||
# Pass through CNN: (batch, channels, height, width) -> (batch, cnn_features)
|
||||
cnn_out = self.cnn(observations)
|
||||
|
||||
# Pass through linear layer: (batch, cnn_features) -> (batch, 256)
|
||||
fc_out = self.fc(cnn_out)
|
||||
|
||||
# Add time dimension for LSTM: (batch, 256) -> (batch, 1, 256)
|
||||
lstm_input = fc_out.unsqueeze(1)
|
||||
|
||||
# LSTM: (batch, 1, 256) -> (batch, 1, features_dim)
|
||||
lstm_out, _ = self.lstm(lstm_input)
|
||||
|
||||
# Remove time dimension: (batch, 1, features_dim) -> (batch, features_dim)
|
||||
return lstm_out.squeeze(1)
|
||||
|
||||
|
||||
class CustomCNN(BaseFeaturesExtractor):
|
||||
"""
|
||||
Custom feature extractor for the snake environment.
|
||||
Uses fully connected layers since our input is a 1D vector of normalized features.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
|
||||
super().__init__(observation_space, features_dim)
|
||||
self.input_dim = int(np.prod(observation_space.shape))
|
||||
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Linear(self.input_dim, 512),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.Linear(512, 512),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.Linear(512, features_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
# Handle both single and batch observations
|
||||
if len(observations.shape) == 1:
|
||||
observations = observations.unsqueeze(0)
|
||||
|
||||
# Ensure the input dimension matches what we expect
|
||||
flat_obs = observations.reshape(observations.shape[0], self.input_dim)
|
||||
return self.cnn(flat_obs)
|
423
src/ai/train.py
Normal file
423
src/ai/train.py
Normal file
@ -0,0 +1,423 @@
|
||||
"""
|
||||
Training script for Snake AI using Stable Baselines3.
|
||||
|
||||
This script sets up and trains RL agents using the SnakeEnv environment.
|
||||
It supports multiple difficulty levels and provides real-time visualization
|
||||
of the training process.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
import gymnasium as gym
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from queue import Empty
|
||||
|
||||
from src.ai.environment import SnakeEnv
|
||||
from src.ai.visualize import run_dashboard
|
||||
from src.ai.model_networks import CustomRecurrentCNN, CustomCNN
|
||||
|
||||
class VisualizationCallback(BaseCallback):
|
||||
"""Callback for updating the training visualization."""
|
||||
|
||||
def __init__(self, viz_queue: Queue, total_timesteps: int, eval_freq: int = 1000, n_steps: int = 512, n_envs: int = 32, initial_timesteps: int = 0):
|
||||
super().__init__()
|
||||
self.viz_queue = viz_queue
|
||||
self.total_timesteps = total_timesteps
|
||||
self.eval_freq = eval_freq
|
||||
self.n_steps = n_steps
|
||||
self.n_envs = n_envs
|
||||
self.initial_timesteps = initial_timesteps # Store initial timesteps
|
||||
|
||||
# Training metrics
|
||||
self.episode_rewards = []
|
||||
self.episode_lengths = []
|
||||
|
||||
# Evaluation metrics
|
||||
self.eval_scores = []
|
||||
self.high_score = 0
|
||||
self.recent_high_score = 0
|
||||
|
||||
# Current episode tracking
|
||||
self.current_reward = 0
|
||||
self.current_length = 0
|
||||
|
||||
# Evaluation games (one with wrap, one without)
|
||||
self.eval_envs = [
|
||||
Monitor(SnakeEnv(render_mode="rgb_array")) for _ in range(2)
|
||||
]
|
||||
self.eval_envs[0].unwrapped.rules.wrap_around = False
|
||||
self.eval_envs[1].unwrapped.rules.wrap_around = True
|
||||
|
||||
self.eval_obs = [None, None]
|
||||
self.eval_done = [True, True]
|
||||
self.eval_scores_current = [0, 0]
|
||||
|
||||
# Real-time demo environment
|
||||
self.demo_env = Monitor(SnakeEnv(render_mode="rgb_array"))
|
||||
self.demo_env.unwrapped.rules.wrap_around = True
|
||||
self.demo_obs = None
|
||||
self.demo_done = True
|
||||
self.last_demo_time = 0
|
||||
self.demo_speed = 50
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
"""Update visualization on each step."""
|
||||
try:
|
||||
current_time = time.time() * 1000 # Convert to milliseconds
|
||||
|
||||
# Update episode tracking
|
||||
self.current_length += 1
|
||||
reward = self.locals.get("rewards", [0])[0]
|
||||
self.current_reward += reward
|
||||
|
||||
# Get action probabilities and value estimate for demo game
|
||||
if self.demo_obs is not None:
|
||||
# Ensure proper tensor shape (batch_size, obs_dim)
|
||||
demo_obs_tensor = torch.as_tensor(self.demo_obs).float()
|
||||
if len(demo_obs_tensor.shape) == 1:
|
||||
demo_obs_tensor = demo_obs_tensor.unsqueeze(0)
|
||||
|
||||
# Move tensor to same device as model's policy
|
||||
device = self.model.policy.device
|
||||
demo_obs_tensor = demo_obs_tensor.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Get action distribution and value estimate
|
||||
dist = self.model.policy.get_distribution(demo_obs_tensor)
|
||||
action_probs = dist.distribution.probs
|
||||
value_estimate = self.model.policy.predict_values(demo_obs_tensor)
|
||||
|
||||
# Move results back to CPU and get first item (since we added batch dimension)
|
||||
action_probs = action_probs[0].cpu().numpy()
|
||||
value_estimate = value_estimate[0].cpu().numpy()
|
||||
else:
|
||||
action_probs = np.zeros(4)
|
||||
value_estimate = np.array([0.0])
|
||||
|
||||
# Check if episode ended
|
||||
dones = self.locals.get("dones", [False])
|
||||
if any(dones):
|
||||
self.episode_rewards.append(self.current_reward)
|
||||
self.episode_lengths.append(self.current_length)
|
||||
self.current_reward = 0
|
||||
self.current_length = 0
|
||||
|
||||
# Run evaluation games (keep these deterministic as they're meant to show best performance)
|
||||
for i, (env, obs, done) in enumerate(zip(self.eval_envs, self.eval_obs, self.eval_done)):
|
||||
if done:
|
||||
if obs is not None: # If this wasn't the first reset
|
||||
self.eval_scores.append(self.eval_scores_current[i])
|
||||
# Update high scores based on evaluation performance
|
||||
self.high_score = max(self.high_score, self.eval_scores_current[i])
|
||||
self.recent_high_score = max(self.recent_high_score, self.eval_scores_current[i])
|
||||
self.eval_obs[i] = env.reset()[0]
|
||||
self.eval_done[i] = False
|
||||
self.eval_scores_current[i] = 0
|
||||
|
||||
eval_action, _ = self.model.predict(self.eval_obs[i], deterministic=True)
|
||||
eval_action = eval_action.item() if hasattr(eval_action, 'item') else eval_action
|
||||
self.eval_obs[i], _, done, truncated, info = env.step(eval_action)
|
||||
self.eval_done[i] = done or truncated
|
||||
# Update current evaluation score
|
||||
self.eval_scores_current[i] = info.get("score", 0)
|
||||
|
||||
# Run real-time demo game
|
||||
if current_time - self.last_demo_time >= self.demo_speed:
|
||||
if self.demo_done:
|
||||
self.demo_obs = self.demo_env.reset()[0]
|
||||
self.demo_done = False
|
||||
|
||||
demo_action, _ = self.model.predict(self.demo_obs, deterministic=True)
|
||||
demo_action = demo_action.item() if hasattr(demo_action, 'item') else demo_action
|
||||
self.demo_obs, _, done, truncated, info = self.demo_env.step(demo_action)
|
||||
self.demo_done = done or truncated
|
||||
self.last_demo_time = current_time
|
||||
|
||||
# Get current states from actual training environments
|
||||
# Sample one no-wrap and one wrap environment from the training environments
|
||||
no_wrap_idx = 0 # First half are no-wrap
|
||||
wrap_idx = self.n_envs // 2 # Second half are wrap
|
||||
|
||||
training_state_1 = self.training_env.get_attr("session")[no_wrap_idx].get_state()
|
||||
training_state_1["wrap_around"] = False
|
||||
|
||||
training_state_2 = self.training_env.get_attr("session")[wrap_idx].get_state()
|
||||
training_state_2["wrap_around"] = True
|
||||
|
||||
eval_state_1 = self.eval_envs[0].unwrapped.session.get_state()
|
||||
eval_state_1["wrap_around"] = self.eval_envs[0].unwrapped.rules.wrap_around
|
||||
|
||||
eval_state_2 = self.eval_envs[1].unwrapped.session.get_state()
|
||||
eval_state_2["wrap_around"] = self.eval_envs[1].unwrapped.rules.wrap_around
|
||||
|
||||
demo_state = self.demo_env.unwrapped.session.get_state()
|
||||
demo_state["wrap_around"] = self.demo_env.unwrapped.rules.wrap_around
|
||||
|
||||
# Update viz_data
|
||||
viz_data = {
|
||||
'training_state': [training_state_1, training_state_2],
|
||||
'eval_states': [eval_state_1, eval_state_2],
|
||||
'demo_state': demo_state,
|
||||
'weights_info': {
|
||||
'action_probs': action_probs.tolist() if 'action_probs' in locals() else [0.25] * 4,
|
||||
'value_estimate': float(value_estimate[0]) if 'value_estimate' in locals() else 0.0,
|
||||
'action_labels': ['Up', 'Right', 'Down', 'Left']
|
||||
},
|
||||
'training_info': {
|
||||
'total_timesteps': self.num_timesteps,
|
||||
'initial_timesteps': self.initial_timesteps, # Add initial timesteps
|
||||
'target_timesteps': self.total_timesteps,
|
||||
'episode_reward': float(self.current_reward),
|
||||
'mean_reward': float(np.mean(self.episode_rewards[-100:])) if self.episode_rewards else 0.0,
|
||||
'episode_length': int(self.current_length),
|
||||
'mean_length': float(np.mean(self.episode_lengths[-100:])) if self.episode_lengths else 0.0,
|
||||
'mean_eval_score': float(np.mean(self.eval_scores[-100:])) if self.eval_scores else 0.0,
|
||||
'rewards_history': self.episode_rewards[-1000:],
|
||||
'lengths_history': self.episode_lengths[-1000:],
|
||||
'eval_scores_history': self.eval_scores[-1000:],
|
||||
'high_score': self.high_score,
|
||||
'recent_high_score': self.recent_high_score
|
||||
}
|
||||
}
|
||||
|
||||
# Send update to visualization
|
||||
self.viz_queue.put(('update', viz_data))
|
||||
|
||||
# Reset recent high score periodically
|
||||
if self.num_timesteps % (self.n_steps * self.n_envs) == 0:
|
||||
print(f"\nEvaluation High Score: {self.high_score}")
|
||||
print(f"Recent Evaluation High Score: {self.recent_high_score}")
|
||||
self.recent_high_score = 0
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in visualization callback: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return True
|
||||
|
||||
def make_env(seed: int = 0, wrap_walls: bool = False) -> gym.Env:
|
||||
"""Create a training environment.
|
||||
|
||||
Args:
|
||||
seed: Random seed
|
||||
wrap_walls: Whether snake can wrap around walls
|
||||
"""
|
||||
def _init() -> gym.Env:
|
||||
env = SnakeEnv(render_mode=None)
|
||||
env = Monitor(env)
|
||||
# Set wrap-around mode
|
||||
env.unwrapped.rules.wrap_around = wrap_walls
|
||||
env.reset(seed=seed)
|
||||
return env
|
||||
|
||||
set_random_seed(seed)
|
||||
return _init
|
||||
|
||||
def train_model(
|
||||
total_timesteps: int,
|
||||
viz_queue: Queue,
|
||||
model_path: Optional[str] = None,
|
||||
cuda: bool = True,
|
||||
learning_rate: float = 3e-4, # Standard PPO learning rate
|
||||
batch_size: int = 256, # More reasonable batch size
|
||||
n_envs: int = 16, # Balanced number of environments
|
||||
n_steps: int = 2048 # Standard PPO steps
|
||||
) -> None:
|
||||
"""
|
||||
Train the snake AI model.
|
||||
|
||||
Args:
|
||||
total_timesteps: Number of timesteps to train for
|
||||
viz_queue: Queue for sending visualization updates
|
||||
model_path: Path to save/load the model
|
||||
cuda: Whether to use CUDA for training
|
||||
learning_rate: Learning rate for the optimizer
|
||||
batch_size: Batch size for training
|
||||
n_envs: Number of parallel environments
|
||||
n_steps: Number of steps per environment
|
||||
"""
|
||||
try:
|
||||
# Set up environments - half with wrap, half without
|
||||
envs = []
|
||||
for i in range(n_envs):
|
||||
wrap_walls = i >= n_envs // 2 # Half the envs with wrap-around
|
||||
envs.append(make_env(i, wrap_walls))
|
||||
|
||||
env = SubprocVecEnv(envs)
|
||||
|
||||
# PPO hyperparameters - adjusted for more stable learning
|
||||
ppo_params = dict(
|
||||
learning_rate=1e-4, # Slightly higher learning rate
|
||||
n_steps=1024, # Shorter horizon for faster updates
|
||||
batch_size=128, # Smaller batch size for more frequent updates
|
||||
n_epochs=4, # Balanced number of epochs
|
||||
gamma=0.98, # Slightly lower discount for more immediate rewards
|
||||
gae_lambda=0.9, # Lower GAE lambda for more immediate advantage estimates
|
||||
clip_range=0.1, # Smaller clip range for more conservative updates
|
||||
clip_range_vf=0.1, # Also clip value function
|
||||
normalize_advantage=True,
|
||||
ent_coef=0.2, # Moderate entropy for exploration
|
||||
vf_coef=0.8, # Balanced value function coefficient
|
||||
max_grad_norm=0.3, # More conservative gradient clipping
|
||||
target_kl=None,
|
||||
verbose=1,
|
||||
device="cuda" if cuda and torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
# Create or load model with updated architecture
|
||||
policy_kwargs = dict(
|
||||
features_extractor_class=CustomRecurrentCNN, # Use the new recurrent extractor
|
||||
features_extractor_kwargs=dict(features_dim=512), # You can adjust features_dim as needed
|
||||
net_arch=dict(
|
||||
pi=[512, 256, 128], # Policy network layers
|
||||
vf=[512, 256, 128] # Value network layers
|
||||
),
|
||||
activation_fn=nn.ReLU,
|
||||
normalize_images=False
|
||||
)
|
||||
|
||||
# Track initial timesteps for progress tracking
|
||||
initial_timesteps = 0
|
||||
|
||||
if model_path and os.path.exists(model_path):
|
||||
print(f"Loading model from {model_path}")
|
||||
try:
|
||||
model = PPO.load(
|
||||
model_path,
|
||||
env=env,
|
||||
custom_objects={"policy_kwargs": policy_kwargs},
|
||||
**ppo_params
|
||||
)
|
||||
initial_timesteps = model.num_timesteps
|
||||
print(f"Loaded model with {initial_timesteps:,} timesteps of training")
|
||||
except Exception as e:
|
||||
print(f"Failed to load model with error: {e}")
|
||||
print("Creating new model instead")
|
||||
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)
|
||||
else:
|
||||
print("Creating new model")
|
||||
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)
|
||||
|
||||
# Set up visualization callback with initial timesteps
|
||||
viz_callback = VisualizationCallback(
|
||||
viz_queue=viz_queue,
|
||||
total_timesteps=total_timesteps,
|
||||
eval_freq=1000,
|
||||
n_steps=n_steps,
|
||||
n_envs=n_envs,
|
||||
initial_timesteps=initial_timesteps # Pass initial timesteps to callback
|
||||
)
|
||||
|
||||
# Training loop with stop check
|
||||
should_stop = False
|
||||
timesteps_per_batch = n_steps * n_envs
|
||||
num_batches = total_timesteps // timesteps_per_batch
|
||||
|
||||
for i in range(num_batches):
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
# Train for one batch
|
||||
model.learn(
|
||||
total_timesteps=timesteps_per_batch,
|
||||
callback=viz_callback,
|
||||
progress_bar=True,
|
||||
reset_num_timesteps=False
|
||||
)
|
||||
|
||||
# Check for stop signal from visualization
|
||||
try:
|
||||
while True: # Process all pending messages
|
||||
msg_type, _ = viz_queue.get_nowait()
|
||||
if msg_type == 'stop_training':
|
||||
should_stop = True
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# Save model if training completed or interrupted
|
||||
if model_path:
|
||||
print(f"Saving model to {model_path}")
|
||||
model.save(model_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in training process: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Clean up
|
||||
env.close()
|
||||
# Signal visualization to stop if we haven't already
|
||||
viz_queue.put(('stop', None))
|
||||
|
||||
def main():
|
||||
"""Main training script."""
|
||||
parser = argparse.ArgumentParser(description="Train Snake AI")
|
||||
parser.add_argument("--config", action="store_true",
|
||||
help="Show configuration UI before training")
|
||||
parser.add_argument("--model", type=str, default="models/snake_ai.zip",
|
||||
help="Path to save/load model")
|
||||
parser.add_argument("--no-cuda", action="store_true",
|
||||
help="Disable CUDA training")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get training configuration
|
||||
if args.config:
|
||||
from src.ai.config_ui import ConfigUI
|
||||
config_ui = ConfigUI()
|
||||
config = config_ui.run()
|
||||
if config is None: # User cancelled
|
||||
return
|
||||
else:
|
||||
# Load saved config or use defaults
|
||||
config_file = "training_config.json"
|
||||
if os.path.exists(config_file):
|
||||
with open(config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {
|
||||
'timesteps': 2000000,
|
||||
'learning_rate': 0.0001,
|
||||
'batch_size': 512,
|
||||
'n_envs': 32,
|
||||
'n_steps': 512
|
||||
}
|
||||
|
||||
# Create visualization queue
|
||||
viz_queue = Queue()
|
||||
|
||||
# Start visualization in separate thread
|
||||
viz_thread = Thread(target=run_dashboard, args=(viz_queue,))
|
||||
viz_thread.start()
|
||||
|
||||
# Start training with configuration
|
||||
train_model(
|
||||
total_timesteps=config['timesteps'],
|
||||
viz_queue=viz_queue,
|
||||
model_path=args.model,
|
||||
cuda=not args.no_cuda,
|
||||
learning_rate=config['learning_rate'],
|
||||
batch_size=config['batch_size'],
|
||||
n_envs=config['n_envs'],
|
||||
n_steps=config['n_steps']
|
||||
)
|
||||
|
||||
# Wait for visualization to finish
|
||||
viz_thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
540
src/ai/visualize.py
Normal file
540
src/ai/visualize.py
Normal file
@ -0,0 +1,540 @@
|
||||
"""
|
||||
Training Visualization Dashboard
|
||||
|
||||
This module provides a real-time visualization of the training process,
|
||||
showing both the training metrics and evaluation games.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
import threading
|
||||
from queue import Queue, Empty
|
||||
import time
|
||||
|
||||
class TrainingDashboard:
|
||||
def __init__(self, width: int = 1920, height: int = 1080):
|
||||
"""Initialize the training dashboard."""
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
# Fonts
|
||||
self.title_font = None
|
||||
self.header_font = None
|
||||
self.text_font = None
|
||||
|
||||
# Colors
|
||||
self.colors = {
|
||||
'background': (0, 0, 0),
|
||||
'text': (200, 200, 200),
|
||||
'highlight': (0, 255, 0),
|
||||
'grid': (40, 40, 40),
|
||||
'border': (0, 255, 0),
|
||||
'snake': (0, 255, 0),
|
||||
'food': (255, 0, 0),
|
||||
'graph_bg': (20, 20, 20),
|
||||
'graph_line': (0, 255, 0),
|
||||
'graph_grid': (40, 40, 40)
|
||||
}
|
||||
|
||||
# Layout
|
||||
self.layout = {
|
||||
'margin': 20,
|
||||
'game_size': 400,
|
||||
'graph_height': 200,
|
||||
'metrics_width': 300
|
||||
}
|
||||
|
||||
# Training metrics history
|
||||
self.metrics_history = {
|
||||
'rewards': [],
|
||||
'scores': [],
|
||||
'lengths': []
|
||||
}
|
||||
self.max_history = 1000
|
||||
|
||||
# Current state
|
||||
self.current_state = None
|
||||
self.eval_state = None
|
||||
self.demo_state = None
|
||||
self.training_info = None
|
||||
self.model_info = None
|
||||
|
||||
# Performance tracking
|
||||
self.frame_times = []
|
||||
self.max_frame_times = 60
|
||||
|
||||
# Initialize the clock
|
||||
self.clock = None
|
||||
self.target_fps = 60
|
||||
|
||||
def run(self, update_queue: Queue) -> None:
|
||||
"""Run the dashboard, processing updates from the queue."""
|
||||
pygame.init()
|
||||
try:
|
||||
self.screen = pygame.display.set_mode((self.width, self.height))
|
||||
pygame.display.set_caption("Snake AI Training Dashboard")
|
||||
|
||||
# Initialize fonts after pygame is initialized
|
||||
self.title_font = pygame.font.Font(None, 48)
|
||||
self.header_font = pygame.font.Font(None, 36)
|
||||
self.text_font = pygame.font.Font(None, 24)
|
||||
|
||||
# Initialize clock
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
running = True
|
||||
while running:
|
||||
# Handle events
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
running = False
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
running = False
|
||||
|
||||
# Process updates from queue
|
||||
try:
|
||||
while True: # Process all available updates
|
||||
update_type, data = update_queue.get_nowait()
|
||||
if update_type == 'update':
|
||||
self.update(data)
|
||||
elif update_type == 'stop':
|
||||
running = False
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# Render dashboard
|
||||
self.render()
|
||||
self.clock.tick(self.target_fps)
|
||||
finally:
|
||||
pygame.quit()
|
||||
|
||||
def update(self, data: Dict[str, Any]) -> None:
|
||||
"""Update the dashboard with new data."""
|
||||
self.current_state = data # Store entire data dict
|
||||
if 'training_info' in data:
|
||||
self.training_info = data['training_info']
|
||||
# Update metrics history
|
||||
for key in ['rewards', 'scores', 'lengths']:
|
||||
if key in self.training_info:
|
||||
self.metrics_history[key].append(self.training_info[key])
|
||||
if len(self.metrics_history[key]) > self.max_history:
|
||||
self.metrics_history[key].pop(0)
|
||||
|
||||
def draw_game_view(self, state: Dict[str, Any], position: Tuple[int, int], size: int, title: str) -> None:
|
||||
"""Draw a game view at the specified position."""
|
||||
if not state:
|
||||
return
|
||||
|
||||
x, y = position
|
||||
|
||||
# Draw game area background
|
||||
game_rect = pygame.Rect(x, y, size, size)
|
||||
pygame.draw.rect(self.screen, self.colors['grid'], game_rect)
|
||||
pygame.draw.rect(self.screen, self.colors['border'], game_rect, 2)
|
||||
|
||||
# Get time limit info directly from state
|
||||
steps_since_food = state.get('steps_since_food', 0)
|
||||
current_time_limit = state.get('current_time_limit', 100) # Default to base limit if not provided
|
||||
time_left = max(0, current_time_limit - steps_since_food)
|
||||
|
||||
# Calculate cell size based on grid dimensions
|
||||
cell_size = size // 30 # Assuming 30x30 grid
|
||||
|
||||
# Draw snake
|
||||
if "snake_body" in state:
|
||||
for segment in state["snake_body"]:
|
||||
segment_rect = pygame.Rect(
|
||||
x + segment[0] * cell_size,
|
||||
y + segment[1] * cell_size,
|
||||
cell_size, cell_size
|
||||
)
|
||||
pygame.draw.rect(self.screen, self.colors['snake'], segment_rect)
|
||||
|
||||
# Draw food
|
||||
if "food_position" in state:
|
||||
food_rect = pygame.Rect(
|
||||
x + state["food_position"][0] * cell_size,
|
||||
y + state["food_position"][1] * cell_size,
|
||||
cell_size, cell_size
|
||||
)
|
||||
pygame.draw.rect(self.screen, self.colors['food'], food_rect)
|
||||
|
||||
# Draw score and timer above game area
|
||||
score = state.get('score', 0)
|
||||
score_text = f"Score: {score}"
|
||||
timer_text = f"Time: {time_left}/{current_time_limit}" # Show both current and max time
|
||||
combined_text = f"{score_text} | {timer_text}"
|
||||
score_surface = self.text_font.render(combined_text, True, self.colors['text'])
|
||||
score_rect = score_surface.get_rect(midtop=(x + size//2, y - 25))
|
||||
self.screen.blit(score_surface, score_rect)
|
||||
|
||||
# Draw title below game area, including wrap mode for training view
|
||||
if title == "Training Episode" and 'wrap_around' in state:
|
||||
wrap_mode = "(Wrap)" if state['wrap_around'] else "(No Wrap)"
|
||||
title = f"{title} {wrap_mode}"
|
||||
title_text = title
|
||||
title_surface = self.header_font.render(title_text, True, self.colors['text'])
|
||||
title_rect = title_surface.get_rect(midtop=(x + size//2, y + size + 10))
|
||||
self.screen.blit(title_surface, title_rect)
|
||||
|
||||
def draw_metrics(self, x: int, y: int, width: int) -> None:
|
||||
"""Draw training metrics."""
|
||||
if not self.training_info:
|
||||
return
|
||||
|
||||
font = pygame.font.Font(None, 24)
|
||||
y_offset = 0
|
||||
|
||||
# Draw high scores
|
||||
high_score = self.training_info.get("high_score", 0)
|
||||
recent_high = self.training_info.get("recent_high_score", 0)
|
||||
|
||||
metrics = [
|
||||
("All-time High Score", high_score),
|
||||
("Recent High Score", recent_high),
|
||||
("", ""), # Empty line for spacing
|
||||
("Training Metrics", ""),
|
||||
("Total Steps", self.training_info.get('total_timesteps', 0)),
|
||||
("Episode Reward", f"{self.training_info.get('episode_reward', 0):.2f}"),
|
||||
("Mean Reward", f"{self.training_info.get('mean_reward', 0):.2f}"),
|
||||
("Episode Length", self.training_info.get('episode_length', 0)),
|
||||
("Mean Length", f"{self.training_info.get('mean_length', 0):.2f}"),
|
||||
("FPS", f"{len(self.frame_times) / sum(self.frame_times):.1f}" if self.frame_times else "0")
|
||||
]
|
||||
|
||||
for i, (label, value) in enumerate(metrics):
|
||||
text = f"{label}: {value}" if label else ""
|
||||
surface = self.text_font.render(text, True, self.colors['text'])
|
||||
self.screen.blit(surface, (x, y + i * 25))
|
||||
|
||||
def draw_graph(self, data: List[float], rect: pygame.Rect, title: str,
|
||||
color: Tuple[int, int, int], min_val: float, max_val: float,
|
||||
smoothing_window: int = 10) -> None:
|
||||
"""Draw a line graph of the given data.
|
||||
|
||||
Args:
|
||||
data: List of data points.
|
||||
rect: Rectangle area for drawing the graph.
|
||||
title: Title for the graph.
|
||||
color: Color for the graph line.
|
||||
min_val: Minimum value for scaling the y-axis.
|
||||
max_val: Maximum value for scaling the y-axis.
|
||||
smoothing_window: Optional smoothing window size for drawing
|
||||
long term trends. Increase this value for a smoother, longer-term trend.
|
||||
Default is 10.
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Create a separate surface for the graph
|
||||
graph_surface = pygame.Surface((rect.width, rect.height))
|
||||
graph_surface.fill(self.colors['graph_bg'])
|
||||
|
||||
# Draw grid lines on the graph surface
|
||||
num_lines = 5
|
||||
for i in range(num_lines):
|
||||
y_line = i * rect.height / (num_lines - 1)
|
||||
pygame.draw.line(graph_surface, self.colors['graph_grid'],
|
||||
(0, y_line), (rect.width, y_line))
|
||||
|
||||
# Add padding to the min/max values for better visualization
|
||||
value_range = max_val - min_val
|
||||
if value_range == 0:
|
||||
value_range = 1.0
|
||||
padding = value_range * 0.1
|
||||
min_val_adjusted = min_val - padding
|
||||
max_val_adjusted = max_val + padding
|
||||
|
||||
# Smooth the data using the provided smoothing_window parameter
|
||||
window_size = smoothing_window if len(data) >= smoothing_window else len(data)
|
||||
if window_size > 1:
|
||||
smoothed_data = np.convolve(data, np.ones(window_size) / window_size, mode='valid')
|
||||
else:
|
||||
smoothed_data = data
|
||||
|
||||
points = []
|
||||
for i, value in enumerate(smoothed_data):
|
||||
x_point = i * rect.width / len(smoothed_data)
|
||||
y_point = rect.height - ((value - min_val_adjusted) * rect.height /
|
||||
(max_val_adjusted - min_val_adjusted))
|
||||
points.append((x_point, y_point))
|
||||
|
||||
if len(points) > 1:
|
||||
pygame.draw.lines(graph_surface, color, False, points, 2)
|
||||
|
||||
# Draw border on the graph surface
|
||||
pygame.draw.rect(graph_surface, self.colors['border'],
|
||||
pygame.Rect(0, 0, rect.width, rect.height), 1)
|
||||
|
||||
# Blit the graph surface onto the main screen
|
||||
self.screen.blit(graph_surface, rect)
|
||||
|
||||
# Draw title and current value
|
||||
current_value = data[-1] if data else 0
|
||||
title_text = f"{title} (Current: {current_value:.2f})"
|
||||
title_surface = self.text_font.render(title_text, True, self.colors['text'])
|
||||
title_rect = title_surface.get_rect(midtop=(rect.centerx, rect.top - 20))
|
||||
self.screen.blit(title_surface, title_rect)
|
||||
|
||||
# Draw value labels to the right of the graph
|
||||
label_spacing = rect.height / (num_lines - 1)
|
||||
label_x = rect.right + 10
|
||||
for i in range(num_lines):
|
||||
value = max_val_adjusted - (i * (max_val_adjusted - min_val_adjusted) / (num_lines - 1))
|
||||
label_text = f"{value:.1f}"
|
||||
label_surface = self.text_font.render(label_text, True, self.colors['text'])
|
||||
label_rect = label_surface.get_rect(
|
||||
left=label_x,
|
||||
centery=rect.top + i * label_spacing
|
||||
)
|
||||
self.screen.blit(label_surface, label_rect)
|
||||
|
||||
def draw_weights_visualization(self, x: int, y: int, width: int, height: int, weights_info: Dict) -> None:
|
||||
"""Draw the neural network weights visualization."""
|
||||
if not weights_info:
|
||||
return
|
||||
|
||||
# Get data
|
||||
action_probs = weights_info.get('action_probs', [0, 0, 0, 0])
|
||||
value_estimate = weights_info.get('value_estimate', 0.0)
|
||||
action_labels = weights_info.get('action_labels', ['Up', 'Right', 'Down', 'Left'])
|
||||
|
||||
# Draw background
|
||||
rect = pygame.Rect(x, y, width, height)
|
||||
pygame.draw.rect(self.screen, self.colors['graph_bg'], rect)
|
||||
pygame.draw.rect(self.screen, self.colors['border'], rect, 1)
|
||||
|
||||
# Draw title
|
||||
title = "Action Probabilities"
|
||||
title_surface = self.header_font.render(title, True, self.colors['text'])
|
||||
title_rect = title_surface.get_rect(midtop=(x + width//2, y + 5))
|
||||
self.screen.blit(title_surface, title_rect)
|
||||
|
||||
# Draw action probability bars
|
||||
bar_height = 20
|
||||
bar_spacing = 30
|
||||
bar_start_y = y + 50
|
||||
max_bar_width = width - 120 # Leave space for labels
|
||||
|
||||
for i, (prob, label) in enumerate(zip(action_probs, action_labels)):
|
||||
# Draw label
|
||||
label_surface = self.text_font.render(label, True, self.colors['text'])
|
||||
self.screen.blit(label_surface, (x + 10, bar_start_y + i * bar_spacing))
|
||||
|
||||
# Draw bar background
|
||||
bar_bg_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, max_bar_width, bar_height)
|
||||
pygame.draw.rect(self.screen, self.colors['grid'], bar_bg_rect)
|
||||
|
||||
# Draw probability bar
|
||||
bar_width = int(prob * max_bar_width)
|
||||
bar_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, bar_width, bar_height)
|
||||
pygame.draw.rect(self.screen, self.colors['highlight'], bar_rect)
|
||||
|
||||
# Draw probability value
|
||||
prob_text = f"{prob:.2f}"
|
||||
prob_surface = self.text_font.render(prob_text, True, self.colors['text'])
|
||||
self.screen.blit(prob_surface, (x + 80 + max_bar_width, bar_start_y + i * bar_spacing))
|
||||
|
||||
# Draw value estimate
|
||||
value_text = f"Value Estimate: {value_estimate:.2f}"
|
||||
value_surface = self.text_font.render(value_text, True, self.colors['text'])
|
||||
value_rect = value_surface.get_rect(midbottom=(x + width//2, y + height - 10))
|
||||
self.screen.blit(value_surface, value_rect)
|
||||
|
||||
def draw_progress_bar(self, x: int, y: int, width: int, height: int, progress: float, text: str) -> None:
|
||||
"""Draw a progress bar with text.
|
||||
|
||||
Args:
|
||||
x, y: Position of the progress bar
|
||||
width, height: Dimensions of the progress bar
|
||||
progress: Progress value between 0 and 1
|
||||
text: Text to display above the progress bar
|
||||
"""
|
||||
# Draw background
|
||||
bg_rect = pygame.Rect(x, y, width, height)
|
||||
pygame.draw.rect(self.screen, self.colors['graph_bg'], bg_rect)
|
||||
pygame.draw.rect(self.screen, self.colors['border'], bg_rect, 1)
|
||||
|
||||
# Draw progress
|
||||
if progress > 0:
|
||||
progress_width = int(width * progress)
|
||||
progress_rect = pygame.Rect(x, y, progress_width, height)
|
||||
pygame.draw.rect(self.screen, self.colors['highlight'], progress_rect)
|
||||
|
||||
# Draw text
|
||||
text_surface = self.text_font.render(text, True, self.colors['text'])
|
||||
text_rect = text_surface.get_rect(bottomleft=(x, y - 5))
|
||||
self.screen.blit(text_surface, text_rect)
|
||||
|
||||
def render(self) -> None:
|
||||
"""Render the dashboard."""
|
||||
if not self.current_state:
|
||||
return
|
||||
|
||||
self.screen.fill(self.colors['background'])
|
||||
|
||||
# Draw game views
|
||||
if 'training_state' in self.current_state:
|
||||
# Draw both training games
|
||||
self.draw_game_view(
|
||||
self.current_state['training_state'][0], # No wrap training
|
||||
(50, 50),
|
||||
250,
|
||||
"Training (No Wrap)"
|
||||
)
|
||||
self.draw_game_view(
|
||||
self.current_state['training_state'][1], # Wrap training
|
||||
(350, 50),
|
||||
250,
|
||||
"Training (Wrap)"
|
||||
)
|
||||
|
||||
if 'eval_states' in self.current_state:
|
||||
self.draw_game_view(
|
||||
self.current_state['eval_states'][0],
|
||||
(650, 50),
|
||||
250,
|
||||
"Evaluation (No Wrap)"
|
||||
)
|
||||
self.draw_game_view(
|
||||
self.current_state['eval_states'][1],
|
||||
(950, 50),
|
||||
250,
|
||||
"Evaluation (Wrap)"
|
||||
)
|
||||
|
||||
if 'demo_state' in self.current_state:
|
||||
self.draw_game_view(
|
||||
self.current_state['demo_state'],
|
||||
(1250, 50),
|
||||
250,
|
||||
"Demo Game"
|
||||
)
|
||||
|
||||
# Draw metrics
|
||||
if self.training_info:
|
||||
metrics_data = [
|
||||
(self.training_info['rewards_history'], "Episode Rewards"),
|
||||
(self.training_info['lengths_history'], "Episode Lengths"),
|
||||
(self.training_info['eval_scores_history'], "Evaluation Scores")
|
||||
]
|
||||
|
||||
for i, (data, label) in enumerate(metrics_data):
|
||||
if data: # Only draw if we have data
|
||||
rect = pygame.Rect(50, 400 + i * 200, 500, 150)
|
||||
self.draw_graph(
|
||||
data,
|
||||
rect,
|
||||
label,
|
||||
self.colors['graph_line'],
|
||||
min(data) if data else 0,
|
||||
max(data) if data else 1
|
||||
)
|
||||
|
||||
# Draw current metrics text
|
||||
metrics_text = [
|
||||
f"Total Steps: {self.training_info.get('total_timesteps', 0):,}",
|
||||
f"Mean Reward: {self.training_info.get('mean_reward', 0):.2f}",
|
||||
f"Mean Length: {self.training_info.get('mean_length', 0):.1f}",
|
||||
f"Mean Eval Score: {self.training_info.get('mean_eval_score', 0):.2f}",
|
||||
f"High Score: {self.training_info.get('high_score', 0)}",
|
||||
f"Recent High Score: {self.training_info.get('recent_high_score', 0)}"
|
||||
]
|
||||
|
||||
# Add vertical spacer between metrics sections
|
||||
metrics_text.insert(3, "") # Insert empty string as spacer after first 3 metrics
|
||||
|
||||
font = pygame.font.Font(None, 36)
|
||||
for i, text in enumerate(metrics_text):
|
||||
surface = font.render(text, True, self.colors['text'])
|
||||
self.screen.blit(surface, (600, 400 + i * 30))
|
||||
|
||||
# Draw action probabilities if available
|
||||
if 'weights_info' in self.current_state:
|
||||
self.draw_weights_visualization(
|
||||
1100, 400,
|
||||
300, 300,
|
||||
self.current_state['weights_info']
|
||||
)
|
||||
|
||||
# Draw trend indicator graph
|
||||
if self.training_info:
|
||||
# Calculate trend indicators
|
||||
window = 100 # Use last 100 episodes for trends
|
||||
rewards = self.training_info['rewards_history'][-window:]
|
||||
lengths = self.training_info['lengths_history'][-window:]
|
||||
eval_scores = self.training_info['eval_scores_history'][-window:]
|
||||
|
||||
if rewards and lengths and eval_scores:
|
||||
# Normalize each metric to 0-1 range for fair combination
|
||||
norm_rewards = [(r - min(rewards)) / (max(rewards) - min(rewards) + 1e-8) for r in rewards]
|
||||
norm_lengths = [(l - min(lengths)) / (max(lengths) - min(lengths) + 1e-8) for l in lengths]
|
||||
norm_scores = [(s - min(eval_scores)) / (max(eval_scores) - min(eval_scores) + 1e-8) for s in eval_scores]
|
||||
|
||||
# Combine metrics with weights (rewards: 0.4, lengths: 0.3, eval_scores: 0.3)
|
||||
trend_data = [0.4 * r + 0.3 * l + 0.3 * s
|
||||
for r, l, s in zip(norm_rewards, norm_lengths, norm_scores)]
|
||||
# Add trend direction indicator
|
||||
|
||||
trending = 0 # -1, 0, 1
|
||||
|
||||
if len(trend_data) >= 2:
|
||||
recent_trend = trend_data[-1] - trend_data[0]
|
||||
trending = 1 if recent_trend > 0.1 else -1 if recent_trend < -0.1 else 0
|
||||
trend_text = "↑ Improving" if trending == 1 else "↓ Declining" if trending == -1 else "→ Stable"
|
||||
trend_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0)
|
||||
trend_surface = self.text_font.render(trend_text, True, trend_color)
|
||||
trend_rect = trend_surface.get_rect(midtop=(1250, 860))
|
||||
self.screen.blit(trend_surface, trend_rect)
|
||||
|
||||
graph_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0)
|
||||
|
||||
# Draw trend graph 50 pixels below the action probabilities graph
|
||||
trend_rect = pygame.Rect(1100, 750, 300, 150)
|
||||
self.draw_graph(
|
||||
trend_data,
|
||||
trend_rect,
|
||||
"Learning Trend",
|
||||
graph_color, # color based on trend
|
||||
0, # Min value (normalized)
|
||||
1, # Max value (normalized)
|
||||
smoothing_window=10 # Use a larger smoothing window for longer-term trends
|
||||
)
|
||||
|
||||
# Draw training progress bar at the bottom of the screen
|
||||
if self.training_info:
|
||||
# Get current session steps (from start of this training run)
|
||||
current_session_steps = self.training_info.get('total_timesteps', 0) - self.training_info.get('initial_timesteps', 0)
|
||||
target_timesteps = self.training_info.get('target_timesteps', 1000000)
|
||||
progress = min(1.0, current_session_steps / target_timesteps)
|
||||
|
||||
# Show both current session progress and total model training
|
||||
total_model_steps = self.training_info.get('total_timesteps', 0)
|
||||
progress_text = (
|
||||
f"Current Session: {current_session_steps:,}/{target_timesteps:,} steps ({progress*100:.1f}%) | "
|
||||
f"Total Model Training: {total_model_steps:,} steps"
|
||||
)
|
||||
self.draw_progress_bar(50, self.height - 40, self.width - 100, 20, progress, progress_text)
|
||||
|
||||
pygame.display.flip()
|
||||
|
||||
def handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self.running = False
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
self.running = False
|
||||
|
||||
def run_dashboard(update_queue: Queue) -> None:
|
||||
"""Run the training dashboard in a separate thread."""
|
||||
dashboard = TrainingDashboard()
|
||||
try:
|
||||
dashboard.run(update_queue)
|
||||
finally:
|
||||
# Send stop signal back to training process
|
||||
update_queue.put(('stop_training', None))
|
||||
pygame.quit()
|
16
src/cli.py
16
src/cli.py
@ -6,7 +6,8 @@ It handles command-line arguments for different game modes and configurations.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from src import Game, GameMode
|
||||
from src.game import Game, GameState
|
||||
from src.ui.menu import GameMode
|
||||
|
||||
def main():
|
||||
"""
|
||||
@ -35,12 +36,19 @@ def main():
|
||||
if args.test:
|
||||
game.width = 400
|
||||
game.height = 300
|
||||
game.block_size = 20
|
||||
game.fps = 30
|
||||
game.game_session.grid_size = 20
|
||||
game.clock.tick(30) # Set FPS to 30
|
||||
|
||||
# Apply debug mode settings
|
||||
if args.debug:
|
||||
game.debug = True
|
||||
game.settings.debug_mode = True
|
||||
if game.game_session:
|
||||
game.game_session.show_grid = True
|
||||
|
||||
# Apply AI mode settings
|
||||
if args.ai_only:
|
||||
game.game_mode = GameMode.AI_MEDIUM
|
||||
game.state = GameState.PLAYING
|
||||
|
||||
# Run the game
|
||||
try:
|
||||
|
36
src/config/__init__.py
Normal file
36
src/config/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
Game Configuration Package
|
||||
|
||||
This package contains game configuration and constants.
|
||||
"""
|
||||
|
||||
from .constants import *
|
||||
from .settings import GameRules, GameSettings
|
||||
|
||||
__all__ = [
|
||||
'GameRules',
|
||||
'GameSettings',
|
||||
# Include all constants
|
||||
'WINDOW_WIDTH',
|
||||
'WINDOW_HEIGHT',
|
||||
'DEFAULT_GRID_WIDTH',
|
||||
'DEFAULT_GRID_HEIGHT',
|
||||
'DEFAULT_GRID_SIZE',
|
||||
'DEFAULT_PADDING',
|
||||
'DEFAULT_SCORE_HEIGHT',
|
||||
'BLACK',
|
||||
'WHITE',
|
||||
'GREEN',
|
||||
'DARK_GREEN',
|
||||
'GRAY',
|
||||
'DARK_GRAY',
|
||||
'GRID_COLOR',
|
||||
'TITLE_FONT_SIZE',
|
||||
'SUBTITLE_FONT_SIZE',
|
||||
'SCORE_FONT_SIZE',
|
||||
'DEBUG_FONT_SIZE',
|
||||
'FPS',
|
||||
'DEFAULT_MOVE_COOLDOWN',
|
||||
'MIN_MOVE_COOLDOWN',
|
||||
'MENU_ITEM_SPACING'
|
||||
]
|
25
src/config/colors.py
Normal file
25
src/config/colors.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""
|
||||
Game Colors
|
||||
|
||||
This module contains all the color values used throughout the game.
|
||||
"""
|
||||
|
||||
# Modern color palette
|
||||
BLACK = (0, 0, 0)
|
||||
WHITE = (255, 255, 255)
|
||||
GREEN = (0, 255, 0)
|
||||
DARK_GREEN = (0, 100, 0)
|
||||
NEON_GREEN = (57, 255, 20) # Brighter, more neon green
|
||||
LIME_GREEN = (50, 205, 50) # Softer green for variety
|
||||
FOREST_GREEN = (0, 100, 0) # Dark green for contrast
|
||||
DARKER_GREEN = (0, 40, 0) # Very dark green for borders
|
||||
GRAY = (128, 128, 128)
|
||||
DARK_GRAY = (25, 25, 25) # Darker background for better contrast
|
||||
GRID_COLOR = (40, 40, 40)
|
||||
SUBTLE_GRID_COLOR = (35, 35, 35)
|
||||
RED = (255, 0, 0)
|
||||
NEON_RED = (255, 20, 57) # Brighter, more neon red for food
|
||||
GLOW_GREEN = (150, 255, 150) # Light green for glow effects
|
||||
SNAKE_GRADIENT_START = (57, 255, 20) # Head color
|
||||
SNAKE_GRADIENT_END = (0, 150, 0) # Tail color
|
||||
|
31
src/config/constants.py
Normal file
31
src/config/constants.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Game Constants
|
||||
|
||||
This module contains all the constant values used throughout the game.
|
||||
"""
|
||||
|
||||
# Window settings
|
||||
WINDOW_WIDTH = 1024
|
||||
WINDOW_HEIGHT = 768
|
||||
|
||||
# Grid settings
|
||||
DEFAULT_GRID_WIDTH = 30
|
||||
DEFAULT_GRID_HEIGHT = 30
|
||||
DEFAULT_GRID_SIZE = 20
|
||||
|
||||
# Game area settings
|
||||
DEFAULT_PADDING = 40
|
||||
DEFAULT_SCORE_HEIGHT = 60
|
||||
# Font sizes
|
||||
TITLE_FONT_SIZE = 72
|
||||
SUBTITLE_FONT_SIZE = 36
|
||||
SCORE_FONT_SIZE = 36
|
||||
DEBUG_FONT_SIZE = 24
|
||||
|
||||
# Game timing
|
||||
FPS = 60
|
||||
DEFAULT_MOVE_COOLDOWN = 100
|
||||
MIN_MOVE_COOLDOWN = 50
|
||||
|
||||
# Menu settings
|
||||
MENU_ITEM_SPACING = 50
|
52
src/config/settings.py
Normal file
52
src/config/settings.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
Game Settings
|
||||
|
||||
This module contains the game settings and rules that can be configured.
|
||||
"""
|
||||
|
||||
class GameRules:
|
||||
"""Game rules that can be configured through the settings menu."""
|
||||
|
||||
def __init__(self):
|
||||
self.wrap_around = True # Whether snake wraps around screen edges
|
||||
self.speed_increase = True # Whether snake speeds up as it grows
|
||||
self.min_move_cooldown = 50 # Minimum movement delay in milliseconds
|
||||
self.initial_move_cooldown = 100 # Initial movement delay
|
||||
self.starting_length = 3 # Starting length of the snake
|
||||
|
||||
def update_rule(self, rule_name, value):
|
||||
"""
|
||||
Update a game rule with a new value.
|
||||
|
||||
Args:
|
||||
rule_name (str): Name of the rule to update
|
||||
value: New value for the rule
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful, False otherwise
|
||||
"""
|
||||
if hasattr(self, rule_name):
|
||||
# Define validation rules for specific attributes
|
||||
validators = {
|
||||
'starting_length': lambda x: 1 <= x <= 10,
|
||||
'wrap_around': lambda x: isinstance(x, bool),
|
||||
'speed_increase': lambda x: isinstance(x, bool),
|
||||
'min_move_cooldown': lambda x: x > 0,
|
||||
'initial_move_cooldown': lambda x: x > 0
|
||||
}
|
||||
|
||||
# Validate the value if there's a validator for this rule
|
||||
if rule_name in validators and not validators[rule_name](value):
|
||||
return False
|
||||
|
||||
setattr(self, rule_name, value)
|
||||
return True
|
||||
return False
|
||||
|
||||
class GameSettings:
|
||||
"""Global game settings."""
|
||||
|
||||
def __init__(self):
|
||||
self.debug_mode = False # Whether to show debug information
|
||||
self.show_grid = self.debug_mode # Whether to show the grid lines
|
||||
self.rules = GameRules() # Game rules instance
|
16
src/core/__init__.py
Normal file
16
src/core/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Game Core Package
|
||||
|
||||
This package contains the core game mechanics and entities.
|
||||
"""
|
||||
|
||||
from .snake import Snake, Direction
|
||||
from .food import Food
|
||||
from .game_session import GameSession
|
||||
|
||||
__all__ = [
|
||||
'Snake',
|
||||
'Direction',
|
||||
'Food',
|
||||
'GameSession'
|
||||
]
|
82
src/core/food.py
Normal file
82
src/core/food.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""
|
||||
Food Module
|
||||
|
||||
This module provides the Food class that represents the food item in the game.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
import random
|
||||
from typing import Tuple, List
|
||||
from src.ui import GameArea
|
||||
from src.config.colors import *
|
||||
|
||||
class Food:
|
||||
def __init__(self, color: Tuple[int, int, int] = NEON_RED):
|
||||
"""
|
||||
Initialize food with a color.
|
||||
|
||||
Args:
|
||||
color: RGB color tuple for the food
|
||||
"""
|
||||
self.color = color
|
||||
self.position = (0, 0) # Will be set by spawn_at_position()
|
||||
|
||||
def spawn_at_position(self, available_positions: List[Tuple[int, int]]) -> None:
|
||||
"""
|
||||
Spawn food at a random position from the available positions.
|
||||
|
||||
Args:
|
||||
available_positions: List of valid grid positions where food can spawn
|
||||
"""
|
||||
if not available_positions:
|
||||
# No valid positions (snake fills screen) - game should be won
|
||||
return
|
||||
|
||||
# Choose random position from valid positions
|
||||
self.position = random.choice(available_positions)
|
||||
|
||||
def draw(self, screen: pygame.Surface, game_area: GameArea) -> None:
|
||||
"""
|
||||
Draw the food on the screen with a glowing effect
|
||||
|
||||
Args:
|
||||
screen: Pygame surface to draw on
|
||||
game_area: GameArea instance for coordinate conversion
|
||||
"""
|
||||
screen_x, screen_y = game_area.get_screen_pos(self.position[0], self.position[1])
|
||||
|
||||
# Draw outer glow
|
||||
glow_size = game_area.grid_size + 4
|
||||
glow_surface = pygame.Surface((glow_size, glow_size), pygame.SRCALPHA)
|
||||
|
||||
# Create radial gradient for glow
|
||||
for i in range(3):
|
||||
alpha = 100 - (i * 30) # Fade out alpha
|
||||
size = glow_size - (i * 2)
|
||||
pos = (glow_size - size) // 2
|
||||
pygame.draw.circle(glow_surface, (*NEON_RED[:3], alpha), (glow_size//2, glow_size//2), size//2)
|
||||
|
||||
# Draw glow
|
||||
screen.blit(glow_surface, (screen_x - 2, screen_y - 2))
|
||||
|
||||
# Draw main food body
|
||||
food_rect = pygame.Rect(
|
||||
screen_x + 2,
|
||||
screen_y + 2,
|
||||
game_area.grid_size - 4,
|
||||
game_area.grid_size - 4
|
||||
)
|
||||
pygame.draw.rect(screen, NEON_RED, food_rect, border_radius=4)
|
||||
|
||||
# Draw highlight
|
||||
highlight_rect = pygame.Rect(
|
||||
screen_x + 4,
|
||||
screen_y + 4,
|
||||
game_area.grid_size - 8,
|
||||
game_area.grid_size - 8
|
||||
)
|
||||
pygame.draw.rect(screen, (*NEON_RED, 200), highlight_rect, border_radius=3)
|
||||
|
||||
def check_collision(self, position: Tuple[int, int]) -> bool:
|
||||
"""Check if the given position collides with the food"""
|
||||
return self.position == position
|
239
src/core/game_session.py
Normal file
239
src/core/game_session.py
Normal file
@ -0,0 +1,239 @@
|
||||
"""
|
||||
Game Session Manager
|
||||
|
||||
This module provides a reusable game session class that encapsulates the core gameplay mechanics.
|
||||
It can be used by both the main game and the training environment to ensure consistent behavior.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from typing import Tuple, Dict, Optional
|
||||
from src.core import Snake, Direction, Food
|
||||
from src.ui import GameArea
|
||||
from src.config import (
|
||||
WINDOW_WIDTH,
|
||||
WINDOW_HEIGHT,
|
||||
FPS
|
||||
)
|
||||
from src.config.colors import *
|
||||
from src.config.settings import GameRules
|
||||
|
||||
class GameSession:
|
||||
"""Manages a single game session with consistent rules and boundaries."""
|
||||
|
||||
def __init__(self, grid_width: int, grid_height: int, rules : GameRules, window_width: int = WINDOW_WIDTH, window_height: int = WINDOW_HEIGHT):
|
||||
"""
|
||||
Initialize a new game session.
|
||||
|
||||
Args:
|
||||
grid_width: Number of grid cells horizontally
|
||||
grid_height: Number of grid cells vertically
|
||||
rules: GameRules instance containing game settings
|
||||
window_width: Width of the game window
|
||||
window_height: Height of the game window
|
||||
"""
|
||||
self.grid_width = grid_width
|
||||
self.grid_height = grid_height
|
||||
self.rules = rules
|
||||
self.window_width = window_width
|
||||
self.window_height = window_height
|
||||
|
||||
# Game objects
|
||||
self.snake = None
|
||||
self.food = None
|
||||
|
||||
# Game state
|
||||
self.score = 0
|
||||
self.is_game_over = False
|
||||
self.move_cooldown = rules.initial_move_cooldown
|
||||
self.last_move_time = 0
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> Dict:
|
||||
"""
|
||||
Reset the game session to its initial state, including the game area.
|
||||
|
||||
Returns:
|
||||
Dict containing the initial game state
|
||||
"""
|
||||
# Initialize game area with window dimensions
|
||||
self.game_area = GameArea(
|
||||
window_width=self.window_width,
|
||||
window_height=self.window_height,
|
||||
grid_width=self.grid_width,
|
||||
grid_height=self.grid_height
|
||||
)
|
||||
|
||||
# Initialize snake in the middle of the grid
|
||||
start_x = self.grid_width // 2
|
||||
start_y = self.grid_height // 2
|
||||
self.snake = Snake(
|
||||
start_pos=(start_x, start_y),
|
||||
grid_width=self.grid_width,
|
||||
grid_height=self.grid_height
|
||||
)
|
||||
|
||||
# Initialize food
|
||||
self.food = Food()
|
||||
self._spawn_food()
|
||||
|
||||
# Reset game state
|
||||
self.score = 0
|
||||
self.is_game_over = False
|
||||
self.move_cooldown = self.rules.initial_move_cooldown
|
||||
self.last_move_time = 0
|
||||
|
||||
return self.get_state()
|
||||
|
||||
def _spawn_food(self) -> None:
|
||||
"""Spawn food in a random empty grid cell."""
|
||||
# Get all occupied positions
|
||||
occupied = set(self.snake.body)
|
||||
|
||||
# Get all possible positions
|
||||
all_positions = [(x, y) for x in range(self.grid_width)
|
||||
for y in range(self.grid_height)]
|
||||
|
||||
# Filter out occupied positions
|
||||
available = [pos for pos in all_positions if pos not in occupied]
|
||||
|
||||
if available:
|
||||
self.food.spawn_at_position(available)
|
||||
|
||||
def step(self, action: Optional[Direction] = None, current_time: int = 0) -> Tuple[Dict, float, bool]:
|
||||
"""
|
||||
Advance the game state by one step.
|
||||
|
||||
Args:
|
||||
action: Optional direction to change to
|
||||
current_time: Current game time in milliseconds
|
||||
|
||||
Returns:
|
||||
Tuple of (game_state, reward, done)
|
||||
"""
|
||||
if self.is_game_over:
|
||||
return self.get_state(), 0, True
|
||||
|
||||
# Apply action if provided
|
||||
if action is not None:
|
||||
self.snake.change_direction(action)
|
||||
|
||||
# Update snake movement animation
|
||||
self.snake.update_movement(1.0 / FPS) #
|
||||
|
||||
# Check if enough time has passed for next move
|
||||
if current_time - self.last_move_time < self.move_cooldown:
|
||||
return self.get_state(), 0, False
|
||||
|
||||
self.last_move_time = current_time
|
||||
|
||||
# Process any pending direction changes first
|
||||
if self.snake.input_buffer:
|
||||
# Try first direction
|
||||
next_direction = self.snake.input_buffer[0]
|
||||
if self.snake.is_direction_valid(next_direction):
|
||||
self.snake.direction = next_direction
|
||||
self.snake.input_buffer.pop(0)
|
||||
# If invalid and we have another input, try that one first
|
||||
elif len(self.snake.input_buffer) > 1:
|
||||
alt_direction = self.snake.input_buffer[1]
|
||||
if self.snake.is_direction_valid(alt_direction):
|
||||
self.snake.direction = alt_direction
|
||||
self.snake.input_buffer.pop(1) # Remove the successful second input
|
||||
# Remove the first input regardless as it's had its chance
|
||||
self.snake.input_buffer.pop(0)
|
||||
else:
|
||||
# Single invalid input, just remove it
|
||||
self.snake.input_buffer.pop(0)
|
||||
|
||||
# Now get movement vector with updated direction
|
||||
dx, dy = self.snake.direction.to_vector()
|
||||
new_head = (self.snake.body[0][0] + dx, self.snake.body[0][1] + dy)
|
||||
|
||||
# Check for immediate wall collision (before moving)
|
||||
if not self.rules.wrap_around:
|
||||
if (new_head[0] < 0 or new_head[0] >= self.grid_width or
|
||||
new_head[1] < 0 or new_head[1] >= self.grid_height):
|
||||
self.is_game_over = True
|
||||
return self.get_state(), -1, True
|
||||
|
||||
# check for collision with snake body
|
||||
if new_head in self.snake.body[1:]:
|
||||
self.is_game_over = True
|
||||
return self.get_state(), -1, True
|
||||
|
||||
# Move snake
|
||||
head_x, head_y = self.snake.move()
|
||||
|
||||
# Handle wrap-around if enabled
|
||||
if self.rules.wrap_around:
|
||||
head_x = head_x % self.grid_width
|
||||
head_y = head_y % self.grid_height
|
||||
self.snake.body[0] = (head_x, head_y)
|
||||
self.snake.target_positions[0] = (head_x, head_y) # Update target position for smooth movement
|
||||
|
||||
# Check self collision
|
||||
if (head_x, head_y) in self.snake.body[1:]:
|
||||
self.is_game_over = True
|
||||
return self.get_state(), -1, True
|
||||
|
||||
# Check food collision
|
||||
if (head_x, head_y) == self.food.position:
|
||||
self.score += 1
|
||||
self.snake.grow()
|
||||
self._spawn_food()
|
||||
|
||||
# Increase speed if enabled
|
||||
if self.rules.speed_increase:
|
||||
self.move_cooldown = max(
|
||||
self.rules.min_move_cooldown,
|
||||
self.move_cooldown - 2
|
||||
)
|
||||
return self.get_state(), 1, False
|
||||
|
||||
return self.get_state(), 0, False
|
||||
|
||||
def get_state(self) -> Dict:
|
||||
"""
|
||||
Get the current game state.
|
||||
|
||||
Returns:
|
||||
Dict containing the game state
|
||||
"""
|
||||
return {
|
||||
"grid_width": self.grid_width,
|
||||
"grid_height": self.grid_height,
|
||||
"snake_head": self.snake.body[0],
|
||||
"snake_body": self.snake.body,
|
||||
"snake_direction": self.snake.direction,
|
||||
"food_position": self.food.position,
|
||||
"score": self.score,
|
||||
"game_over": self.is_game_over,
|
||||
"move_cooldown": self.move_cooldown
|
||||
}
|
||||
|
||||
def render(self, screen: pygame.Surface) -> None:
|
||||
"""
|
||||
Render the game state to the screen.
|
||||
|
||||
Args:
|
||||
screen: Pygame surface to render to
|
||||
"""
|
||||
# Draw game area first
|
||||
self.game_area.draw(screen)
|
||||
|
||||
# Draw game objects
|
||||
self.snake.draw(screen, self.game_area)
|
||||
self.food.draw(screen, self.game_area)
|
||||
|
||||
# Draw grid (optional, for debugging)
|
||||
if hasattr(self, 'show_grid') and self.show_grid:
|
||||
for x in range(self.grid_width + 1):
|
||||
start_pos = (self.game_area.x + x * self.game_area.grid_size, self.game_area.y)
|
||||
end_pos = (self.game_area.x + x * self.game_area.grid_size, self.game_area.y + self.game_area.height)
|
||||
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
|
||||
|
||||
for y in range(self.grid_height + 1):
|
||||
start_pos = (self.game_area.x, self.game_area.y + y * self.game_area.grid_size)
|
||||
end_pos = (self.game_area.x + self.game_area.width, self.game_area.y + y * self.game_area.grid_size)
|
||||
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
|
357
src/core/snake.py
Normal file
357
src/core/snake.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""
|
||||
Snake Module
|
||||
|
||||
This module provides the Snake class and Direction enum for snake movement and rendering.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from enum import Enum, auto
|
||||
from typing import Tuple, List
|
||||
import math
|
||||
from src.config.colors import *
|
||||
import pygame.gfxdraw
|
||||
|
||||
class Direction(Enum):
|
||||
UP = auto()
|
||||
DOWN = auto()
|
||||
LEFT = auto()
|
||||
RIGHT = auto()
|
||||
|
||||
def to_vector(self) -> Tuple[int, int]:
|
||||
# Convert direction to a movement vector
|
||||
if self == Direction.UP:
|
||||
return (0, -1)
|
||||
elif self == Direction.DOWN:
|
||||
return (0, 1)
|
||||
elif self == Direction.LEFT:
|
||||
return (-1, 0)
|
||||
else: # Direction.RIGHT
|
||||
return (1, 0)
|
||||
|
||||
class Snake:
|
||||
def __init__(self, start_pos: Tuple[int, int], grid_width: int, grid_height: int, length: int = 3):
|
||||
"""
|
||||
Initialize snake at the given grid position.
|
||||
|
||||
Args:
|
||||
start_pos: Starting position in grid coordinates (x, y)
|
||||
grid_width: Width of the game grid
|
||||
grid_height: Height of the game grid
|
||||
length: Initial length of the snake
|
||||
"""
|
||||
self.direction = Direction.RIGHT
|
||||
self.body = [start_pos] # Head is at index 0
|
||||
self.growing = False
|
||||
self.grid_width = grid_width
|
||||
self.grid_height = grid_height
|
||||
|
||||
# FIX: Start from 1 to avoid duplicate head segment.
|
||||
for i in range(1, length):
|
||||
self.body.append((start_pos[0] - i, start_pos[1]))
|
||||
|
||||
# Initialize visual positions for each body segment
|
||||
self.visual_positions = [(float(pos[0]), float(pos[1])) for pos in self.body]
|
||||
self.move_progress = 1.0 # Progress of current move (0.0 to 1.0)
|
||||
self.move_speed = 0.2 # Movement speed
|
||||
self.target_positions = list(self.body) # Target grid positions
|
||||
|
||||
# Input buffer for direction changes
|
||||
self.input_buffer = [] # Store up to 2 pending direction changes
|
||||
self.max_buffer_size = 2
|
||||
|
||||
# Create base segment surface with gradient
|
||||
self.segment_size = 32 # Base size for segment surface
|
||||
self.base_surface = pygame.Surface((self.segment_size, self.segment_size), pygame.SRCALPHA)
|
||||
|
||||
# Create a radial gradient for the base segment
|
||||
center = self.segment_size // 2
|
||||
for radius in range(center, -1, -1):
|
||||
# Calculate color intensity based on distance from center
|
||||
intensity = 1.0 - (radius / center) ** 0.8 # Adjust power for gradient shape
|
||||
color = (
|
||||
int(SNAKE_GRADIENT_START[0] * intensity),
|
||||
int(SNAKE_GRADIENT_START[1] * intensity),
|
||||
int(SNAKE_GRADIENT_START[2] * intensity),
|
||||
255
|
||||
)
|
||||
pygame.draw.circle(self.base_surface, color, (center, center), radius)
|
||||
|
||||
def move(self) -> Tuple[int, int]:
|
||||
"""
|
||||
Move the snake one grid cell in current direction.
|
||||
|
||||
Returns:
|
||||
New head position (x, y)
|
||||
"""
|
||||
# Get new head position
|
||||
head = self.body[0]
|
||||
new_head = self._get_new_head_position(head)
|
||||
|
||||
# Update grid positions
|
||||
self.body.insert(0, new_head)
|
||||
if not self.growing:
|
||||
self.body.pop()
|
||||
else:
|
||||
self.growing = False
|
||||
|
||||
# Update visual positions
|
||||
self.target_positions = list(self.body) # Copy current grid positions as targets
|
||||
if len(self.visual_positions) < len(self.target_positions):
|
||||
self.visual_positions.append(self.visual_positions[-1])
|
||||
|
||||
# Reset move progress for smooth animation
|
||||
self.move_progress = 0.0
|
||||
|
||||
return new_head
|
||||
|
||||
def update_movement(self, dt: float):
|
||||
"""Update visual position interpolation with consistent timing."""
|
||||
if self.move_progress >= 1.0:
|
||||
return
|
||||
|
||||
# Use cubic easing for smoother acceleration/deceleration
|
||||
self.move_progress = min(1.0, self.move_progress + self.move_speed)
|
||||
t = self.move_progress
|
||||
overall_progress = t * t * (3 - 2 * t)
|
||||
|
||||
# Update each segment with fixed delay windows
|
||||
segment_count = len(self.visual_positions)
|
||||
for i in range(segment_count):
|
||||
if i >= len(self.target_positions):
|
||||
continue
|
||||
|
||||
current = self.visual_positions[i]
|
||||
target = self.target_positions[i]
|
||||
|
||||
dx = target[0] - current[0]
|
||||
dy = target[1] - current[1]
|
||||
|
||||
# Check if this segment needs to wrap
|
||||
wrap_x = abs(dx) > 2
|
||||
wrap_y = abs(dy) > 2
|
||||
|
||||
if wrap_x or wrap_y:
|
||||
self.visual_positions[i] = target
|
||||
else:
|
||||
# Use fixed delay windows instead of compounding delays
|
||||
delay_window = 0.2 # Total delay spread across all segments
|
||||
segment_delay = (i / segment_count) * delay_window
|
||||
|
||||
# Calculate progress for this segment
|
||||
segment_progress = max(0.0, (overall_progress - segment_delay) / (1.0 - segment_delay))
|
||||
segment_progress = min(1.0, segment_progress)
|
||||
|
||||
# Apply smoothed movement
|
||||
new_x = current[0] + dx * segment_progress
|
||||
new_y = current[1] + dy * segment_progress
|
||||
self.visual_positions[i] = (new_x, new_y)
|
||||
|
||||
def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]:
|
||||
"""Calculate new head position based on current direction"""
|
||||
x, y = head
|
||||
dx, dy = self.direction.to_vector()
|
||||
return (x + dx, y + dy)
|
||||
|
||||
def grow(self):
|
||||
"""Mark the snake to grow on next move"""
|
||||
self.growing = True
|
||||
|
||||
def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool:
|
||||
"""
|
||||
Check if snake has collided with walls or itself.
|
||||
|
||||
Args:
|
||||
width: Game area width
|
||||
height: Game area height
|
||||
wrap_around: If True, snake wraps around screen edges instead of colliding
|
||||
"""
|
||||
head = self.body[0]
|
||||
|
||||
if not wrap_around:
|
||||
# Check wall collision
|
||||
if (head[0] < 0 or head[0] >= width or
|
||||
head[1] < 0 or head[1] >= height):
|
||||
return True
|
||||
|
||||
# Check self collision (skip head)
|
||||
if head in self.body[1:]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def wrap_position(self, width: int, height: int):
|
||||
"""Wrap snake's head position around screen edges"""
|
||||
head_x, head_y = self.body[0]
|
||||
wrapped_head = (
|
||||
head_x % width,
|
||||
head_y % height
|
||||
)
|
||||
self.body[0] = wrapped_head
|
||||
|
||||
def is_direction_valid(self, new_direction: Direction) -> bool:
|
||||
"""
|
||||
Try to change direction, ensuring no 180-degree turns.
|
||||
|
||||
Args:
|
||||
new_direction: The direction to change to
|
||||
|
||||
Returns:
|
||||
True if direction was changed, False otherwise
|
||||
"""
|
||||
opposite_directions = {
|
||||
Direction.UP: Direction.DOWN,
|
||||
Direction.DOWN: Direction.UP,
|
||||
Direction.LEFT: Direction.RIGHT,
|
||||
Direction.RIGHT: Direction.LEFT
|
||||
}
|
||||
|
||||
if opposite_directions[new_direction] == self.direction or new_direction == self.direction:
|
||||
return False
|
||||
|
||||
self.direction = new_direction
|
||||
return True
|
||||
|
||||
def change_direction(self, new_direction: Direction):
|
||||
"""
|
||||
Buffer a direction change to be applied on next move.
|
||||
|
||||
Args:
|
||||
new_direction: The direction to change to
|
||||
"""
|
||||
if self.input_buffer and self.input_buffer[-1] == new_direction:
|
||||
return
|
||||
|
||||
if len(self.input_buffer) < self.max_buffer_size:
|
||||
self.input_buffer.append(new_direction)
|
||||
|
||||
def draw(self, screen: pygame.Surface, game_area):
|
||||
"""Draw snake with clean, connected design and cute tongue."""
|
||||
if len(self.visual_positions) < 2:
|
||||
return
|
||||
|
||||
base_size = game_area.grid_size
|
||||
segment_size = base_size - 4
|
||||
|
||||
# Draw snake body segments
|
||||
for i in range(len(self.visual_positions) - 1):
|
||||
pos1 = self.visual_positions[i]
|
||||
pos2 = self.visual_positions[i + 1]
|
||||
|
||||
screen_positions = self._get_wrapped_segment_positions(pos1, pos2, game_area)
|
||||
if not screen_positions:
|
||||
continue
|
||||
|
||||
for sp1, sp2 in screen_positions:
|
||||
# Calculate segment properties
|
||||
progress = max(0.4, 1.0 - (i / len(self.visual_positions)))
|
||||
|
||||
color = (
|
||||
int(SNAKE_GRADIENT_START[0] * progress),
|
||||
int(SNAKE_GRADIENT_START[1] * progress),
|
||||
int(SNAKE_GRADIENT_START[2] * progress)
|
||||
)
|
||||
|
||||
# Draw connecting segments with extra thickness
|
||||
center1 = (sp1[0] + base_size//2, sp1[1] + base_size//2)
|
||||
center2 = (sp2[0] + base_size//2, sp2[1] + base_size//2)
|
||||
|
||||
# Draw a thicker base line first
|
||||
pygame.draw.line(screen, color, center1, center2, segment_size + 4)
|
||||
|
||||
# Draw larger circles at joints to ensure corner coverage
|
||||
pygame.draw.circle(screen, color, center1, (segment_size + 2)//2)
|
||||
pygame.draw.circle(screen, color, center2, (segment_size + 2)//2)
|
||||
|
||||
# If this is a corner (check by looking at next segment)
|
||||
if i < len(self.visual_positions) - 2:
|
||||
pos3 = self.visual_positions[i + 2]
|
||||
if pos2[0] != pos1[0] and pos3[1] != pos2[1] or \
|
||||
pos2[1] != pos1[1] and pos3[0] != pos2[0]:
|
||||
# Draw extra circle at the corner to ensure coverage
|
||||
pygame.draw.circle(screen, color, center2, (segment_size + 6)//2)
|
||||
|
||||
# Draw head
|
||||
head_pos = self.visual_positions[0]
|
||||
screen_pos = game_area.get_screen_pos(head_pos[0], head_pos[1])
|
||||
center = (screen_pos[0] + base_size//2, screen_pos[1] + base_size//2)
|
||||
|
||||
# Draw head circle
|
||||
pygame.draw.circle(screen, SNAKE_GRADIENT_START, center, segment_size//2)
|
||||
|
||||
# Calculate direction for eyes and tongue
|
||||
next_pos = self.visual_positions[1] if len(self.visual_positions) > 1 else head_pos
|
||||
dx = head_pos[0] - next_pos[0] # Reversed the direction calculation
|
||||
dy = head_pos[1] - next_pos[1] # Reversed the direction calculation
|
||||
angle = math.atan2(dy, dx)
|
||||
|
||||
# Draw eyes
|
||||
eye_offset = segment_size//4
|
||||
eye_size = 2
|
||||
eye_angle = angle + math.pi/2
|
||||
|
||||
for side in [-1, 1]:
|
||||
eye_x = center[0] + math.cos(eye_angle) * eye_offset * side
|
||||
eye_y = center[1] + math.sin(eye_angle) * eye_offset
|
||||
pygame.draw.circle(screen, (0, 0, 0), (int(eye_x), int(eye_y)), eye_size)
|
||||
|
||||
# Draw flickering tongue with more pronounced animation
|
||||
time = pygame.time.get_ticks()
|
||||
tongue_flick = math.sin(time * 0.01) * 0.8 + 0.2 # Slower, more pronounced flicking
|
||||
|
||||
# Tongue base position (at front of head)
|
||||
tongue_base_x = center[0] + math.cos(angle) * (segment_size//2)
|
||||
tongue_base_y = center[1] + math.sin(angle) * (segment_size//2)
|
||||
|
||||
# Tongue length varies with flicking animation
|
||||
tongue_length = (3 + tongue_flick * 6) # Length varies between 3 and 9 pixels
|
||||
fork_length = 4 # Slightly longer fork
|
||||
fork_angle = math.pi/3 # Wider fork angle
|
||||
|
||||
# Calculate tongue tip
|
||||
tongue_tip_x = tongue_base_x + math.cos(angle) * tongue_length
|
||||
tongue_tip_y = tongue_base_y + math.sin(angle) * tongue_length
|
||||
|
||||
# Calculate fork tips
|
||||
left_fork_x = tongue_tip_x + math.cos(angle + fork_angle) * fork_length
|
||||
left_fork_y = tongue_tip_y + math.sin(angle + fork_angle) * fork_length
|
||||
right_fork_x = tongue_tip_x + math.cos(angle - fork_angle) * fork_length
|
||||
right_fork_y = tongue_tip_y + math.sin(angle - fork_angle) * fork_length
|
||||
|
||||
# Draw tongue with brighter red
|
||||
tongue_color = (255, 0, 0) # Bright red
|
||||
|
||||
# Draw main tongue line
|
||||
pygame.draw.line(screen, tongue_color,
|
||||
(tongue_base_x, tongue_base_y),
|
||||
(tongue_tip_x, tongue_tip_y), 2)
|
||||
|
||||
# Draw forked tips
|
||||
pygame.draw.line(screen, tongue_color,
|
||||
(tongue_tip_x, tongue_tip_y),
|
||||
(left_fork_x, left_fork_y), 2)
|
||||
pygame.draw.line(screen, tongue_color,
|
||||
(tongue_tip_x, tongue_tip_y),
|
||||
(right_fork_x, right_fork_y), 2)
|
||||
|
||||
def _get_wrapped_segment_positions(self, pos1, pos2, game_area):
|
||||
"""
|
||||
Get screen positions for segment rendering, handling wrap-around.
|
||||
Only returns positions for segments that are close enough to interpolate.
|
||||
"""
|
||||
grid_w = self.grid_width
|
||||
grid_h = self.grid_height
|
||||
|
||||
# Calculate primary direction
|
||||
dx = pos2[0] - pos1[0]
|
||||
dy = pos2[1] - pos1[1]
|
||||
|
||||
# If segments are too far apart (wrapping), don't draw connecting segment
|
||||
if abs(dx) > 2 or abs(dy) > 2:
|
||||
return []
|
||||
|
||||
# Convert to screen coordinates
|
||||
screen_pos1 = game_area.get_screen_pos(pos1[0], pos1[1])
|
||||
screen_pos2 = game_area.get_screen_pos(pos2[0], pos2[1])
|
||||
|
||||
return [(screen_pos1, screen_pos2)]
|
||||
|
50
src/food.py
50
src/food.py
@ -1,50 +0,0 @@
|
||||
import pygame
|
||||
import random
|
||||
from typing import Tuple, List
|
||||
|
||||
class Food:
|
||||
def __init__(self, block_size: int, color: Tuple[int, int, int] = (255, 0, 0)):
|
||||
self.block_size = block_size
|
||||
self.color = color
|
||||
self.position = (0, 0) # Will be set by spawn()
|
||||
|
||||
def spawn(self, width: int, height: int, occupied_positions: List[Tuple[int, int]]) -> None:
|
||||
"""
|
||||
Spawn food at a random position that isn't occupied by the snake.
|
||||
|
||||
Args:
|
||||
width: Game area width
|
||||
height: Game area height
|
||||
occupied_positions: List of positions (typically snake body positions) where food shouldn't spawn
|
||||
"""
|
||||
# Calculate all valid grid positions
|
||||
valid_positions = [
|
||||
(x, y)
|
||||
for x in range(0, width, self.block_size)
|
||||
for y in range(0, height, self.block_size)
|
||||
if (x, y) not in occupied_positions
|
||||
]
|
||||
|
||||
if not valid_positions:
|
||||
# No valid positions (snake fills screen) - game should be won
|
||||
return
|
||||
|
||||
# Choose random position from valid positions
|
||||
self.position = random.choice(valid_positions)
|
||||
|
||||
def draw(self, screen: pygame.Surface) -> None:
|
||||
"""Draw the food on the screen"""
|
||||
pygame.draw.rect(
|
||||
screen,
|
||||
self.color,
|
||||
pygame.Rect(
|
||||
self.position[0],
|
||||
self.position[1],
|
||||
self.block_size,
|
||||
self.block_size
|
||||
)
|
||||
)
|
||||
|
||||
def check_collision(self, position: Tuple[int, int]) -> bool:
|
||||
"""Check if the given position collides with the food"""
|
||||
return self.position == position
|
399
src/game.py
399
src/game.py
@ -1,190 +1,109 @@
|
||||
"""
|
||||
Main Game Module
|
||||
|
||||
This module provides the main game class that handles the game loop,
|
||||
state management, and user interface.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
import sys
|
||||
from enum import Enum, auto
|
||||
from src.snake import Snake, Direction
|
||||
from src.food import Food
|
||||
from src.menu import Menu, GameMode, MenuItem
|
||||
|
||||
from src.config import (
|
||||
WINDOW_WIDTH,
|
||||
WINDOW_HEIGHT,
|
||||
FPS,
|
||||
SCORE_FONT_SIZE
|
||||
)
|
||||
from src.config import GameSettings
|
||||
from src.config.colors import BLACK, GREEN, NEON_GREEN
|
||||
from src.core import Direction, GameSession
|
||||
from src.ui import Menu, GameMode, SettingsMenu, PauseMenu, GameArea
|
||||
|
||||
class GameState(Enum):
|
||||
"""Game states."""
|
||||
MENU = auto()
|
||||
SETTINGS = auto() # New state for settings menu
|
||||
SETTINGS = auto()
|
||||
PLAYING = auto()
|
||||
GAME_OVER = auto()
|
||||
PAUSED = auto()
|
||||
|
||||
class GameRules:
|
||||
def __init__(self):
|
||||
self.wrap_around = False # Whether snake wraps around screen edges
|
||||
self.speed_increase = True # Whether snake speeds up as it grows
|
||||
self.min_move_cooldown = 50 # Minimum movement delay in milliseconds
|
||||
self.initial_move_cooldown = 100 # Initial movement delay
|
||||
|
||||
class SettingsMenu:
|
||||
def __init__(self, width, height, rules):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.rules = rules
|
||||
self.setup_menu_items()
|
||||
|
||||
self.title_font = pygame.font.Font(None, 72)
|
||||
self.subtitle_font = pygame.font.Font(None, 36)
|
||||
|
||||
# Create title surfaces
|
||||
self.title_surface = self.title_font.render("Settings", True, (0, 255, 0))
|
||||
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
|
||||
|
||||
# Initialize first item as selected
|
||||
self.selected_index = 0
|
||||
self.menu_items[0].hover = True
|
||||
self.menu_items[0]._setup_font()
|
||||
|
||||
def setup_menu_items(self):
|
||||
start_y = self.height // 2
|
||||
spacing = 50
|
||||
center_x = self.width // 2
|
||||
|
||||
self.menu_items = [
|
||||
MenuItem(f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}",
|
||||
(center_x, start_y),
|
||||
'toggle_wrap'),
|
||||
MenuItem(f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}",
|
||||
(center_x, start_y + spacing),
|
||||
'toggle_speed'),
|
||||
MenuItem("Back to Menu",
|
||||
(center_x, start_y + spacing * 3),
|
||||
'back')
|
||||
]
|
||||
|
||||
def update(self):
|
||||
# Handle mouse hover
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
# Update selected index when mouse hovers
|
||||
self.selected_index = i
|
||||
item.hover = True
|
||||
item._setup_font()
|
||||
else:
|
||||
# Keep keyboard selection visible
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
def handle_input(self, event):
|
||||
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
|
||||
# Handle mouse clicks
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
self.selected_index = i
|
||||
if item.action == 'toggle_wrap':
|
||||
self.rules.wrap_around = not self.rules.wrap_around
|
||||
item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}"
|
||||
item._setup_font()
|
||||
elif item.action == 'toggle_speed':
|
||||
self.rules.speed_increase = not self.rules.speed_increase
|
||||
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
|
||||
item._setup_font()
|
||||
elif item.action == 'back':
|
||||
return 'back'
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
return 'back'
|
||||
elif event.key == pygame.K_RETURN:
|
||||
item = self.menu_items[self.selected_index]
|
||||
if item.action == 'toggle_wrap':
|
||||
self.rules.wrap_around = not self.rules.wrap_around
|
||||
item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}"
|
||||
item._setup_font()
|
||||
elif item.action == 'toggle_speed':
|
||||
self.rules.speed_increase = not self.rules.speed_increase
|
||||
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
|
||||
item._setup_font()
|
||||
elif item.action == 'back':
|
||||
return 'back'
|
||||
elif event.key in (pygame.K_UP, pygame.K_DOWN):
|
||||
# Update selected index
|
||||
if event.key == pygame.K_UP:
|
||||
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
|
||||
else:
|
||||
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
|
||||
|
||||
# Update hover states
|
||||
for i, item in enumerate(self.menu_items):
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
return None
|
||||
|
||||
def draw(self, screen):
|
||||
# Draw background
|
||||
screen.fill((0, 0, 0))
|
||||
|
||||
# Draw title
|
||||
screen.blit(self.title_surface, self.title_rect)
|
||||
|
||||
# Draw menu items
|
||||
for item in self.menu_items:
|
||||
item.draw(screen)
|
||||
|
||||
# Draw controls
|
||||
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back"
|
||||
font = pygame.font.Font(None, 24)
|
||||
controls_surface = font.render(controls_text, True, (100, 100, 100))
|
||||
screen.blit(controls_surface,
|
||||
(self.width - controls_surface.get_width() - 10,
|
||||
self.height - 30))
|
||||
|
||||
class Game:
|
||||
"""Main game class that manages the game loop and state."""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize pygame and font system
|
||||
if not pygame.get_init():
|
||||
pygame.init()
|
||||
if not pygame.font.get_init():
|
||||
pygame.font.init()
|
||||
"""Initialize the game."""
|
||||
pygame.init()
|
||||
pygame.font.init()
|
||||
|
||||
# Initialize display
|
||||
self.width = 800
|
||||
self.height = 600
|
||||
# Window setup
|
||||
self.width = WINDOW_WIDTH
|
||||
self.height = WINDOW_HEIGHT
|
||||
self.screen = pygame.display.set_mode((self.width, self.height))
|
||||
pygame.display.set_caption("AI Snake Game")
|
||||
pygame.display.set_caption("Snake Game")
|
||||
|
||||
# Initialize clock
|
||||
# Initialize game components
|
||||
self.clock = pygame.time.Clock()
|
||||
self.fps = 60
|
||||
|
||||
# Game rules
|
||||
self.rules = GameRules()
|
||||
|
||||
# Game objects
|
||||
self.block_size = 20
|
||||
self.menu = Menu(self.width, self.height)
|
||||
self.reset_game()
|
||||
self.font = pygame.font.Font(None, SCORE_FONT_SIZE)
|
||||
self.settings = GameSettings()
|
||||
self.menu = Menu()
|
||||
self.settings_menu = SettingsMenu(rules=self.settings.rules)
|
||||
self.pause_menu = PauseMenu()
|
||||
|
||||
# Game state
|
||||
self.state = GameState.MENU # Start in menu state
|
||||
self.running = True
|
||||
self.state = GameState.MENU
|
||||
self.game_session = None
|
||||
self.game_mode = None
|
||||
self.running = True
|
||||
|
||||
# Debug mode
|
||||
self.debug = False
|
||||
|
||||
# Add settings menu
|
||||
self.settings_menu = SettingsMenu(self.width, self.height, self.rules)
|
||||
self.reset_game()
|
||||
|
||||
def reset_game(self):
|
||||
"""Reset the game state for a new game"""
|
||||
self.snake = Snake((self.width // 2, self.height // 2), self.block_size)
|
||||
self.food = Food(self.block_size)
|
||||
self.food.spawn(self.width, self.height, self.snake.body)
|
||||
self.score = 0
|
||||
"""Initialize or reset the game session."""
|
||||
if self.game_session:
|
||||
del self.game_session
|
||||
self.game_session = GameSession(
|
||||
grid_width=30,
|
||||
grid_height=30,
|
||||
rules=self.settings.rules,
|
||||
window_width=self.width,
|
||||
window_height=self.height
|
||||
)
|
||||
|
||||
def handle_events(self):
|
||||
"""Handle game events."""
|
||||
current_time = pygame.time.get_ticks()
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self.running = False
|
||||
elif event.type == pygame.MOUSEBUTTONDOWN:
|
||||
# Handle mouse clicks in menus
|
||||
if self.state == GameState.MENU:
|
||||
selected_mode = self.menu.handle_input(event)
|
||||
if selected_mode is not None:
|
||||
if selected_mode == GameMode.SETTINGS:
|
||||
self.state = GameState.SETTINGS
|
||||
elif selected_mode == GameMode.PLAYER:
|
||||
self.game_mode = GameMode.PLAYER
|
||||
self.state = GameState.PLAYING
|
||||
self.reset_game()
|
||||
elif selected_mode in (GameMode.AI_EASY, GameMode.AI_MEDIUM, GameMode.AI_HARD):
|
||||
self.game_mode = selected_mode
|
||||
self.state = GameState.PLAYING
|
||||
self.reset_game()
|
||||
else: # Quit selected
|
||||
self.running = False
|
||||
elif self.state == GameState.SETTINGS:
|
||||
result = self.settings_menu.handle_input(event)
|
||||
if result == 'back':
|
||||
self.state = GameState.MENU
|
||||
elif self.state == GameState.PAUSED:
|
||||
result = self.pause_menu.handle_input(event)
|
||||
if result == 'resume':
|
||||
self.state = GameState.PLAYING
|
||||
elif result == 'menu':
|
||||
self.state = GameState.MENU
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
if self.state == GameState.PLAYING:
|
||||
@ -196,156 +115,108 @@ class Game:
|
||||
elif self.state == GameState.SETTINGS:
|
||||
self.state = GameState.MENU
|
||||
elif event.key == pygame.K_F3: # Toggle debug mode
|
||||
self.debug = not self.debug
|
||||
|
||||
if self.state == GameState.MENU:
|
||||
# Handle menu input
|
||||
selected_mode = self.menu.handle_input(event)
|
||||
if selected_mode is not None:
|
||||
if selected_mode == GameMode.SETTINGS:
|
||||
self.state = GameState.SETTINGS
|
||||
elif selected_mode == GameMode.PLAYER:
|
||||
self.game_mode = GameMode.PLAYER
|
||||
self.state = GameState.PLAYING
|
||||
elif selected_mode in (GameMode.AI_EASY, GameMode.AI_MEDIUM, GameMode.AI_HARD):
|
||||
self.game_mode = selected_mode
|
||||
self.state = GameState.PLAYING
|
||||
else: # Quit selected
|
||||
self.running = False
|
||||
|
||||
elif self.state == GameState.SETTINGS:
|
||||
# Handle settings menu input
|
||||
result = self.settings_menu.handle_input(event)
|
||||
if result == 'back':
|
||||
self.state = GameState.MENU
|
||||
|
||||
self.settings.debug_mode = not self.settings.debug_mode
|
||||
if self.game_session:
|
||||
self.game_session.show_grid = self.settings.debug_mode
|
||||
elif self.state == GameState.PLAYING:
|
||||
# Handle snake direction
|
||||
if event.key == pygame.K_UP:
|
||||
self.snake.change_direction(Direction.UP, current_time)
|
||||
self.game_session.step(Direction.UP, current_time)
|
||||
elif event.key == pygame.K_DOWN:
|
||||
self.snake.change_direction(Direction.DOWN, current_time)
|
||||
self.game_session.step(Direction.DOWN, current_time)
|
||||
elif event.key == pygame.K_LEFT:
|
||||
self.snake.change_direction(Direction.LEFT, current_time)
|
||||
self.game_session.step(Direction.LEFT, current_time)
|
||||
elif event.key == pygame.K_RIGHT:
|
||||
self.snake.change_direction(Direction.RIGHT, current_time)
|
||||
|
||||
self.game_session.step(Direction.RIGHT, current_time)
|
||||
elif self.state == GameState.GAME_OVER:
|
||||
# Any key to return to menu in game over state
|
||||
self.state = GameState.MENU
|
||||
self.reset_game()
|
||||
|
||||
def update(self):
|
||||
"""Update game state."""
|
||||
current_time = pygame.time.get_ticks()
|
||||
|
||||
if self.state == GameState.MENU:
|
||||
self.menu.update()
|
||||
elif self.state == GameState.SETTINGS:
|
||||
self.settings_menu.update()
|
||||
elif self.state == GameState.PAUSED:
|
||||
self.pause_menu.update()
|
||||
elif self.state == GameState.PLAYING:
|
||||
# Move snake
|
||||
if self.snake.move(pygame.time.get_ticks()):
|
||||
# Handle wrap-around
|
||||
if self.rules.wrap_around:
|
||||
self.snake.wrap_position(self.width, self.height)
|
||||
|
||||
# Check collisions
|
||||
if self.snake.check_collision(self.width, self.height, self.rules.wrap_around):
|
||||
self.state = GameState.GAME_OVER
|
||||
return
|
||||
|
||||
# Check food collision
|
||||
if self.food.check_collision(self.snake.body[0]):
|
||||
self.snake.grow()
|
||||
self.score += 1
|
||||
self.food.spawn(self.width, self.height, self.snake.body)
|
||||
|
||||
# Increase speed if enabled
|
||||
if self.rules.speed_increase:
|
||||
self.snake.move_cooldown = max(
|
||||
self.rules.min_move_cooldown,
|
||||
self.rules.initial_move_cooldown - self.score * 2
|
||||
)
|
||||
|
||||
def draw_grid(self):
|
||||
"""Draw the game grid (debug mode)"""
|
||||
for x in range(0, self.width, self.block_size):
|
||||
pygame.draw.line(self.screen, (50, 50, 50), (x, 0), (x, self.height))
|
||||
for y in range(0, self.height, self.block_size):
|
||||
pygame.draw.line(self.screen, (50, 50, 50), (0, y), (self.width, y))
|
||||
# Update game session
|
||||
_, _, done = self.game_session.step(current_time=current_time)
|
||||
if done:
|
||||
self.state = GameState.GAME_OVER
|
||||
|
||||
def draw_debug_info(self):
|
||||
"""Draw debug information"""
|
||||
"""Draw debug information."""
|
||||
if not self.game_session:
|
||||
return
|
||||
|
||||
state = self.game_session.get_state()
|
||||
font = pygame.font.Font(None, 24)
|
||||
debug_info = [
|
||||
f"FPS: {int(self.clock.get_fps())}",
|
||||
f"Snake Length: {len(self.snake.body)}",
|
||||
f"Snake Head: {self.snake.body[0]}",
|
||||
f"Food Pos: {self.food.position}",
|
||||
f"Snake Length: {len(state['snake_body'])}",
|
||||
f"Snake Head: {state['snake_body'][0]}",
|
||||
f"Food Pos: {state['food_position']}",
|
||||
f"Game State: {self.state.name}",
|
||||
f"Game Mode: {self.game_mode.name if self.game_mode else 'None'}",
|
||||
f"Wrap Around: {self.rules.wrap_around}",
|
||||
f"Move Cooldown: {self.snake.move_cooldown}ms"
|
||||
f"Move Cooldown: {state['move_cooldown']}ms"
|
||||
]
|
||||
|
||||
for i, text in enumerate(debug_info):
|
||||
surface = font.render(text, True, (0, 255, 0))
|
||||
surface = font.render(text, True, GREEN)
|
||||
self.screen.blit(surface, (10, self.height - 200 + i * 25))
|
||||
|
||||
def render(self):
|
||||
# Clear screen
|
||||
self.screen.fill((0, 0, 0)) # Black background
|
||||
def draw_overlay(self, text):
|
||||
"""Draw a semi-transparent overlay with text."""
|
||||
overlay = pygame.Surface((self.width, self.height))
|
||||
overlay.fill(BLACK)
|
||||
overlay.set_alpha(128)
|
||||
self.screen.blit(overlay, (0, 0))
|
||||
|
||||
# Draw grid in debug mode
|
||||
if self.debug and self.state not in (GameState.MENU, GameState.SETTINGS):
|
||||
self.draw_grid()
|
||||
text_surface = self.font.render(text, True, NEON_GREEN)
|
||||
text_rect = text_surface.get_rect(center=(self.width // 2, self.height // 2))
|
||||
self.screen.blit(text_surface, text_rect)
|
||||
|
||||
def render(self):
|
||||
"""Render the current game state."""
|
||||
# Clear screen
|
||||
self.screen.fill(BLACK)
|
||||
|
||||
if self.state == GameState.MENU:
|
||||
self.menu.draw(self.screen)
|
||||
elif self.state == GameState.SETTINGS:
|
||||
self.settings_menu.draw(self.screen)
|
||||
elif self.state == GameState.PLAYING or self.state == GameState.PAUSED:
|
||||
# Draw game objects
|
||||
self.snake.draw(self.screen)
|
||||
self.food.draw(self.screen)
|
||||
elif self.state in [GameState.PLAYING, GameState.PAUSED, GameState.GAME_OVER]:
|
||||
# Draw game session
|
||||
self.game_session.render(self.screen)
|
||||
|
||||
# Draw score
|
||||
font = pygame.font.Font(None, 36)
|
||||
score_text = font.render(f'Score: {self.score}', True, (255, 255, 255))
|
||||
self.screen.blit(score_text, (10, 10))
|
||||
# Draw score panel
|
||||
score_text = f"Score: {self.game_session.score}"
|
||||
score_surface = self.font.render(score_text, True, GREEN)
|
||||
score_rect = score_surface.get_rect(midtop=(self.width // 2, 20))
|
||||
self.screen.blit(score_surface, score_rect)
|
||||
|
||||
# Draw game mode
|
||||
mode_text = font.render(f'Mode: {self.game_mode.name}', True, (255, 255, 255))
|
||||
self.screen.blit(mode_text, (10, 50))
|
||||
|
||||
# Draw pause indicator
|
||||
# Draw overlays
|
||||
if self.state == GameState.PAUSED:
|
||||
pause_text = font.render('PAUSED', True, (255, 255, 255))
|
||||
text_rect = pause_text.get_rect(center=(self.width//2, self.height//2))
|
||||
self.screen.blit(pause_text, text_rect)
|
||||
|
||||
elif self.state == GameState.GAME_OVER:
|
||||
font = pygame.font.Font(None, 74)
|
||||
text = font.render('Game Over!', True, (255, 0, 0))
|
||||
score_text = font.render(f'Score: {self.score}', True, (255, 255, 255))
|
||||
restart_text = font.render('Press any key for menu', True, (255, 255, 255))
|
||||
self.pause_menu.draw(self.screen)
|
||||
elif self.state == GameState.GAME_OVER:
|
||||
self.draw_overlay("GAME OVER")
|
||||
|
||||
text_rect = text.get_rect(center=(self.width//2, self.height//2 - 50))
|
||||
score_rect = score_text.get_rect(center=(self.width//2, self.height//2 + 50))
|
||||
restart_rect = restart_text.get_rect(center=(self.width//2, self.height//2 + 150))
|
||||
|
||||
self.screen.blit(text, text_rect)
|
||||
self.screen.blit(score_text, score_rect)
|
||||
self.screen.blit(restart_text, restart_rect)
|
||||
|
||||
# Draw debug info
|
||||
if self.debug:
|
||||
self.draw_debug_info()
|
||||
# Draw debug info if enabled
|
||||
if self.settings.debug_mode:
|
||||
self.draw_debug_info()
|
||||
|
||||
# Update display
|
||||
pygame.display.flip()
|
||||
|
||||
def run(self):
|
||||
"""Run the main game loop."""
|
||||
self.running = True
|
||||
while self.running:
|
||||
self.handle_events()
|
||||
self.update()
|
||||
self.render()
|
||||
self.clock.tick(self.fps)
|
||||
self.clock.tick(FPS)
|
10
src/main.py
10
src/main.py
@ -1,8 +1,16 @@
|
||||
"""
|
||||
Main entry point for the AI Snake Game.
|
||||
|
||||
This module provides the standard entry point for running the game
|
||||
without any command line arguments.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
import sys
|
||||
from game import Game
|
||||
from src.game import Game
|
||||
|
||||
def main():
|
||||
"""Run the game with default settings."""
|
||||
# Initialize Pygame
|
||||
pygame.init()
|
||||
|
||||
|
@ -81,6 +81,8 @@ class Menu:
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
self.selected_index = i
|
||||
item.hover = True
|
||||
item._setup_font()
|
||||
return item.action
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
|
153
src/snake.py
153
src/snake.py
@ -1,153 +0,0 @@
|
||||
import pygame
|
||||
from enum import Enum, auto
|
||||
from typing import List, Tuple
|
||||
|
||||
class Direction(Enum):
|
||||
UP = auto()
|
||||
DOWN = auto()
|
||||
LEFT = auto()
|
||||
RIGHT = auto()
|
||||
|
||||
class Snake:
|
||||
def __init__(self, start_pos: Tuple[int, int], block_size: int):
|
||||
self.block_size = block_size
|
||||
self.direction = Direction.RIGHT
|
||||
self.body = [start_pos] # Head is at index 0
|
||||
self.growing = False
|
||||
|
||||
# Movement cooldown
|
||||
self.move_cooldown = 100 # milliseconds
|
||||
self.last_move_time = 0
|
||||
|
||||
# Direction change cooldown
|
||||
self.direction_change_cooldown = 50 # milliseconds
|
||||
self.last_direction_change = 0
|
||||
self.queued_direction = None
|
||||
self.last_valid_move_time = 0 # Track when we last actually moved
|
||||
|
||||
def move(self, current_time: int) -> bool:
|
||||
"""Move the snake if enough time has passed. Returns True if moved."""
|
||||
if current_time - self.last_move_time < self.move_cooldown:
|
||||
return False
|
||||
|
||||
# Apply queued direction change if it exists and is valid
|
||||
if self.queued_direction:
|
||||
if current_time - self.last_valid_move_time >= self.direction_change_cooldown:
|
||||
if self._apply_direction_change(self.queued_direction):
|
||||
self.last_direction_change = current_time
|
||||
self.queued_direction = None
|
||||
|
||||
# Update position
|
||||
head = self.body[0]
|
||||
new_head = self._get_new_head_position(head)
|
||||
|
||||
# Insert new head
|
||||
self.body.insert(0, new_head)
|
||||
|
||||
# Remove tail if not growing
|
||||
if not self.growing:
|
||||
self.body.pop()
|
||||
else:
|
||||
self.growing = False
|
||||
|
||||
self.last_move_time = current_time
|
||||
self.last_valid_move_time = current_time
|
||||
return True
|
||||
|
||||
def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]:
|
||||
x, y = head
|
||||
if self.direction == Direction.UP:
|
||||
return (x, y - self.block_size)
|
||||
elif self.direction == Direction.DOWN:
|
||||
return (x, y + self.block_size)
|
||||
elif self.direction == Direction.LEFT:
|
||||
return (x - self.block_size, y)
|
||||
else: # Direction.RIGHT
|
||||
return (x + self.block_size, y)
|
||||
|
||||
def grow(self):
|
||||
"""Mark the snake to grow on next move"""
|
||||
self.growing = True
|
||||
|
||||
def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool:
|
||||
"""
|
||||
Check if snake has collided with walls or itself.
|
||||
|
||||
Args:
|
||||
width: Game area width
|
||||
height: Game area height
|
||||
wrap_around: If True, snake wraps around screen edges instead of colliding
|
||||
"""
|
||||
head = self.body[0]
|
||||
|
||||
if not wrap_around:
|
||||
# Check wall collision
|
||||
if (head[0] < 0 or head[0] >= width or
|
||||
head[1] < 0 or head[1] >= height):
|
||||
return True
|
||||
|
||||
# Check self collision (skip head)
|
||||
if head in self.body[1:]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def wrap_position(self, width: int, height: int):
|
||||
"""Wrap snake's head position around screen edges"""
|
||||
head_x, head_y = self.body[0]
|
||||
wrapped_head = (
|
||||
head_x % width if head_x != width else 0,
|
||||
head_y % height if head_y != height else 0
|
||||
)
|
||||
self.body[0] = wrapped_head
|
||||
|
||||
def _apply_direction_change(self, new_direction: Direction) -> bool:
|
||||
"""
|
||||
Internal method to actually change direction.
|
||||
Returns True if direction was changed, False otherwise.
|
||||
"""
|
||||
# Prevent 180-degree turns when snake is longer than 1
|
||||
if len(self.body) > 1:
|
||||
opposites = {
|
||||
Direction.UP: Direction.DOWN,
|
||||
Direction.DOWN: Direction.UP,
|
||||
Direction.LEFT: Direction.RIGHT,
|
||||
Direction.RIGHT: Direction.LEFT
|
||||
}
|
||||
if opposites[new_direction] == self.direction:
|
||||
return False
|
||||
|
||||
self.direction = new_direction
|
||||
return True
|
||||
|
||||
def change_direction(self, new_direction: Direction, current_time: int):
|
||||
"""
|
||||
Queue a direction change if it's valid and cooldown has passed.
|
||||
|
||||
Args:
|
||||
new_direction: The desired new direction
|
||||
current_time: Current game time in milliseconds
|
||||
"""
|
||||
# If we haven't moved since the last direction change, queue it
|
||||
if current_time - self.last_valid_move_time < self.direction_change_cooldown:
|
||||
self.queued_direction = new_direction
|
||||
return
|
||||
|
||||
# Try to change direction immediately
|
||||
if self._apply_direction_change(new_direction):
|
||||
self.last_direction_change = current_time
|
||||
|
||||
def draw(self, screen: pygame.Surface):
|
||||
"""Draw the snake on the screen"""
|
||||
# Draw head in a slightly different color
|
||||
head_color = (0, 200, 0) # Darker green for head
|
||||
body_color = (0, 255, 0) # Regular green for body
|
||||
|
||||
# Draw body segments
|
||||
for i, segment in enumerate(self.body):
|
||||
color = head_color if i == 0 else body_color
|
||||
pygame.draw.rect(
|
||||
screen,
|
||||
color,
|
||||
pygame.Rect(segment[0], segment[1], self.block_size, self.block_size)
|
||||
)
|
20
src/ui/__init__.py
Normal file
20
src/ui/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""
|
||||
Game UI Package
|
||||
|
||||
This package contains all user interface components.
|
||||
"""
|
||||
|
||||
from .menu import Menu, GameMode
|
||||
from .menu_item import MenuItem
|
||||
from .settings_menu import SettingsMenu
|
||||
from .pause_menu import PauseMenu
|
||||
from .game_area import GameArea
|
||||
|
||||
__all__ = [
|
||||
'Menu',
|
||||
'GameMode',
|
||||
'MenuItem',
|
||||
'SettingsMenu',
|
||||
'PauseMenu',
|
||||
'GameArea'
|
||||
]
|
138
src/ui/game_area.py
Normal file
138
src/ui/game_area.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""
|
||||
Game Area Module
|
||||
|
||||
This module provides the GameArea class that handles the rendering of the game area
|
||||
and coordinate conversions between screen and grid coordinates.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from src.config import (
|
||||
DEFAULT_PADDING,
|
||||
DEFAULT_SCORE_HEIGHT,
|
||||
)
|
||||
from src.config.colors import *
|
||||
from typing import Tuple
|
||||
|
||||
class GameArea:
|
||||
"""Manages the game area rendering and coordinate conversions."""
|
||||
|
||||
def __init__(self, window_width, window_height, grid_size=20, grid_width=30, grid_height=30,
|
||||
padding=DEFAULT_PADDING, score_height=DEFAULT_SCORE_HEIGHT):
|
||||
# Grid properties
|
||||
self.grid_size = grid_size
|
||||
self.grid_width = grid_width
|
||||
self.grid_height = grid_height
|
||||
|
||||
# Area properties
|
||||
self.padding = padding
|
||||
self.score_height = score_height
|
||||
self.border_thickness = 4 # Increased border thickness
|
||||
|
||||
# Calculate game area dimensions
|
||||
self.width = self.grid_width * self.grid_size
|
||||
self.height = self.grid_height * self.grid_size
|
||||
|
||||
# Center the game area in the window
|
||||
self.x = (window_width - self.width) // 2
|
||||
self.y = self.score_height + self.padding
|
||||
|
||||
# Create rectangle for game bounds
|
||||
self.rect = pygame.Rect(self.x, self.y, self.width, self.height)
|
||||
|
||||
# Create inner rectangle for gameplay area with thicker border
|
||||
self.inner_rect = pygame.Rect(
|
||||
self.x + self.border_thickness,
|
||||
self.y + self.border_thickness,
|
||||
self.width - (self.border_thickness * 2),
|
||||
self.height - (self.border_thickness * 2)
|
||||
)
|
||||
|
||||
self.screen_offset = (self.x, self.y)
|
||||
|
||||
def get_grid_pos(self, screen_x, screen_y):
|
||||
"""Convert screen coordinates to grid position."""
|
||||
grid_x = (screen_x - self.x) // self.grid_size
|
||||
grid_y = (screen_y - self.y) // self.grid_size
|
||||
return grid_x, grid_y
|
||||
|
||||
def get_screen_pos(self, grid_x, grid_y):
|
||||
"""Convert grid position to screen coordinates."""
|
||||
screen_x = self.x + (grid_x * self.grid_size)
|
||||
screen_y = self.y + (grid_y * self.grid_size)
|
||||
return screen_x, screen_y
|
||||
|
||||
def is_within_bounds(self, grid_x, grid_y):
|
||||
"""Check if grid position is within game bounds."""
|
||||
return 0 <= grid_x < self.grid_width and 0 <= grid_y < self.grid_height
|
||||
|
||||
def draw(self, screen):
|
||||
"""Draw the game area boundary and background with modern effects."""
|
||||
# Draw outer border with gradient effect
|
||||
for i in range(self.border_thickness):
|
||||
border_rect = pygame.Rect(
|
||||
self.x + i,
|
||||
self.y + i,
|
||||
self.width - (i * 2),
|
||||
self.height - (i * 2)
|
||||
)
|
||||
# Gradient from darker to lighter
|
||||
color_factor = i / self.border_thickness
|
||||
r = int(DARKER_GREEN[0] * (1 - color_factor) + FOREST_GREEN[0] * color_factor)
|
||||
g = int(DARKER_GREEN[1] * (1 - color_factor) + FOREST_GREEN[1] * color_factor)
|
||||
b = int(DARKER_GREEN[2] * (1 - color_factor) + FOREST_GREEN[2] * color_factor)
|
||||
pygame.draw.rect(screen, (r, g, b), border_rect)
|
||||
|
||||
# Draw inner background
|
||||
pygame.draw.rect(screen, DARK_GRAY, self.inner_rect)
|
||||
|
||||
# Draw subtle grid pattern
|
||||
if hasattr(self, 'show_grid') and self.show_grid:
|
||||
# Debug grid lines
|
||||
for x in range(self.grid_width + 1):
|
||||
start_pos = (self.x + x * self.grid_size, self.y)
|
||||
end_pos = (self.x + x * self.grid_size, self.y + self.height)
|
||||
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
|
||||
|
||||
for y in range(self.grid_height + 1):
|
||||
start_pos = (self.x, self.y + y * self.grid_size)
|
||||
end_pos = (self.x + self.width, self.y + y * self.grid_size)
|
||||
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
|
||||
else:
|
||||
# Subtle grid dots with fade effect
|
||||
for x in range(self.grid_width + 1):
|
||||
for y in range(self.grid_height + 1):
|
||||
pos_x = self.x + x * self.grid_size
|
||||
pos_y = self.y + y * self.grid_size
|
||||
|
||||
# Calculate distance from center for fade effect
|
||||
center_x = self.x + (self.width / 2)
|
||||
center_y = self.y + (self.height / 2)
|
||||
dist_x = abs(pos_x - center_x) / (self.width / 2)
|
||||
dist_y = abs(pos_y - center_y) / (self.height / 2)
|
||||
dist = min(1.0, (dist_x + dist_y) / 2)
|
||||
|
||||
# Fade dot color based on distance from center
|
||||
dot_color = tuple(int(c * (1 - dist * 0.5)) for c in SUBTLE_GRID_COLOR)
|
||||
pygame.draw.circle(screen, dot_color, (pos_x, pos_y), 1)
|
||||
|
||||
# Draw inner border glow
|
||||
glow_colors = [
|
||||
(int(NEON_GREEN[0] * 0.4), int(NEON_GREEN[1] * 0.4), int(NEON_GREEN[2] * 0.4)),
|
||||
(int(NEON_GREEN[0] * 0.6), int(NEON_GREEN[1] * 0.6), int(NEON_GREEN[2] * 0.6)),
|
||||
NEON_GREEN
|
||||
]
|
||||
|
||||
for i, color in enumerate(glow_colors):
|
||||
glow_rect = pygame.Rect(
|
||||
self.inner_rect.x - i,
|
||||
self.inner_rect.y - i,
|
||||
self.inner_rect.width + (i * 2),
|
||||
self.inner_rect.height + (i * 2)
|
||||
)
|
||||
pygame.draw.rect(screen, color, glow_rect, 1)
|
||||
|
||||
def grid_to_screen(self, grid_pos: Tuple[int, int]) -> Tuple[int, int]:
|
||||
"""Convert grid coordinates to screen coordinates"""
|
||||
x = self.rect.left + grid_pos[0] * self.grid_size + self.grid_size // 2
|
||||
y = self.rect.top + grid_pos[1] * self.grid_size + self.grid_size // 2
|
||||
return (x, y)
|
127
src/ui/menu.py
Normal file
127
src/ui/menu.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Main Menu Module
|
||||
|
||||
This module provides the main menu interface for the game.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from enum import Enum, auto
|
||||
from src.config import (
|
||||
WINDOW_WIDTH,
|
||||
WINDOW_HEIGHT,
|
||||
TITLE_FONT_SIZE,
|
||||
SUBTITLE_FONT_SIZE,
|
||||
MENU_ITEM_SPACING
|
||||
)
|
||||
from src.config.colors import BLACK, GRAY
|
||||
from src.ui.menu_item import MenuItem
|
||||
|
||||
class GameMode(Enum):
|
||||
"""Available game modes."""
|
||||
PLAYER = auto()
|
||||
AI_EASY = auto()
|
||||
AI_MEDIUM = auto()
|
||||
AI_HARD = auto()
|
||||
SETTINGS = auto()
|
||||
|
||||
class Menu:
|
||||
"""Main menu interface."""
|
||||
|
||||
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
|
||||
"""Initialize the main menu."""
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.setup_menu_items()
|
||||
|
||||
# Create fonts
|
||||
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
|
||||
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
|
||||
|
||||
# Create title surfaces
|
||||
self.title_surface = self.title_font.render("Snake Game", True, (0, 255, 0))
|
||||
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
|
||||
|
||||
# Initialize first item as selected
|
||||
self.selected_index = 0
|
||||
self.menu_items[0].hover = True
|
||||
self.menu_items[0]._setup_font()
|
||||
|
||||
def setup_menu_items(self):
|
||||
"""Setup the menu items."""
|
||||
start_y = self.height // 2
|
||||
spacing = MENU_ITEM_SPACING
|
||||
center_x = self.width // 2
|
||||
|
||||
self.menu_items = [
|
||||
MenuItem("Player Game", (center_x, start_y), GameMode.PLAYER),
|
||||
MenuItem("AI Game (Easy)", (center_x, start_y + spacing), GameMode.AI_EASY),
|
||||
MenuItem("AI Game (Medium)", (center_x, start_y + spacing * 2), GameMode.AI_MEDIUM),
|
||||
MenuItem("AI Game (Hard)", (center_x, start_y + spacing * 3), GameMode.AI_HARD),
|
||||
MenuItem("Settings", (center_x, start_y + spacing * 4), GameMode.SETTINGS),
|
||||
MenuItem("Quit", (center_x, start_y + spacing * 5), 'quit')
|
||||
]
|
||||
|
||||
def update(self):
|
||||
"""Update menu state."""
|
||||
# Handle mouse hover
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
# Update selected index when mouse hovers
|
||||
self.selected_index = i
|
||||
item.hover = True
|
||||
item._setup_font()
|
||||
else:
|
||||
# Keep keyboard selection visible
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
def handle_input(self, event):
|
||||
"""
|
||||
Handle input events.
|
||||
|
||||
Returns:
|
||||
GameMode or None: The selected game mode or None if no selection made
|
||||
"""
|
||||
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
|
||||
# Handle mouse clicks
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for item in self.menu_items:
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
return item.action
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_RETURN:
|
||||
return self.menu_items[self.selected_index].action
|
||||
elif event.key in (pygame.K_UP, pygame.K_DOWN):
|
||||
# Update selected index
|
||||
if event.key == pygame.K_UP:
|
||||
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
|
||||
else:
|
||||
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
|
||||
|
||||
# Update hover states
|
||||
for i, item in enumerate(self.menu_items):
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
return None
|
||||
|
||||
def draw(self, screen):
|
||||
"""Draw the menu to the screen."""
|
||||
# Draw background
|
||||
screen.fill(BLACK)
|
||||
|
||||
# Draw title
|
||||
screen.blit(self.title_surface, self.title_rect)
|
||||
|
||||
# Draw menu items
|
||||
for item in self.menu_items:
|
||||
item.draw(screen)
|
||||
|
||||
# Draw controls
|
||||
controls_text = "Arrow keys or mouse to navigate, Enter to select"
|
||||
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
|
||||
screen.blit(controls_surface,
|
||||
(self.width - controls_surface.get_width() - 10,
|
||||
self.height - 30))
|
84
src/ui/menu_item.py
Normal file
84
src/ui/menu_item.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Menu Item Module
|
||||
|
||||
This module provides the MenuItem class that represents a clickable menu item.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from src.config.colors import *
|
||||
|
||||
class MenuItem:
|
||||
"""A clickable menu item with hover effects."""
|
||||
|
||||
def __init__(self, text, position, action, font_size_normal=36, font_size_hover=48):
|
||||
"""
|
||||
Initialize a menu item.
|
||||
|
||||
Args:
|
||||
text: The text to display
|
||||
position: (x, y) tuple for center position
|
||||
action: String identifier for the action to take when clicked
|
||||
font_size_normal: Font size when not hovered (default 36)
|
||||
font_size_hover: Font size when hovered (default 48)
|
||||
"""
|
||||
self.text = text
|
||||
self.position = position
|
||||
self.action = action
|
||||
self.hover = False
|
||||
self.font_size_normal = font_size_normal
|
||||
self.font_size_hover = font_size_hover
|
||||
self._setup_font()
|
||||
|
||||
def _setup_font(self):
|
||||
"""Setup the font based on hover state."""
|
||||
size = self.font_size_hover if self.hover else self.font_size_normal
|
||||
self.font = pygame.font.Font(None, size)
|
||||
self.surface = self.font.render(self.text, True, GREEN if self.hover else WHITE)
|
||||
self.rect = self.surface.get_rect(center=self.position)
|
||||
|
||||
def draw(self, screen):
|
||||
"""Draw the menu item to the screen."""
|
||||
screen.blit(self.surface, self.rect)
|
||||
|
||||
|
||||
class NumberMenuItem(MenuItem):
|
||||
"""A menu item that displays a number that can be adjusted with arrow keys."""
|
||||
|
||||
def __init__(self, text, position, action, min_value, max_value,
|
||||
step=1, font_size_normal=36, font_size_hover=48):
|
||||
"""
|
||||
Initialize a number menu item.
|
||||
|
||||
Args:
|
||||
text: The text to display
|
||||
position: (x, y) tuple for center position
|
||||
action: String identifier for the action to take when clicked
|
||||
min_value: Minimum allowed value
|
||||
max_value: Maximum allowed value
|
||||
step: Amount to increment/decrement by (default 1)
|
||||
font_size_normal: Font size when not hovered (default 36)
|
||||
font_size_hover: Font size when hovered (default 48)
|
||||
"""
|
||||
super().__init__(text, position, action, font_size_normal, font_size_hover)
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
self.step = step
|
||||
self.current_value = int(text.split(": ")[1]) # Extract number from text
|
||||
|
||||
def increment(self):
|
||||
"""Increment the current value by step amount."""
|
||||
if self.current_value + self.step <= self.max_value:
|
||||
self.current_value += self.step
|
||||
self._update_text()
|
||||
|
||||
def decrement(self):
|
||||
"""Decrement the current value by step amount."""
|
||||
if self.current_value - self.step >= self.min_value:
|
||||
self.current_value -= self.step
|
||||
self._update_text()
|
||||
|
||||
def _update_text(self):
|
||||
"""Update the displayed text with new value."""
|
||||
base_text = self.text.split(": ")[0]
|
||||
self.text = f"{base_text}: {self.current_value}"
|
||||
self._setup_font()
|
129
src/ui/pause_menu.py
Normal file
129
src/ui/pause_menu.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""
|
||||
Pause Menu Module
|
||||
|
||||
This module provides the pause menu interface that appears during gameplay.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from src.config import (
|
||||
WINDOW_WIDTH,
|
||||
WINDOW_HEIGHT,
|
||||
TITLE_FONT_SIZE,
|
||||
SUBTITLE_FONT_SIZE,
|
||||
MENU_ITEM_SPACING
|
||||
)
|
||||
from src.config.colors import *
|
||||
from src.ui.menu_item import MenuItem
|
||||
|
||||
class PauseMenu:
|
||||
"""Pause menu interface."""
|
||||
|
||||
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
|
||||
"""
|
||||
Initialize the pause menu.
|
||||
|
||||
Args:
|
||||
width: Window width
|
||||
height: Window height
|
||||
"""
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.setup_menu_items()
|
||||
|
||||
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
|
||||
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
|
||||
|
||||
# Create title surfaces
|
||||
self.title_surface = self.title_font.render("Paused", True, (0, 255, 0))
|
||||
self.title_rect = self.title_surface.get_rect(center=(width//2, height//3))
|
||||
|
||||
# Initialize first item as selected
|
||||
self.selected_index = 0
|
||||
self.menu_items[0].hover = True
|
||||
self.menu_items[0]._setup_font()
|
||||
|
||||
def setup_menu_items(self):
|
||||
"""Setup the menu items."""
|
||||
start_y = self.height // 2
|
||||
spacing = MENU_ITEM_SPACING
|
||||
center_x = self.width // 2
|
||||
|
||||
self.menu_items = [
|
||||
MenuItem("Resume",
|
||||
(center_x, start_y),
|
||||
'resume'),
|
||||
MenuItem("Return to Menu",
|
||||
(center_x, start_y + spacing),
|
||||
'menu')
|
||||
]
|
||||
|
||||
def update(self):
|
||||
"""Update menu state."""
|
||||
# Handle mouse hover
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
# Update selected index when mouse hovers
|
||||
self.selected_index = i
|
||||
item.hover = True
|
||||
item._setup_font()
|
||||
else:
|
||||
# Keep keyboard selection visible
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
def handle_input(self, event):
|
||||
"""
|
||||
Handle input events.
|
||||
|
||||
Returns:
|
||||
str or None: The action to take or None if no action
|
||||
"""
|
||||
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
|
||||
# Handle mouse clicks
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
self.selected_index = i
|
||||
return item.action
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
return 'resume'
|
||||
elif event.key == pygame.K_RETURN:
|
||||
return self.menu_items[self.selected_index].action
|
||||
elif event.key in (pygame.K_UP, pygame.K_DOWN):
|
||||
# Update selected index
|
||||
if event.key == pygame.K_UP:
|
||||
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
|
||||
else:
|
||||
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
|
||||
|
||||
# Update hover states
|
||||
for i, item in enumerate(self.menu_items):
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
return None
|
||||
|
||||
def draw(self, screen):
|
||||
"""Draw the menu to the screen."""
|
||||
# Draw semi-transparent background
|
||||
overlay = pygame.Surface((screen.get_width(), screen.get_height()))
|
||||
overlay.fill(BLACK)
|
||||
overlay.set_alpha(128)
|
||||
screen.blit(overlay, (0, 0))
|
||||
|
||||
# Draw title
|
||||
screen.blit(self.title_surface, self.title_rect)
|
||||
|
||||
# Draw menu items
|
||||
for item in self.menu_items:
|
||||
item.draw(screen)
|
||||
|
||||
# Draw controls
|
||||
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to resume"
|
||||
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
|
||||
screen.blit(controls_surface,
|
||||
(self.width - controls_surface.get_width() - 10,
|
||||
self.height - 30))
|
159
src/ui/settings_menu.py
Normal file
159
src/ui/settings_menu.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""
|
||||
Settings Menu Module
|
||||
|
||||
This module provides the settings menu interface for configuring game rules.
|
||||
"""
|
||||
|
||||
import pygame
|
||||
from src.config import (
|
||||
WINDOW_WIDTH,
|
||||
WINDOW_HEIGHT,
|
||||
TITLE_FONT_SIZE,
|
||||
SUBTITLE_FONT_SIZE,
|
||||
MENU_ITEM_SPACING,
|
||||
)
|
||||
from src.config.colors import *
|
||||
from src.ui.menu_item import MenuItem, NumberMenuItem
|
||||
|
||||
class SettingsMenu:
|
||||
"""Settings menu interface."""
|
||||
|
||||
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT, rules=None):
|
||||
"""
|
||||
Initialize the settings menu.
|
||||
|
||||
Args:
|
||||
width: Window width
|
||||
height: Window height
|
||||
rules: GameRules instance to modify
|
||||
"""
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.rules = rules
|
||||
self.setup_menu_items()
|
||||
|
||||
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
|
||||
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
|
||||
|
||||
# Create title surfaces
|
||||
self.title_surface = self.title_font.render("Settings", True, (0, 255, 0))
|
||||
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
|
||||
|
||||
# Initialize first item as selected
|
||||
self.selected_index = 0
|
||||
self.menu_items[0].hover = True
|
||||
self.menu_items[0]._setup_font()
|
||||
|
||||
def setup_menu_items(self):
|
||||
"""Setup the menu items."""
|
||||
start_y = self.height // 2
|
||||
spacing = MENU_ITEM_SPACING
|
||||
center_x = self.width // 2
|
||||
|
||||
self.menu_items = [
|
||||
MenuItem(f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}",
|
||||
(center_x, start_y),
|
||||
'toggle_wrap'),
|
||||
MenuItem(f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}",
|
||||
(center_x, start_y + spacing),
|
||||
'toggle_speed'),
|
||||
NumberMenuItem(f"Starting Length: {self.rules.starting_length}",
|
||||
(center_x, start_y + spacing * 2),
|
||||
'set_starting_length',
|
||||
1, 10),
|
||||
MenuItem("Back to Menu",
|
||||
(center_x, start_y + spacing * 3),
|
||||
'back')
|
||||
]
|
||||
|
||||
def update(self):
|
||||
"""Update menu state."""
|
||||
# Handle mouse hover
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
# Update selected index when mouse hovers
|
||||
self.selected_index = i
|
||||
item.hover = True
|
||||
item._setup_font()
|
||||
else:
|
||||
# Keep keyboard selection visible
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
def handle_input(self, event):
|
||||
"""
|
||||
Handle input events.
|
||||
|
||||
Returns:
|
||||
str or None: The action to take or None if no action
|
||||
"""
|
||||
if event.type == pygame.MOUSEBUTTONDOWN:
|
||||
# Handle mouse clicks
|
||||
mouse_pos = pygame.mouse.get_pos()
|
||||
for i, item in enumerate(self.menu_items):
|
||||
if item.rect.collidepoint(mouse_pos):
|
||||
self.selected_index = i
|
||||
if event.button == 1:
|
||||
if item.action == 'toggle_wrap':
|
||||
self.rules.update_rule('wrap_around', not self.rules.wrap_around)
|
||||
elif item.action == 'toggle_speed':
|
||||
self.rules.speed_increase = not self.rules.speed_increase
|
||||
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
|
||||
item._setup_font()
|
||||
elif item.action == 'set_starting_length':
|
||||
item.increment()
|
||||
self.rules.update_rule('starting_length', item.current_value)
|
||||
elif item.action == 'back':
|
||||
return 'back'
|
||||
elif event.button == 3:
|
||||
if item.action == 'set_starting_length':
|
||||
item.decrement()
|
||||
self.rules.update_rule('starting_length', item.current_value)
|
||||
|
||||
elif event.type == pygame.KEYDOWN:
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
return 'back'
|
||||
elif event.key == pygame.K_RETURN:
|
||||
item = self.menu_items[self.selected_index]
|
||||
if item.action == 'toggle_wrap':
|
||||
self.rules.update_rule('wrap_around', not self.rules.wrap_around)
|
||||
elif item.action == 'toggle_speed':
|
||||
self.rules.update_rule('speed_increase', not self.rules.speed_increase)
|
||||
elif item.action == 'set_starting_length':
|
||||
item.increment()
|
||||
self.rules.update_rule('starting_length', item.current_value)
|
||||
elif item.action == 'back':
|
||||
return 'back'
|
||||
elif event.key in (pygame.K_UP, pygame.K_DOWN):
|
||||
# Update selected index
|
||||
if event.key == pygame.K_UP:
|
||||
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
|
||||
else:
|
||||
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
|
||||
|
||||
# Update hover states
|
||||
for i, item in enumerate(self.menu_items):
|
||||
item.hover = (i == self.selected_index)
|
||||
item._setup_font()
|
||||
|
||||
return None
|
||||
|
||||
def draw(self, screen):
|
||||
"""Draw the menu to the screen."""
|
||||
# Draw background
|
||||
screen.fill(BLACK)
|
||||
|
||||
# Draw title
|
||||
screen.blit(self.title_surface, self.title_rect)
|
||||
|
||||
# Draw menu items
|
||||
for item in self.menu_items:
|
||||
item.draw(screen)
|
||||
|
||||
# Draw controls
|
||||
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back"
|
||||
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
|
||||
screen.blit(controls_surface,
|
||||
(self.width - controls_surface.get_width() - 10,
|
||||
self.height - 30))
|
@ -8,12 +8,12 @@ This package contains all test modules including:
|
||||
"""
|
||||
|
||||
# Import test utilities and fixtures
|
||||
from tests.conftest import game_config, mock_screen
|
||||
from conftest import game_config, mock_screen
|
||||
|
||||
__version__ = '0.1.0'
|
||||
|
||||
# Define test utilities available for import
|
||||
__all__ = [
|
||||
'game_config',
|
||||
'mock_screen'
|
||||
'mock_screen',
|
||||
]
|
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from src.food import Food
|
||||
from src.core import Food
|
||||
|
||||
def test_food_initialization():
|
||||
"""Test food initialization"""
|
||||
|
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
from src.snake import Snake, Direction
|
||||
from src.core import Snake, Direction
|
||||
|
||||
def test_snake_initialization():
|
||||
"""Test snake initialization with default values"""
|
||||
|
15
training_config.json
Normal file
15
training_config.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"timesteps": 3000000,
|
||||
"learning_rate": 3e-4,
|
||||
"batch_size": 256,
|
||||
"n_envs": 16,
|
||||
"n_steps": 2048,
|
||||
"gamma": 0.99,
|
||||
"ent_coef": 0.1,
|
||||
"n_epochs": 10,
|
||||
"gae_lambda": 0.95,
|
||||
"clip_range": 0.2,
|
||||
"vf_coef": 0.5,
|
||||
"max_grad_norm": 0.5,
|
||||
"normalize_advantage": true
|
||||
}
|
Loading…
Reference in New Issue
Block a user