You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
8.8 KiB
184 lines
8.8 KiB
|
|
from typing import List, Set
|
|
from routes.game_world import GameWorld
|
|
import heapq
|
|
import networkx as nx
|
|
|
|
class RoutePlan:
    """Outcome of a successful route search."""

    def __init__(self, path: List[str], cost: int, conditions: Set[str]):
        # path: ordered locations the route visits
        # cost: total cost of the route
        # conditions: conditions held when the route ends
        self.path, self.cost, self.conditions = path, cost, conditions
|
|
|
|
# Condition name that, when held, lets RoutePlanner.astar_search add fly moves between towns/cities.
FLY_OUT_OF_BATTLE = 'Fly out of battle'
|
|
|
|
class State:
    """A single search node: current location, held conditions and the route so far."""

    def __init__(self, location, conditions, cost, path, visited_required_nodes):
        # Where the search currently stands.
        self.location = location
        # Conditions held at this point; expected to be a frozenset so states can be keyed.
        self.conditions = conditions
        # Accumulated route cost up to this state.
        self.cost = cost
        # Locations visited so far, in visiting order.
        self.path = path
        # Required nodes visited so far; expected to be a frozenset.
        self.visited_required_nodes = visited_required_nodes

    def __lt__(self, other):
        """Rank the cheaper state first (used by heapq when estimates tie)."""
        return other.cost > self.cost
