Major code restructuring: Reorganized project structure, added AI components - Added new module structure (ai, config, core, ui) - Moved snake and food logic into core module - Added training configuration - Updated gitignore for project-specific files - Modified tests to match new structure

This commit is contained in:
Rbanh 2025-02-24 18:38:44 -05:00
parent 251822ec35
commit b8eeaa4739
34 changed files with 3928 additions and 513 deletions

12
.gitignore vendored
View File

@ -7,6 +7,7 @@ ENV/
.virtualenv/
virtualenv/
*venv/
Activate.ps1
# Python
__pycache__/
@ -39,3 +40,14 @@ wheels/
# OS
.DS_Store
Thumbs.db
# Project specific
logs/
models/
.pytest_cache/
path/
training_data/
*.log
*.pth
*.ckpt
*.h5

View File

@ -3,19 +3,26 @@
This is a living document that tracks the development progress of the AI Snake Game project.
## Current Status
Last Updated: [Menu System Implementation]
- Added main menu interface with game mode selection
- Implemented menu navigation and selection
- Added game mode display during gameplay
- Integrated menu system with game states
Last Updated: [Control and Menu Improvements]
- Fixed menu interaction issues:
- Added proper mouse click handling
- Fixed menu item highlighting
- Improved keyboard and mouse navigation
- Improved snake movement system:
- Implemented input buffer for direction changes
- Fixed rapid direction change issues
- Improved responsiveness while maintaining safety
- Previous updates:
- Added settings menu with game options
- Added wrap-around mode toggle
- Added speed increase toggle
- Added main menu interface
- Added game mode selection
- Implemented menu navigation
- Added game mode display
- Added game restart functionality
- Added "Press any key to restart" message
- Implemented game state reset
- Added Food class with spawning mechanics
- Integrated food system with main game loop
- Implemented basic scoring system
- Added game over screen with score display
## Phase 1: Project Setup and Basic Structure ✅
- [x] Create project structure and virtual environment
@ -27,11 +34,18 @@ Last Updated: [Menu System Implementation]
## Phase 2: Core Game Development ✅
- [x] Implement basic game window using Pygame
- [x] Create Snake class with movement mechanics
- [x] Basic movement and collision
- [x] Input buffering for direction changes
- [x] Fix rapid direction change issues
- [x] Prevent 180-degree turns
- [x] Implement food spawning system
- [x] Create Food class
- [x] Add random food placement
- [x] Add collision detection with food
- [x] Add collision detection (walls and self)
- [x] Add collision detection
- [x] Wall collisions
- [x] Self collisions
- [x] Add wrap-around mode option
- [x] Implement scoring system
- [x] Add score display
- [ ] Add high score tracking
@ -39,12 +53,19 @@ Last Updated: [Menu System Implementation]
- [x] Implement game over state transition
- [x] Add restart functionality
## Phase 3: Game Polish and UI 🔄
## Phase 3: Game Polish and UI
- [x] Add main menu interface
- [x] Create menu layout
- [x] Add game mode selection
- [ ] Add settings options
- [x] Add settings menu
- [x] Fix menu interaction issues
- [x] Mouse click handling
- [x] Menu item highlighting
- [x] Keyboard/mouse navigation
- [x] Implement pause functionality
- [x] Add settings options
- [x] Wrap-around mode toggle
- [x] Speed increase toggle
- [ ] Add visual effects and animations
- [ ] Snake movement animation
- [ ] Food spawn animation
@ -58,36 +79,63 @@ Last Updated: [Menu System Implementation]
- [x] Add game over screen with restart option
## Phase 4: AI Implementation 🔄
- [ ] Design AI algorithm (pathfinding)
- [ ] Research and select appropriate algorithm
- [ ] Implement basic pathfinding
- [ ] Implement AI snake movement logic
- [ ] Create AI controller class
- [ ] Implement decision making
- [ ] Add AI difficulty levels
- [ ] Easy mode (basic pathfinding)
- [ ] Medium mode (improved decision making)
- [ ] Set up ML training environment
- [ ] Create gym-like interface for the game
- [ ] Define observation space (snake + food state)
- [ ] Define action space (4 directions)
- [ ] Implement reward system
- [ ] Implement ML infrastructure
- [ ] Add ML dependencies (PyTorch/TensorFlow)
- [ ] Create neural network architecture
- [ ] Set up training pipeline
- [ ] Train AI models for different difficulties
- [ ] Easy mode (basic food seeking)
- [ ] Medium mode (balanced survival/food seeking)
- [ ] Hard mode (optimal pathfinding)
- [ ] Create AI controller class
- [ ] Model loading and inference
- [ ] Real-time decision making
- [ ] Performance optimization
- [x] Create mode selection (Player vs AI)
- [ ] Optimize AI performance
- [ ] Add training utilities
- [ ] Save/load model checkpoints
- [ ] Training progress visualization
- [ ] Model performance metrics
## Phase 5: Final Polish
- [ ] Add configuration options
- [ ] Game speed settings
- [x] Add configuration options
- [x] Game speed settings
- [x] Wrap-around option
- [ ] Visual settings
- [ ] Sound settings
- [ ] Implement save/load functionality for high scores
- [ ] Bug fixing and performance optimization
- [x] Fix menu interaction bugs
- [x] Fix movement control issues
- [x] Optimize input handling
- [ ] Code cleanup and documentation
- [ ] Add docstrings
- [x] Add docstrings
- [ ] Create API documentation
- [ ] Update README with final features
- [ ] Final testing and refinements
## Resolved Issues
1. Menu Interaction
- Issue: Menu items not responding to clicks
- Solution: Added proper mouse click event handling in game loop
- Solution: Fixed menu item highlighting and selection state
2. Snake Movement
- Issue: Rapid direction changes causing self-collision
- Issue: Missed inputs during quick turns
- Solution: Implemented input buffer system
- Solution: Process one direction change per grid movement
- Solution: Store up to 2 pending direction changes
## Next Steps
1. Implement AI snake movement
- Design pathfinding algorithm
- Create AI controller class
- Set up ML training environment
- Create neural network architecture
- Train initial model
2. Add high score system
- Implement score persistence
- Add high score display
@ -96,7 +144,7 @@ Last Updated: [Menu System Implementation]
- Add game sounds
## Known Issues
- None reported yet
- None reported
## Future Considerations
- Multiplayer support

View File

@ -1,5 +1,12 @@
pygame==2.5.2
numpy==1.24.3
numpy==1.26.3
black==23.12.1 # for code formatting
pytest==7.4.3 # for testing
pytest-xdist==3.5.0 # for parallel test execution
pytest-xdist==3.5.0 # for parallel test execution
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.1.2+cu118 # CUDA 11.8 support
stable-baselines3[extra]==2.2.1 # for reinforcement learning
tensorboard==2.15.1 # for training visualization
gymnasium==0.29.1 # for RL environment interface
tqdm==4.66.1
rich==13.7.0

View File

