2011-04-03 20:07:36 -07:00
|
|
|
package magic.ai;
|
|
|
|
|
|
|
|
import java.util.Random;
|
|
|
|
import java.util.Arrays;
|
|
|
|
import java.util.List;
|
|
|
|
import java.util.LinkedList;
|
|
|
|
import java.util.ArrayList;
|
2011-04-07 02:19:23 -07:00
|
|
|
import java.util.Iterator;
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
import magic.model.MagicGame;
|
|
|
|
import magic.model.phase.MagicPhase;
|
|
|
|
import magic.model.MagicPlayer;
|
|
|
|
import magic.model.event.MagicEvent;
|
|
|
|
|
|
|
|
/*
|
|
|
|
UCT algorithm from Kocsis and Szepesvari 2006
|
|
|
|
|
|
|
|
function playOneSeq(root)
|
|
|
|
nodes = [root]
|
|
|
|
while (nodes.last is not leaf) do
|
|
|
|
nodes append descendByUCB1(nodes.last)
|
|
|
|
//assume value of leaf nodes are known
|
|
|
|
//nodes.init is all elements except the last one
|
|
|
|
updateValue(nodes.init, -nodes.last.value)
|
|
|
|
|
|
|
|
function descendByUCB1(node)
|
|
|
|
nb = sum of nb in node's children
|
|
|
|
for each node n in node's children
|
|
|
|
if n.nb = 0
|
|
|
|
v[n] = infinity
|
|
|
|
else
|
|
|
|
v[n] = 1 - n.value/n.nb + sqrt(2 * log(nb) / n.nb)
|
|
|
|
return n that maximizes v[n]
|
|
|
|
|
|
|
|
function updateValue(nodes, value)
|
|
|
|
for each node n in nodes
|
|
|
|
n.value += value
|
|
|
|
n.nb += 1
|
|
|
|
value = 1 - value
|
|
|
|
|
|
|
|
Modified UCT for MoGo in Wang and Gelly 2007
|
|
|
|
|
|
|
|
function playOneGame(state)
|
|
|
|
create node root from current game state
|
|
|
|
init tree to empty tree
|
|
|
|
while there is time and memory
|
|
|
|
//build the game tree one node at a time
|
|
|
|
playOneSeqMC(root, tree)
|
|
|
|
return descendByUCB1(root)
|
|
|
|
|
|
|
|
function playOneSeqMC(root, tree)
|
|
|
|
nodes = [root]
|
|
|
|
while (nodes.last is in the tree)
|
|
|
|
nodes append descendByUCB1(nodes.last)
|
|
|
|
tree add nodes.last
|
|
|
|
nodes.last.value = getValueByMC(nodes.last)
|
|
|
|
updateValue(nodes.init, -nodes.last.value)
|
|
|
|
|
|
|
|
function getValueByMC(node)
|
|
|
|
play one random game starting from node
|
|
|
|
return 1 if player 1 (max) wins, 0 if player 2 wins (min)
|
|
|
|
*/
|
|
|
|
|
|
|
|
//AI using Monte Carlo Tree Search
|
2011-04-04 10:22:52 -07:00
|
|
|
@SuppressWarnings("unused")
|
2011-04-03 20:07:36 -07:00
|
|
|
public class MCTSAI implements MagicAI {
|
|
|
|
|
2011-04-07 23:44:28 -07:00
|
|
|
private static final int MAXSIM = 1000;
|
|
|
|
private static final int MAXTIME = 10000;
|
|
|
|
private static final boolean LOGGING = false;
|
|
|
|
private final Random RNG = new Random(123);
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
private static void log(final String message) {
|
|
|
|
if (LOGGING) {
|
|
|
|
System.out.println(message);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
private static void logc(final char message) {
|
|
|
|
if (LOGGING) {
|
|
|
|
System.out.print(message);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
public synchronized Object[] findNextEventChoiceResults(
|
|
|
|
final MagicGame game,
|
|
|
|
final MagicPlayer scorePlayer) {
|
|
|
|
|
2011-04-04 20:31:42 -07:00
|
|
|
final long startTime = System.currentTimeMillis();
|
2011-04-08 18:47:53 -07:00
|
|
|
final String pinfo = "MCTS " + scorePlayer.getIndex() + "(" + scorePlayer.getLife() + ")";
|
2011-04-05 22:27:39 -07:00
|
|
|
final List<Object[]> choices = getCR(game, scorePlayer);
|
2011-04-03 20:07:36 -07:00
|
|
|
final int size = choices.size();
|
|
|
|
|
|
|
|
// No choice results
|
|
|
|
if (size == 0) {
|
2011-04-04 20:31:42 -07:00
|
|
|
log(pinfo + " NO CHOICE");
|
|
|
|
return null;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// Single choice result
|
|
|
|
if (size == 1) {
|
2011-04-05 22:27:39 -07:00
|
|
|
final ArtificialChoiceResults selected = getACR(choices).get(0);
|
|
|
|
log(pinfo + " " + selected);
|
|
|
|
return game.map(selected.choiceResults);
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// repeat a number of simulations
|
|
|
|
// each simulation does the following
|
|
|
|
// selects a path down the game tree and create a new leaf
|
|
|
|
// score the leaf by doing a random play to the end of the game
|
|
|
|
// update the score of all the ancestors of the leaf
|
2011-04-05 22:27:39 -07:00
|
|
|
// return the "best" choice
|
2011-04-07 23:44:28 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
//root represents the start state
|
2011-04-07 23:44:28 -07:00
|
|
|
final MCTSGameTree root = new MCTSGameTree(-1, -1);
|
|
|
|
for (int i = 1; i <= MAXSIM && System.currentTimeMillis() - startTime < MAXTIME; i++) {
|
2011-04-05 22:27:39 -07:00
|
|
|
//create a new MagicGame for simulation
|
2011-04-08 18:47:53 -07:00
|
|
|
final MagicGame start = new MagicGame(game, scorePlayer, true);
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
//pass in a clone of the state, genNewTreeNode grows the tree by one node
|
|
|
|
//and returns the path from the root to the new node
|
|
|
|
final List<MCTSGameTree> path = genNewTreeNode(root, start);
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
// play a simulated game to get score
|
|
|
|
// update all nodes along the path from root to new node
|
|
|
|
final int score = randomPlay(start);
|
|
|
|
logc((score == 1) ? '.' : 'X');
|
|
|
|
for (MCTSGameTree node : path) {
|
2011-04-03 20:07:36 -07:00
|
|
|
node.updateScore(score);
|
|
|
|
}
|
|
|
|
}
|
2011-04-05 22:27:39 -07:00
|
|
|
logc('\n');
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-08 18:47:53 -07:00
|
|
|
//select the best choice (child that has the largest visit count)
|
|
|
|
int maxV = -1;
|
2011-04-07 23:44:28 -07:00
|
|
|
int maxS = 0;
|
2011-04-03 20:07:36 -07:00
|
|
|
int idx = -1;
|
2011-04-05 22:27:39 -07:00
|
|
|
final List<ArtificialChoiceResults> achoices = getACR(choices);
|
2011-04-07 02:19:23 -07:00
|
|
|
for (MCTSGameTree node : root) {
|
2011-04-05 22:27:39 -07:00
|
|
|
achoices.get(node.getChoice()).worker = node.getScore();
|
|
|
|
achoices.get(node.getChoice()).gameCount = node.getNumSim();
|
2011-04-08 18:47:53 -07:00
|
|
|
if (node.getNumSim() > maxV) {
|
|
|
|
maxV = node.getNumSim();
|
2011-04-07 23:44:28 -07:00
|
|
|
maxS = node.getEvalScore();
|
2011-04-03 20:07:36 -07:00
|
|
|
idx = node.getChoice();
|
|
|
|
}
|
|
|
|
}
|
2011-04-07 23:44:28 -07:00
|
|
|
|
2011-04-04 20:31:42 -07:00
|
|
|
final long duration = System.currentTimeMillis() - startTime;
|
|
|
|
log("MCTS took " + duration);
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
log(pinfo);
|
2011-04-03 20:07:36 -07:00
|
|
|
final ArtificialChoiceResults selected = achoices.get(idx);
|
|
|
|
for (final ArtificialChoiceResults achoice : achoices) {
|
2011-04-05 22:27:39 -07:00
|
|
|
log((achoice == selected ? "* ":" ") + achoice);
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
return game.map(selected.choiceResults);
|
|
|
|
}
|
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
private List<Object[]> getCR(final MagicGame game, final MagicPlayer player) {
|
2011-04-03 20:07:36 -07:00
|
|
|
final MagicGame choiceGame = new MagicGame(game, player);
|
|
|
|
final MagicEvent event = choiceGame.getNextEvent();
|
|
|
|
return event.getArtificialChoiceResults(choiceGame);
|
|
|
|
}
|
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
private List<ArtificialChoiceResults> getACR(final List<Object[]> choices) {
|
|
|
|
final List<ArtificialChoiceResults> aiChoiceResultsList =
|
|
|
|
new ArrayList<ArtificialChoiceResults>();
|
2011-04-03 20:07:36 -07:00
|
|
|
for (final Object choiceResults[] : choices) {
|
|
|
|
aiChoiceResultsList.add(new ArtificialChoiceResults(choiceResults));
|
|
|
|
}
|
|
|
|
return aiChoiceResultsList;
|
|
|
|
}
|
|
|
|
|
|
|
|
// p is parent of n
|
|
|
|
// n.nb is how many times the node n is simulated
|
|
|
|
// sum of nb in all children of parent of n (same as p.nb)
|
|
|
|
// select node n (child of node) that maximize v[n]
|
|
|
|
// where v[n] = 1 - n.value/n.nb + sqrt(2 * log(nb) / n.nb)
|
2011-04-04 20:31:42 -07:00
|
|
|
// find a path from root to an unexplored node
|
2011-04-06 21:31:41 -07:00
|
|
|
private List<MCTSGameTree> genNewTreeNode(final MCTSGameTree root, final MagicGame game) {
|
|
|
|
final List<MCTSGameTree> path = new LinkedList<MCTSGameTree>();
|
2011-04-05 22:27:39 -07:00
|
|
|
MCTSGameTree curr = root;
|
|
|
|
path.add(curr);
|
2011-04-04 20:31:42 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
for (MagicEvent event = getNextMultiChoiceEvent(game, curr != root);
|
|
|
|
event != null;
|
|
|
|
event = getNextMultiChoiceEvent(game, curr != root)) {
|
|
|
|
|
|
|
|
final List<Object[]> choices = event.getArtificialChoiceResults(game);
|
2011-04-06 21:31:41 -07:00
|
|
|
|
2011-04-07 02:19:23 -07:00
|
|
|
assert choices.size() > 1 : "number of choices is " + choices.size();
|
2011-04-05 22:27:39 -07:00
|
|
|
|
|
|
|
if (curr.size() < choices.size()) {
|
|
|
|
//there are unexplored children of node
|
|
|
|
//assume we explore children of a node in increasing order of the choices
|
2011-04-07 02:19:23 -07:00
|
|
|
game.executeNextEvent(choices.get(curr.size()));
|
2011-04-07 23:44:28 -07:00
|
|
|
final MCTSGameTree child = new MCTSGameTree(curr.size(), game.getScore());
|
2011-04-05 22:27:39 -07:00
|
|
|
curr.addChild(child);
|
|
|
|
path.add(child);
|
|
|
|
return path;
|
|
|
|
} else {
|
|
|
|
final int totalSim = curr.getNumSim();
|
|
|
|
double bestV = -1e10;
|
2011-04-07 23:44:28 -07:00
|
|
|
MCTSGameTree child = curr.first();
|
2011-04-07 02:19:23 -07:00
|
|
|
for (MCTSGameTree node : curr) {
|
2011-04-06 21:31:41 -07:00
|
|
|
final double v =
|
2011-04-05 22:27:39 -07:00
|
|
|
((game.getScorePlayer() == event.getPlayer()) ? 1.0 : -1.0) * node.getV() +
|
|
|
|
Math.sqrt(2.0 * Math.log(totalSim) / node.getNumSim());
|
2011-04-08 18:47:53 -07:00
|
|
|
if (v > bestV) {
|
2011-04-05 22:27:39 -07:00
|
|
|
bestV = v;
|
2011-04-06 21:31:41 -07:00
|
|
|
child = node;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-04-06 21:31:41 -07:00
|
|
|
|
|
|
|
//move down the tree
|
|
|
|
curr = child;
|
2011-04-07 23:44:28 -07:00
|
|
|
assert curr != null;
|
2011-04-08 18:47:53 -07:00
|
|
|
|
|
|
|
//QQQ: choices.get crashed with out of bounds exception (index 4, size 4)
|
2011-04-05 22:27:39 -07:00
|
|
|
game.executeNextEvent(choices.get(curr.getChoice()));
|
|
|
|
path.add(curr);
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-04-05 22:27:39 -07:00
|
|
|
}
|
2011-04-06 21:31:41 -07:00
|
|
|
|
|
|
|
//game is finished
|
2011-04-07 02:19:23 -07:00
|
|
|
assert game.isFinished() : "game is not finished";
|
2011-04-04 20:31:42 -07:00
|
|
|
return path;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-04-05 22:27:39 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
private int randomPlay(final MagicGame game) {
|
2011-04-04 20:31:42 -07:00
|
|
|
// play game until it is finished
|
2011-04-05 22:27:39 -07:00
|
|
|
for (MagicEvent event = getNextMultiChoiceEvent(game, true);
|
|
|
|
event != null;
|
|
|
|
event = getNextMultiChoiceEvent(game, true)) {
|
|
|
|
final List<Object[]> choices = event.getArtificialChoiceResults(game);
|
|
|
|
final int idx = RNG.nextInt(choices.size());
|
|
|
|
final Object[] selected = choices.get(idx);
|
2011-04-07 23:44:28 -07:00
|
|
|
//logc('-');
|
2011-04-05 22:27:39 -07:00
|
|
|
game.executeNextEvent(selected);
|
|
|
|
}
|
2011-04-06 21:31:41 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
// game is finished, check who lost
|
2011-04-07 02:19:23 -07:00
|
|
|
assert game.isFinished() : "game is not finished";
|
|
|
|
assert (game.getLosingPlayer() != null) : "losing player is null";
|
2011-04-07 23:44:28 -07:00
|
|
|
assert (game.getMainPhaseCount() > 0) : "main phase count is zero";
|
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
if (game.getLosingPlayer() == game.getScorePlayer()) {
|
|
|
|
return -1;
|
|
|
|
} else {
|
|
|
|
return 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
private MagicEvent getNextMultiChoiceEvent(MagicGame game, boolean fastChoices) {
|
|
|
|
game.setFastChoices(fastChoices);
|
2011-04-07 23:44:28 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
while (!game.isFinished()) {
|
2011-04-07 23:44:28 -07:00
|
|
|
if (!game.hasNextEvent()) {
|
2011-04-05 22:27:39 -07:00
|
|
|
game.getPhase().executePhase(game);
|
2011-04-07 23:44:28 -07:00
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
//game has next event
|
|
|
|
//logc('e');
|
|
|
|
final MagicEvent event = game.getNextEvent();
|
|
|
|
//logc('E');
|
|
|
|
|
|
|
|
if (!event.hasChoice()) {
|
|
|
|
game.executeNextEvent(MagicEvent.NO_CHOICE_RESULTS);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
//event has choice
|
|
|
|
//logc('c');
|
|
|
|
final List<Object[]> choices = event.getArtificialChoiceResults(game);
|
|
|
|
//logc('C');
|
|
|
|
final int size = choices.size();
|
|
|
|
if (size == 0) {
|
|
|
|
//QQQ: when does this occur?
|
|
|
|
assert false : "size of choices is 0" ;
|
|
|
|
} else if (size == 1) {
|
|
|
|
game.executeNextEvent(choices.get(0));
|
|
|
|
} else {
|
|
|
|
//multiple choice
|
|
|
|
return event;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-04-07 23:44:28 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
//game is finished
|
|
|
|
return null;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * A node of the MCTS game tree.
 *
 * <p>Nodes do not store a MagicGame. Each node stores only the index of the choice at its
 * parent that leads to it, so a single game copy can be driven down the tree by replaying those
 * choices from the root. Statistics are a visit count and a summed score.
 */
class MCTSGameTree implements Iterable<MCTSGameTree> {

    /** Index of the parent's choice leading here (MCTSAI passes -1 for the root). */
    private final int choice;

    /** Score recorded when the node was created (MCTSAI passes game.getScore()). */
    private final int evalScore;

    /** Children in insertion order. */
    private final List<MCTSGameTree> children = new ArrayList<MCTSGameTree>();

    /** Number of simulations that passed through this node. */
    private int numSim = 0;

    /** Sum of the scores of those simulations. */
    private int score = 0;

    public MCTSGameTree(final int choice, final int evalScore) {
        this.choice = choice;
        this.evalScore = evalScore;
    }

    /**
     * Returns the first child added.
     *
     * @throws IndexOutOfBoundsException if the node has no children
     */
    public MCTSGameTree first() {
        return children.get(0);
    }

    /** Iterates over the children in insertion order. */
    public Iterator<MCTSGameTree> iterator() {
        return children.iterator();
    }

    public int getChoice() {
        return choice;
    }

    public int getEvalScore() {
        return evalScore;
    }

    public int getScore() {
        return score;
    }

    /** Records one simulation with the given outcome. */
    public void updateScore(final int score) {
        this.score += score;
        numSim += 1;
    }

    public int getNumSim() {
        return numSim;
    }

    /**
     * Returns the mean simulation score, or 0.0 for a node that has never been simulated.
     * Without the guard the result would be NaN (0/0), which makes every comparison false.
     */
    public double getV() {
        return numSim == 0 ? 0.0 : (double) score / numSim;
    }

    public void addChild(final MCTSGameTree child) {
        children.add(child);
    }

    /** Returns the number of children. */
    public int size() {
        return children.size();
    }
}
|