|
|
|
|
class RoutePlanner:
    """Plan a route through a ``GameWorld`` with an A*-style best-first search.

    A search state is identified by (location, held conditions, required nodes
    visited so far).  A state is a goal once it is at ``world.end``, holds every
    condition in ``world.goals`` and has visited every node in
    ``world.must_visit``.
    """

    def __init__(self, world: GameWorld):
        self.world: GameWorld = world

    def heuristic(self, state, goal_conditions, required_nodes):
        """Return unmet goal conditions plus unvisited required nodes.

        Kept as an alternative to :meth:`heuristic2`; ``astar_search`` does not
        call it.
        """
        remaining_conditions = goal_conditions - state.conditions
        remaining_nodes = required_nodes - state.visited_required_nodes
        return len(remaining_conditions) + len(remaining_nodes)

    def heuristic2(self, state, goal_conditions, end_goal, required_nodes, distances):
        """Estimate the remaining cost from ``state`` to a goal state.

        The estimate has three terms:

        - the number of unmet goal conditions;
        - the number of unvisited required nodes;
        - the unconditioned graph distance from ``state.location`` to the
          nearest unvisited required node or ``end_goal``.

        ``distances`` maps ``(source, target)`` pairs to hop counts, as built by
        :meth:`compute_shortest_path`.  When none of the lookups yields a finite
        distance (e.g. the current location is not a key node), the distance
        term is 0.

        NOTE(review): the terms are summed, and the distance ignores flying.
        The estimate can therefore exceed the true remaining cost, so the search
        is not guaranteed to return the cheapest route.
        """
        inf = float('inf')
        remaining_conditions = goal_conditions - state.conditions
        remaining_nodes = required_nodes - state.visited_required_nodes

        # When no required nodes remain, the goal is the only candidate.
        candidates = list(remaining_nodes) + [end_goal]
        nearest = min(distances.get((state.location, target), inf) for target in candidates)
        distance_term = 0 if nearest == inf else nearest
        return len(remaining_conditions) + len(remaining_nodes) + distance_term

    def is_goal_state(self, state, goal_location, goals, required_nodes):
        """Return True when ``state`` meets all goal requirements.

        That means it is at ``goal_location``, holds every condition in
        ``goals`` and has visited every node in ``required_nodes``.
        """
        return (
            state.location == goal_location
            and goals.issubset(state.conditions)
            and required_nodes.issubset(state.visited_required_nodes)
        )

    def compute_shortest_path(self, graph, key_nodes):
        """Return unweighted shortest-path lengths between every pair of key nodes.

        Edge conditions and flying are ignored.  The result maps ``(u, v)`` to
        the hop count from ``u`` to ``v``, or ``float('inf')`` when ``v`` is
        unreachable from ``u``.  Duplicate entries in ``key_nodes`` are
        searched only once.
        """
        unique_nodes = list(dict.fromkeys(key_nodes))  # order-preserving dedupe
        distances = {}
        for source in unique_nodes:
            lengths = nx.single_source_shortest_path_length(graph, source)
            for target in unique_nodes:
                distances[(source, target)] = lengths.get(target, float('inf'))
        return distances

    @staticmethod
    def _as_condition_set(spec):
        """Normalise a condition spec (None, one name, or an iterable of names) to a set."""
        if spec is None:
            return set()
        if isinstance(spec, str):
            # A bare string is a single condition name; set(spec) would split it
            # into individual characters.
            return {spec}
        return set(spec)

    def _successor(self, current_state, target):
        """Build the state reached by moving from ``current_state`` to ``target`` (cost 1).

        Grants listed on ``target`` are applied in list order.  A grant is taken
        when its ``required_conditions`` are already held, including conditions
        granted earlier in the same list.
        """
        conditions = set(current_state.conditions)
        grants = self.world.graph.nodes[target].get('grants_conditions') or []
        for grant in grants:
            required = self._as_condition_set(grant.get('required_conditions', []))
            if required.issubset(conditions):
                conditions.add(grant['condition'])

        visited_required = set(current_state.visited_required_nodes)
        if target in self.world.must_visit:
            visited_required.add(target)

        return State(
            location=target,
            conditions=frozenset(conditions),
            cost=current_state.cost + 1,  # uniform cost for walking and flying
            path=current_state.path + [target],
            visited_required_nodes=frozenset(visited_required),
        )

    def astar_search(self) -> RoutePlan:
        """Search for a route from ``world.start`` to ``world.end``.

        Returns:
            A ``RoutePlan`` holding the visited locations, the route cost and
            the conditions held on arrival.  Returns ``None`` when no goal
            state is reachable.

        Side effect: stores ``set(world.goals)`` on ``self.goals``.
        """
        world = self.world
        self.goals = set(world.goals)
        must_visit = world.must_visit

        key_nodes = [world.start, world.end] + list(world.towns_and_cities)
        if must_visit:
            key_nodes += list(must_visit)
        distances = self.compute_shortest_path(world.graph, key_nodes)

        def estimate(state):
            # f = g + h
            return state.cost + self.heuristic2(state, self.goals, world.end, must_visit, distances)

        start_state = State(
            location=world.start,
            # frozenset keeps the state key hashable even if the world supplies a set or list.
            conditions=frozenset(world.initial_conditions),
            cost=0,
            path=[world.start],
            visited_required_nodes=frozenset([world.start]) if world.start in must_visit else frozenset(),
        )
        open_list = [(0, start_state)]  # heap of (estimated total cost, state)
        best_cost = {}  # (location, conditions, visited required nodes) -> cheapest cost expanded

        while open_list:
            _, current = heapq.heappop(open_list)

            if self.is_goal_state(current, world.end, self.goals, must_visit):
                return RoutePlan(current.path, current.cost, current.conditions)

            # Skip states already expanded at an equal or lower cost.
            state_key = (current.location, current.conditions, current.visited_required_nodes)
            if state_key in best_cost and best_cost[state_key] <= current.cost:
                continue
            best_cost[state_key] = current.cost

            # Ordinary moves along edges whose required conditions are held.
            for neighbor in world.graph.neighbors(current.location):
                edge_data = world.graph.get_edge_data(current.location, neighbor)
                edge_requires = self._as_condition_set(edge_data.get('condition', []))
                if not edge_requires.issubset(current.conditions):
                    continue
                successor = self._successor(current, neighbor)
                heapq.heappush(open_list, (estimate(successor), successor))

            # Fly moves: from a town/city to any other town/city already on the path.
            if FLY_OUT_OF_BATTLE in current.conditions and current.location in world.towns_and_cities:
                already_visited = set(current.path)
                for fly_target in world.towns_and_cities:
                    if fly_target != current.location and fly_target in already_visited:
                        successor = self._successor(current, fly_target)
                        heapq.heappush(open_list, (estimate(successor), successor))

        return None  # No path found
|