@ -9,11 +9,12 @@ This package contains all the core game components including:
- AI controllers (coming soon)
"""
from src import config
from src import core
from src import ai
from src import ui
from src.game import Game, GameState
from src.snake import Snake, Direction
from src.food import Food
from src.menu import Menu, GameMode, MenuItem
from src.cli import main as cli_main
from src.ui.menu import GameMode
__version__ = '0.1.0'
__author__ = 'Rbanh'
@ -28,5 +29,9 @@ __all__ = [
'Menu',
'GameMode',
'MenuItem',
'cli_main'
'cli_main',
'config',
'core',
'ai',
'ui'
]

235
src/ai/config_ui.py Normal file
View File

@ -0,0 +1,235 @@
"""
Training Configuration UI
This module provides a graphical interface for configuring AI training parameters.
"""
import pygame
import json
import os
from typing import Dict, Any, Optional, Tuple, List
class ConfigUI:
def __init__(self, width: int = 800, height: int = 600):
"""Initialize the configuration UI."""
pygame.init()
self.width = width
self.height = height
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Training Configuration")
# Fonts
self.title_font = pygame.font.Font(None, 48)
self.header_font = pygame.font.Font(None, 36)
self.text_font = pygame.font.Font(None, 24)
# Colors
self.colors = {
'background': (0, 0, 0),
'text': (200, 200, 200),
'highlight': (0, 255, 0),
'button': (40, 40, 40),
'button_hover': (60, 60, 60),
'input_bg': (30, 30, 30),
'input_active': (50, 50, 50)
}
# Parameters
self.parameters = {
'timesteps': {
'value': 1000000,
'type': 'int',
'min': 1000,
'max': 10000000,
'description': 'Total timesteps to train for'
},
'learning_rate': {
'value': 0.0003,
'type': 'float',
'min': 0.00001,
'max': 0.01,
'description': 'Learning rate for training'
},
'batch_size': {
'value': 64,
'type': 'int',
'min': 32,
'max': 512,
'description': 'Batch size for training'
},
'n_envs': {
'value': 8,
'type': 'int',
'min': 1,
'max': 32,
'description': 'Number of parallel environments'
},
'n_steps': {
'value': 2048,
'type': 'int',
'min': 128,
'max': 8192,
'description': 'Number of steps per update'
}
}
# UI state
self.active_input = None
self.input_text = ""
self.scroll_offset = 0
self.max_scroll = max(0, len(self.parameters) * 60 - (height - 200))
# Load saved config if exists
self.config_file = "training_config.json"
self.load_config()
def load_config(self) -> None:
"""Load configuration from file."""
if os.path.exists(self.config_file):
try:
with open(self.config_file, 'r') as f:
saved_config = json.load(f)
for key, value in saved_config.items():
if key in self.parameters:
self.parameters[key]['value'] = value
except:
pass
def save_config(self) -> None:
"""Save configuration to file."""
config = {key: param['value'] for key, param in self.parameters.items()}
with open(self.config_file, 'w') as f:
json.dump(config, f, indent=4)
def draw_text_input(self, rect: pygame.Rect, value: Any, active: bool) -> None:
"""Draw a text input field."""
color = self.colors['input_active'] if active else self.colors['input_bg']
pygame.draw.rect(self.screen, color, rect)
pygame.draw.rect(self.screen, self.colors['text'], rect, 1)
text = str(value)
if active:
text = self.input_text + "|"
text_surface = self.text_font.render(text, True, self.colors['text'])
text_rect = text_surface.get_rect(midleft=(rect.left + 5, rect.centery))
self.screen.blit(text_surface, text_rect)
def draw_button(self, rect: pygame.Rect, text: str, hover: bool = False) -> None:
"""Draw a button."""
color = self.colors['button_hover'] if hover else self.colors['button']
pygame.draw.rect(self.screen, color, rect)
pygame.draw.rect(self.screen, self.colors['text'], rect, 1)
text_surface = self.text_font.render(text, True, self.colors['text'])
text_rect = text_surface.get_rect(center=rect.center)
self.screen.blit(text_surface, text_rect)
def validate_input(self, param: Dict[str, Any], value: str) -> Optional[Any]:
"""Validate and convert input value."""
try:
if param['type'] == 'int':
val = int(value)
else:
val = float(value)
if val < param['min'] or val > param['max']:
return None
return val
except:
return None
def run(self) -> Optional[Dict[str, Any]]:
"""Run the configuration UI. Returns the config dict if saved, None if cancelled."""
running = True
save_clicked = False
mouse_pos = (0, 0)
while running:
for event in pygame.event.get():
if event.type == pygame.QUIT:
running = False
elif event.type == pygame.MOUSEBUTTONDOWN:
mouse_pos = event.pos
# Check parameter inputs
y = 100 - self.scroll_offset
for name, param in self.parameters.items():
input_rect = pygame.Rect(300, y, 200, 30)
if input_rect.collidepoint(mouse_pos):
self.active_input = name
self.input_text = str(param['value'])
y += 60
# Check buttons
save_rect = pygame.Rect(self.width//2 - 150, self.height - 60, 140, 40)
cancel_rect = pygame.Rect(self.width//2 + 10, self.height - 60, 140, 40)
if save_rect.collidepoint(mouse_pos):
self.save_config()
save_clicked = True
running = False
elif cancel_rect.collidepoint(mouse_pos):
running = False
elif event.type == pygame.MOUSEBUTTONUP:
if event.button == 4: # Mouse wheel up
self.scroll_offset = max(0, self.scroll_offset - 30)
elif event.button == 5: # Mouse wheel down
self.scroll_offset = min(self.max_scroll, self.scroll_offset + 30)
elif event.type == pygame.KEYDOWN:
if self.active_input is not None:
if event.key == pygame.K_RETURN:
param = self.parameters[self.active_input]
if val := self.validate_input(param, self.input_text):
param['value'] = val
self.active_input = None
elif event.key == pygame.K_BACKSPACE:
self.input_text = self.input_text[:-1]
else:
if event.unicode.isnumeric() or event.unicode == '.':
self.input_text += event.unicode
# Draw UI
self.screen.fill(self.colors['background'])
# Title
title = "Training Configuration"
title_surface = self.title_font.render(title, True, self.colors['highlight'])
title_rect = title_surface.get_rect(midtop=(self.width//2, 20))
self.screen.blit(title_surface, title_rect)
# Parameters
y = 100 - self.scroll_offset
for name, param in self.parameters.items():
if 0 <= y <= self.height - 100:
# Parameter name and description
name_surface = self.header_font.render(name, True, self.colors['text'])
desc_surface = self.text_font.render(param['description'], True, self.colors['text'])
self.screen.blit(name_surface, (20, y))
self.screen.blit(desc_surface, (20, y + 30))
# Input field
input_rect = pygame.Rect(300, y, 200, 30)
self.draw_text_input(input_rect, param['value'], name == self.active_input)
y += 60
# Buttons
save_rect = pygame.Rect(self.width//2 - 150, self.height - 60, 140, 40)
cancel_rect = pygame.Rect(self.width//2 + 10, self.height - 60, 140, 40)
save_hover = save_rect.collidepoint(mouse_pos)
cancel_hover = cancel_rect.collidepoint(mouse_pos)
self.draw_button(save_rect, "Save", save_hover)
self.draw_button(cancel_rect, "Cancel", cancel_hover)
pygame.display.flip()
pygame.quit()
if save_clicked:
return {key: param['value'] for key, param in self.parameters.items()}
return None

70
src/ai/controller.py Normal file
View File

@ -0,0 +1,70 @@
"""
AI Controller for Snake Game
This module provides the AI controller that uses trained models to play the game.
It handles model loading, state processing, and decision making during gameplay.
"""
import os
import numpy as np
from stable_baselines3 import PPO
from ai.environment import SnakeEnv, Direction
class AIController:
"""AI controller that uses trained models to play the game."""
def __init__(self, difficulty: str = "medium"):
"""
Initialize the AI controller.
Args:
difficulty: "easy", "medium", or "hard"
"""
self.difficulty = difficulty
self.model = None
self.env = SnakeEnv() # For state processing
self._load_model()
def _load_model(self):
"""Load the appropriate model based on difficulty."""
model_path = f"models/{self.difficulty}/best_model.zip"
if not os.path.exists(model_path):
# Fall back to final model if best model doesn't exist
model_path = f"models/{self.difficulty}/final_model.zip"
if not os.path.exists(model_path):
raise FileNotFoundError(
f"No model found for difficulty {self.difficulty}. "
"Please train the model first."
)
self.model = PPO.load(model_path)
def get_action(self, game_state: dict) -> Direction:
"""
Get the next action based on the current game state.
Args:
game_state: Dictionary containing:
- snake: Snake object
- food: Food object
- width: Game width
- height: Game height
Returns:
Direction enum indicating the chosen action
"""
# Update environment with current game state
self.env.snake = game_state["snake"]
self.env.food = game_state["food"]
self.env.width = game_state["width"]
self.env.height = game_state["height"]
# Get state observation
state = self.env._get_state()
# Get action from model
action, _ = self.model.predict(state, deterministic=True)
# Convert action index to Direction
return self.env.action_space[action]

787
src/ai/environment.py Normal file
View File

@ -0,0 +1,787 @@
"""
Snake Game Environment for Reinforcement Learning
This module provides a gym-like interface for training AI agents to play the snake game.
It includes:
- State observation space
- Action space
- Reward system
- Environment dynamics
"""
import numpy as np
import random
import gymnasium as gym
from gymnasium import spaces
from typing import Tuple, List, Dict, Any, Optional
import pygame
from src.core import Snake
from src.core import Direction
from src.core import GameSession
from src.ui import GameArea
from src.config import GameRules
def manhattan_distance(pos1: Tuple[int, int], pos2: Tuple[int, int]) -> int:
"""Calculate Manhattan distance between two points."""
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def manhattan_distance_wrap(pos1: Tuple[int, int], pos2: Tuple[int, int], size: int) -> int:
"""Calculate Manhattan distance between two points with wrap-around."""
dx = abs(pos1[0] - pos2[0])
dy = abs(pos1[1] - pos2[1])
return min(dx, size - dx) + min(dy, size - dy)
class SnakeEnv(gym.Env):
"""A Gymnasium environment for training snake AI agents."""
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}
def __init__(self, size=30, render_mode=None, difficulty="easy"):
super().__init__()
self.size = size
self.render_mode = render_mode
self.difficulty = difficulty
# Action to Direction mapping
self.action_to_direction = {
0: Direction.UP,
1: Direction.RIGHT,
2: Direction.DOWN,
3: Direction.LEFT
}
# Direction to index mapping for one-hot encoding
self.direction_to_index = {
Direction.UP: 0,
Direction.RIGHT: 1,
Direction.DOWN: 2,
Direction.LEFT: 3
}
# Initialize reward scales with emphasis on precision
self.reward_scales = {
'food': 20.0, # Base food reward (will be scaled exponentially)
'death': -5.0, # Reduced death penalty to encourage exploration
'distance': 0.5, # Increased distance reward
'survival': -0.01, # Small survival penalty
'milestone': 10.0, # Larger milestone rewards
'efficiency': 0.2, # Small efficiency reward
'exploration': 0.5, # Increased exploration reward
'safety': 1.0, # Safety reward for avoiding danger
'timeout': -1.0, # Reduced timeout penalty
'wrap_bonus': 0.2, # Small wrap-around bonus
'near_miss': -0.1, # Very small near miss penalty
'repetitive': -2 # Increased base penalty for repetitive movement
}
# Near miss detection
self.near_miss_threshold = 1 # Only penalize very close misses
# Time limit parameters - more lenient
self.base_time_limit = 200 # Double base time limit
self.length_time_bonus = 100 # Double length bonus
self.max_time_limit = 1000 # Much longer maximum time
# Exploration decay
self.initial_exploration_bonus = 2.0 # Higher initial exploration
self.exploration_decay = 0.9995 # Slower decay
self.min_exploration_bonus = 0.5 # Higher minimum exploration
self.current_exploration_bonus = self.initial_exploration_bonus
self.total_episodes = 0 # Track number of episodes for decay
# Curriculum learning parameters
self.curriculum_stage = 0
self.success_threshold = {
0: 3, # Stage 0: Basic movement (need more consistent success)
1: 5, # Stage 1: Food collection
2: 7, # Stage 2: Longer snake
3: 10 # Stage 3: Full difficulty
}
self.consecutive_successes = 0
self.stage_requirements = {
0: 1, # Stage 0: Get one food
1: 3, # Stage 1: Get three food
2: 5, # Stage 2: Get five food
3: 7 # Stage 3: Get seven food
}
# Dynamic episode length based on curriculum stage
self.base_max_steps = 300 # More steps for exploration
self.max_steps = self.base_max_steps
# Window dimensions for rendering
self.window_width = 1024
self.window_height = 768
# Create game area and rules
self.game_area = GameArea(self.window_width, self.window_height)
self.rules = GameRules()
self.rules.wrap_around = False # No wrap-around for AI training
# Define action and observation spaces
self.action_space = spaces.Discrete(4) # Up, Right, Down, Left
num_channels = 26 # Total number of channels we created in _get_normalized_observation
self.observation_space = spaces.Box(
low=0,
high=1,
shape=(num_channels, size, size), # (channels, height, width)
dtype=np.float32
)
# Game session
self.session = None
self.steps = 0
self.steps_since_food = 0
self.current_time = 0
self.last_score = 0
self.last_distance = float('inf')
self.last_food_steps = 0
self.steps_since_direction_change = 0 # Track steps since last direction change
self.last_direction = None # Track previous direction
self.recent_positions = []
self.steps_in_same_direction = 0 # Track steps without direction change
self.last_min_food_distance = float('inf') # Initialize minimum distance tracking
# For rendering
self.window = None
self.clock = None
self.score_milestone_rewards = {5: 10.0, 10: 20.0, 15: 30.0} # Bonus rewards at score milestones
self.last_milestone = 0 # Track last milestone reached
self.base_survival_penalty = -1.0 # Much larger constant penalty per step
# Movement control parameters
self.max_direction_steps = 10 # Reduced from 20 to 10
self.direction_step_increase = 2 # Reduced from 5 to 2
self.min_safe_turns = 2 # Reduced from 3 to 2
self.repetitive_threshold = 10 # Base threshold before repetitive penalty kicks in
self.repetitive_scale = 1.2 # Base scale for penalty growth
self.repetitive_wrap_scale = 2.0 # Much faster growth rate in wrap mode
self.max_repetitive_penalty = -5.0 # Increased maximum repetitive movement penalty
# Initialize base rewards with adjusted values
rewards = {
'food': 0.0,
'death': 0.0,
'distance': 0.0,
'survival': self.base_survival_penalty * 0.5,
'direction': 0.0,
'efficiency': 0.0,
'progress': 0.0,
'alignment': 0.0,
'repetitive': 0.0
}
def _get_normalized_observation(self) -> np.ndarray:
"""
Create a multi-channel 2D observation from the game state.
Spatial channels encode the grid layout (snake & food),
while extra scalar features are broadcast over additional channels.
"""
state = self.session.get_state()
snake_body = state["snake_body"]
food_position = state["food_position"]
head = snake_body[0]
body = snake_body
# --- Spatial Channels ---
# Channel 0: Snake positions (1 for snake, 0 otherwise)
snake_channel = np.zeros((self.size, self.size), dtype=np.float32)
for pos in snake_body:
snake_channel[pos[1], pos[0]] = 1.0
assert pos[0] < self.size and pos[1] < self.size, \
f"Snake position {pos} exceeds grid size {self.size}"
# Channel 1: Food position
food_channel = np.zeros((self.size, self.size), dtype=np.float32)
food_channel[food_position[1], food_position[0]] = 1.0
# --- Auxiliary Scalar Features ---
# Compute distances and apply wrap-around adjustments
dx = food_position[0] - head[0]
dy = food_position[1] - head[1]
if self.rules.wrap_around:
if dx > self.size/2:
wrap_dx = dx - self.size
elif dx < -self.size/2:
wrap_dx = dx + self.size
else:
wrap_dx = dx
if dy > self.size/2:
wrap_dy = dy - self.size
elif dy < -self.size/2:
wrap_dy = dy + self.size
else:
wrap_dy = dy
else:
wrap_dx = dx
wrap_dy = dy
# Normalize
dx_norm = dx / self.size
dy_norm = dy / self.size
wrap_dx_norm = wrap_dx / self.size
wrap_dy_norm = wrap_dy / self.size
# Danger indicators (vector of 4 values)
dangers = self._get_danger_observations_all_directions(head, body)
# Other auxiliary scalar features
snake_length_norm = len(body) / (self.size**2)
# One-hot encode current direction
direction_one_hot = [0.0] * 4
direction_one_hot[self.direction_to_index[state["snake_direction"]]] = 1.0
# One-hot encode relative food direction
food_dir = [0.0] * 4
if abs(dx_norm) > abs(dy_norm):
if dx_norm > 0:
food_dir[self.direction_to_index[Direction.RIGHT]] = 1.0
else:
food_dir[self.direction_to_index[Direction.LEFT]] = 1.0
else:
if dy_norm > 0:
food_dir[self.direction_to_index[Direction.DOWN]] = 1.0
else:
food_dir[self.direction_to_index[Direction.UP]] = 1.0
current_time_limit = min(
self.max_time_limit,
self.base_time_limit + (len(body) - 1) * self.length_time_bonus
)
time_pressure = np.clip(self.steps_since_food / current_time_limit, 0, 1)
wrap_mode = float(self.rules.wrap_around)
progress = min(1.0, len(body) / 10)
recent_food = min(1.0, self.steps_since_food / 50)
curriculum_stage_norm = float(self.curriculum_stage) / 3
advanced_length = float(len(body) > 5)
manhattan_norm = min(1.0, manhattan_distance(head, food_position) / (self.size * 2))
# --- Create "broadcasted" channels for each scalar ---
def broadcast_channel(value):
return np.full((self.size, self.size), value, dtype=np.float32)
dx_channel = broadcast_channel(dx_norm)
dy_channel = broadcast_channel(dy_norm)
wrap_dx_channel = broadcast_channel(wrap_dx_norm)
wrap_dy_channel = broadcast_channel(wrap_dy_norm)
snake_length_chan = broadcast_channel(snake_length_norm)
time_pressure_chan = broadcast_channel(time_pressure)
wrap_mode_chan = broadcast_channel(wrap_mode)
progress_chan = broadcast_channel(progress)
recent_food_chan = broadcast_channel(recent_food)
curriculum_chan = broadcast_channel(curriculum_stage_norm)
advanced_length_chan = broadcast_channel(advanced_length)
manhattan_chan = broadcast_channel(manhattan_norm)
# For one-hot features, create one channel per value
direction_channels = [broadcast_channel(v) for v in direction_one_hot]
food_dir_channels = [broadcast_channel(v) for v in food_dir]
# For dangers, one channel per direction (4 channels)
danger_channels = [broadcast_channel(d) for d in dangers]
# --- Stack all channels ---
# You can adjust the channel order as desired.
channels = [
snake_channel, # Channel 0
food_channel, # Channel 1
dx_channel, # Channel 2
dy_channel, # Channel 3
wrap_dx_channel, # Channel 4
wrap_dy_channel, # Channel 5
]
channels.extend(danger_channels) # Channels 6-9
channels.append(snake_length_chan) # Channel 10
channels.extend(direction_channels) # Channels 11-14
channels.extend(food_dir_channels) # Channels 15-18
channels.append(time_pressure_chan) # Channel 19
channels.append(wrap_mode_chan) # Channel 20
channels.append(progress_chan) # Channel 21
channels.append(recent_food_chan) # Channel 22
channels.append(curriculum_chan) # Channel 23
channels.append(advanced_length_chan) # Channel 24
channels.append(manhattan_chan) # Channel 25
# Final observation: shape (num_channels, size, size)
observation = np.stack(channels, axis=0)
return np.clip(observation, 0, 1)
def _get_danger_observations_all_directions(self, head: Tuple[int, int], body: List[Tuple[int, int]]) -> List[float]:
"""Get danger observations in all four directions."""
dangers = [0.0] * 4 # [UP, RIGHT, DOWN, LEFT]
# Check each direction
directions = [Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT]
for i, direction in enumerate(directions):
dx, dy = direction.to_vector()
next_pos = (head[0] + dx, head[1] + dy)
# Check wall collision
if not self.rules.wrap_around:
if (next_pos[0] < 0 or next_pos[0] >= self.size or
next_pos[1] < 0 or next_pos[1] >= self.size):
dangers[i] = 1.0
continue
else:
next_pos = (next_pos[0] % self.size, next_pos[1] % self.size)
# Check self collision
if next_pos in body[1:]:
dangers[i] = 1.0
return dangers
def reset(self, seed=None, options=None):
"""Reset with curriculum-based difficulty and exploration decay."""
# Call parent reset without unpacking
super().reset(seed=seed)
if seed is not None:
random.seed(seed)
np.random.seed(seed)
# Update exploration bonus
self.total_episodes += 1
self.current_exploration_bonus = max(
self.min_exploration_bonus,
self.initial_exploration_bonus * (self.exploration_decay ** self.total_episodes)
)
# Initialize game session with curriculum-appropriate settings
self.session = GameSession(
self.size,
self.size,
self.rules
)
# Reset counters
self.steps = 0
self.steps_since_food = 0
self.current_time = 0
self.last_score = 0
self.last_distance = float('inf')
self.last_direction = self.session.get_state()["snake_direction"] # Track initial direction
self.recent_positions = []
self.steps_in_same_direction = 0 # Track steps without direction change
self.last_min_food_distance = float('inf') # Initialize minimum distance tracking
# Get initial observation
observation = self._get_normalized_observation()
info = {
"curriculum_stage": self.curriculum_stage,
"exploration_bonus": self.current_exploration_bonus
}
return observation, info
def _calculate_area_penalty(self, head_pos):
"""Calculate penalty for staying in a small area."""
# Add current position to history
self.recent_positions.append(head_pos)
if len(self.recent_positions) > self.max_recent_positions:
self.recent_positions.pop(0)
if len(self.recent_positions) < 10: # Need minimum history to calculate
return 0.0
# Calculate bounding box of recent positions
x_coords = [p[0] for p in self.recent_positions]
y_coords = [p[1] for p in self.recent_positions]
area_width = max(x_coords) - min(x_coords) + 1
area_height = max(y_coords) - min(y_coords) + 1
area = area_width * area_height
# Calculate unique positions visited
unique_positions = len(set(self.recent_positions))
# Penalize small areas and repeated positions
area_penalty = 0.0
if area < 9: # 3x3 grid or smaller
area_penalty = -0.2 * (9 - area) / 8 # Max penalty -0.2 for 1x1 area
# Additional penalty for revisiting same positions frequently
repetition_penalty = -0.3 * (1 - unique_positions / len(self.recent_positions))
return area_penalty + repetition_penalty
def step(self, action):
"""Improved reward structure with length-aware exploration."""
self.steps += 1
self.steps_since_food += 1
# Calculate current time limit based on snake length
snake_length = len(self.session.get_state()["snake_body"])
current_time_limit = min(
self.max_time_limit,
self.base_time_limit + (snake_length - 1) * self.length_time_bonus
)
# Get state before action
prev_state = self.session.get_state()
prev_head = prev_state["snake_head"]
prev_food_dist = manhattan_distance(prev_head, prev_state["food_position"])
if self.rules.wrap_around:
prev_food_dist = min(prev_food_dist,
manhattan_distance_wrap(prev_head, prev_state["food_position"], self.size))
# Convert action to Direction and apply
direction = self.action_to_direction[int(action)]
# Update game state
state, base_reward, done = self.session.step(direction, self.current_time)
# Initialize reward components
rewards = {
'food': 0.0,
'death': 0.0,
'distance': 0.0,
'survival': 0.0,
'milestone': 0.0,
'efficiency': 0.0,
'exploration': 0.0,
'safety': 0.0,
'timeout': 0.0,
'direction': 0.0,
'repetitive': 0.0
}
# Get current state info
curr_head = state["snake_body"][0]
curr_food = state["food_position"]
curr_food_dist = manhattan_distance(curr_head, curr_food)
if self.rules.wrap_around:
curr_food_dist = min(curr_food_dist,
manhattan_distance_wrap(curr_head, curr_food, self.size))
# Update current time
self.current_time += self.session.move_cooldown
# Food reward with exponential scaling based on snake length
food_eaten = state["score"] > prev_state["score"]
if food_eaten:
# Reset timer when food is eaten
self.steps_since_food = 0
# Calculate exponential food reward
base_food_reward = self.reward_scales['food']
length_multiplier = 1.2 ** (len(state["snake_body"]) - 1) # 20% increase per length
rewards['food'] = base_food_reward * length_multiplier
# Add milestone rewards for achieving certain scores
if state["score"] in [5, 10, 15, 20]:
rewards['milestone'] = self.reward_scales['milestone'] * (state["score"] / 5)
else:
# Only apply survival penalty if we haven't eaten food in a while
if self.steps_since_food > current_time_limit / 2:
penalty_factor = (self.steps_since_food - current_time_limit/2) / (current_time_limit/2)
rewards['survival'] = self.reward_scales['survival'] * penalty_factor
# Check if direction changed and update counters
direction_changed = direction != self.last_direction
if direction_changed:
self.steps_in_same_direction = 0
else:
self.steps_in_same_direction += 1
# Apply growing penalty for repetitive movement
if self.steps_in_same_direction > (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold):
excess_steps = self.steps_in_same_direction - (self.repetitive_threshold - 5 if self.rules.wrap_around else self.repetitive_threshold)
# Use much faster growth rate in wrap-around mode
growth_rate = self.repetitive_wrap_scale if self.rules.wrap_around else self.repetitive_scale
repetitive_penalty = self.reward_scales['repetitive'] * (growth_rate ** excess_steps)
# Apply additional multiplier in wrap-around mode that grows with steps
if self.rules.wrap_around:
wrap_multiplier = 1.0 + (excess_steps * 0.5) # Multiplier grows with each step
repetitive_penalty *= wrap_multiplier
rewards['repetitive'] = repetitive_penalty # No maximum cap - let it grow unbounded in wrap-around mode
else:
# Cap the penalty when not in wrap-around mode
rewards['repetitive'] = max(repetitive_penalty, self.max_repetitive_penalty)
self.last_direction = direction
# Check for timeout or stuck in loop - only if we haven't just eaten food
if (self.steps_since_food >= current_time_limit and not food_eaten):
done = True
rewards['timeout'] = self.reward_scales['timeout']
# Scale timeout penalty based on distance to food
if curr_food_dist < 5: # Harsher penalty for timing out near food
rewards['timeout'] *= 1.5
# Additional penalty for getting stuck in a loop
if self.steps_in_same_direction >= 30: # Also update this threshold
rewards['timeout'] *= 1.2
# Death penalty - calculate AFTER food rewards
if done and not self.steps_since_food >= current_time_limit:
base_death_penalty = self.reward_scales['death']
# Check if death was due to self-collision
if curr_head in state["snake_body"][1:]:
# Base multiplier for self-collision
self_collision_multiplier = 2.0
# Additional penalty scaling with score
# At score 5: 2.5x penalty
# At score 10: 3.0x penalty
# At score 15: 3.5x penalty
score_multiplier = 1.0 + (state["score"] / 10)
# Apply both multipliers
rewards['death'] = base_death_penalty * self_collision_multiplier * score_multiplier
# Additional penalty if we died near food
if curr_food_dist < 5:
rewards['death'] *= 1.5 # Even harsher if we collide with ourselves near food
else:
# For wall collisions, keep original scaling but make it less punishing at higher scores
rewards['death'] = base_death_penalty * (1.0 - min(0.5, state["score"] / 20))
# Add an immediate negative reward for losing potential score
potential_loss_penalty = -0.5 * state["score"] # Larger penalty for dying with higher scores
rewards['death'] += potential_loss_penalty
# Track minimum distance to food and check for near misses
if curr_food_dist < self.last_min_food_distance:
self.last_min_food_distance = curr_food_dist
elif curr_food_dist > self.last_min_food_distance:
# We're moving away from our closest approach
if self.last_min_food_distance <= self.near_miss_threshold:
# We got very close but missed
rewards['distance'] += self.reward_scales['near_miss']
self.last_min_food_distance = float('inf') # Reset tracking
# Progressive distance reward - more reward for getting closer when near food
distance_change = prev_food_dist - curr_food_dist
if curr_food_dist < 5: # When close to food
distance_multiplier = 2.0 # Double the reward/penalty
else:
distance_multiplier = 1.0
rewards['distance'] = distance_change * self.reward_scales['distance'] * distance_multiplier
# Calculate direction penalty
if curr_food_dist <= 5: # Only apply when close to food
# Calculate optimal direction to food
dx = curr_food[0] - curr_head[0]
dy = curr_food[1] - curr_head[1]
if self.rules.wrap_around:
# Adjust for wrap-around
if dx > self.size/2: dx -= self.size
elif dx < -self.size/2: dx += self.size
if dy > self.size/2: dy -= self.size
elif dy < -self.size/2: dy += self.size
# Determine optimal direction(s)
optimal_directions = []
if abs(dx) > abs(dy):
if dx > 0: optimal_directions.append(Direction.RIGHT)
elif dx < 0: optimal_directions.append(Direction.LEFT)
if abs(dy) >= abs(dx):
if dy > 0: optimal_directions.append(Direction.DOWN)
elif dy < 0: optimal_directions.append(Direction.UP)
# Base penalty scales with proximity (closer = higher penalty)
base_penalty = (6 - curr_food_dist) * 0.1 # Scales from 0.1 to 0.5
if direction not in optimal_directions:
# Check if we're moving directly away from food
opposite_dirs = {
Direction.UP: Direction.DOWN,
Direction.DOWN: Direction.UP,
Direction.LEFT: Direction.RIGHT,
Direction.RIGHT: Direction.LEFT
}
optimal_opposites = [opposite_dirs[d] for d in optimal_directions]
if direction in optimal_opposites:
# Double penalty for moving directly away from food
rewards['direction'] = -2 * base_penalty
else:
# Regular penalty for suboptimal direction
rewards['direction'] = -base_penalty
elif curr_food_dist <= 2: # Small reward for correct direction when very close
rewards['direction'] = 0.1
# Efficiency reward - encourage purposeful movement
if not done and self.steps_since_food < 50: # Only apply when actively hunting
rewards['efficiency'] = self.reward_scales['efficiency'] * (1.0 - self.steps_since_food / 50)
# Calculate final reward - simple sum of components
reward = sum(rewards.values())
# Get observation
observation = self._get_normalized_observation()
# No more episode-based truncation
truncated = False
# Include reward components and exploration info in info dict for monitoring
info = {
"score": state["score"],
"reward_components": rewards,
"exploration_bonus": self.current_exploration_bonus,
"direction_changed": direction_changed,
"steps_in_same_direction": self.steps_in_same_direction,
"snake_length": snake_length,
"steps": self.steps # Include total steps for monitoring
}
return observation, reward, done, truncated, info
def render(self):
"""Render the environment."""
if self.window is None and self.render_mode == "human":
pygame.init()
self.window = pygame.display.set_mode((self.window_width, self.window_height))
pygame.display.set_caption("Snake Environment")
self.clock = pygame.time.Clock()
if self.window is not None:
self.window.fill((0, 0, 0))
# Draw game area
self.game_area.draw(self.window)
# Draw game session
self.session.render(self.window, self.game_area)
# Draw score
font = pygame.font.Font(None, 36)
score_text = f"Score: {self.session.score}"
score_surface = font.render(score_text, True, (0, 255, 0))
score_rect = score_surface.get_rect(midtop=(self.window_width // 2, 20))
self.window.blit(score_surface, score_rect)
pygame.display.flip()
self.clock.tick(self.metadata["render_fps"])
if self.render_mode == "rgb_array":
return np.transpose(
np.array(pygame.surfarray.pixels3d(self.window)),
axes=(1, 0, 2)
)
def close(self):
"""Close the environment."""
if self.window is not None:
pygame.quit()
self.window = None
def get_state(self) -> Dict:
"""Get the current game state."""
state = self.session.get_state()
state['steps_since_food'] = self.steps_since_food
# Calculate and include current time limit
snake_length = len(state["snake_body"])
current_time_limit = min(
self.max_time_limit,
self.base_time_limit + (snake_length - 1) * self.length_time_bonus
)
state['current_time_limit'] = current_time_limit
return state
def _get_adjacent_positions(self, position: Tuple[int, int]) -> List[Tuple[int, int]]:
"""
Get all adjacent positions to the given position.
Args:
position: Current position (x, y)
Returns:
List of adjacent positions [(x, y), ...]
Raises:
ValueError: If position is None or not a tuple of 2 integers
ValueError: If position coordinates are not integers
"""
if not position or not isinstance(position, tuple) or len(position) != 2:
raise ValueError("Position must be a tuple of 2 coordinates")
try:
x, y = int(position[0]), int(position[1])
except (TypeError, ValueError):
raise ValueError("Position coordinates must be integers")
# Get positions in all 4 directions (up, right, down, left)
adjacent = [
(x, y - 1), # up
(x + 1, y), # right
(x, y + 1), # down
(x - 1, y) # left
]
# Validate size attribute exists and is positive integer
if not hasattr(self, 'size') or not isinstance(self.size, int) or self.size <= 0:
raise ValueError("Environment size must be a positive integer")
# If wrap-around is enabled, adjust out-of-bounds positions
if hasattr(self, 'rules') and hasattr(self.rules, 'wrap_around') and self.rules.wrap_around:
adjacent = [
(pos[0] % self.size, pos[1] % self.size)
for pos in adjacent
]
return adjacent
# Otherwise, only return positions that are within bounds
return [
(pos[0], pos[1]) for pos in adjacent
if 0 <= pos[0] < self.size and 0 <= pos[1] < self.size
]
def _update_curriculum(self, score: int, done: bool) -> None:
"""Update curriculum stage based on agent's performance."""
if done:
if score >= self.stage_requirements[self.curriculum_stage]:
self.consecutive_successes += 1
else:
self.consecutive_successes = 0
# Check for stage advancement
if (self.curriculum_stage < 3 and
self.consecutive_successes >= self.success_threshold[self.curriculum_stage]):
self.curriculum_stage += 1
self.consecutive_successes = 0
# Increase episode length with curriculum stage
self.max_steps = self.base_max_steps * (1 + self.curriculum_stage * 0.5)
def _get_curriculum_rules(self) -> GameRules:
"""Get game rules appropriate for current curriculum stage."""
rules = GameRules()
# Adjust rules based on curriculum stage
if self.curriculum_stage == 0:
rules.wrap_around = True # Make it easier to avoid walls
rules.speed_increase = False
rules.initial_move_cooldown = 200 # Much slower movement
elif self.curriculum_stage == 1:
rules.wrap_around = True
rules.speed_increase = False # Still no speed increase
rules.initial_move_cooldown = 150
elif self.curriculum_stage == 2:
rules.wrap_around = True # Keep wrap-around until final stage
rules.speed_increase = True
rules.initial_move_cooldown = 120
else: # Stage 3
rules.wrap_around = False
rules.speed_increase = True
rules.initial_move_cooldown = 100
return rules
# Register the environment with Gymnasium
from gymnasium.envs.registration import register
register(
id='Snake-v0',
entry_point='src.ai.environment:SnakeEnv',
kwargs={
'size': 30,
'render_mode': None,
'difficulty': 'easy'
}
)

