diff --git a/pokedex/util/astar.py b/pokedex/util/astar.py
new file mode 100644
index 0000000..0e54080
--- /dev/null
+++ b/pokedex/util/astar.py
@@ -0,0 +1,301 @@
+"""A pure-Python implementation of the A* search algorithm
+"""
+
+import heapq
+
+class Node(object):
+    """Node for the A* search algorithm.
+
+    To get started, implement the `expand` method and call `find_path`.
+
+    N.B. Node object must be hashable.
+    """
+
+    def expand(self):
+        """Return a list of (costs, transition, next_node) for next states
+
+        "Next states" are those reachable from this node.
+
+        May return any finite iterable.
+        """
+        raise NotImplementedError
+
+    def estimate(self, goal):
+        """Return an *optimistic* estimate of the cost to the given goal node.
+
+        If there are multiple goal states, return the lowest estimate among all
+        of them.
+        """
+        return 0
+
+    def is_goal(self, goal):
+        """Return true iff this is a goal node.
+        """
+        return self == goal
+
+    def find_path(self, goal=None, **kwargs):
+        """Return the best path to the goal
+
+        Returns an iterator of (cost, transition, node) triples, in reverse
+        order (i.e. the first element will have the total cost and goal node).
+
+        `goal` will be passed to the `estimate` and `is_goal` methods.
+
+        See a_star for the advanced keyword arguments, `notify` and
+        `estimate_error_callback`.
+        """
+        paths = self.find_all_paths(goal=goal, **kwargs)
+        try:
+            return paths.next()
+        except StopIteration:
+            return None
+
+    def find_all_paths(self, goal=None, **kwargs):
+        """Yield the best path to each goal
+
+        Returns an iterator of paths. See the `find_path` method for how paths
+        look.
+
+        Giving the `goal` argument will cause it to search for that goal,
+        instead of consulting the `is_goal` method.
+        This means that if you wish to find more than one path, you must not
+        pass a `goal` to this method, and instead reimplement `is_goal`.
+
+        See a_star for the advanced keyword arguments, `notify` and
+        `estimate_error_callback`.
+        """
+        return a_star(
+                initial=self,
+                expand=lambda s: s.expand(),
+                estimate=lambda s: s.estimate(goal),
+                is_goal=lambda s: s.is_goal(goal),
+                **kwargs)
+
+### The main algorithm
+
+def a_star(initial, expand, is_goal, estimate=lambda x: 0, notify=None,
+        estimate_error_callback=None):
+    """A* search algorithm for a consistent heuristic
+
+    General background: http://en.wikipedia.org/wiki/A*_search_algorithm
+
+    This algorithm will work in large or infinite search spaces.
+
+    This version of the algorithm is modified for multiple possible goals:
+    it does not end when it reaches a goal. Rather, it yields the best path
+    for each goal.
+    (Exhausting the iterator is of course not recommended for large search
+    spaces.)
+
+    Returns an iterable of paths, where each path is an iterable of
+    (cumulative cost, transition, node) triples representing the path to
+    the goal. The transition is the one leading to the corresponding node.
+    The path is in reverse order, thus its first element will contain the
+    total cost and the goal node.
+    The initial node is not included in the returned path.
+
+    Arguments:
+
+    `initial`: the initial node
+
+    `expand`: function yielding a (cost of transition, transition, next node)
+        triple for each node reachable from its argument.
+        The `transition` element is application data; it is not touched, only
+        returned as part of the best path.
+    `estimate`: function(x) returning optimistic estimate of cost from node x
+        to a goal. If not given, 0 will be used for estimates.
+    `is_goal`: function(x) returning true iff x is a goal node
+
+    `notify`: If given, it is called at each step with four arguments:
+        - current cost (with estimate). The cost to the next goal will not be
+            smaller than this.
+        - current node
+        - open set cardinality: roughly, an estimate of the size of the
+            boundary between "explored" and "unexplored" parts of node space
+        - debug: stats that may be useful for debugging or tuning (in this
+            implementation, this is the open heap size)
+        The number of calls to notify or the current cost can be useful as
+        stopping criteria; the other values may help in tuning estimators.
+
+    `estimate_error_callback`: function handling cases where an estimate was
+        detected not to be optimistic (as A* requires). The function is given a
+        path (as would be returned by a_star, except it does not lead to a goal
+        node). By default, nothing is done (indeed, an estimate that's not
+        strictly optimistic can be useful, esp. if the optimal path is not
+        required)
+    """
+    # g: best cumulative cost (from initial node) found so far
+    # h: optimistic estimate of cost to goal
+    # f: g + h
+    closed = set()  # nodes we don't want to visit again
+    est = estimate(initial)  # estimate total cost
+    opened = _HeapDict()  # node -> (f, g, h)
+    opened[initial] = (est, 0, est)
+    came_from = {initial: None}  # node -> (prev_node, came_from[prev_node])
+    while True:  # _HeapDict will raise StopIteration for us
+        x, (f, g, h) = opened.pop()
+        closed.add(x)
+
+        if notify is not None:
+            notify(f, x, len(opened.dict), len(opened.heap))
+
+        if is_goal(x):
+            yield _trace_path(came_from[x])
+
+        for cost, transition, y in expand(x):
+            if y in closed:
+                continue
+            tentative_g = g + cost
+
+            old_f, old_g, h = opened.get(y, (None, None, None))
+
+            if old_f is None:
+                h = estimate(y)
+            elif tentative_g > old_g:
+                continue
+
+            came_from[y] = ((tentative_g, transition, y), came_from[x])
+            new_f = tentative_g + h
+
+            opened[y] = new_f, tentative_g, h
+
+            if estimate_error_callback is not None and new_f < f:
+                estimate_error_callback(_trace_path(came_from[y]))
+
+def _trace_path(cdr):
+    """Backtrace an A* result"""
+    # Convert a lispy list to a pythony iterator
+    while cdr:
+        car, cdr = cdr
+        yield car
+
+class _HeapDict(object):
+    """A custom parallel heap/dict structure -- the best of both worlds.
+
+    This is NOT a general-purpose class; it only supports what a_star needs.
+    """
+    # The dict has the definitive contents
+    # The heap has (value, key) pairs. It may have some extra elements.
+    def __init__(self):
+        self.dict = {}
+        self.heap = []
+
+    def __setitem__(self, key, value):
+        self.dict[key] = value
+        heapq.heappush(self.heap, (value, key))
+
+    def __delitem__(self, key):
+        del self.dict[key]
+
+    def get(self, key, default):
+        """Return value for key, or default if not found
+        """
+        return self.dict.get(key, default)
+
+    def pop(self):
+        """Return (key, value) with the smallest value.
+
+        Raise StopIteration (!!) if empty
+        """
+        while True:
+            try:
+                value, key = heapq.heappop(self.heap)
+                # Identity, not equality: only the tuple most recently stored
+                # for this key is live; older heap entries are stale.
+                if value is self.dict[key]:
+                    del self.dict[key]
+                    return key, value
+            except KeyError:
+                # deleted from dict = not here
+                pass
+            except IndexError:
+                # nothing more to pop
+                raise StopIteration
+
+
+### Example/test
+
+
+def test_example_knights():
+    """Test/example: the "knights" problem
+
+    Definition and another solution may be found at:
+    http://brandon.sternefamily.net/posts/2005/02/a-star-algorithm-in-python/
+    """
+    # Legal moves
+    moves = { 1: [4, 7],
+              2: [8, 10],
+              3: [9],
+              4: [1, 6, 10],
+              5: [7],
+              6: [4],
+              7: [1, 5],
+              8: [2, 9],
+              9: [8, 3],
+              10: [2, 4] }
+
+    class Positions(dict, Node):
+        """Node class representing positions as a dictionary.
+
+        Keys are unique piece names, values are (color, position) where color
+        is True for white, False for black.
+        """
+        def expand(self):
+            for piece, (color, position) in self.items():
+                for new_position in moves[position]:
+                    if new_position not in (p for c, p in self.values()):
+                        new_node = Positions(self)
+                        new_node.update({piece: (color, new_position)})
+                        yield 1, None, new_node
+
+        def estimate(self, goal):
+            # Number of misplaced figures
+            misplaced = 0
+            for piece, (color, position) in self.items():
+                if (color, position) not in goal.values():
+                    misplaced += 1
+            return misplaced
+
+        def is_goal(self, goal):
+            return self.estimate(goal) == 0
+
+        def __hash__(self):
+            return hash(tuple(sorted(self.items())))
+
+    initial = Positions({
+        'White 1': (True, 1),
+        'white 2': (True, 6),
+        'Black 1': (False, 5),
+        'black 2': (False, 7),
+    })
+
+    # Goal: colors should be switched
+    goal = Positions((piece, (not color, position))
+        for piece, (color, position) in initial.items())
+
+    def print_board(positions, linebreak='\n', extra=''):
+        board = dict((position, piece)
+            for piece, (color, position) in positions.items())
+        for i in range(1, 11):
+            # line breaks
+            if i in (2, 6, 9):
+                print linebreak,
+            print board.get(i, '_')[0],
+        print extra
+
+    def notify(cost, state, b, c):
+        print 'Looking at state with cost %s:' % cost,
+        print_board(state, '|', '(%s; %s; %s)' % (state.estimate(goal), b, c))
+
+    solution_path = list(initial.find_path(goal, notify=notify))
+
+    print 'Step', 0
+    print_board(initial)
+    for i, (cost, transition, positions) in enumerate(reversed(solution_path)):
+        print 'Step', i + 1
+        print_board(positions)
+
+    # Check solution is correct
+    cost, transition, positions = solution_path[0]
+    assert set(positions.values()) == set(goal.values())
+    assert cost == 40
diff --git a/pokedex/util/movesets.py b/pokedex/util/movesets.py
new file mode 100755
index 0000000..6f90a06
--- /dev/null
+++ b/pokedex/util/movesets.py
@@ -0,0 +1,342 @@
+#! /usr/bin/env python
+# Encoding: UTF-8
+
+import sys
+import argparse
+from collections import defaultdict
+
+from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.sql.expression import not_, and_, or_
+
+from pokedex.db import connect, tables, util
+
+from pokedex.util import querytimer
+from pokedex.util.astar import a_star
+
+class IllegalMoveCombination(ValueError): pass
+class TooManyMoves(IllegalMoveCombination): pass
+class NoMoves(IllegalMoveCombination): pass
+class MovesNotLearnable(IllegalMoveCombination): pass
+class NoParent(IllegalMoveCombination): pass
+class TargetExcluded(IllegalMoveCombination): pass
+
+class MovesetSearch(object):
+    """Gather the version-group and move data used to search for a moveset
+
+    If the arguments are invalid, `error` is set to an
+    IllegalMoveCombination instance and nothing is loaded.
+    """
+    def __init__(self, session, pokemon, version, moves, level=100, costs=None,
+            exclude_versions=(), exclude_pokemon=(), debug=False):
+
+        self.generator = self.error = None
+
+        if not moves:
+            self.error = NoMoves('No moves specified.')
+            return
+        elif len(moves) > 4:
+            self.error = TooManyMoves('Too many moves specified.')
+            return
+
+        self.debug = debug
+
+        self.session = session
+
+        if costs is None:
+            self.costs = default_costs
+        else:
+            self.costs = costs
+
+        self.excluded_families = frozenset(p.evolution_chain_id
+                for p in exclude_pokemon)
+
+        if pokemon:
+            self.goal_evolution_chain = pokemon.evolution_chain_id
+            if self.goal_evolution_chain in self.excluded_families:
+                self.error = TargetExcluded('The target pokemon was excluded.')
+                return
+        else:
+            self.goal_evolution_chain = None
+
+        if debug:
+            print 'Specified moves:', [move.id for move in moves]
+
+        self.goal_moves = frozenset(move.id for move in moves)
+        self.goal_version_group = version.version_group_id
+
+        # Fill self.generation_id_by_version_group
+        self.load_version_groups(version.version_group_id,
+                [v.version_group_id for v in exclude_versions])
+
+        self.pokemon_moves = defaultdict(  # key: pokemon
+                lambda: defaultdict(  # key: move
+                    lambda: defaultdict(  # key: version_group
+                        lambda: defaultdict(  # key: method
+                            list))))  # list of (level, cost)
+        self.movepools = defaultdict(dict)  # evo chain -> move -> best cost
+        self.learnpools = defaultdict(set)  # as above, but not egg moves
+
+        easy_moves, non_egg_moves = self.load_pokemon_moves(
+                self.goal_evolution_chain, 'family')
+
+        hard_moves = self.goal_moves - easy_moves
+        egg_moves = self.goal_moves - non_egg_moves
+        if hard_moves:
+            # Have to breed!
+            self.load_pokemon_moves(self.goal_evolution_chain, 'others')
+
+    def load_version_groups(self, version, excluded):
+        """Fill self.generation_id_by_version_group
+
+        Keeps the goal version group plus every non-excluded version group
+        found by following possible trades backwards from the goal.
+        """
+        query = self.session.query(tables.VersionGroup.id,
+                tables.VersionGroup.generation_id)
+        query = query.join(tables.Version.version_group)
+        if excluded:
+            query = query.filter(not_(tables.VersionGroup.id.in_(excluded)))
+        self.generation_id_by_version_group = dict(query)
+        def expand(v2):
+            for v1 in self.generation_id_by_version_group:
+                if self.trade_cost(v1, v2):
+                    yield 0, None, v1
+        def is_goal(v):
+            return True
+        goal = self.goal_version_group
+        filtered_map = {goal: self.generation_id_by_version_group[goal]}
+        for result in a_star(self.goal_version_group, expand, is_goal):
+            for cost, transition, version in result:
+                filtered_map[version] = (
+                        self.generation_id_by_version_group[version])
+        self.generation_id_by_version_group = filtered_map
+        if self.debug:
+            print 'Excluded version groups:', excluded
+            print 'Trade cost table:'
+            print '%03s' % '',
+            for g1 in sorted(self.generation_id_by_version_group):
+                print '%03s' % g1,
+            print
+            for g1 in sorted(self.generation_id_by_version_group):
+                print '%03s' % g1,
+                for g2 in sorted(self.generation_id_by_version_group):
+                    print '%03s' % (self.trade_cost(g1, g2) or '---'),
+                print
+
+    def load_pokemon_moves(self, evolution_chain, selection):
+        """Load pokemon_moves, movepools, learnpools
+
+        `selection`:
+            'family' for loading only pokemon in evolution_chain
+            'others' for loading only pokemon NOT in evolution_chain
+
+        Returns: (easy_moves, non_egg_moves)
+        If `selection` == 'family':
+            easy_moves is a set of moves that are easier to obtain than by
+                breeding
+            non_egg_moves is a set of moves that don't require breeding
+        Otherwise, these are empty sets.
+        """
+        if self.debug:
+            print 'Loading moves, c%s %s' % (evolution_chain, selection)
+        query = self.session.query(
+                tables.PokemonMove.pokemon_id,
+                tables.PokemonMove.move_id,
+                tables.PokemonMove.version_group_id,
+                tables.PokemonMoveMethod.identifier,
+                tables.PokemonMove.level,
+                tables.Pokemon.evolution_chain_id,
+            )
+        query = query.join(tables.PokemonMove.pokemon)
+        query = query.filter(tables.PokemonMoveMethod.id ==
+                tables.PokemonMove.pokemon_move_method_id)
+        query = query.filter(tables.PokemonMove.version_group_id.in_(
+                set(self.generation_id_by_version_group)))
+        query = query.filter(or_(
+                tables.PokemonMove.level > 100,  # XXX: Chaff?
+                tables.PokemonMove.move_id.in_(self.goal_moves),
+            ))
+        if self.excluded_families:
+            query = query.filter(not_(tables.Pokemon.evolution_chain_id.in_(
+                    self.excluded_families)))
+        if evolution_chain:
+            if selection == 'family':
+                query = query.filter(tables.Pokemon.evolution_chain_id == (
+                        evolution_chain))
+            elif selection == 'others':
+                query = query.filter(tables.Pokemon.evolution_chain_id != (
+                        evolution_chain))
+        query = query.order_by(tables.PokemonMove.level)
+        easy_moves = set()
+        non_egg_moves = set()
+        for pokemon, move, vg, method, level, chain in query:
+            if move in self.goal_moves:
+                cost = self.learn_cost(method, vg)
+                self.movepools[chain][move] = min(
+                        self.movepools[chain].get(move, cost), cost)
+                if method != 'egg':
+                    self.learnpools[chain].add(move)
+                    non_egg_moves.add(move)
+                    if cost < self.costs['breed']:
+                        easy_moves.add(move)
+            else:
+                cost = 0
+            self.pokemon_moves[pokemon][move][vg][method].append((level, cost))
+        if self.debug and selection == 'family':
+            print 'Easy moves:', sorted(easy_moves)
+            print 'Non-egg moves:', sorted(non_egg_moves)
+        return easy_moves, non_egg_moves
+
+    def learn_cost(self, method, version_group):
+        """Return cost of learning a move by method (identifier) in ver. group
+        """
+        if method == 'level-up':
+            return self.costs['level-up']
+        gen = self.generation_id_by_version_group[version_group]
+        if method == 'machine' and gen < 5:
+            return self.costs['machine-once']
+        elif method == 'tutor' and gen == 3:
+            return self.costs['tutor-once']
+        elif method == 'egg':
+            return self.costs['breed']
+        else:
+            return self.costs[method]
+
+    def trade_cost(self, version_group_from, version_group_to, *thing_generations):
+        """Return cost of trading between versions, None if impossible
+
+        `thing_generations` should be the generation IDs of the pokemon and
+        moves being traded.
+        """
+        # XXX: this ignores HM transfer restrictions
+        gen_from = self.generation_id_by_version_group[version_group_from]
+        gen_to = self.generation_id_by_version_group[version_group_to]
+        if gen_from == gen_to:
+            return self.costs['trade']
+        elif any(gen > gen_to for gen in thing_generations):
+            return None
+        elif gen_from in (1, 2):
+            if gen_to in (1, 2):
+                return self.costs['trade']
+            else:
+                return None
+        elif gen_to in (1, 2):
+            return None
+        elif gen_from > gen_to:
+            return None
+        elif gen_from < gen_to - 1:
+            return None
+        else:
+            return self.costs['trade'] + self.costs['transfer']
+
+default_costs = {
+    # Costs for learning a move in various ways
+    'level-up': 20,  # The normal way
+    'machine': 40,  # Machines are slightly inconvenient.
+    'machine-once': 2000,  # before gen. 5, TMs only work once. Avoid.
+    'tutor': 60,  # Tutors are slightly more inconvenient than TMs – can't carry them around
+    'tutor-once': 2100,  # gen III: tutors only work once (well except Emerald frontier ones)
+    'sketch': 10,  # Quite cheap. (Doesn't include learning Sketch itself)
+
+    # Gimmick moves – we need to use this method to learn the move anyway,
+    # so make a big-ish dent in the score if missing
+    'stadium-surfing-pikachu': 100,
+    'light-ball-egg': 100,  # …
+
+    # Ugh... I don't know?
+    'colosseum-purification': 100,
+    'xd-shadow': 100,
+    'xd-purification': 100,
+    'form-change': 100,
+
+    # Other actions.
+    # Breeding should cost more than 3 times than a lv-up/machine/tutor move.
+    'evolution': 100,  # We have to do this anyway, usually.
+    'evolution-delayed': 50,  # *in addition* to evolution. Who wants to mash B on every level.
+    'breed': 400,  # Breeding's a pain.
+    'trade': 200,  # Trading's a pain, but not as much as breeding.
+    'transfer': 200,  # *in addition* to trade. For one-way cross-generation transfers
+    'delete': 300,  # Deleting a move. (Not needed unless deleting an evolution move.)
+    'relearn': 150,  # Also a pain, though not as big as breeding.
+    'per-level': 1,  # Prefer less grinding. This is for all lv-ups but the final “grow”
+}
+
+def main(argv):
+    """Command-line entry point; `argv` is the argument list sans program name
+
+    Returns 1 if the search reports an error. Exits with status 2 if a
+    given identifier is not found.
+    """
+    parser = argparse.ArgumentParser(description=
+        'Find out if the specified moveset is valid, and provide a suggestion '
+        'on how to obtain it.')
+
+    parser.add_argument('pokemon', metavar='POKEMON', type=unicode,
+        help='Pokemon to check the moveset for')
+
+    parser.add_argument('move', metavar='MOVE', type=unicode, nargs='*',
+        help='Moves in the moveset')
+
+    parser.add_argument('-l', '--level', metavar='LV', type=int, default=100,
+        help='Level of the pokemon')
+
+    parser.add_argument('-v', '--version', metavar='VER', type=unicode,
+        default='black',
+        help='Version to search in.')
+
+    parser.add_argument('-V', '--exclude-version', metavar='VER', type=unicode,
+        action='append', default=[],
+        help='Versions to exclude (along with their '
+            'counterparts, if any, e.g. `black` will also exclude White).')
+
+    parser.add_argument('-P', '--exclude-pokemon', metavar='PKM', type=unicode,
+        action='append', default=[],
+        help='Pokemon to exclude (along with their families, e.g. `pichu` '
+            'will also exclude Pikachu and Raichu).')
+
+    parser.add_argument('-d', '--debug', action='append_const', const=1,
+        default=[],
+        help='Output timing and debugging information (can be specified more '
+            'than once).')
+
+    args = parser.parse_args(argv)
+    args.debug = len(args.debug)
+
+    if args.debug:
+        print 'Connecting'
+
+    session = connect(engine_args={'echo': args.debug > 1})
+
+    if args.debug:
+        print 'Parsing arguments'
+
+    def _get_list(table, idents, name):
+        result = []
+        for ident in idents:
+            try:
+                result.append(util.get(session, table, identifier=ident))
+            except NoResultFound:
+                print>>sys.stderr, ('%s %s not found. Please use '
+                        'the identifier.' % (name, ident))
+                sys.exit(2)
+        return result
+
+    pokemon = _get_list(tables.Pokemon, [args.pokemon], 'Pokemon')[0]
+    moves = _get_list(tables.Move, args.move, 'Move')
+    version = _get_list(tables.Version, [args.version], 'Version')[0]
+    excl_versions = _get_list(tables.Version, args.exclude_version, 'Version')
+    excl_pokemon = _get_list(tables.Pokemon, args.exclude_pokemon, 'Pokemon')
+
+    if args.debug:
+        print 'Starting search'
+
+    search = MovesetSearch(session, pokemon, version, moves, args.level,
+        exclude_versions=excl_versions, exclude_pokemon=excl_pokemon,
+        debug=args.debug)
+
+    if search.error:
+        print 'Error:', search.error
+        return 1
+
+if __name__ == '__main__':
+    sys.exit(main(sys.argv[1:]))
diff --git a/setup.py b/setup.py
index 5c71530..596c0ef 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,7 @@ setup(
         'whoosh>=2.2.2',
         'markdown',
         'construct',
+        'argparse',
     ],
 
     entry_points = {