AI-Snake-Game/src/ai/controller.py

70 lines
2.3 KiB
Python

"""
AI Controller for Snake Game
This module provides the AI controller that uses trained models to play the game.
It handles model loading, state processing, and decision making during gameplay.
"""
import os
import numpy as np
from stable_baselines3 import PPO
from ai.environment import SnakeEnv, Direction
class AIController:
    """AI controller that uses trained models to play the game.

    A PPO policy (Stable-Baselines3) is loaded from disk for the requested
    difficulty. A ``SnakeEnv`` instance is kept only as a helper for turning
    the live game objects into the observation the policy consumes.
    """

    def __init__(self, difficulty: str = "medium", models_dir: str = "models"):
        """
        Initialize the AI controller.

        Args:
            difficulty: "easy", "medium", or "hard". Used as the name of the
                sub-directory of ``models_dir`` that holds the trained model.
            models_dir: Root directory containing one sub-directory per
                difficulty. Defaults to ``"models"``, which (like any relative
                path) is resolved against the current working directory.

        Raises:
            FileNotFoundError: If no trained model exists for ``difficulty``.
        """
        self.difficulty = difficulty
        self.models_dir = models_dir
        self.model = None
        self.env = SnakeEnv()  # For state processing only, never stepped here
        self._load_model()

    def _load_model(self):
        """Load the appropriate model based on difficulty.

        ``best_model.zip`` is preferred. ``final_model.zip`` is used only when
        ``best_model.zip`` does not exist.

        Raises:
            FileNotFoundError: If neither model file exists.
        """
        model_dir = os.path.join(self.models_dir, self.difficulty)
        candidates = (
            os.path.join(model_dir, "best_model.zip"),
            # Fall back to final model if best model doesn't exist
            os.path.join(model_dir, "final_model.zip"),
        )
        for model_path in candidates:
            if os.path.exists(model_path):
                self.model = PPO.load(model_path)
                return
        raise FileNotFoundError(
            f"No model found for difficulty {self.difficulty}. "
            "Please train the model first."
        )

    def get_action(self, game_state: dict) -> "Direction":
        """
        Get the next action based on the current game state.

        Args:
            game_state: Dictionary containing:
                - snake: Snake object
                - food: Food object
                - width: Game width
                - height: Game height

        Returns:
            The entry of ``self.env.action_space`` at the index chosen by the
            policy (intended to be a Direction enum member).
        """
        # Copy the live game objects onto the helper env so that its
        # observation builder sees the current board. These assignments
        # persist on self.env until the next call.
        self.env.snake = game_state["snake"]
        self.env.food = game_state["food"]
        self.env.width = game_state["width"]
        self.env.height = game_state["height"]

        state = self.env._get_state()
        action, _ = self.model.predict(state, deterministic=True)

        # predict() returns a NumPy array, not a Python int. Its shape depends
        # on whether the observation is treated as batched. A shape-(1,) array
        # cannot index a Python sequence, and ndarrays are unhashable, so
        # reduce the prediction to a plain int before the lookup.
        action_index = int(np.asarray(action).item())
        return self.env.action_space[action_index]