Spaces:
Running
Running
File size: 9,962 Bytes
93779c5 503cdbf 93779c5 503cdbf aac050c 503cdbf 93779c5 503cdbf 93779c5 503cdbf 93779c5 503cdbf 93779c5 503cdbf 93779c5 503cdbf aac050c 93779c5 503cdbf 93779c5 503cdbf 93779c5 503cdbf 93779c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
# agent.py
import os
import json
import asyncio
import random
from openai import AsyncOpenAI # Use AsyncOpenAI for async compatibility with poke-env
# Import necessary poke-env components for type hinting and functionality
from poke_env.player import Player
from poke_env.environment.battle import Battle
from poke_env.environment.move import Move
from poke_env.environment.pokemon import Pokemon
from tools import toolsList
class OpenAIAgent(Player):
    """
    An AI agent for Pokemon Showdown that uses OpenAI's API
    with function calling to decide its moves.

    Requires the OPENAI_API_KEY environment variable to be set.

    Each turn the battle state is rendered as text and sent to the model
    together with the function schemas from ``toolsList``. The model's
    ``choose_move`` (argument ``move_name``) or ``choose_switch`` (argument
    ``pokemon_name``) call is then translated into an order. Any failure along
    that path falls back to Player's random move selection. Failures include
    an API error, no function call, malformed arguments, or an unknown
    move/Pokemon.
    """

    def __init__(self, *args, model: str = "gpt-4o", **kwargs):
        """Create the agent.

        Args:
            *args, **kwargs: Forwarded unchanged to poke-env's ``Player``
                (e.g. ``account_configuration``).
            model: OpenAI chat model used for decisions. Defaults to
                ``"gpt-4o"``, the value that used to be hard-coded.

        Raises:
            ValueError: If OPENAI_API_KEY is unset or empty.
        """
        # Validate the key before any Player setup runs so a misconfigured
        # agent fails fast. `.get()` (not indexing) ensures an *unset*
        # variable reaches this descriptive error instead of a bare KeyError.
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set or loaded.")

        # Pass account_configuration and other Player args/kwargs to the parent
        super().__init__(*args, **kwargs)

        # AsyncOpenAI keeps the API call awaitable inside poke-env's async choose_move.
        self.openai_client = AsyncOpenAI(api_key=api_key)
        self.model = model  # e.g. "gpt-4o", "gpt-4-turbo-preview", ...
        # Function schemas the model may "call" (defined in tools.toolsList).
        self.functions = toolsList
        self.battle_history = []  # Optional: reserved for adding context later (currently unused)

    # ------------------------------------------------------------------ #
    # Small formatting / normalisation helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _normalize_id(name) -> str:
        """Reduce a name to a Showdown-style id: lowercase ASCII letters and digits only.

        Examples: "U-turn" -> "uturn", "King's Shield" -> "kingsshield",
        "Landorus-Therian" -> "landorustherian". Non-string input (the model
        controls these values) is stringified first.
        """
        return "".join(ch for ch in str(name).lower() if ch.isascii() and ch.isalnum())

    @staticmethod
    def _status_name(pkmn: Pokemon) -> str:
        """Return the status enum name, or 'None' when the Pokemon has no status."""
        return pkmn.status.name if pkmn.status else "None"

    @staticmethod
    def _type_string(pkmn: Pokemon) -> str:
        """Join a Pokemon's types with '/'.

        poke-env may report a missing second type as None. Skip it so
        single-typed Pokemon don't render as e.g. "FIRE/None".
        """
        return "/".join(str(t) for t in pkmn.types if t is not None)

    @classmethod
    def _describe_active(cls, label: str, pkmn: Pokemon | None) -> str:
        """One-line summary of an active Pokemon (species, types, HP %, status, boosts).

        The active Pokemon can be None depending on the battle phase, so this
        guards against that rather than crashing the whole turn.
        """
        if pkmn is None:
            return f"{label}: None"
        return (
            f"{label}: {pkmn.species} "
            f"(Type: {cls._type_string(pkmn)}) "
            f"HP: {pkmn.current_hp_fraction * 100:.1f}% "
            f"Status: {cls._status_name(pkmn)} "
            f"Boosts: {pkmn.boosts}"
        )

    # ------------------------------------------------------------------ #
    # Prompt construction and model call
    # ------------------------------------------------------------------ #

    def _format_battle_state(self, battle: Battle) -> str:
        """Format the current battle state into a plain-text block for the LLM.

        Covers both active Pokemon, the legal moves and switches, weather,
        fields and side conditions for each player.
        """
        lines = [
            self._describe_active("Your active Pokemon", battle.active_pokemon),
            self._describe_active("Opponent's active Pokemon", battle.opponent_active_pokemon),
            "",
            "Available moves:",
        ]
        if battle.available_moves:
            lines.extend(
                f"- {move.id} (Type: {move.type}, BP: {move.base_power}, Acc: {move.accuracy}, "
                f"PP: {move.current_pp}/{move.max_pp}, Cat: {move.category.name})"
                for move in battle.available_moves
            )
        else:
            lines.append("- None (Must switch or Struggle)")

        lines += ["", "Available switches:"]
        if battle.available_switches:
            lines.extend(
                f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, "
                f"Status: {self._status_name(pkmn)})"
                for pkmn in battle.available_switches
            )
        else:
            lines.append("- None")

        lines += [
            "",
            f"Weather: {battle.weather}",
            f"Terrains: {battle.fields}",
            f"Your Side Conditions: {battle.side_conditions}",
            f"Opponent Side Conditions: {battle.opponent_side_conditions}",
        ]
        return "\n".join(lines).strip()

    async def _get_openai_decision(self, battle_state: str) -> dict | None:
        """Send the battle state to OpenAI and return its function-call decision.

        Returns:
            ``{"name": <function name>, "arguments": <dict>}``, or None when
            any of these happen:
            - the API call fails;
            - the model answers without a function call;
            - the arguments are not a JSON object.
        """
        system_prompt = (
            "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
            "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
            "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
            "Only choose actions listed as available."
        )
        user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."

        try:
            # NOTE(review): `functions`/`function_call` are the legacy
            # function-calling parameters (newer SDKs prefer `tools`/`tool_choice`).
            # Kept as-is because toolsList's schema format is defined elsewhere.
            # TODO: confirm its shape before migrating.
            response = await self.openai_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                functions=self.functions,
                function_call="auto",  # Let the model choose which function to call
                temperature=0.5,  # Adjust for creativity vs consistency
            )
            message = response.choices[0].message
        except Exception as e:
            # Deliberate best-effort boundary: any API/network failure degrades
            # to the random fallback in choose_move instead of stalling the battle.
            print(f"Error during OpenAI API call: {e}")
            return None

        if not message.function_call:
            # Model decided not to call a function (or generated text instead)
            print(f"Warning: OpenAI did not return a function call. Response: {message.content}")
            return None

        raw_arguments = message.function_call.arguments
        try:
            arguments = json.loads(raw_arguments)
        except (json.JSONDecodeError, TypeError):
            print(f"Error decoding function call arguments: {raw_arguments}")
            return None
        # Valid JSON that isn't an object (e.g. a list) would later break `args.get(...)`.
        if not isinstance(arguments, dict):
            print(f"Error decoding function call arguments: {raw_arguments}")
            return None
        return {"name": message.function_call.name, "arguments": arguments}

    # ------------------------------------------------------------------ #
    # Mapping model output back to legal actions
    # ------------------------------------------------------------------ #

    def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
        """Return the available Move matching ``move_name``, or None.

        Both sides are reduced to Showdown-style ids, so "U-turn", "u turn"
        and "uturn" all match the same move, and so do names containing
        apostrophes or periods.
        """
        target = self._normalize_id(move_name)
        if not target:
            return None
        return next(
            (move for move in battle.available_moves if self._normalize_id(move.id) == target),
            None,
        )

    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
        """Return the available switch whose species matches ``pokemon_name``, or None.

        Uses the same id normalisation as moves, so "Landorus-Therian" matches
        an id-form species such as "landorustherian".
        """
        target = self._normalize_id(pokemon_name)
        if not target:
            return None
        return next(
            (pkmn for pkmn in battle.available_switches if self._normalize_id(pkmn.species) == target),
            None,
        )

    def _order_from_decision(self, battle: Battle, decision: dict):
        """Translate a model decision into a battle order, or None if it is unusable."""
        function_name = decision["name"]
        args = decision["arguments"]

        if function_name == "choose_move":
            move_name = args.get("move_name")
            if not move_name:
                print("Warning: OpenAI 'choose_move' called without 'move_name'. Falling back.")
                return None
            chosen_move = self._find_move_by_name(battle, move_name)
            if chosen_move is None:
                print(f"Warning: OpenAI chose unavailable/invalid move '{move_name}'. Falling back.")
                return None
            return self.create_order(chosen_move)

        if function_name == "choose_switch":
            pokemon_name = args.get("pokemon_name")
            if not pokemon_name:
                print("Warning: OpenAI 'choose_switch' called without 'pokemon_name'. Falling back.")
                return None
            chosen_switch = self._find_pokemon_by_name(battle, pokemon_name)
            if chosen_switch is None:
                print(f"Warning: OpenAI chose unavailable/invalid switch '{pokemon_name}'. Falling back.")
                return None
            return self.create_order(chosen_switch)

        print(f"Warning: OpenAI called unknown function '{function_name}'. Falling back.")
        return None

    async def choose_move(self, battle: Battle):
        """Main decision-making hook called by poke-env each turn.

        Returns the order produced by ``create_order`` for the model's choice.
        If the model's choice is unusable, returns Player's random or default
        choice instead.
        """
        battle_state_str = self._format_battle_state(battle)
        decision = await self._get_openai_decision(battle_state_str)

        if decision:
            order = self._order_from_decision(battle, decision)
            if order is not None:
                return order

        # Fallback if the API fails, returns an invalid action, or makes no function call.
        print("Fallback: Choosing random move/switch.")
        if battle.available_moves or battle.available_switches:
            # Use the built-in random choice method from Player for fallback
            return self.choose_random_move(battle)
        # Should only happen if forced to Struggle
        return self.choose_default_move(battle)