93
src/ai/model_networks.py Normal file
View File

@ -0,0 +1,93 @@
import gym
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomRecurrentCNN(BaseFeaturesExtractor):
"""
A simple CNN with an LSTM layer added on top. This enables the network to keep an internal
memory of past observations so that it can adapt to changing rules/environment dynamics.
"""
def __init__(self, observation_space, features_dim=512):
# Initialize with a features_dim that matches the LSTM's output.
super(CustomRecurrentCNN, self).__init__(observation_space, features_dim)
# Calculate the expected CNN output size first
n_input_channels = observation_space.shape[0] # Number of input channels
self.cnn = nn.Sequential(
# Layer 1: (n_channels, 30, 30) -> (32, 15, 15)
nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
# Layer 2: (32, 15, 15) -> (64, 8, 8)
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
# Layer 3: (64, 8, 8) -> (64, 4, 4)
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten()
)
# Calculate CNN output dimension using a dummy forward pass
with torch.no_grad():
dummy_input = torch.zeros(1, *observation_space.shape)
cnn_output = self.cnn(dummy_input)
self._cnn_output_dim = cnn_output.shape[1] # Should be 64 * 4 * 4 = 1024
# Add linear layer to reduce CNN output to desired feature dimension
self.fc = nn.Linear(self._cnn_output_dim, 256)
# LSTM layer
self.lstm = nn.LSTM(256, features_dim, batch_first=True)
def forward(self, observations):
# Pass through CNN: (batch, channels, height, width) -> (batch, cnn_features)
cnn_out = self.cnn(observations)
# Pass through linear layer: (batch, cnn_features) -> (batch, 256)
fc_out = self.fc(cnn_out)
# Add time dimension for LSTM: (batch, 256) -> (batch, 1, 256)
lstm_input = fc_out.unsqueeze(1)
# LSTM: (batch, 1, 256) -> (batch, 1, features_dim)
lstm_out, _ = self.lstm(lstm_input)
# Remove time dimension: (batch, 1, features_dim) -> (batch, features_dim)
return lstm_out.squeeze(1)
class CustomCNN(BaseFeaturesExtractor):
"""
Custom feature extractor for the snake environment.
Uses fully connected layers since our input is a 1D vector of normalized features.
"""
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
super().__init__(observation_space, features_dim)
self.input_dim = int(np.prod(observation_space.shape))
self.cnn = nn.Sequential(
nn.Linear(self.input_dim, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Linear(512, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Linear(512, features_dim),
nn.ReLU(),
)
def forward(self, observations: torch.Tensor) -> torch.Tensor:
# Handle both single and batch observations
if len(observations.shape) == 1:
observations = observations.unsqueeze(0)
# Ensure the input dimension matches what we expect
flat_obs = observations.reshape(observations.shape[0], self.input_dim)
return self.cnn(flat_obs)

423
src/ai/train.py Normal file
View File

@ -0,0 +1,423 @@
"""
Training script for Snake AI using Stable Baselines3.
This script sets up and trains RL agents using the SnakeEnv environment.
It supports multiple difficulty levels and provides real-time visualization
of the training process.
"""
import os
import time
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
import gymnasium as gym
import torch.nn as nn
import argparse
from queue import Queue
from threading import Thread
from typing import Dict, Any, Optional
import json
from queue import Empty
from src.ai.environment import SnakeEnv
from src.ai.visualize import run_dashboard
from src.ai.model_networks import CustomRecurrentCNN, CustomCNN
class VisualizationCallback(BaseCallback):
"""Callback for updating the training visualization."""
def __init__(self, viz_queue: Queue, total_timesteps: int, eval_freq: int = 1000, n_steps: int = 512, n_envs: int = 32, initial_timesteps: int = 0):
super().__init__()
self.viz_queue = viz_queue
self.total_timesteps = total_timesteps
self.eval_freq = eval_freq
self.n_steps = n_steps
self.n_envs = n_envs
self.initial_timesteps = initial_timesteps # Store initial timesteps
# Training metrics
self.episode_rewards = []
self.episode_lengths = []
# Evaluation metrics
self.eval_scores = []
self.high_score = 0
self.recent_high_score = 0
# Current episode tracking
self.current_reward = 0
self.current_length = 0
# Evaluation games (one with wrap, one without)
self.eval_envs = [
Monitor(SnakeEnv(render_mode="rgb_array")) for _ in range(2)
]
self.eval_envs[0].unwrapped.rules.wrap_around = False
self.eval_envs[1].unwrapped.rules.wrap_around = True
self.eval_obs = [None, None]
self.eval_done = [True, True]
self.eval_scores_current = [0, 0]
# Real-time demo environment
self.demo_env = Monitor(SnakeEnv(render_mode="rgb_array"))
self.demo_env.unwrapped.rules.wrap_around = True
self.demo_obs = None
self.demo_done = True
self.last_demo_time = 0
self.demo_speed = 50
def _on_step(self) -> bool:
"""Update visualization on each step."""
try:
current_time = time.time() * 1000 # Convert to milliseconds
# Update episode tracking
self.current_length += 1
reward = self.locals.get("rewards", [0])[0]
self.current_reward += reward
# Get action probabilities and value estimate for demo game
if self.demo_obs is not None:
# Ensure proper tensor shape (batch_size, obs_dim)
demo_obs_tensor = torch.as_tensor(self.demo_obs).float()
if len(demo_obs_tensor.shape) == 1:
demo_obs_tensor = demo_obs_tensor.unsqueeze(0)
# Move tensor to same device as model's policy
device = self.model.policy.device
demo_obs_tensor = demo_obs_tensor.to(device)
with torch.no_grad():
# Get action distribution and value estimate
dist = self.model.policy.get_distribution(demo_obs_tensor)
action_probs = dist.distribution.probs
value_estimate = self.model.policy.predict_values(demo_obs_tensor)
# Move results back to CPU and get first item (since we added batch dimension)
action_probs = action_probs[0].cpu().numpy()
value_estimate = value_estimate[0].cpu().numpy()
else:
action_probs = np.zeros(4)
value_estimate = np.array([0.0])
# Check if episode ended
dones = self.locals.get("dones", [False])
if any(dones):
self.episode_rewards.append(self.current_reward)
self.episode_lengths.append(self.current_length)
self.current_reward = 0
self.current_length = 0
# Run evaluation games (keep these deterministic as they're meant to show best performance)
for i, (env, obs, done) in enumerate(zip(self.eval_envs, self.eval_obs, self.eval_done)):
if done:
if obs is not None: # If this wasn't the first reset
self.eval_scores.append(self.eval_scores_current[i])
# Update high scores based on evaluation performance
self.high_score = max(self.high_score, self.eval_scores_current[i])
self.recent_high_score = max(self.recent_high_score, self.eval_scores_current[i])
self.eval_obs[i] = env.reset()[0]
self.eval_done[i] = False
self.eval_scores_current[i] = 0
eval_action, _ = self.model.predict(self.eval_obs[i], deterministic=True)
eval_action = eval_action.item() if hasattr(eval_action, 'item') else eval_action
self.eval_obs[i], _, done, truncated, info = env.step(eval_action)
self.eval_done[i] = done or truncated
# Update current evaluation score
self.eval_scores_current[i] = info.get("score", 0)
# Run real-time demo game
if current_time - self.last_demo_time >= self.demo_speed:
if self.demo_done:
self.demo_obs = self.demo_env.reset()[0]
self.demo_done = False
demo_action, _ = self.model.predict(self.demo_obs, deterministic=True)
demo_action = demo_action.item() if hasattr(demo_action, 'item') else demo_action
self.demo_obs, _, done, truncated, info = self.demo_env.step(demo_action)
self.demo_done = done or truncated
self.last_demo_time = current_time
# Get current states from actual training environments
# Sample one no-wrap and one wrap environment from the training environments
no_wrap_idx = 0 # First half are no-wrap
wrap_idx = self.n_envs // 2 # Second half are wrap
training_state_1 = self.training_env.get_attr("session")[no_wrap_idx].get_state()
training_state_1["wrap_around"] = False
training_state_2 = self.training_env.get_attr("session")[wrap_idx].get_state()
training_state_2["wrap_around"] = True
eval_state_1 = self.eval_envs[0].unwrapped.session.get_state()
eval_state_1["wrap_around"] = self.eval_envs[0].unwrapped.rules.wrap_around
eval_state_2 = self.eval_envs[1].unwrapped.session.get_state()
eval_state_2["wrap_around"] = self.eval_envs[1].unwrapped.rules.wrap_around
demo_state = self.demo_env.unwrapped.session.get_state()
demo_state["wrap_around"] = self.demo_env.unwrapped.rules.wrap_around
# Update viz_data
viz_data = {
'training_state': [training_state_1, training_state_2],
'eval_states': [eval_state_1, eval_state_2],
'demo_state': demo_state,
'weights_info': {
'action_probs': action_probs.tolist() if 'action_probs' in locals() else [0.25] * 4,
'value_estimate': float(value_estimate[0]) if 'value_estimate' in locals() else 0.0,
'action_labels': ['Up', 'Right', 'Down', 'Left']
},
'training_info': {
'total_timesteps': self.num_timesteps,
'initial_timesteps': self.initial_timesteps, # Add initial timesteps
'target_timesteps': self.total_timesteps,
'episode_reward': float(self.current_reward),
'mean_reward': float(np.mean(self.episode_rewards[-100:])) if self.episode_rewards else 0.0,
'episode_length': int(self.current_length),
'mean_length': float(np.mean(self.episode_lengths[-100:])) if self.episode_lengths else 0.0,
'mean_eval_score': float(np.mean(self.eval_scores[-100:])) if self.eval_scores else 0.0,
'rewards_history': self.episode_rewards[-1000:],
'lengths_history': self.episode_lengths[-1000:],
'eval_scores_history': self.eval_scores[-1000:],
'high_score': self.high_score,
'recent_high_score': self.recent_high_score
}
}
# Send update to visualization
self.viz_queue.put(('update', viz_data))
# Reset recent high score periodically
if self.num_timesteps % (self.n_steps * self.n_envs) == 0:
print(f"\nEvaluation High Score: {self.high_score}")
print(f"Recent Evaluation High Score: {self.recent_high_score}")
self.recent_high_score = 0
return True
except Exception as e:
print(f"Error in visualization callback: {e}")
import traceback
traceback.print_exc()
return True
def make_env(seed: int = 0, wrap_walls: bool = False) -> gym.Env:
"""Create a training environment.
Args:
seed: Random seed
wrap_walls: Whether snake can wrap around walls
"""
def _init() -> gym.Env:
env = SnakeEnv(render_mode=None)
env = Monitor(env)
# Set wrap-around mode
env.unwrapped.rules.wrap_around = wrap_walls
env.reset(seed=seed)
return env
set_random_seed(seed)
return _init
def train_model(
total_timesteps: int,
viz_queue: Queue,
model_path: Optional[str] = None,
cuda: bool = True,
learning_rate: float = 3e-4, # Standard PPO learning rate
batch_size: int = 256, # More reasonable batch size
n_envs: int = 16, # Balanced number of environments
n_steps: int = 2048 # Standard PPO steps
) -> None:
"""
Train the snake AI model.
Args:
total_timesteps: Number of timesteps to train for
viz_queue: Queue for sending visualization updates
model_path: Path to save/load the model
cuda: Whether to use CUDA for training
learning_rate: Learning rate for the optimizer
batch_size: Batch size for training
n_envs: Number of parallel environments
n_steps: Number of steps per environment
"""
try:
# Set up environments - half with wrap, half without
envs = []
for i in range(n_envs):
wrap_walls = i >= n_envs // 2 # Half the envs with wrap-around
envs.append(make_env(i, wrap_walls))
env = SubprocVecEnv(envs)
# PPO hyperparameters - adjusted for more stable learning
ppo_params = dict(
learning_rate=1e-4, # Slightly higher learning rate
n_steps=1024, # Shorter horizon for faster updates
batch_size=128, # Smaller batch size for more frequent updates
n_epochs=4, # Balanced number of epochs
gamma=0.98, # Slightly lower discount for more immediate rewards
gae_lambda=0.9, # Lower GAE lambda for more immediate advantage estimates
clip_range=0.1, # Smaller clip range for more conservative updates
clip_range_vf=0.1, # Also clip value function
normalize_advantage=True,
ent_coef=0.2, # Moderate entropy for exploration
vf_coef=0.8, # Balanced value function coefficient
max_grad_norm=0.3, # More conservative gradient clipping
target_kl=None,
verbose=1,
device="cuda" if cuda and torch.cuda.is_available() else "cpu"
)
# Create or load model with updated architecture
policy_kwargs = dict(
features_extractor_class=CustomRecurrentCNN, # Use the new recurrent extractor
features_extractor_kwargs=dict(features_dim=512), # You can adjust features_dim as needed
net_arch=dict(
pi=[512, 256, 128], # Policy network layers
vf=[512, 256, 128] # Value network layers
),
activation_fn=nn.ReLU,
normalize_images=False
)
# Track initial timesteps for progress tracking
initial_timesteps = 0
if model_path and os.path.exists(model_path):
print(f"Loading model from {model_path}")
try:
model = PPO.load(
model_path,
env=env,
custom_objects={"policy_kwargs": policy_kwargs},
**ppo_params
)
initial_timesteps = model.num_timesteps
print(f"Loaded model with {initial_timesteps:,} timesteps of training")
except Exception as e:
print(f"Failed to load model with error: {e}")
print("Creating new model instead")
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)
else:
print("Creating new model")
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, **ppo_params)
# Set up visualization callback with initial timesteps
viz_callback = VisualizationCallback(
viz_queue=viz_queue,
total_timesteps=total_timesteps,
eval_freq=1000,
n_steps=n_steps,
n_envs=n_envs,
initial_timesteps=initial_timesteps # Pass initial timesteps to callback
)
# Training loop with stop check
should_stop = False
timesteps_per_batch = n_steps * n_envs
num_batches = total_timesteps // timesteps_per_batch
for i in range(num_batches):
if should_stop:
break
# Train for one batch
model.learn(
total_timesteps=timesteps_per_batch,
callback=viz_callback,
progress_bar=True,
reset_num_timesteps=False
)
# Check for stop signal from visualization
try:
while True: # Process all pending messages
msg_type, _ = viz_queue.get_nowait()
if msg_type == 'stop_training':
should_stop = True
break
except Empty:
pass
# Save model if training completed or interrupted
if model_path:
print(f"Saving model to {model_path}")
model.save(model_path)
except Exception as e:
print(f"Error in training process: {e}")
import traceback
traceback.print_exc()
finally:
# Clean up
env.close()
# Signal visualization to stop if we haven't already
viz_queue.put(('stop', None))
def main():
"""Main training script."""
parser = argparse.ArgumentParser(description="Train Snake AI")
parser.add_argument("--config", action="store_true",
help="Show configuration UI before training")
parser.add_argument("--model", type=str, default="models/snake_ai.zip",
help="Path to save/load model")
parser.add_argument("--no-cuda", action="store_true",
help="Disable CUDA training")
args = parser.parse_args()
# Get training configuration
if args.config:
from src.ai.config_ui import ConfigUI
config_ui = ConfigUI()
config = config_ui.run()
if config is None: # User cancelled
return
else:
# Load saved config or use defaults
config_file = "training_config.json"
if os.path.exists(config_file):
with open(config_file, 'r') as f:
config = json.load(f)
else:
config = {
'timesteps': 2000000,
'learning_rate': 0.0001,
'batch_size': 512,
'n_envs': 32,
'n_steps': 512
}
# Create visualization queue
viz_queue = Queue()
# Start visualization in separate thread
viz_thread = Thread(target=run_dashboard, args=(viz_queue,))
viz_thread.start()
# Start training with configuration
train_model(
total_timesteps=config['timesteps'],
viz_queue=viz_queue,
model_path=args.model,
cuda=not args.no_cuda,
learning_rate=config['learning_rate'],
batch_size=config['batch_size'],
n_envs=config['n_envs'],
n_steps=config['n_steps']
)
# Wait for visualization to finish
viz_thread.join()
if __name__ == "__main__":
main()

540
src/ai/visualize.py Normal file
View File

@ -0,0 +1,540 @@
"""
Training Visualization Dashboard
This module provides a real-time visualization of the training process,
showing both the training metrics and evaluation games.
"""
import pygame
import numpy as np
from typing import Dict, Any, Optional, List, Tuple
import threading
from queue import Queue, Empty
import time
class TrainingDashboard:
def __init__(self, width: int = 1920, height: int = 1080):
"""Initialize the training dashboard."""
self.width = width
self.height = height
# Fonts
self.title_font = None
self.header_font = None
self.text_font = None
# Colors
self.colors = {
'background': (0, 0, 0),
'text': (200, 200, 200),
'highlight': (0, 255, 0),
'grid': (40, 40, 40),
'border': (0, 255, 0),
'snake': (0, 255, 0),
'food': (255, 0, 0),
'graph_bg': (20, 20, 20),
'graph_line': (0, 255, 0),
'graph_grid': (40, 40, 40)
}
# Layout
self.layout = {
'margin': 20,
'game_size': 400,
'graph_height': 200,
'metrics_width': 300
}
# Training metrics history
self.metrics_history = {
'rewards': [],
'scores': [],
'lengths': []
}
self.max_history = 1000
# Current state
self.current_state = None
self.eval_state = None
self.demo_state = None
self.training_info = None
self.model_info = None
# Performance tracking
self.frame_times = []
self.max_frame_times = 60
# Initialize the clock
self.clock = None
self.target_fps = 60
def run(self, update_queue: Queue) -> None:
"""Run the dashboard, processing updates from the queue."""
pygame.init()
try:
self.screen = pygame.display.set_mode((self.width, self.height))
pygame.display.set_caption("Snake AI Training Dashboard")
# Initialize fonts after pygame is initialized
self.title_font = pygame.font.Font(None, 48)
self.header_font = pygame.font.Font(None, 36)
self.text_font = pygame.font.Font(None, 24)
# Initialize clock
self.clock = pygame.time.Clock()
running = True
while running:
# Handle events
for event in pygame.event.get():
if event.type == pygame.QUIT:
running = False
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
running = False
# Process updates from queue
try:
while True: # Process all available updates
update_type, data = update_queue.get_nowait()
if update_type == 'update':
self.update(data)
elif update_type == 'stop':
running = False
break
except Empty:
pass
# Render dashboard
self.render()
self.clock.tick(self.target_fps)
finally:
pygame.quit()
def update(self, data: Dict[str, Any]) -> None:
"""Update the dashboard with new data."""
self.current_state = data # Store entire data dict
if 'training_info' in data:
self.training_info = data['training_info']
# Update metrics history
for key in ['rewards', 'scores', 'lengths']:
if key in self.training_info:
self.metrics_history[key].append(self.training_info[key])
if len(self.metrics_history[key]) > self.max_history:
self.metrics_history[key].pop(0)
def draw_game_view(self, state: Dict[str, Any], position: Tuple[int, int], size: int, title: str) -> None:
"""Draw a game view at the specified position."""
if not state:
return
x, y = position
# Draw game area background
game_rect = pygame.Rect(x, y, size, size)
pygame.draw.rect(self.screen, self.colors['grid'], game_rect)
pygame.draw.rect(self.screen, self.colors['border'], game_rect, 2)
# Get time limit info directly from state
steps_since_food = state.get('steps_since_food', 0)
current_time_limit = state.get('current_time_limit', 100) # Default to base limit if not provided
time_left = max(0, current_time_limit - steps_since_food)
# Calculate cell size based on grid dimensions
cell_size = size // 30 # Assuming 30x30 grid
# Draw snake
if "snake_body" in state:
for segment in state["snake_body"]:
segment_rect = pygame.Rect(
x + segment[0] * cell_size,
y + segment[1] * cell_size,
cell_size, cell_size
)
pygame.draw.rect(self.screen, self.colors['snake'], segment_rect)
# Draw food
if "food_position" in state:
food_rect = pygame.Rect(
x + state["food_position"][0] * cell_size,
y + state["food_position"][1] * cell_size,
cell_size, cell_size
)
pygame.draw.rect(self.screen, self.colors['food'], food_rect)
# Draw score and timer above game area
score = state.get('score', 0)
score_text = f"Score: {score}"
timer_text = f"Time: {time_left}/{current_time_limit}" # Show both current and max time
combined_text = f"{score_text} | {timer_text}"
score_surface = self.text_font.render(combined_text, True, self.colors['text'])
score_rect = score_surface.get_rect(midtop=(x + size//2, y - 25))
self.screen.blit(score_surface, score_rect)
# Draw title below game area, including wrap mode for training view
if title == "Training Episode" and 'wrap_around' in state:
wrap_mode = "(Wrap)" if state['wrap_around'] else "(No Wrap)"
title = f"{title} {wrap_mode}"
title_text = title
title_surface = self.header_font.render(title_text, True, self.colors['text'])
title_rect = title_surface.get_rect(midtop=(x + size//2, y + size + 10))
self.screen.blit(title_surface, title_rect)
def draw_metrics(self, x: int, y: int, width: int) -> None:
"""Draw training metrics."""
if not self.training_info:
return
font = pygame.font.Font(None, 24)
y_offset = 0
# Draw high scores
high_score = self.training_info.get("high_score", 0)
recent_high = self.training_info.get("recent_high_score", 0)
metrics = [
("All-time High Score", high_score),
("Recent High Score", recent_high),
("", ""), # Empty line for spacing
("Training Metrics", ""),
("Total Steps", self.training_info.get('total_timesteps', 0)),
("Episode Reward", f"{self.training_info.get('episode_reward', 0):.2f}"),
("Mean Reward", f"{self.training_info.get('mean_reward', 0):.2f}"),
("Episode Length", self.training_info.get('episode_length', 0)),
("Mean Length", f"{self.training_info.get('mean_length', 0):.2f}"),
("FPS", f"{len(self.frame_times) / sum(self.frame_times):.1f}" if self.frame_times else "0")
]
for i, (label, value) in enumerate(metrics):
text = f"{label}: {value}" if label else ""
surface = self.text_font.render(text, True, self.colors['text'])
self.screen.blit(surface, (x, y + i * 25))
def draw_graph(self, data: List[float], rect: pygame.Rect, title: str,
color: Tuple[int, int, int], min_val: float, max_val: float,
smoothing_window: int = 10) -> None:
"""Draw a line graph of the given data.
Args:
data: List of data points.
rect: Rectangle area for drawing the graph.
title: Title for the graph.
color: Color for the graph line.
min_val: Minimum value for scaling the y-axis.
max_val: Maximum value for scaling the y-axis.
smoothing_window: Optional smoothing window size for drawing
long term trends. Increase this value for a smoother, longer-term trend.
Default is 10.
"""
if not data:
return
# Create a separate surface for the graph
graph_surface = pygame.Surface((rect.width, rect.height))
graph_surface.fill(self.colors['graph_bg'])
# Draw grid lines on the graph surface
num_lines = 5
for i in range(num_lines):
y_line = i * rect.height / (num_lines - 1)
pygame.draw.line(graph_surface, self.colors['graph_grid'],
(0, y_line), (rect.width, y_line))
# Add padding to the min/max values for better visualization
value_range = max_val - min_val
if value_range == 0:
value_range = 1.0
padding = value_range * 0.1
min_val_adjusted = min_val - padding
max_val_adjusted = max_val + padding
# Smooth the data using the provided smoothing_window parameter
window_size = smoothing_window if len(data) >= smoothing_window else len(data)
if window_size > 1:
smoothed_data = np.convolve(data, np.ones(window_size) / window_size, mode='valid')
else:
smoothed_data = data
points = []
for i, value in enumerate(smoothed_data):
x_point = i * rect.width / len(smoothed_data)
y_point = rect.height - ((value - min_val_adjusted) * rect.height /
(max_val_adjusted - min_val_adjusted))
points.append((x_point, y_point))
if len(points) > 1:
pygame.draw.lines(graph_surface, color, False, points, 2)
# Draw border on the graph surface
pygame.draw.rect(graph_surface, self.colors['border'],
pygame.Rect(0, 0, rect.width, rect.height), 1)
# Blit the graph surface onto the main screen
self.screen.blit(graph_surface, rect)
# Draw title and current value
current_value = data[-1] if data else 0
title_text = f"{title} (Current: {current_value:.2f})"
title_surface = self.text_font.render(title_text, True, self.colors['text'])
title_rect = title_surface.get_rect(midtop=(rect.centerx, rect.top - 20))
self.screen.blit(title_surface, title_rect)
# Draw value labels to the right of the graph
label_spacing = rect.height / (num_lines - 1)
label_x = rect.right + 10
for i in range(num_lines):
value = max_val_adjusted - (i * (max_val_adjusted - min_val_adjusted) / (num_lines - 1))
label_text = f"{value:.1f}"
label_surface = self.text_font.render(label_text, True, self.colors['text'])
label_rect = label_surface.get_rect(
left=label_x,
centery=rect.top + i * label_spacing
)
self.screen.blit(label_surface, label_rect)
def draw_weights_visualization(self, x: int, y: int, width: int, height: int, weights_info: Dict) -> None:
"""Draw the neural network weights visualization."""
if not weights_info:
return
# Get data
action_probs = weights_info.get('action_probs', [0, 0, 0, 0])
value_estimate = weights_info.get('value_estimate', 0.0)
action_labels = weights_info.get('action_labels', ['Up', 'Right', 'Down', 'Left'])
# Draw background
rect = pygame.Rect(x, y, width, height)
pygame.draw.rect(self.screen, self.colors['graph_bg'], rect)
pygame.draw.rect(self.screen, self.colors['border'], rect, 1)
# Draw title
title = "Action Probabilities"
title_surface = self.header_font.render(title, True, self.colors['text'])
title_rect = title_surface.get_rect(midtop=(x + width//2, y + 5))
self.screen.blit(title_surface, title_rect)
# Draw action probability bars
bar_height = 20
bar_spacing = 30
bar_start_y = y + 50
max_bar_width = width - 120 # Leave space for labels
for i, (prob, label) in enumerate(zip(action_probs, action_labels)):
# Draw label
label_surface = self.text_font.render(label, True, self.colors['text'])
self.screen.blit(label_surface, (x + 10, bar_start_y + i * bar_spacing))
# Draw bar background
bar_bg_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, max_bar_width, bar_height)
pygame.draw.rect(self.screen, self.colors['grid'], bar_bg_rect)
# Draw probability bar
bar_width = int(prob * max_bar_width)
bar_rect = pygame.Rect(x + 70, bar_start_y + i * bar_spacing, bar_width, bar_height)
pygame.draw.rect(self.screen, self.colors['highlight'], bar_rect)
# Draw probability value
prob_text = f"{prob:.2f}"
prob_surface = self.text_font.render(prob_text, True, self.colors['text'])
self.screen.blit(prob_surface, (x + 80 + max_bar_width, bar_start_y + i * bar_spacing))
# Draw value estimate
value_text = f"Value Estimate: {value_estimate:.2f}"
value_surface = self.text_font.render(value_text, True, self.colors['text'])
value_rect = value_surface.get_rect(midbottom=(x + width//2, y + height - 10))
self.screen.blit(value_surface, value_rect)
def draw_progress_bar(self, x: int, y: int, width: int, height: int, progress: float, text: str) -> None:
"""Draw a progress bar with text.
Args:
x, y: Position of the progress bar
width, height: Dimensions of the progress bar
progress: Progress value between 0 and 1
text: Text to display above the progress bar
"""
# Draw background
bg_rect = pygame.Rect(x, y, width, height)
pygame.draw.rect(self.screen, self.colors['graph_bg'], bg_rect)
pygame.draw.rect(self.screen, self.colors['border'], bg_rect, 1)
# Draw progress
if progress > 0:
progress_width = int(width * progress)
progress_rect = pygame.Rect(x, y, progress_width, height)
pygame.draw.rect(self.screen, self.colors['highlight'], progress_rect)
# Draw text
text_surface = self.text_font.render(text, True, self.colors['text'])
text_rect = text_surface.get_rect(bottomleft=(x, y - 5))
self.screen.blit(text_surface, text_rect)
def render(self) -> None:
"""Render the dashboard."""
if not self.current_state:
return
self.screen.fill(self.colors['background'])
# Draw game views
if 'training_state' in self.current_state:
# Draw both training games
self.draw_game_view(
self.current_state['training_state'][0], # No wrap training
(50, 50),
250,
"Training (No Wrap)"
)
self.draw_game_view(
self.current_state['training_state'][1], # Wrap training
(350, 50),
250,
"Training (Wrap)"
)
if 'eval_states' in self.current_state:
self.draw_game_view(
self.current_state['eval_states'][0],
(650, 50),
250,
"Evaluation (No Wrap)"
)
self.draw_game_view(
self.current_state['eval_states'][1],
(950, 50),
250,
"Evaluation (Wrap)"
)
if 'demo_state' in self.current_state:
self.draw_game_view(
self.current_state['demo_state'],
(1250, 50),
250,
"Demo Game"
)
# Draw metrics
if self.training_info:
metrics_data = [
(self.training_info['rewards_history'], "Episode Rewards"),
(self.training_info['lengths_history'], "Episode Lengths"),
(self.training_info['eval_scores_history'], "Evaluation Scores")
]
for i, (data, label) in enumerate(metrics_data):
if data: # Only draw if we have data
rect = pygame.Rect(50, 400 + i * 200, 500, 150)
self.draw_graph(
data,
rect,
label,
self.colors['graph_line'],
min(data) if data else 0,
max(data) if data else 1
)
# Draw current metrics text
metrics_text = [
f"Total Steps: {self.training_info.get('total_timesteps', 0):,}",
f"Mean Reward: {self.training_info.get('mean_reward', 0):.2f}",
f"Mean Length: {self.training_info.get('mean_length', 0):.1f}",
f"Mean Eval Score: {self.training_info.get('mean_eval_score', 0):.2f}",
f"High Score: {self.training_info.get('high_score', 0)}",
f"Recent High Score: {self.training_info.get('recent_high_score', 0)}"
]
# Add vertical spacer between metrics sections
metrics_text.insert(3, "") # Insert empty string as spacer after first 3 metrics
font = pygame.font.Font(None, 36)
for i, text in enumerate(metrics_text):
surface = font.render(text, True, self.colors['text'])
self.screen.blit(surface, (600, 400 + i * 30))
# Draw action probabilities if available
if 'weights_info' in self.current_state:
self.draw_weights_visualization(
1100, 400,
300, 300,
self.current_state['weights_info']
)
# Draw trend indicator graph
if self.training_info:
# Calculate trend indicators
window = 100 # Use last 100 episodes for trends
rewards = self.training_info['rewards_history'][-window:]
lengths = self.training_info['lengths_history'][-window:]
eval_scores = self.training_info['eval_scores_history'][-window:]
if rewards and lengths and eval_scores:
# Normalize each metric to 0-1 range for fair combination
norm_rewards = [(r - min(rewards)) / (max(rewards) - min(rewards) + 1e-8) for r in rewards]
norm_lengths = [(l - min(lengths)) / (max(lengths) - min(lengths) + 1e-8) for l in lengths]
norm_scores = [(s - min(eval_scores)) / (max(eval_scores) - min(eval_scores) + 1e-8) for s in eval_scores]
# Combine metrics with weights (rewards: 0.4, lengths: 0.3, eval_scores: 0.3)
trend_data = [0.4 * r + 0.3 * l + 0.3 * s
for r, l, s in zip(norm_rewards, norm_lengths, norm_scores)]
# Add trend direction indicator
trending = 0 # -1, 0, 1
if len(trend_data) >= 2:
recent_trend = trend_data[-1] - trend_data[0]
trending = 1 if recent_trend > 0.1 else -1 if recent_trend < -0.1 else 0
trend_text = "↑ Improving" if trending == 1 else "↓ Declining" if trending == -1 else "→ Stable"
trend_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0)
trend_surface = self.text_font.render(trend_text, True, trend_color)
trend_rect = trend_surface.get_rect(midtop=(1250, 860))
self.screen.blit(trend_surface, trend_rect)
graph_color = (0, 255, 0) if trending == 1 else (255, 0, 0) if trending == -1 else (200, 200, 0)
# Draw trend graph 50 pixels below the action probabilities graph
trend_rect = pygame.Rect(1100, 750, 300, 150)
self.draw_graph(
trend_data,
trend_rect,
"Learning Trend",
graph_color, # color based on trend
0, # Min value (normalized)
1, # Max value (normalized)
smoothing_window=10 # Use a larger smoothing window for longer-term trends
)
# Draw training progress bar at the bottom of the screen
if self.training_info:
# Get current session steps (from start of this training run)
current_session_steps = self.training_info.get('total_timesteps', 0) - self.training_info.get('initial_timesteps', 0)
target_timesteps = self.training_info.get('target_timesteps', 1000000)
progress = min(1.0, current_session_steps / target_timesteps)
# Show both current session progress and total model training
total_model_steps = self.training_info.get('total_timesteps', 0)
progress_text = (
f"Current Session: {current_session_steps:,}/{target_timesteps:,} steps ({progress*100:.1f}%) | "
f"Total Model Training: {total_model_steps:,} steps"
)
self.draw_progress_bar(50, self.height - 40, self.width - 100, 20, progress, progress_text)
pygame.display.flip()
def handle_events(self):
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
self.running = False
def run_dashboard(update_queue: Queue) -> None:
"""Run the training dashboard in a separate thread."""
dashboard = TrainingDashboard()
try:
dashboard.run(update_queue)
finally:
# Send stop signal back to training process
update_queue.put(('stop_training', None))
pygame.quit()

View File

@ -6,7 +6,8 @@ It handles command-line arguments for different game modes and configurations.
"""
import argparse
from src import Game, GameMode
from src.game import Game, GameState
from src.ui.menu import GameMode
def main():
"""
@ -35,12 +36,19 @@ def main():
if args.test:
game.width = 400
game.height = 300
game.block_size = 20
game.fps = 30
game.game_session.grid_size = 20
game.clock.tick(30) # Set FPS to 30
# Apply debug mode settings
if args.debug:
game.debug = True
game.settings.debug_mode = True
if game.game_session:
game.game_session.show_grid = True
# Apply AI mode settings
if args.ai_only:
game.game_mode = GameMode.AI_MEDIUM
game.state = GameState.PLAYING
# Run the game
try:

36
src/config/__init__.py Normal file
View File

@ -0,0 +1,36 @@
"""
Game Configuration Package
This package contains game configuration and constants.
"""
from .constants import *
from .settings import GameRules, GameSettings
__all__ = [
'GameRules',
'GameSettings',
# Include all constants
'WINDOW_WIDTH',
'WINDOW_HEIGHT',
'DEFAULT_GRID_WIDTH',
'DEFAULT_GRID_HEIGHT',
'DEFAULT_GRID_SIZE',
'DEFAULT_PADDING',
'DEFAULT_SCORE_HEIGHT',
'BLACK',
'WHITE',
'GREEN',
'DARK_GREEN',
'GRAY',
'DARK_GRAY',
'GRID_COLOR',
'TITLE_FONT_SIZE',
'SUBTITLE_FONT_SIZE',
'SCORE_FONT_SIZE',
'DEBUG_FONT_SIZE',
'FPS',
'DEFAULT_MOVE_COOLDOWN',
'MIN_MOVE_COOLDOWN',
'MENU_ITEM_SPACING'
]

25
src/config/colors.py Normal file
View File

@ -0,0 +1,25 @@
"""
Game Colors
This module contains all the color values used throughout the game.
"""
# Modern color palette
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
GREEN = (0, 255, 0)
DARK_GREEN = (0, 100, 0)
NEON_GREEN = (57, 255, 20) # Brighter, more neon green
LIME_GREEN = (50, 205, 50) # Softer green for variety
FOREST_GREEN = (0, 100, 0) # Dark green for contrast
DARKER_GREEN = (0, 40, 0) # Very dark green for borders
GRAY = (128, 128, 128)
DARK_GRAY = (25, 25, 25) # Darker background for better contrast
GRID_COLOR = (40, 40, 40)
SUBTLE_GRID_COLOR = (35, 35, 35)
RED = (255, 0, 0)
NEON_RED = (255, 20, 57) # Brighter, more neon red for food
GLOW_GREEN = (150, 255, 150) # Light green for glow effects
SNAKE_GRADIENT_START = (57, 255, 20) # Head color
SNAKE_GRADIENT_END = (0, 150, 0) # Tail color

31
src/config/constants.py Normal file
View File

@ -0,0 +1,31 @@
"""
Game Constants
This module contains all the constant values used throughout the game.
"""
# Window settings
WINDOW_WIDTH = 1024
WINDOW_HEIGHT = 768
# Grid settings
DEFAULT_GRID_WIDTH = 30
DEFAULT_GRID_HEIGHT = 30
DEFAULT_GRID_SIZE = 20
# Game area settings
DEFAULT_PADDING = 40
DEFAULT_SCORE_HEIGHT = 60
# Font sizes
TITLE_FONT_SIZE = 72
SUBTITLE_FONT_SIZE = 36
SCORE_FONT_SIZE = 36
DEBUG_FONT_SIZE = 24
# Game timing
FPS = 60
DEFAULT_MOVE_COOLDOWN = 100
MIN_MOVE_COOLDOWN = 50
# Menu settings
MENU_ITEM_SPACING = 50

52
src/config/settings.py Normal file
View File

@ -0,0 +1,52 @@
"""
Game Settings
This module contains the game settings and rules that can be configured.
"""
class GameRules:
"""Game rules that can be configured through the settings menu."""
def __init__(self):
self.wrap_around = True # Whether snake wraps around screen edges
self.speed_increase = True # Whether snake speeds up as it grows
self.min_move_cooldown = 50 # Minimum movement delay in milliseconds
self.initial_move_cooldown = 100 # Initial movement delay
self.starting_length = 3 # Starting length of the snake
def update_rule(self, rule_name, value):
"""
Update a game rule with a new value.
Args:
rule_name (str): Name of the rule to update
value: New value for the rule
Returns:
bool: True if update was successful, False otherwise
"""
if hasattr(self, rule_name):
# Define validation rules for specific attributes
validators = {
'starting_length': lambda x: 1 <= x <= 10,
'wrap_around': lambda x: isinstance(x, bool),
'speed_increase': lambda x: isinstance(x, bool),
'min_move_cooldown': lambda x: x > 0,
'initial_move_cooldown': lambda x: x > 0
}
# Validate the value if there's a validator for this rule
if rule_name in validators and not validators[rule_name](value):
return False
setattr(self, rule_name, value)
return True
return False
class GameSettings:
"""Global game settings."""
def __init__(self):
self.debug_mode = False # Whether to show debug information
self.show_grid = self.debug_mode # Whether to show the grid lines
self.rules = GameRules() # Game rules instance

16
src/core/__init__.py Normal file
View File

@ -0,0 +1,16 @@
"""
Game Core Package
This package contains the core game mechanics and entities.
"""
from .snake import Snake, Direction
from .food import Food
from .game_session import GameSession
__all__ = [
'Snake',
'Direction',
'Food',
'GameSession'
]

82
src/core/food.py Normal file
View File

@ -0,0 +1,82 @@
"""
Food Module
This module provides the Food class that represents the food item in the game.
"""
import pygame
import random
from typing import Tuple, List
from src.ui import GameArea
from src.config.colors import *
class Food:
def __init__(self, color: Tuple[int, int, int] = NEON_RED):
"""
Initialize food with a color.
Args:
color: RGB color tuple for the food
"""
self.color = color
self.position = (0, 0) # Will be set by spawn_at_position()
def spawn_at_position(self, available_positions: List[Tuple[int, int]]) -> None:
"""
Spawn food at a random position from the available positions.
Args:
available_positions: List of valid grid positions where food can spawn
"""
if not available_positions:
# No valid positions (snake fills screen) - game should be won
return
# Choose random position from valid positions
self.position = random.choice(available_positions)
def draw(self, screen: pygame.Surface, game_area: GameArea) -> None:
"""
Draw the food on the screen with a glowing effect
Args:
screen: Pygame surface to draw on
game_area: GameArea instance for coordinate conversion
"""
screen_x, screen_y = game_area.get_screen_pos(self.position[0], self.position[1])
# Draw outer glow
glow_size = game_area.grid_size + 4
glow_surface = pygame.Surface((glow_size, glow_size), pygame.SRCALPHA)
# Create radial gradient for glow
for i in range(3):
alpha = 100 - (i * 30) # Fade out alpha
size = glow_size - (i * 2)
pos = (glow_size - size) // 2
pygame.draw.circle(glow_surface, (*NEON_RED[:3], alpha), (glow_size//2, glow_size//2), size//2)
# Draw glow
screen.blit(glow_surface, (screen_x - 2, screen_y - 2))
# Draw main food body
food_rect = pygame.Rect(
screen_x + 2,
screen_y + 2,
game_area.grid_size - 4,
game_area.grid_size - 4
)
pygame.draw.rect(screen, NEON_RED, food_rect, border_radius=4)
# Draw highlight
highlight_rect = pygame.Rect(
screen_x + 4,
screen_y + 4,
game_area.grid_size - 8,
game_area.grid_size - 8
)
pygame.draw.rect(screen, (*NEON_RED, 200), highlight_rect, border_radius=3)
def check_collision(self, position: Tuple[int, int]) -> bool:
"""Check if the given position collides with the food"""
return self.position == position

239
src/core/game_session.py Normal file
View File

@ -0,0 +1,239 @@
"""
Game Session Manager
This module provides a reusable game session class that encapsulates the core gameplay mechanics.
It can be used by both the main game and the training environment to ensure consistent behavior.
"""
import pygame
from typing import Tuple, Dict, Optional
from src.core import Snake, Direction, Food
from src.ui import GameArea
from src.config import (
WINDOW_WIDTH,
WINDOW_HEIGHT,
FPS
)
from src.config.colors import *
from src.config.settings import GameRules
class GameSession:
"""Manages a single game session with consistent rules and boundaries."""
def __init__(self, grid_width: int, grid_height: int, rules : GameRules, window_width: int = WINDOW_WIDTH, window_height: int = WINDOW_HEIGHT):
"""
Initialize a new game session.
Args:
grid_width: Number of grid cells horizontally
grid_height: Number of grid cells vertically
rules: GameRules instance containing game settings
window_width: Width of the game window
window_height: Height of the game window
"""
self.grid_width = grid_width
self.grid_height = grid_height
self.rules = rules
self.window_width = window_width
self.window_height = window_height
# Game objects
self.snake = None
self.food = None
# Game state
self.score = 0
self.is_game_over = False
self.move_cooldown = rules.initial_move_cooldown
self.last_move_time = 0
self.reset()
def reset(self) -> Dict:
"""
Reset the game session to its initial state, including the game area.
Returns:
Dict containing the initial game state
"""
# Initialize game area with window dimensions
self.game_area = GameArea(
window_width=self.window_width,
window_height=self.window_height,
grid_width=self.grid_width,
grid_height=self.grid_height
)
# Initialize snake in the middle of the grid
start_x = self.grid_width // 2
start_y = self.grid_height // 2
self.snake = Snake(
start_pos=(start_x, start_y),
grid_width=self.grid_width,
grid_height=self.grid_height
)
# Initialize food
self.food = Food()
self._spawn_food()
# Reset game state
self.score = 0
self.is_game_over = False
self.move_cooldown = self.rules.initial_move_cooldown
self.last_move_time = 0
return self.get_state()
def _spawn_food(self) -> None:
"""Spawn food in a random empty grid cell."""
# Get all occupied positions
occupied = set(self.snake.body)
# Get all possible positions
all_positions = [(x, y) for x in range(self.grid_width)
for y in range(self.grid_height)]
# Filter out occupied positions
available = [pos for pos in all_positions if pos not in occupied]
if available:
self.food.spawn_at_position(available)
def step(self, action: Optional[Direction] = None, current_time: int = 0) -> Tuple[Dict, float, bool]:
"""
Advance the game state by one step.
Args:
action: Optional direction to change to
current_time: Current game time in milliseconds
Returns:
Tuple of (game_state, reward, done)
"""
if self.is_game_over:
return self.get_state(), 0, True
# Apply action if provided
if action is not None:
self.snake.change_direction(action)
# Update snake movement animation
self.snake.update_movement(1.0 / FPS) #
# Check if enough time has passed for next move
if current_time - self.last_move_time < self.move_cooldown:
return self.get_state(), 0, False
self.last_move_time = current_time
# Process any pending direction changes first
if self.snake.input_buffer:
# Try first direction
next_direction = self.snake.input_buffer[0]
if self.snake.is_direction_valid(next_direction):
self.snake.direction = next_direction
self.snake.input_buffer.pop(0)
# If invalid and we have another input, try that one first
elif len(self.snake.input_buffer) > 1:
alt_direction = self.snake.input_buffer[1]
if self.snake.is_direction_valid(alt_direction):
self.snake.direction = alt_direction
self.snake.input_buffer.pop(1) # Remove the successful second input
# Remove the first input regardless as it's had its chance
self.snake.input_buffer.pop(0)
else:
# Single invalid input, just remove it
self.snake.input_buffer.pop(0)
# Now get movement vector with updated direction
dx, dy = self.snake.direction.to_vector()
new_head = (self.snake.body[0][0] + dx, self.snake.body[0][1] + dy)
# Check for immediate wall collision (before moving)
if not self.rules.wrap_around:
if (new_head[0] < 0 or new_head[0] >= self.grid_width or
new_head[1] < 0 or new_head[1] >= self.grid_height):
self.is_game_over = True
return self.get_state(), -1, True
# check for collision with snake body
if new_head in self.snake.body[1:]:
self.is_game_over = True
return self.get_state(), -1, True
# Move snake
head_x, head_y = self.snake.move()
# Handle wrap-around if enabled
if self.rules.wrap_around:
head_x = head_x % self.grid_width
head_y = head_y % self.grid_height
self.snake.body[0] = (head_x, head_y)
self.snake.target_positions[0] = (head_x, head_y) # Update target position for smooth movement
# Check self collision
if (head_x, head_y) in self.snake.body[1:]:
self.is_game_over = True
return self.get_state(), -1, True
# Check food collision
if (head_x, head_y) == self.food.position:
self.score += 1
self.snake.grow()
self._spawn_food()
# Increase speed if enabled
if self.rules.speed_increase:
self.move_cooldown = max(
self.rules.min_move_cooldown,
self.move_cooldown - 2
)
return self.get_state(), 1, False
return self.get_state(), 0, False
def get_state(self) -> Dict:
"""
Get the current game state.
Returns:
Dict containing the game state
"""
return {
"grid_width": self.grid_width,
"grid_height": self.grid_height,
"snake_head": self.snake.body[0],
"snake_body": self.snake.body,
"snake_direction": self.snake.direction,
"food_position": self.food.position,
"score": self.score,
"game_over": self.is_game_over,
"move_cooldown": self.move_cooldown
}
def render(self, screen: pygame.Surface) -> None:
"""
Render the game state to the screen.
Args:
screen: Pygame surface to render to
"""
# Draw game area first
self.game_area.draw(screen)
# Draw game objects
self.snake.draw(screen, self.game_area)
self.food.draw(screen, self.game_area)
# Draw grid (optional, for debugging)
if hasattr(self, 'show_grid') and self.show_grid:
for x in range(self.grid_width + 1):
start_pos = (self.game_area.x + x * self.game_area.grid_size, self.game_area.y)
end_pos = (self.game_area.x + x * self.game_area.grid_size, self.game_area.y + self.game_area.height)
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
for y in range(self.grid_height + 1):
start_pos = (self.game_area.x, self.game_area.y + y * self.game_area.grid_size)
end_pos = (self.game_area.x + self.game_area.width, self.game_area.y + y * self.game_area.grid_size)
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)

357
src/core/snake.py Normal file
View File

@ -0,0 +1,357 @@
"""
Snake Module
This module provides the Snake class and Direction enum for snake movement and rendering.
"""
import pygame
from enum import Enum, auto
from typing import Tuple, List
import math
from src.config.colors import *
import pygame.gfxdraw
class Direction(Enum):
UP = auto()
DOWN = auto()
LEFT = auto()
RIGHT = auto()
def to_vector(self) -> Tuple[int, int]:
# Convert direction to a movement vector
if self == Direction.UP:
return (0, -1)
elif self == Direction.DOWN:
return (0, 1)
elif self == Direction.LEFT:
return (-1, 0)
else: # Direction.RIGHT
return (1, 0)
class Snake:
def __init__(self, start_pos: Tuple[int, int], grid_width: int, grid_height: int, length: int = 3):
"""
Initialize snake at the given grid position.
Args:
start_pos: Starting position in grid coordinates (x, y)
grid_width: Width of the game grid
grid_height: Height of the game grid
length: Initial length of the snake
"""
self.direction = Direction.RIGHT
self.body = [start_pos] # Head is at index 0
self.growing = False
self.grid_width = grid_width
self.grid_height = grid_height
# FIX: Start from 1 to avoid duplicate head segment.
for i in range(1, length):
self.body.append((start_pos[0] - i, start_pos[1]))
# Initialize visual positions for each body segment
self.visual_positions = [(float(pos[0]), float(pos[1])) for pos in self.body]
self.move_progress = 1.0 # Progress of current move (0.0 to 1.0)
self.move_speed = 0.2 # Movement speed
self.target_positions = list(self.body) # Target grid positions
# Input buffer for direction changes
self.input_buffer = [] # Store up to 2 pending direction changes
self.max_buffer_size = 2
# Create base segment surface with gradient
self.segment_size = 32 # Base size for segment surface
self.base_surface = pygame.Surface((self.segment_size, self.segment_size), pygame.SRCALPHA)
# Create a radial gradient for the base segment
center = self.segment_size // 2
for radius in range(center, -1, -1):
# Calculate color intensity based on distance from center
intensity = 1.0 - (radius / center) ** 0.8 # Adjust power for gradient shape
color = (
int(SNAKE_GRADIENT_START[0] * intensity),
int(SNAKE_GRADIENT_START[1] * intensity),
int(SNAKE_GRADIENT_START[2] * intensity),
255
)
pygame.draw.circle(self.base_surface, color, (center, center), radius)
def move(self) -> Tuple[int, int]:
"""
Move the snake one grid cell in current direction.
Returns:
New head position (x, y)
"""
# Get new head position
head = self.body[0]
new_head = self._get_new_head_position(head)
# Update grid positions
self.body.insert(0, new_head)
if not self.growing:
self.body.pop()
else:
self.growing = False
# Update visual positions
self.target_positions = list(self.body) # Copy current grid positions as targets
if len(self.visual_positions) < len(self.target_positions):
self.visual_positions.append(self.visual_positions[-1])
# Reset move progress for smooth animation
self.move_progress = 0.0
return new_head
def update_movement(self, dt: float):
"""Update visual position interpolation with consistent timing."""
if self.move_progress >= 1.0:
return
# Use cubic easing for smoother acceleration/deceleration
self.move_progress = min(1.0, self.move_progress + self.move_speed)
t = self.move_progress
overall_progress = t * t * (3 - 2 * t)
# Update each segment with fixed delay windows
segment_count = len(self.visual_positions)
for i in range(segment_count):
if i >= len(self.target_positions):
continue
current = self.visual_positions[i]
target = self.target_positions[i]
dx = target[0] - current[0]
dy = target[1] - current[1]
# Check if this segment needs to wrap
wrap_x = abs(dx) > 2
wrap_y = abs(dy) > 2
if wrap_x or wrap_y:
self.visual_positions[i] = target
else:
# Use fixed delay windows instead of compounding delays
delay_window = 0.2 # Total delay spread across all segments
segment_delay = (i / segment_count) * delay_window
# Calculate progress for this segment
segment_progress = max(0.0, (overall_progress - segment_delay) / (1.0 - segment_delay))
segment_progress = min(1.0, segment_progress)
# Apply smoothed movement
new_x = current[0] + dx * segment_progress
new_y = current[1] + dy * segment_progress
self.visual_positions[i] = (new_x, new_y)
def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]:
"""Calculate new head position based on current direction"""
x, y = head
dx, dy = self.direction.to_vector()
return (x + dx, y + dy)
def grow(self):
"""Mark the snake to grow on next move"""
self.growing = True
def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool:
"""
Check if snake has collided with walls or itself.
Args:
width: Game area width
height: Game area height
wrap_around: If True, snake wraps around screen edges instead of colliding
"""
head = self.body[0]
if not wrap_around:
# Check wall collision
if (head[0] < 0 or head[0] >= width or
head[1] < 0 or head[1] >= height):
return True
# Check self collision (skip head)
if head in self.body[1:]:
return True
return False
def wrap_position(self, width: int, height: int):
"""Wrap snake's head position around screen edges"""
head_x, head_y = self.body[0]
wrapped_head = (
head_x % width,
head_y % height
)
self.body[0] = wrapped_head
def is_direction_valid(self, new_direction: Direction) -> bool:
"""
Try to change direction, ensuring no 180-degree turns.
Args:
new_direction: The direction to change to
Returns:
True if direction was changed, False otherwise
"""
opposite_directions = {
Direction.UP: Direction.DOWN,
Direction.DOWN: Direction.UP,
Direction.LEFT: Direction.RIGHT,
Direction.RIGHT: Direction.LEFT
}
if opposite_directions[new_direction] == self.direction or new_direction == self.direction:
return False
self.direction = new_direction
return True
def change_direction(self, new_direction: Direction):
"""
Buffer a direction change to be applied on next move.
Args:
new_direction: The direction to change to
"""
if self.input_buffer and self.input_buffer[-1] == new_direction:
return
if len(self.input_buffer) < self.max_buffer_size:
self.input_buffer.append(new_direction)
def draw(self, screen: pygame.Surface, game_area):
"""Draw snake with clean, connected design and cute tongue."""
if len(self.visual_positions) < 2:
return
base_size = game_area.grid_size
segment_size = base_size - 4
# Draw snake body segments
for i in range(len(self.visual_positions) - 1):
pos1 = self.visual_positions[i]
pos2 = self.visual_positions[i + 1]
screen_positions = self._get_wrapped_segment_positions(pos1, pos2, game_area)
if not screen_positions:
continue
for sp1, sp2 in screen_positions:
# Calculate segment properties
progress = max(0.4, 1.0 - (i / len(self.visual_positions)))
color = (
int(SNAKE_GRADIENT_START[0] * progress),
int(SNAKE_GRADIENT_START[1] * progress),
int(SNAKE_GRADIENT_START[2] * progress)
)
# Draw connecting segments with extra thickness
center1 = (sp1[0] + base_size//2, sp1[1] + base_size//2)
center2 = (sp2[0] + base_size//2, sp2[1] + base_size//2)
# Draw a thicker base line first
pygame.draw.line(screen, color, center1, center2, segment_size + 4)
# Draw larger circles at joints to ensure corner coverage
pygame.draw.circle(screen, color, center1, (segment_size + 2)//2)
pygame.draw.circle(screen, color, center2, (segment_size + 2)//2)
# If this is a corner (check by looking at next segment)
if i < len(self.visual_positions) - 2:
pos3 = self.visual_positions[i + 2]
if pos2[0] != pos1[0] and pos3[1] != pos2[1] or \
pos2[1] != pos1[1] and pos3[0] != pos2[0]:
# Draw extra circle at the corner to ensure coverage
pygame.draw.circle(screen, color, center2, (segment_size + 6)//2)
# Draw head
head_pos = self.visual_positions[0]
screen_pos = game_area.get_screen_pos(head_pos[0], head_pos[1])
center = (screen_pos[0] + base_size//2, screen_pos[1] + base_size//2)
# Draw head circle
pygame.draw.circle(screen, SNAKE_GRADIENT_START, center, segment_size//2)
# Calculate direction for eyes and tongue
next_pos = self.visual_positions[1] if len(self.visual_positions) > 1 else head_pos
dx = head_pos[0] - next_pos[0] # Reversed the direction calculation
dy = head_pos[1] - next_pos[1] # Reversed the direction calculation
angle = math.atan2(dy, dx)
# Draw eyes
eye_offset = segment_size//4
eye_size = 2
eye_angle = angle + math.pi/2
for side in [-1, 1]:
eye_x = center[0] + math.cos(eye_angle) * eye_offset * side
eye_y = center[1] + math.sin(eye_angle) * eye_offset
pygame.draw.circle(screen, (0, 0, 0), (int(eye_x), int(eye_y)), eye_size)
# Draw flickering tongue with more pronounced animation
time = pygame.time.get_ticks()
tongue_flick = math.sin(time * 0.01) * 0.8 + 0.2 # Slower, more pronounced flicking
# Tongue base position (at front of head)
tongue_base_x = center[0] + math.cos(angle) * (segment_size//2)
tongue_base_y = center[1] + math.sin(angle) * (segment_size//2)
# Tongue length varies with flicking animation
tongue_length = (3 + tongue_flick * 6) # Length varies between 3 and 9 pixels
fork_length = 4 # Slightly longer fork
fork_angle = math.pi/3 # Wider fork angle
# Calculate tongue tip
tongue_tip_x = tongue_base_x + math.cos(angle) * tongue_length
tongue_tip_y = tongue_base_y + math.sin(angle) * tongue_length
# Calculate fork tips
left_fork_x = tongue_tip_x + math.cos(angle + fork_angle) * fork_length
left_fork_y = tongue_tip_y + math.sin(angle + fork_angle) * fork_length
right_fork_x = tongue_tip_x + math.cos(angle - fork_angle) * fork_length
right_fork_y = tongue_tip_y + math.sin(angle - fork_angle) * fork_length
# Draw tongue with brighter red
tongue_color = (255, 0, 0) # Bright red
# Draw main tongue line
pygame.draw.line(screen, tongue_color,
(tongue_base_x, tongue_base_y),
(tongue_tip_x, tongue_tip_y), 2)
# Draw forked tips
pygame.draw.line(screen, tongue_color,
(tongue_tip_x, tongue_tip_y),
(left_fork_x, left_fork_y), 2)
pygame.draw.line(screen, tongue_color,
(tongue_tip_x, tongue_tip_y),
(right_fork_x, right_fork_y), 2)
def _get_wrapped_segment_positions(self, pos1, pos2, game_area):
"""
Get screen positions for segment rendering, handling wrap-around.
Only returns positions for segments that are close enough to interpolate.
"""
grid_w = self.grid_width
grid_h = self.grid_height
# Calculate primary direction
dx = pos2[0] - pos1[0]
dy = pos2[1] - pos1[1]
# If segments are too far apart (wrapping), don't draw connecting segment
if abs(dx) > 2 or abs(dy) > 2:
return []
# Convert to screen coordinates
screen_pos1 = game_area.get_screen_pos(pos1[0], pos1[1])
screen_pos2 = game_area.get_screen_pos(pos2[0], pos2[1])
return [(screen_pos1, screen_pos2)]

View File

@ -1,50 +0,0 @@
import pygame
import random
from typing import Tuple, List
class Food:
def __init__(self, block_size: int, color: Tuple[int, int, int] = (255, 0, 0)):
self.block_size = block_size
self.color = color
self.position = (0, 0) # Will be set by spawn()
def spawn(self, width: int, height: int, occupied_positions: List[Tuple[int, int]]) -> None:
"""
Spawn food at a random position that isn't occupied by the snake.
Args:
width: Game area width
height: Game area height
occupied_positions: List of positions (typically snake body positions) where food shouldn't spawn
"""
# Calculate all valid grid positions
valid_positions = [
(x, y)
for x in range(0, width, self.block_size)
for y in range(0, height, self.block_size)
if (x, y) not in occupied_positions
]
if not valid_positions:
# No valid positions (snake fills screen) - game should be won
return
# Choose random position from valid positions
self.position = random.choice(valid_positions)
def draw(self, screen: pygame.Surface) -> None:
"""Draw the food on the screen"""
pygame.draw.rect(
screen,
self.color,
pygame.Rect(
self.position[0],
self.position[1],
self.block_size,
self.block_size
)
)
def check_collision(self, position: Tuple[int, int]) -> bool:
"""Check if the given position collides with the food"""
return self.position == position

View File

@ -1,190 +1,109 @@
"""
Main Game Module
This module provides the main game class that handles the game loop,
state management, and user interface.
"""
import pygame
import sys
from enum import Enum, auto
from src.snake import Snake, Direction
from src.food import Food
from src.menu import Menu, GameMode, MenuItem
from src.config import (
WINDOW_WIDTH,
WINDOW_HEIGHT,
FPS,
SCORE_FONT_SIZE
)
from src.config import GameSettings
from src.config.colors import BLACK, GREEN, NEON_GREEN
from src.core import Direction, GameSession
from src.ui import Menu, GameMode, SettingsMenu, PauseMenu, GameArea
class GameState(Enum):
"""Game states."""
MENU = auto()
SETTINGS = auto() # New state for settings menu
SETTINGS = auto()
PLAYING = auto()
GAME_OVER = auto()
PAUSED = auto()
class GameRules:
def __init__(self):
self.wrap_around = False # Whether snake wraps around screen edges
self.speed_increase = True # Whether snake speeds up as it grows
self.min_move_cooldown = 50 # Minimum movement delay in milliseconds
self.initial_move_cooldown = 100 # Initial movement delay
class SettingsMenu:
def __init__(self, width, height, rules):
self.width = width
self.height = height
self.rules = rules
self.setup_menu_items()
self.title_font = pygame.font.Font(None, 72)
self.subtitle_font = pygame.font.Font(None, 36)
# Create title surfaces
self.title_surface = self.title_font.render("Settings", True, (0, 255, 0))
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
# Initialize first item as selected
self.selected_index = 0
self.menu_items[0].hover = True
self.menu_items[0]._setup_font()
def setup_menu_items(self):
start_y = self.height // 2
spacing = 50
center_x = self.width // 2
self.menu_items = [
MenuItem(f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}",
(center_x, start_y),
'toggle_wrap'),
MenuItem(f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}",
(center_x, start_y + spacing),
'toggle_speed'),
MenuItem("Back to Menu",
(center_x, start_y + spacing * 3),
'back')
]
def update(self):
# Handle mouse hover
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
# Update selected index when mouse hovers
self.selected_index = i
item.hover = True
item._setup_font()
else:
# Keep keyboard selection visible
item.hover = (i == self.selected_index)
item._setup_font()
def handle_input(self, event):
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
# Handle mouse clicks
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
self.selected_index = i
if item.action == 'toggle_wrap':
self.rules.wrap_around = not self.rules.wrap_around
item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}"
item._setup_font()
elif item.action == 'toggle_speed':
self.rules.speed_increase = not self.rules.speed_increase
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
item._setup_font()
elif item.action == 'back':
return 'back'
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
return 'back'
elif event.key == pygame.K_RETURN:
item = self.menu_items[self.selected_index]
if item.action == 'toggle_wrap':
self.rules.wrap_around = not self.rules.wrap_around
item.text = f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}"
item._setup_font()
elif item.action == 'toggle_speed':
self.rules.speed_increase = not self.rules.speed_increase
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
item._setup_font()
elif item.action == 'back':
return 'back'
elif event.key in (pygame.K_UP, pygame.K_DOWN):
# Update selected index
if event.key == pygame.K_UP:
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
else:
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
# Update hover states
for i, item in enumerate(self.menu_items):
item.hover = (i == self.selected_index)
item._setup_font()
return None
def draw(self, screen):
# Draw background
screen.fill((0, 0, 0))
# Draw title
screen.blit(self.title_surface, self.title_rect)
# Draw menu items
for item in self.menu_items:
item.draw(screen)
# Draw controls
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back"
font = pygame.font.Font(None, 24)
controls_surface = font.render(controls_text, True, (100, 100, 100))
screen.blit(controls_surface,
(self.width - controls_surface.get_width() - 10,
self.height - 30))
class Game:
"""Main game class that manages the game loop and state."""
def __init__(self):
# Initialize pygame and font system
if not pygame.get_init():
pygame.init()
if not pygame.font.get_init():
pygame.font.init()
"""Initialize the game."""
pygame.init()
pygame.font.init()
# Initialize display
self.width = 800
self.height = 600
# Window setup
self.width = WINDOW_WIDTH
self.height = WINDOW_HEIGHT
self.screen = pygame.display.set_mode((self.width, self.height))
pygame.display.set_caption("AI Snake Game")
pygame.display.set_caption("Snake Game")
# Initialize clock
# Initialize game components
self.clock = pygame.time.Clock()
self.fps = 60
# Game rules
self.rules = GameRules()
# Game objects
self.block_size = 20
self.menu = Menu(self.width, self.height)
self.reset_game()
self.font = pygame.font.Font(None, SCORE_FONT_SIZE)
self.settings = GameSettings()
self.menu = Menu()
self.settings_menu = SettingsMenu(rules=self.settings.rules)
self.pause_menu = PauseMenu()
# Game state
self.state = GameState.MENU # Start in menu state
self.running = True
self.state = GameState.MENU
self.game_session = None
self.game_mode = None
self.running = True
# Debug mode
self.debug = False
# Add settings menu
self.settings_menu = SettingsMenu(self.width, self.height, self.rules)
self.reset_game()
def reset_game(self):
"""Reset the game state for a new game"""
self.snake = Snake((self.width // 2, self.height // 2), self.block_size)
self.food = Food(self.block_size)
self.food.spawn(self.width, self.height, self.snake.body)
self.score = 0
"""Initialize or reset the game session."""
if self.game_session:
del self.game_session
self.game_session = GameSession(
grid_width=30,
grid_height=30,
rules=self.settings.rules,
window_width=self.width,
window_height=self.height
)
def handle_events(self):
"""Handle game events."""
current_time = pygame.time.get_ticks()
for event in pygame.event.get():
if event.type == pygame.QUIT:
self.running = False
elif event.type == pygame.MOUSEBUTTONDOWN:
# Handle mouse clicks in menus
if self.state == GameState.MENU:
selected_mode = self.menu.handle_input(event)
if selected_mode is not None:
if selected_mode == GameMode.SETTINGS:
self.state = GameState.SETTINGS
elif selected_mode == GameMode.PLAYER:
self.game_mode = GameMode.PLAYER
self.state = GameState.PLAYING
self.reset_game()
elif selected_mode in (GameMode.AI_EASY, GameMode.AI_MEDIUM, GameMode.AI_HARD):
self.game_mode = selected_mode
self.state = GameState.PLAYING
self.reset_game()
else: # Quit selected
self.running = False
elif self.state == GameState.SETTINGS:
result = self.settings_menu.handle_input(event)
if result == 'back':
self.state = GameState.MENU
elif self.state == GameState.PAUSED:
result = self.pause_menu.handle_input(event)
if result == 'resume':
self.state = GameState.PLAYING
elif result == 'menu':
self.state = GameState.MENU
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
if self.state == GameState.PLAYING:
@ -196,156 +115,108 @@ class Game:
elif self.state == GameState.SETTINGS:
self.state = GameState.MENU
elif event.key == pygame.K_F3: # Toggle debug mode
self.debug = not self.debug
if self.state == GameState.MENU:
# Handle menu input
selected_mode = self.menu.handle_input(event)
if selected_mode is not None:
if selected_mode == GameMode.SETTINGS:
self.state = GameState.SETTINGS
elif selected_mode == GameMode.PLAYER:
self.game_mode = GameMode.PLAYER
self.state = GameState.PLAYING
elif selected_mode in (GameMode.AI_EASY, GameMode.AI_MEDIUM, GameMode.AI_HARD):
self.game_mode = selected_mode
self.state = GameState.PLAYING
else: # Quit selected
self.running = False
elif self.state == GameState.SETTINGS:
# Handle settings menu input
result = self.settings_menu.handle_input(event)
if result == 'back':
self.state = GameState.MENU
self.settings.debug_mode = not self.settings.debug_mode
if self.game_session:
self.game_session.show_grid = self.settings.debug_mode
elif self.state == GameState.PLAYING:
# Handle snake direction
if event.key == pygame.K_UP:
self.snake.change_direction(Direction.UP, current_time)
self.game_session.step(Direction.UP, current_time)
elif event.key == pygame.K_DOWN:
self.snake.change_direction(Direction.DOWN, current_time)
self.game_session.step(Direction.DOWN, current_time)
elif event.key == pygame.K_LEFT:
self.snake.change_direction(Direction.LEFT, current_time)
self.game_session.step(Direction.LEFT, current_time)
elif event.key == pygame.K_RIGHT:
self.snake.change_direction(Direction.RIGHT, current_time)
self.game_session.step(Direction.RIGHT, current_time)
elif self.state == GameState.GAME_OVER:
# Any key to return to menu in game over state
self.state = GameState.MENU
self.reset_game()
def update(self):
"""Update game state."""
current_time = pygame.time.get_ticks()
if self.state == GameState.MENU:
self.menu.update()
elif self.state == GameState.SETTINGS:
self.settings_menu.update()
elif self.state == GameState.PAUSED:
self.pause_menu.update()
elif self.state == GameState.PLAYING:
# Move snake
if self.snake.move(pygame.time.get_ticks()):
# Handle wrap-around
if self.rules.wrap_around:
self.snake.wrap_position(self.width, self.height)
# Check collisions
if self.snake.check_collision(self.width, self.height, self.rules.wrap_around):
self.state = GameState.GAME_OVER
return
# Check food collision
if self.food.check_collision(self.snake.body[0]):
self.snake.grow()
self.score += 1
self.food.spawn(self.width, self.height, self.snake.body)
# Increase speed if enabled
if self.rules.speed_increase:
self.snake.move_cooldown = max(
self.rules.min_move_cooldown,
self.rules.initial_move_cooldown - self.score * 2
)
def draw_grid(self):
"""Draw the game grid (debug mode)"""
for x in range(0, self.width, self.block_size):
pygame.draw.line(self.screen, (50, 50, 50), (x, 0), (x, self.height))
for y in range(0, self.height, self.block_size):
pygame.draw.line(self.screen, (50, 50, 50), (0, y), (self.width, y))
# Update game session
_, _, done = self.game_session.step(current_time=current_time)
if done:
self.state = GameState.GAME_OVER
def draw_debug_info(self):
"""Draw debug information"""
"""Draw debug information."""
if not self.game_session:
return
state = self.game_session.get_state()
font = pygame.font.Font(None, 24)
debug_info = [
f"FPS: {int(self.clock.get_fps())}",
f"Snake Length: {len(self.snake.body)}",
f"Snake Head: {self.snake.body[0]}",
f"Food Pos: {self.food.position}",
f"Snake Length: {len(state['snake_body'])}",
f"Snake Head: {state['snake_body'][0]}",
f"Food Pos: {state['food_position']}",
f"Game State: {self.state.name}",
f"Game Mode: {self.game_mode.name if self.game_mode else 'None'}",
f"Wrap Around: {self.rules.wrap_around}",
f"Move Cooldown: {self.snake.move_cooldown}ms"
f"Move Cooldown: {state['move_cooldown']}ms"
]
for i, text in enumerate(debug_info):
surface = font.render(text, True, (0, 255, 0))
surface = font.render(text, True, GREEN)
self.screen.blit(surface, (10, self.height - 200 + i * 25))
def render(self):
# Clear screen
self.screen.fill((0, 0, 0)) # Black background
def draw_overlay(self, text):
"""Draw a semi-transparent overlay with text."""
overlay = pygame.Surface((self.width, self.height))
overlay.fill(BLACK)
overlay.set_alpha(128)
self.screen.blit(overlay, (0, 0))
# Draw grid in debug mode
if self.debug and self.state not in (GameState.MENU, GameState.SETTINGS):
self.draw_grid()
text_surface = self.font.render(text, True, NEON_GREEN)
text_rect = text_surface.get_rect(center=(self.width // 2, self.height // 2))
self.screen.blit(text_surface, text_rect)
def render(self):
"""Render the current game state."""
# Clear screen
self.screen.fill(BLACK)
if self.state == GameState.MENU:
self.menu.draw(self.screen)
elif self.state == GameState.SETTINGS:
self.settings_menu.draw(self.screen)
elif self.state == GameState.PLAYING or self.state == GameState.PAUSED:
# Draw game objects
self.snake.draw(self.screen)
self.food.draw(self.screen)
elif self.state in [GameState.PLAYING, GameState.PAUSED, GameState.GAME_OVER]:
# Draw game session
self.game_session.render(self.screen)
# Draw score
font = pygame.font.Font(None, 36)
score_text = font.render(f'Score: {self.score}', True, (255, 255, 255))
self.screen.blit(score_text, (10, 10))
# Draw score panel
score_text = f"Score: {self.game_session.score}"
score_surface = self.font.render(score_text, True, GREEN)
score_rect = score_surface.get_rect(midtop=(self.width // 2, 20))
self.screen.blit(score_surface, score_rect)
# Draw game mode
mode_text = font.render(f'Mode: {self.game_mode.name}', True, (255, 255, 255))
self.screen.blit(mode_text, (10, 50))
# Draw pause indicator
# Draw overlays
if self.state == GameState.PAUSED:
pause_text = font.render('PAUSED', True, (255, 255, 255))
text_rect = pause_text.get_rect(center=(self.width//2, self.height//2))
self.screen.blit(pause_text, text_rect)
elif self.state == GameState.GAME_OVER:
font = pygame.font.Font(None, 74)
text = font.render('Game Over!', True, (255, 0, 0))
score_text = font.render(f'Score: {self.score}', True, (255, 255, 255))
restart_text = font.render('Press any key for menu', True, (255, 255, 255))
self.pause_menu.draw(self.screen)
elif self.state == GameState.GAME_OVER:
self.draw_overlay("GAME OVER")
text_rect = text.get_rect(center=(self.width//2, self.height//2 - 50))
score_rect = score_text.get_rect(center=(self.width//2, self.height//2 + 50))
restart_rect = restart_text.get_rect(center=(self.width//2, self.height//2 + 150))
self.screen.blit(text, text_rect)
self.screen.blit(score_text, score_rect)
self.screen.blit(restart_text, restart_rect)
# Draw debug info
if self.debug:
self.draw_debug_info()
# Draw debug info if enabled
if self.settings.debug_mode:
self.draw_debug_info()
# Update display
pygame.display.flip()
def run(self):
"""Run the main game loop."""
self.running = True
while self.running:
self.handle_events()
self.update()
self.render()
self.clock.tick(self.fps)
self.clock.tick(FPS)

View File

@ -1,8 +1,16 @@
"""
Main entry point for the AI Snake Game.
This module provides the standard entry point for running the game
without any command line arguments.
"""
import pygame
import sys
from game import Game
from src.game import Game
def main():
"""Run the game with default settings."""
# Initialize Pygame
pygame.init()

View File

@ -81,6 +81,8 @@ class Menu:
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
self.selected_index = i
item.hover = True
item._setup_font()
return item.action
elif event.type == pygame.KEYDOWN:

View File

@ -1,153 +0,0 @@
import pygame
from enum import Enum, auto
from typing import List, Tuple
class Direction(Enum):
UP = auto()
DOWN = auto()
LEFT = auto()
RIGHT = auto()
class Snake:
def __init__(self, start_pos: Tuple[int, int], block_size: int):
self.block_size = block_size
self.direction = Direction.RIGHT
self.body = [start_pos] # Head is at index 0
self.growing = False
# Movement cooldown
self.move_cooldown = 100 # milliseconds
self.last_move_time = 0
# Direction change cooldown
self.direction_change_cooldown = 50 # milliseconds
self.last_direction_change = 0
self.queued_direction = None
self.last_valid_move_time = 0 # Track when we last actually moved
def move(self, current_time: int) -> bool:
"""Move the snake if enough time has passed. Returns True if moved."""
if current_time - self.last_move_time < self.move_cooldown:
return False
# Apply queued direction change if it exists and is valid
if self.queued_direction:
if current_time - self.last_valid_move_time >= self.direction_change_cooldown:
if self._apply_direction_change(self.queued_direction):
self.last_direction_change = current_time
self.queued_direction = None
# Update position
head = self.body[0]
new_head = self._get_new_head_position(head)
# Insert new head
self.body.insert(0, new_head)
# Remove tail if not growing
if not self.growing:
self.body.pop()
else:
self.growing = False
self.last_move_time = current_time
self.last_valid_move_time = current_time
return True
def _get_new_head_position(self, head: Tuple[int, int]) -> Tuple[int, int]:
x, y = head
if self.direction == Direction.UP:
return (x, y - self.block_size)
elif self.direction == Direction.DOWN:
return (x, y + self.block_size)
elif self.direction == Direction.LEFT:
return (x - self.block_size, y)
else: # Direction.RIGHT
return (x + self.block_size, y)
def grow(self):
"""Mark the snake to grow on next move"""
self.growing = True
def check_collision(self, width: int, height: int, wrap_around: bool = False) -> bool:
"""
Check if snake has collided with walls or itself.
Args:
width: Game area width
height: Game area height
wrap_around: If True, snake wraps around screen edges instead of colliding
"""
head = self.body[0]
if not wrap_around:
# Check wall collision
if (head[0] < 0 or head[0] >= width or
head[1] < 0 or head[1] >= height):
return True
# Check self collision (skip head)
if head in self.body[1:]:
return True
return False
def wrap_position(self, width: int, height: int):
"""Wrap snake's head position around screen edges"""
head_x, head_y = self.body[0]
wrapped_head = (
head_x % width if head_x != width else 0,
head_y % height if head_y != height else 0
)
self.body[0] = wrapped_head
def _apply_direction_change(self, new_direction: Direction) -> bool:
"""
Internal method to actually change direction.
Returns True if direction was changed, False otherwise.
"""
# Prevent 180-degree turns when snake is longer than 1
if len(self.body) > 1:
opposites = {
Direction.UP: Direction.DOWN,
Direction.DOWN: Direction.UP,
Direction.LEFT: Direction.RIGHT,
Direction.RIGHT: Direction.LEFT
}
if opposites[new_direction] == self.direction:
return False
self.direction = new_direction
return True
def change_direction(self, new_direction: Direction, current_time: int):
"""
Queue a direction change if it's valid and cooldown has passed.
Args:
new_direction: The desired new direction
current_time: Current game time in milliseconds
"""
# If we haven't moved since the last direction change, queue it
if current_time - self.last_valid_move_time < self.direction_change_cooldown:
self.queued_direction = new_direction
return
# Try to change direction immediately
if self._apply_direction_change(new_direction):
self.last_direction_change = current_time
def draw(self, screen: pygame.Surface):
"""Draw the snake on the screen"""
# Draw head in a slightly different color
head_color = (0, 200, 0) # Darker green for head
body_color = (0, 255, 0) # Regular green for body
# Draw body segments
for i, segment in enumerate(self.body):
color = head_color if i == 0 else body_color
pygame.draw.rect(
screen,
color,
pygame.Rect(segment[0], segment[1], self.block_size, self.block_size)
)

20
src/ui/__init__.py Normal file
View File

@ -0,0 +1,20 @@
"""
Game UI Package
This package contains all user interface components.
"""
from .menu import Menu, GameMode
from .menu_item import MenuItem
from .settings_menu import SettingsMenu
from .pause_menu import PauseMenu
from .game_area import GameArea
__all__ = [
'Menu',
'GameMode',
'MenuItem',
'SettingsMenu',
'PauseMenu',
'GameArea'
]

138
src/ui/game_area.py Normal file
View File

@ -0,0 +1,138 @@
"""
Game Area Module
This module provides the GameArea class that handles the rendering of the game area
and coordinate conversions between screen and grid coordinates.
"""
import pygame
from src.config import (
DEFAULT_PADDING,
DEFAULT_SCORE_HEIGHT,
)
from src.config.colors import *
from typing import Tuple
class GameArea:
"""Manages the game area rendering and coordinate conversions."""
def __init__(self, window_width, window_height, grid_size=20, grid_width=30, grid_height=30,
padding=DEFAULT_PADDING, score_height=DEFAULT_SCORE_HEIGHT):
# Grid properties
self.grid_size = grid_size
self.grid_width = grid_width
self.grid_height = grid_height
# Area properties
self.padding = padding
self.score_height = score_height
self.border_thickness = 4 # Increased border thickness
# Calculate game area dimensions
self.width = self.grid_width * self.grid_size
self.height = self.grid_height * self.grid_size
# Center the game area in the window
self.x = (window_width - self.width) // 2
self.y = self.score_height + self.padding
# Create rectangle for game bounds
self.rect = pygame.Rect(self.x, self.y, self.width, self.height)
# Create inner rectangle for gameplay area with thicker border
self.inner_rect = pygame.Rect(
self.x + self.border_thickness,
self.y + self.border_thickness,
self.width - (self.border_thickness * 2),
self.height - (self.border_thickness * 2)
)
self.screen_offset = (self.x, self.y)
def get_grid_pos(self, screen_x, screen_y):
"""Convert screen coordinates to grid position."""
grid_x = (screen_x - self.x) // self.grid_size
grid_y = (screen_y - self.y) // self.grid_size
return grid_x, grid_y
def get_screen_pos(self, grid_x, grid_y):
"""Convert grid position to screen coordinates."""
screen_x = self.x + (grid_x * self.grid_size)
screen_y = self.y + (grid_y * self.grid_size)
return screen_x, screen_y
def is_within_bounds(self, grid_x, grid_y):
"""Check if grid position is within game bounds."""
return 0 <= grid_x < self.grid_width and 0 <= grid_y < self.grid_height
def draw(self, screen):
"""Draw the game area boundary and background with modern effects."""
# Draw outer border with gradient effect
for i in range(self.border_thickness):
border_rect = pygame.Rect(
self.x + i,
self.y + i,
self.width - (i * 2),
self.height - (i * 2)
)
# Gradient from darker to lighter
color_factor = i / self.border_thickness
r = int(DARKER_GREEN[0] * (1 - color_factor) + FOREST_GREEN[0] * color_factor)
g = int(DARKER_GREEN[1] * (1 - color_factor) + FOREST_GREEN[1] * color_factor)
b = int(DARKER_GREEN[2] * (1 - color_factor) + FOREST_GREEN[2] * color_factor)
pygame.draw.rect(screen, (r, g, b), border_rect)
# Draw inner background
pygame.draw.rect(screen, DARK_GRAY, self.inner_rect)
# Draw subtle grid pattern
if hasattr(self, 'show_grid') and self.show_grid:
# Debug grid lines
for x in range(self.grid_width + 1):
start_pos = (self.x + x * self.grid_size, self.y)
end_pos = (self.x + x * self.grid_size, self.y + self.height)
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
for y in range(self.grid_height + 1):
start_pos = (self.x, self.y + y * self.grid_size)
end_pos = (self.x + self.width, self.y + y * self.grid_size)
pygame.draw.line(screen, GRID_COLOR, start_pos, end_pos, 1)
else:
# Subtle grid dots with fade effect
for x in range(self.grid_width + 1):
for y in range(self.grid_height + 1):
pos_x = self.x + x * self.grid_size
pos_y = self.y + y * self.grid_size
# Calculate distance from center for fade effect
center_x = self.x + (self.width / 2)
center_y = self.y + (self.height / 2)
dist_x = abs(pos_x - center_x) / (self.width / 2)
dist_y = abs(pos_y - center_y) / (self.height / 2)
dist = min(1.0, (dist_x + dist_y) / 2)
# Fade dot color based on distance from center
dot_color = tuple(int(c * (1 - dist * 0.5)) for c in SUBTLE_GRID_COLOR)
pygame.draw.circle(screen, dot_color, (pos_x, pos_y), 1)
# Draw inner border glow
glow_colors = [
(int(NEON_GREEN[0] * 0.4), int(NEON_GREEN[1] * 0.4), int(NEON_GREEN[2] * 0.4)),
(int(NEON_GREEN[0] * 0.6), int(NEON_GREEN[1] * 0.6), int(NEON_GREEN[2] * 0.6)),
NEON_GREEN
]
for i, color in enumerate(glow_colors):
glow_rect = pygame.Rect(
self.inner_rect.x - i,
self.inner_rect.y - i,
self.inner_rect.width + (i * 2),
self.inner_rect.height + (i * 2)
)
pygame.draw.rect(screen, color, glow_rect, 1)
def grid_to_screen(self, grid_pos: Tuple[int, int]) -> Tuple[int, int]:
"""Convert grid coordinates to screen coordinates"""
x = self.rect.left + grid_pos[0] * self.grid_size + self.grid_size // 2
y = self.rect.top + grid_pos[1] * self.grid_size + self.grid_size // 2
return (x, y)

127
src/ui/menu.py Normal file
View File

@ -0,0 +1,127 @@
"""
Main Menu Module
This module provides the main menu interface for the game.
"""
import pygame
from enum import Enum, auto
from src.config import (
WINDOW_WIDTH,
WINDOW_HEIGHT,
TITLE_FONT_SIZE,
SUBTITLE_FONT_SIZE,
MENU_ITEM_SPACING
)
from src.config.colors import BLACK, GRAY
from src.ui.menu_item import MenuItem
class GameMode(Enum):
"""Available game modes."""
PLAYER = auto()
AI_EASY = auto()
AI_MEDIUM = auto()
AI_HARD = auto()
SETTINGS = auto()
class Menu:
"""Main menu interface."""
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
"""Initialize the main menu."""
self.width = width
self.height = height
self.setup_menu_items()
# Create fonts
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
# Create title surfaces
self.title_surface = self.title_font.render("Snake Game", True, (0, 255, 0))
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
# Initialize first item as selected
self.selected_index = 0
self.menu_items[0].hover = True
self.menu_items[0]._setup_font()
def setup_menu_items(self):
"""Setup the menu items."""
start_y = self.height // 2
spacing = MENU_ITEM_SPACING
center_x = self.width // 2
self.menu_items = [
MenuItem("Player Game", (center_x, start_y), GameMode.PLAYER),
MenuItem("AI Game (Easy)", (center_x, start_y + spacing), GameMode.AI_EASY),
MenuItem("AI Game (Medium)", (center_x, start_y + spacing * 2), GameMode.AI_MEDIUM),
MenuItem("AI Game (Hard)", (center_x, start_y + spacing * 3), GameMode.AI_HARD),
MenuItem("Settings", (center_x, start_y + spacing * 4), GameMode.SETTINGS),
MenuItem("Quit", (center_x, start_y + spacing * 5), 'quit')
]
def update(self):
"""Update menu state."""
# Handle mouse hover
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
# Update selected index when mouse hovers
self.selected_index = i
item.hover = True
item._setup_font()
else:
# Keep keyboard selection visible
item.hover = (i == self.selected_index)
item._setup_font()
def handle_input(self, event):
"""
Handle input events.
Returns:
GameMode or None: The selected game mode or None if no selection made
"""
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
# Handle mouse clicks
mouse_pos = pygame.mouse.get_pos()
for item in self.menu_items:
if item.rect.collidepoint(mouse_pos):
return item.action
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_RETURN:
return self.menu_items[self.selected_index].action
elif event.key in (pygame.K_UP, pygame.K_DOWN):
# Update selected index
if event.key == pygame.K_UP:
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
else:
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
# Update hover states
for i, item in enumerate(self.menu_items):
item.hover = (i == self.selected_index)
item._setup_font()
return None
def draw(self, screen):
"""Draw the menu to the screen."""
# Draw background
screen.fill(BLACK)
# Draw title
screen.blit(self.title_surface, self.title_rect)
# Draw menu items
for item in self.menu_items:
item.draw(screen)
# Draw controls
controls_text = "Arrow keys or mouse to navigate, Enter to select"
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
screen.blit(controls_surface,
(self.width - controls_surface.get_width() - 10,
self.height - 30))

84
src/ui/menu_item.py Normal file
View File

@ -0,0 +1,84 @@
"""
Menu Item Module
This module provides the MenuItem class that represents a clickable menu item.
"""
import pygame
from src.config.colors import *
class MenuItem:
"""A clickable menu item with hover effects."""
def __init__(self, text, position, action, font_size_normal=36, font_size_hover=48):
"""
Initialize a menu item.
Args:
text: The text to display
position: (x, y) tuple for center position
action: String identifier for the action to take when clicked
font_size_normal: Font size when not hovered (default 36)
font_size_hover: Font size when hovered (default 48)
"""
self.text = text
self.position = position
self.action = action
self.hover = False
self.font_size_normal = font_size_normal
self.font_size_hover = font_size_hover
self._setup_font()
def _setup_font(self):
"""Setup the font based on hover state."""
size = self.font_size_hover if self.hover else self.font_size_normal
self.font = pygame.font.Font(None, size)
self.surface = self.font.render(self.text, True, GREEN if self.hover else WHITE)
self.rect = self.surface.get_rect(center=self.position)
def draw(self, screen):
"""Draw the menu item to the screen."""
screen.blit(self.surface, self.rect)
class NumberMenuItem(MenuItem):
"""A menu item that displays a number that can be adjusted with arrow keys."""
def __init__(self, text, position, action, min_value, max_value,
step=1, font_size_normal=36, font_size_hover=48):
"""
Initialize a number menu item.
Args:
text: The text to display
position: (x, y) tuple for center position
action: String identifier for the action to take when clicked
min_value: Minimum allowed value
max_value: Maximum allowed value
step: Amount to increment/decrement by (default 1)
font_size_normal: Font size when not hovered (default 36)
font_size_hover: Font size when hovered (default 48)
"""
super().__init__(text, position, action, font_size_normal, font_size_hover)
self.min_value = min_value
self.max_value = max_value
self.step = step
self.current_value = int(text.split(": ")[1]) # Extract number from text
def increment(self):
"""Increment the current value by step amount."""
if self.current_value + self.step <= self.max_value:
self.current_value += self.step
self._update_text()
def decrement(self):
"""Decrement the current value by step amount."""
if self.current_value - self.step >= self.min_value:
self.current_value -= self.step
self._update_text()
def _update_text(self):
"""Update the displayed text with new value."""
base_text = self.text.split(": ")[0]
self.text = f"{base_text}: {self.current_value}"
self._setup_font()

129
src/ui/pause_menu.py Normal file
View File

@ -0,0 +1,129 @@
"""
Pause Menu Module
This module provides the pause menu interface that appears during gameplay.
"""
import pygame
from src.config import (
WINDOW_WIDTH,
WINDOW_HEIGHT,
TITLE_FONT_SIZE,
SUBTITLE_FONT_SIZE,
MENU_ITEM_SPACING
)
from src.config.colors import *
from src.ui.menu_item import MenuItem
class PauseMenu:
"""Pause menu interface."""
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT):
"""
Initialize the pause menu.
Args:
width: Window width
height: Window height
"""
self.width = width
self.height = height
self.setup_menu_items()
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
# Create title surfaces
self.title_surface = self.title_font.render("Paused", True, (0, 255, 0))
self.title_rect = self.title_surface.get_rect(center=(width//2, height//3))
# Initialize first item as selected
self.selected_index = 0
self.menu_items[0].hover = True
self.menu_items[0]._setup_font()
def setup_menu_items(self):
"""Setup the menu items."""
start_y = self.height // 2
spacing = MENU_ITEM_SPACING
center_x = self.width // 2
self.menu_items = [
MenuItem("Resume",
(center_x, start_y),
'resume'),
MenuItem("Return to Menu",
(center_x, start_y + spacing),
'menu')
]
def update(self):
"""Update menu state."""
# Handle mouse hover
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
# Update selected index when mouse hovers
self.selected_index = i
item.hover = True
item._setup_font()
else:
# Keep keyboard selection visible
item.hover = (i == self.selected_index)
item._setup_font()
def handle_input(self, event):
"""
Handle input events.
Returns:
str or None: The action to take or None if no action
"""
if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
# Handle mouse clicks
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
self.selected_index = i
return item.action
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
return 'resume'
elif event.key == pygame.K_RETURN:
return self.menu_items[self.selected_index].action
elif event.key in (pygame.K_UP, pygame.K_DOWN):
# Update selected index
if event.key == pygame.K_UP:
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
else:
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
# Update hover states
for i, item in enumerate(self.menu_items):
item.hover = (i == self.selected_index)
item._setup_font()
return None
def draw(self, screen):
"""Draw the menu to the screen."""
# Draw semi-transparent background
overlay = pygame.Surface((screen.get_width(), screen.get_height()))
overlay.fill(BLACK)
overlay.set_alpha(128)
screen.blit(overlay, (0, 0))
# Draw title
screen.blit(self.title_surface, self.title_rect)
# Draw menu items
for item in self.menu_items:
item.draw(screen)
# Draw controls
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to resume"
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
screen.blit(controls_surface,
(self.width - controls_surface.get_width() - 10,
self.height - 30))

159
src/ui/settings_menu.py Normal file
View File

@ -0,0 +1,159 @@
"""
Settings Menu Module
This module provides the settings menu interface for configuring game rules.
"""
import pygame
from src.config import (
WINDOW_WIDTH,
WINDOW_HEIGHT,
TITLE_FONT_SIZE,
SUBTITLE_FONT_SIZE,
MENU_ITEM_SPACING,
)
from src.config.colors import *
from src.ui.menu_item import MenuItem, NumberMenuItem
class SettingsMenu:
"""Settings menu interface."""
def __init__(self, width=WINDOW_WIDTH, height=WINDOW_HEIGHT, rules=None):
"""
Initialize the settings menu.
Args:
width: Window width
height: Window height
rules: GameRules instance to modify
"""
self.width = width
self.height = height
self.rules = rules
self.setup_menu_items()
self.title_font = pygame.font.Font(None, TITLE_FONT_SIZE)
self.subtitle_font = pygame.font.Font(None, SUBTITLE_FONT_SIZE)
# Create title surfaces
self.title_surface = self.title_font.render("Settings", True, (0, 255, 0))
self.title_rect = self.title_surface.get_rect(center=(width//2, height//4))
# Initialize first item as selected
self.selected_index = 0
self.menu_items[0].hover = True
self.menu_items[0]._setup_font()
def setup_menu_items(self):
"""Setup the menu items."""
start_y = self.height // 2
spacing = MENU_ITEM_SPACING
center_x = self.width // 2
self.menu_items = [
MenuItem(f"Wrap Around: {'On' if self.rules.wrap_around else 'Off'}",
(center_x, start_y),
'toggle_wrap'),
MenuItem(f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}",
(center_x, start_y + spacing),
'toggle_speed'),
NumberMenuItem(f"Starting Length: {self.rules.starting_length}",
(center_x, start_y + spacing * 2),
'set_starting_length',
1, 10),
MenuItem("Back to Menu",
(center_x, start_y + spacing * 3),
'back')
]
def update(self):
"""Update menu state."""
# Handle mouse hover
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
# Update selected index when mouse hovers
self.selected_index = i
item.hover = True
item._setup_font()
else:
# Keep keyboard selection visible
item.hover = (i == self.selected_index)
item._setup_font()
def handle_input(self, event):
"""
Handle input events.
Returns:
str or None: The action to take or None if no action
"""
if event.type == pygame.MOUSEBUTTONDOWN:
# Handle mouse clicks
mouse_pos = pygame.mouse.get_pos()
for i, item in enumerate(self.menu_items):
if item.rect.collidepoint(mouse_pos):
self.selected_index = i
if event.button == 1:
if item.action == 'toggle_wrap':
self.rules.update_rule('wrap_around', not self.rules.wrap_around)
elif item.action == 'toggle_speed':
self.rules.speed_increase = not self.rules.speed_increase
item.text = f"Speed Increase: {'On' if self.rules.speed_increase else 'Off'}"
item._setup_font()
elif item.action == 'set_starting_length':
item.increment()
self.rules.update_rule('starting_length', item.current_value)
elif item.action == 'back':
return 'back'
elif event.button == 3:
if item.action == 'set_starting_length':
item.decrement()
self.rules.update_rule('starting_length', item.current_value)
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_ESCAPE:
return 'back'
elif event.key == pygame.K_RETURN:
item = self.menu_items[self.selected_index]
if item.action == 'toggle_wrap':
self.rules.update_rule('wrap_around', not self.rules.wrap_around)
elif item.action == 'toggle_speed':
self.rules.update_rule('speed_increase', not self.rules.speed_increase)
elif item.action == 'set_starting_length':
item.increment()
self.rules.update_rule('starting_length', item.current_value)
elif item.action == 'back':
return 'back'
elif event.key in (pygame.K_UP, pygame.K_DOWN):
# Update selected index
if event.key == pygame.K_UP:
self.selected_index = (self.selected_index - 1) % len(self.menu_items)
else:
self.selected_index = (self.selected_index + 1) % len(self.menu_items)
# Update hover states
for i, item in enumerate(self.menu_items):
item.hover = (i == self.selected_index)
item._setup_font()
return None
def draw(self, screen):
"""Draw the menu to the screen."""
# Draw background
screen.fill(BLACK)
# Draw title
screen.blit(self.title_surface, self.title_rect)
# Draw menu items
for item in self.menu_items:
item.draw(screen)
# Draw controls
controls_text = "Arrow keys or mouse to navigate, Enter to select, Esc to go back"
controls_surface = self.subtitle_font.render(controls_text, True, GRAY)
screen.blit(controls_surface,
(self.width - controls_surface.get_width() - 10,
self.height - 30))

View File

@ -8,12 +8,12 @@ This package contains all test modules including:
"""
# Import test utilities and fixtures
from tests.conftest import game_config, mock_screen
from conftest import game_config, mock_screen
__version__ = '0.1.0'
# Define test utilities available for import
__all__ = [
'game_config',
'mock_screen'
'mock_screen',
]

View File

@ -1,5 +1,5 @@
import pytest
from src.food import Food
from src.core import Food
def test_food_initialization():
"""Test food initialization"""

View File

@ -1,5 +1,4 @@
import pytest
from src.snake import Snake, Direction
from src.core import Snake, Direction
def test_snake_initialization():
"""Test snake initialization with default values"""

15
training_config.json Normal file
View File

@ -0,0 +1,15 @@
{
"timesteps": 3000000,
"learning_rate": 3e-4,
"batch_size": 256,
"n_envs": 16,
"n_steps": 2048,
"gamma": 0.99,
"ent_coef": 0.1,
"n_epochs": 10,
"gae_lambda": 0.95,
"clip_range": 0.2,
"vf_coef": 0.5,
"max_grad_norm": 0.5,
"normalize_advantage": true
}