2011-04-03 20:07:36 -07:00
|
|
|
package magic.ai;
|
|
|
|
|
2011-06-09 21:03:27 -07:00
|
|
|
import java.util.*;
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
import magic.model.MagicGame;
|
|
|
|
import magic.model.MagicPlayer;
|
2011-06-16 00:12:14 -07:00
|
|
|
import magic.model.phase.MagicPhase;
|
2011-04-03 20:07:36 -07:00
|
|
|
import magic.model.event.MagicEvent;
|
2011-06-16 00:12:14 -07:00
|
|
|
import magic.model.choice.*;
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
/*
|
|
|
|
UCT algorithm from Kocsis and Szepesvari 2006
|
|
|
|
|
|
|
|
function playOneSeq(root)
|
|
|
|
nodes = [root]
|
|
|
|
while (nodes.last is not leaf) do
|
|
|
|
nodes append descendByUCB1(nodes.last)
|
|
|
|
//assume value of leaf nodes are known
|
|
|
|
//nodes.init is all elements except the last one
|
|
|
|
updateValue(nodes.init, -nodes.last.value)
|
|
|
|
|
|
|
|
function descendByUCB1(node)
|
|
|
|
nb = sum of nb in node's children
|
|
|
|
for each node n in node's children
|
|
|
|
if n.nb = 0
|
|
|
|
v[n] = infinity
|
|
|
|
else
|
|
|
|
v[n] = 1 - n.value/n.nb + sqrt(2 * log(nb) / n.nb)
|
|
|
|
return n that maximizes v[n]
|
|
|
|
|
|
|
|
function updateValue(nodes, value)
|
|
|
|
for each node n in nodes
|
|
|
|
n.value += value
|
|
|
|
n.nb += 1
|
|
|
|
value = 1 - value
|
|
|
|
|
|
|
|
Modified UCT for MoGO in Wang and Gelly 2007
|
|
|
|
|
|
|
|
function playOneGame(state)
|
|
|
|
create node root from current game state
|
|
|
|
init tree to empty tree
|
|
|
|
while there is time and memory
|
|
|
|
//build the game tree one node at a time
|
|
|
|
playOneSeqMC(root, tree)
|
|
|
|
return descendByUCB1(root)
|
|
|
|
|
|
|
|
function playOneSeqMC(root, tree)
|
|
|
|
nodes = [root]
|
|
|
|
while (nodes.last is not in the tree)
|
|
|
|
nodes append descendByUCB1(nodes.last)
|
|
|
|
tree add nodes.last
|
|
|
|
nodes.last.value = getValueByMC(nodes.last)
|
|
|
|
updateValue(nodes.init, -nodes.last.value)
|
|
|
|
|
|
|
|
function getValueByMC(node)
|
|
|
|
play one random game starting from node
|
|
|
|
return 1 if player 1 (max) wins, 0 if player 2 wins (min)
|
|
|
|
*/
|
|
|
|
|
|
|
|
//AI using Monte Carlo Tree Search
|
|
|
|
public class MCTSAI implements MagicAI {
|
2011-06-06 02:05:09 -07:00
|
|
|
|
2011-06-11 00:40:25 -07:00
|
|
|
private final List<Integer> LENS = new LinkedList<Integer>();
|
2011-04-08 20:05:09 -07:00
|
|
|
private final boolean LOGGING;
|
2011-06-07 05:57:29 -07:00
|
|
|
private final boolean CHEAT;
|
2011-06-16 22:11:46 -07:00
|
|
|
private static final int MAX_ACTIONS = 10000;
|
2011-06-11 01:25:23 -07:00
|
|
|
|
2011-06-09 23:28:47 -07:00
|
|
|
//higher C -> more exploration less exploitation
|
2011-06-10 00:24:40 -07:00
|
|
|
static final double C = 1.0;
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-06-15 01:16:41 -07:00
|
|
|
//boost score of win nodes by BOOST
|
2011-06-11 01:25:23 -07:00
|
|
|
static final int BOOST = 1000000;
|
|
|
|
|
2011-06-15 01:16:41 -07:00
|
|
|
//cache nodes to reuse them in later decision
|
2011-06-14 00:24:07 -07:00
|
|
|
private final NodeCache cache = new NodeCache(1000);
|
|
|
|
|
2011-04-08 20:05:09 -07:00
|
|
|
/**
 * Creates an AI with the default settings: printLog = false (no logging
 * requested) and cheat = true.
 */
public MCTSAI() {
    this(false, true);
}
|
|
|
|
|
2011-06-15 01:16:41 -07:00
|
|
|
/**
 * Creates an AI with explicit logging and cheating settings.
 *
 * @param printLog when true logging is enabled; logging is also enabled
 *                 whenever the system property "debug" is defined
 * @param cheat    stored as the CHEAT flag; when false, setKnownCards() is
 *                 applied to every game copy used for a simulation
 */
public MCTSAI(final boolean printLog, final boolean cheat) {
    CHEAT = cheat;
    if (printLog) {
        LOGGING = true;
    } else {
        LOGGING = System.getProperty("debug") != null;
    }
}
|
2011-06-15 21:17:26 -07:00
|
|
|
|
2011-06-16 00:12:14 -07:00
|
|
|
static public int obj2StringHash(Object obj) {
|
|
|
|
return obj2String(obj).hashCode();
|
2011-06-15 21:17:26 -07:00
|
|
|
}
|
2011-04-08 20:05:09 -07:00
|
|
|
|
2011-06-16 00:12:14 -07:00
|
|
|
static public String obj2String(Object obj) {
|
|
|
|
if (obj == null) {
|
|
|
|
return "null";
|
|
|
|
} else if (obj instanceof MagicBuilderPayManaCostResult) {
|
|
|
|
return ((MagicBuilderPayManaCostResult)obj).getText();
|
|
|
|
} else {
|
|
|
|
return obj.toString();
|
2011-06-15 21:17:26 -07:00
|
|
|
}
|
2011-06-16 00:12:14 -07:00
|
|
|
}
|
2011-06-15 21:17:26 -07:00
|
|
|
|
2011-06-16 00:12:14 -07:00
|
|
|
private void addNode(final MagicGame game, final MCTSGameTree node) {
|
2011-06-14 00:24:07 -07:00
|
|
|
final long gid = game.getGameId();
|
2011-06-15 00:59:27 -07:00
|
|
|
cache.put(gid, node);
|
2011-06-15 21:17:26 -07:00
|
|
|
node.setCached();
|
2011-06-17 19:59:05 -07:00
|
|
|
System.err.println("ADDED: " + game.getIdString());
|
2011-06-14 00:24:07 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
private MCTSGameTree getNode(final MagicGame game, List<Object[]> rootChoices) {
|
|
|
|
final long gid = game.getGameId();
|
|
|
|
MCTSGameTree candidate = cache.get(gid);
|
|
|
|
|
|
|
|
if (candidate != null) {
|
2011-06-16 02:25:01 -07:00
|
|
|
System.err.println("CACHE HIT");
|
2011-06-14 20:23:22 -07:00
|
|
|
System.err.println("HIT : " + game.getIdString());
|
2011-06-14 21:52:28 -07:00
|
|
|
printNode(candidate, rootChoices);
|
2011-06-14 00:24:07 -07:00
|
|
|
return candidate;
|
|
|
|
} else {
|
2011-06-16 02:25:01 -07:00
|
|
|
System.err.println("CACHE MISS");
|
2011-06-14 20:23:22 -07:00
|
|
|
System.err.println("MISS : " + game.getIdString());
|
2011-06-14 19:40:21 -07:00
|
|
|
printNode(candidate, rootChoices);
|
2011-06-14 00:24:07 -07:00
|
|
|
return new MCTSGameTree(-1, -1, -1);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-04-08 20:05:09 -07:00
|
|
|
private void log(final String message) {
|
2011-04-03 20:07:36 -07:00
|
|
|
if (LOGGING) {
|
2011-06-05 19:45:45 -07:00
|
|
|
System.err.println(message);
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-06-09 21:03:27 -07:00
|
|
|
|
2011-06-11 02:58:52 -07:00
|
|
|
private double UCT(final MCTSGameTree parent, final MCTSGameTree child) {
|
|
|
|
return (parent.isAI() ? 1.0 : -1.0) * child.getV() +
|
|
|
|
C * Math.sqrt(Math.log(parent.getNumSim()) / child.getNumSim());
|
2011-06-10 03:11:33 -07:00
|
|
|
}
|
2011-06-09 21:03:27 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
public synchronized Object[] findNextEventChoiceResults(
|
2011-06-10 21:27:24 -07:00
|
|
|
final MagicGame startGame,
|
2011-04-03 20:07:36 -07:00
|
|
|
final MagicPlayer scorePlayer) {
|
2011-06-13 01:25:31 -07:00
|
|
|
|
|
|
|
final MagicGame choiceGame = new MagicGame(startGame, scorePlayer);
|
|
|
|
final MagicEvent event = choiceGame.getNextEvent();
|
|
|
|
final List<Object[]> rootChoices = event.getArtificialChoiceResults(choiceGame);
|
|
|
|
|
|
|
|
final int size = rootChoices.size();
|
2011-04-08 20:00:06 -07:00
|
|
|
final String pinfo = "MCTS " + scorePlayer.getIndex() + " (" + scorePlayer.getLife() + ")";
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
// No choice results
|
2011-06-14 00:24:07 -07:00
|
|
|
assert size > 0 : "ERROR! MCTS: no choice found at start";
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
// Single choice result
|
|
|
|
if (size == 1) {
|
2011-06-13 01:25:31 -07:00
|
|
|
return startGame.map(rootChoices.get(0));
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
2011-06-13 01:25:31 -07:00
|
|
|
//ArtificialLevel = number of seconds to run MCTSAI
|
2011-06-17 03:16:53 -07:00
|
|
|
//debugging: max time is 1 billion, max sim is 500
|
|
|
|
//normal : max time is 1000 * str, max sim is 1 billion
|
|
|
|
final int MAXTIME = System.getProperty("debug") != null ?
|
|
|
|
1000000000 : 1000 * startGame.getArtificialLevel(scorePlayer.getIndex());
|
|
|
|
final int MAXSIM = System.getProperty("debug") != null ?
|
|
|
|
500 : 1000000000;
|
|
|
|
final long STARTTIME = System.currentTimeMillis();
|
2011-04-07 23:44:28 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
//root represents the start state
|
2011-06-14 18:59:08 -07:00
|
|
|
//final MCTSGameTree root = new MCTSGameTree(-1, -1, -1);
|
|
|
|
final MCTSGameTree root = getNode(startGame, rootChoices);
|
2011-06-15 21:17:26 -07:00
|
|
|
root.desc = "root";
|
2011-06-11 00:40:25 -07:00
|
|
|
LENS.clear();
|
2011-06-10 21:27:24 -07:00
|
|
|
|
2011-06-10 20:45:40 -07:00
|
|
|
//end simulations once root is solved or time is up
|
2011-06-14 21:52:28 -07:00
|
|
|
int sims = 0;
|
2011-06-17 03:16:53 -07:00
|
|
|
for (; System.currentTimeMillis() - STARTTIME < MAXTIME &&
|
|
|
|
sims < MAXSIM &&
|
|
|
|
!root.isAIWin(); sims++) {
|
2011-06-17 01:35:17 -07:00
|
|
|
//clone the MagicGame object for simulation
|
2011-06-10 21:27:24 -07:00
|
|
|
final MagicGame rootGame = new MagicGame(startGame, scorePlayer);
|
2011-06-07 05:57:29 -07:00
|
|
|
if (!CHEAT) {
|
2011-06-10 21:27:24 -07:00
|
|
|
rootGame.setKnownCards();
|
2011-06-07 05:57:29 -07:00
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
//pass in a clone of the state, genNewTreeNode grows the tree by one node
|
|
|
|
//and returns the path from the root to the new node
|
2011-06-15 21:17:26 -07:00
|
|
|
final LinkedList<MCTSGameTree> path = growTree(root, rootGame);
|
2011-06-14 00:24:07 -07:00
|
|
|
|
|
|
|
assert path.size() >= 2 : "ERROR! MCTS: length of path is " + path.size();
|
2011-06-11 00:40:25 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
// play a simulated game to get score
|
|
|
|
// update all nodes along the path from root to new node
|
2011-06-10 21:27:24 -07:00
|
|
|
final double score = randomPlay(path.getLast(), rootGame);
|
2011-06-10 20:10:50 -07:00
|
|
|
|
|
|
|
for (MCTSGameTree child = null, parent = null;
|
2011-06-10 20:45:40 -07:00
|
|
|
!path.isEmpty();
|
|
|
|
child = parent) {
|
2011-06-10 20:10:50 -07:00
|
|
|
|
|
|
|
parent = path.removeLast();
|
|
|
|
parent.updateScore(score);
|
|
|
|
|
|
|
|
//update game theoretic value
|
|
|
|
if (child != null && child.isSolved()) {
|
|
|
|
if (parent.isAI() && child.isAIWin()) {
|
|
|
|
parent.setAIWin();
|
|
|
|
} else if (parent.isAI() && child.isAILose()) {
|
|
|
|
parent.incLose();
|
|
|
|
} else if (!parent.isAI() && child.isAIWin()) {
|
|
|
|
parent.incLose();
|
|
|
|
} else if (!parent.isAI() && child.isAILose()) {
|
|
|
|
parent.setAILose();
|
|
|
|
}
|
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-06-10 00:24:40 -07:00
|
|
|
|
2011-06-14 00:24:07 -07:00
|
|
|
assert root.size() > 0 : "ERROR! MCTS: root has no children but there are " + size + " choices";
|
2011-06-10 06:42:45 -07:00
|
|
|
|
2011-04-10 07:35:57 -07:00
|
|
|
//select the best choice (child that has the highest secure score)
|
2011-06-07 19:54:04 -07:00
|
|
|
final MCTSGameTree first = root.first();
|
2011-06-10 20:10:50 -07:00
|
|
|
double maxR = first.getRank();
|
2011-06-07 19:54:04 -07:00
|
|
|
int bestC = first.getChoice();
|
2011-04-07 02:19:23 -07:00
|
|
|
for (MCTSGameTree node : root) {
|
2011-06-10 20:10:50 -07:00
|
|
|
final double R = node.getRank();
|
2011-06-07 19:54:04 -07:00
|
|
|
final int C = node.getChoice();
|
2011-06-10 20:10:50 -07:00
|
|
|
if (R > maxR) {
|
|
|
|
maxR = R;
|
2011-06-07 19:54:04 -07:00
|
|
|
bestC = C;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-06-13 01:25:31 -07:00
|
|
|
final Object[] selected = rootChoices.get(bestC);
|
2011-06-05 21:29:44 -07:00
|
|
|
|
|
|
|
if (LOGGING) {
|
2011-06-10 21:27:24 -07:00
|
|
|
final long duration = System.currentTimeMillis() - STARTTIME;
|
2011-06-05 21:29:44 -07:00
|
|
|
int minL = 1000000;
|
|
|
|
int maxL = -1;
|
|
|
|
int sumL = 0;
|
2011-06-11 00:40:25 -07:00
|
|
|
for (int len : LENS) {
|
2011-06-05 21:29:44 -07:00
|
|
|
sumL += len;
|
|
|
|
if (len > maxL) maxL = len;
|
|
|
|
if (len < minL) minL = len;
|
|
|
|
}
|
2011-06-15 21:17:26 -07:00
|
|
|
log("MCTS:\ttime: " + duration +
|
|
|
|
"\tsims: " + (root.getNumSim() - sims) + "+" + sims +
|
|
|
|
"\tmin: " + minL +
|
|
|
|
"\tmax: " + maxL +
|
|
|
|
"\tavg: " + (sumL / (LENS.size()+1)));
|
2011-06-09 21:03:27 -07:00
|
|
|
log(pinfo);
|
2011-06-10 20:45:40 -07:00
|
|
|
for (MCTSGameTree node : root) {
|
|
|
|
final StringBuffer out = new StringBuffer();
|
|
|
|
if (node.getChoice() == bestC) {
|
|
|
|
out.append("* ");
|
|
|
|
} else {
|
|
|
|
out.append(" ");
|
|
|
|
}
|
|
|
|
out.append('[');
|
|
|
|
out.append((int)(node.getV() * 100));
|
|
|
|
out.append('/');
|
|
|
|
out.append(node.getNumSim());
|
|
|
|
out.append('/');
|
|
|
|
if (node.isAIWin()) {
|
|
|
|
out.append("win");
|
|
|
|
} else if (node.isAILose()) {
|
|
|
|
out.append("lose");
|
|
|
|
} else {
|
|
|
|
out.append("?");
|
|
|
|
}
|
|
|
|
out.append(']');
|
2011-06-13 01:25:31 -07:00
|
|
|
out.append(CR2String(rootChoices.get(node.getChoice())));
|
2011-06-10 20:45:40 -07:00
|
|
|
log(out.toString());
|
2011-06-09 21:03:27 -07:00
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-06-09 21:03:27 -07:00
|
|
|
|
2011-06-10 21:27:24 -07:00
|
|
|
return startGame.map(selected);
|
2011-06-10 20:45:40 -07:00
|
|
|
}
|
|
|
|
|
2011-06-11 02:58:52 -07:00
|
|
|
private static String CR2String(Object[] choiceResults) {
|
2011-06-10 20:45:40 -07:00
|
|
|
final StringBuffer buffer=new StringBuffer();
|
|
|
|
if (choiceResults!=null) {
|
|
|
|
buffer.append(" (");
|
|
|
|
boolean first=true;
|
|
|
|
for (final Object choiceResult : choiceResults) {
|
|
|
|
if (first) {
|
|
|
|
first=false;
|
|
|
|
} else {
|
|
|
|
buffer.append(',');
|
|
|
|
}
|
|
|
|
buffer.append(choiceResult);
|
|
|
|
}
|
|
|
|
buffer.append(')');
|
|
|
|
}
|
|
|
|
return buffer.toString();
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-06-12 02:40:00 -07:00
|
|
|
|
2011-06-14 00:24:07 -07:00
|
|
|
private boolean checkNode(final MCTSGameTree curr, List<Object[]> choices) {
|
2011-06-11 19:30:08 -07:00
|
|
|
for (MCTSGameTree child : curr) {
|
2011-06-16 02:41:39 -07:00
|
|
|
final String checkStr = obj2String(choices.get(child.getChoice())[0]);
|
|
|
|
if (!child.desc.equals(checkStr)) {
|
2011-06-14 00:24:07 -07:00
|
|
|
System.err.println("ERROR! tree node and choice do not match");
|
2011-06-11 19:30:08 -07:00
|
|
|
printNode(curr, choices);
|
2011-06-14 00:24:07 -07:00
|
|
|
return false;
|
2011-06-11 19:30:08 -07:00
|
|
|
}
|
|
|
|
}
|
2011-06-14 00:24:07 -07:00
|
|
|
return true;
|
2011-06-11 19:30:08 -07:00
|
|
|
}
|
|
|
|
|
2011-06-14 20:23:22 -07:00
|
|
|
private static String printNode(final MCTSGameTree curr, List<Object[]> choices) {
|
2011-06-14 19:40:21 -07:00
|
|
|
if (curr == null) {
|
2011-06-14 20:23:22 -07:00
|
|
|
return "NODE is null";
|
|
|
|
}
|
2011-06-15 21:17:26 -07:00
|
|
|
if (curr.choicesStr != null) {
|
|
|
|
for (String str : curr.choicesStr) {
|
|
|
|
System.err.println("PAREN: " + str);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
System.err.println("PAREN: not defined");
|
2011-06-14 19:40:21 -07:00
|
|
|
}
|
2011-06-11 19:30:08 -07:00
|
|
|
for (MCTSGameTree child : curr) {
|
2011-06-15 00:59:27 -07:00
|
|
|
System.err.println("CHILD: " + child.desc);
|
2011-06-11 19:30:08 -07:00
|
|
|
}
|
|
|
|
for (Object[] choice : choices) {
|
2011-06-16 00:12:14 -07:00
|
|
|
final int checksum = obj2StringHash(choice[0]);
|
|
|
|
System.err.println("GAME : " + obj2String(choice[0]));
|
2011-06-11 19:30:08 -07:00
|
|
|
}
|
2011-06-14 20:23:22 -07:00
|
|
|
return "";
|
2011-06-11 19:30:08 -07:00
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-06-18 08:43:13 -07:00
|
|
|
public static String printPath(final List<MCTSGameTree> path) {
|
|
|
|
for (MCTSGameTree p : path) {
|
|
|
|
System.err.print(" -> " + p.desc);
|
|
|
|
}
|
|
|
|
System.err.println();
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
|
2011-06-15 21:17:26 -07:00
|
|
|
private LinkedList<MCTSGameTree> growTree(final MCTSGameTree root, final MagicGame game) {
|
2011-06-10 20:10:50 -07:00
|
|
|
final LinkedList<MCTSGameTree> path = new LinkedList<MCTSGameTree>();
|
2011-06-15 21:17:26 -07:00
|
|
|
boolean found = false;
|
2011-04-05 22:27:39 -07:00
|
|
|
MCTSGameTree curr = root;
|
|
|
|
path.add(curr);
|
2011-04-04 20:31:42 -07:00
|
|
|
|
2011-06-13 01:25:31 -07:00
|
|
|
for (List<Object[]> choices = getNextChoices(game, curr == root, false);
|
2011-06-04 01:50:52 -07:00
|
|
|
choices != null;
|
2011-06-13 01:25:31 -07:00
|
|
|
choices = getNextChoices(game, curr == root, false)) {
|
2011-06-14 19:40:21 -07:00
|
|
|
|
|
|
|
assert choices.size() > 0 : "ERROR! No choice at start of genNewTreeNode";
|
2011-06-15 00:59:27 -07:00
|
|
|
assert !curr.hasDetails() || curr.getMaxChildren() == choices.size() :
|
2011-06-14 20:23:22 -07:00
|
|
|
"ERROR! Capacity of node is " + curr.getMaxChildren() + ", number of choices is " + choices.size()
|
2011-06-18 08:43:13 -07:00
|
|
|
+ printPath(path) + printNode(curr, choices);
|
2011-06-04 01:50:52 -07:00
|
|
|
|
|
|
|
final MagicEvent event = game.getNextEvent();
|
2011-06-15 00:59:27 -07:00
|
|
|
|
|
|
|
//first time considering the choices available at this node
|
|
|
|
if (!curr.hasDetails()) {
|
|
|
|
curr.setIsAI(game.getScorePlayer() == event.getPlayer());
|
|
|
|
curr.setMaxChildren(choices.size());
|
|
|
|
curr.setChoicesStr(choices);
|
2011-06-14 00:24:07 -07:00
|
|
|
}
|
|
|
|
|
2011-06-15 21:17:26 -07:00
|
|
|
//look for first non root AI node along this path and add it to cache
|
|
|
|
if (!found && curr != root && curr.isAI()) {
|
|
|
|
found = true;
|
2011-06-16 00:12:14 -07:00
|
|
|
if (!curr.isCached()) {
|
|
|
|
addNode(game, curr);
|
|
|
|
}
|
2011-06-15 21:17:26 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
//there are unexplored children of node
|
|
|
|
//assume we explore children of a node in increasing order of the choices
|
2011-04-05 22:27:39 -07:00
|
|
|
if (curr.size() < choices.size()) {
|
2011-06-14 00:24:07 -07:00
|
|
|
Object[] choice = choices.get(curr.size());
|
|
|
|
game.executeNextEvent(choice);
|
|
|
|
final MCTSGameTree child = new MCTSGameTree(
|
|
|
|
curr.size(),
|
|
|
|
game.getScore(),
|
2011-06-16 00:12:14 -07:00
|
|
|
obj2StringHash(choice[0]));
|
|
|
|
child.desc = obj2String(choice[0]);
|
2011-04-05 22:27:39 -07:00
|
|
|
curr.addChild(child);
|
|
|
|
path.add(child);
|
|
|
|
return path;
|
2011-06-15 21:17:26 -07:00
|
|
|
|
|
|
|
//all the children are in the tree, find the "best" child to explore
|
2011-04-05 22:27:39 -07:00
|
|
|
} else {
|
2011-06-16 02:41:39 -07:00
|
|
|
assert checkNode(curr, choices);
|
2011-06-11 07:12:16 -07:00
|
|
|
|
2011-06-14 20:23:22 -07:00
|
|
|
assert curr.size() == choices.size() : "ERROR! Different number of choices in node and game" +
|
|
|
|
printNode(curr,choices);
|
2011-06-09 22:36:56 -07:00
|
|
|
|
2011-06-10 20:10:50 -07:00
|
|
|
MCTSGameTree next = null;
|
|
|
|
double bestV = Double.NEGATIVE_INFINITY;
|
|
|
|
for (MCTSGameTree child : curr) {
|
2011-06-11 22:57:06 -07:00
|
|
|
//skip won nodes
|
|
|
|
if (child.isAIWin()) {
|
2011-06-10 20:10:50 -07:00
|
|
|
continue;
|
|
|
|
}
|
2011-06-11 02:58:52 -07:00
|
|
|
final double v = UCT(curr, child);
|
2011-04-08 18:47:53 -07:00
|
|
|
if (v > bestV) {
|
2011-04-05 22:27:39 -07:00
|
|
|
bestV = v;
|
2011-06-10 20:10:50 -07:00
|
|
|
next = child;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-04-06 21:31:41 -07:00
|
|
|
|
|
|
|
//move down the tree
|
2011-06-10 20:10:50 -07:00
|
|
|
curr = next;
|
2011-04-08 18:47:53 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
game.executeNextEvent(choices.get(curr.getChoice()));
|
|
|
|
path.add(curr);
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-04-05 22:27:39 -07:00
|
|
|
}
|
2011-04-06 21:31:41 -07:00
|
|
|
|
2011-04-04 20:31:42 -07:00
|
|
|
return path;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
2011-06-10 20:10:50 -07:00
|
|
|
private double randomPlay(final MCTSGameTree node, final MagicGame game) {
|
|
|
|
//terminal node, no need for random play
|
|
|
|
if (game.isFinished()) {
|
2011-06-14 00:24:07 -07:00
|
|
|
assert game.getLosingPlayer() != null : "ERROR! game finished but no losing player";
|
|
|
|
|
|
|
|
if (game.getLosingPlayer() == game.getScorePlayer()) {
|
2011-06-10 20:10:50 -07:00
|
|
|
node.setAILose();
|
|
|
|
return -1.0;
|
|
|
|
} else {
|
|
|
|
node.setAIWin();
|
|
|
|
return 1.0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-06-16 20:48:15 -07:00
|
|
|
final int startActions = game.getNumActions();
|
2011-06-13 01:25:31 -07:00
|
|
|
getNextChoices(game, false, true);
|
2011-06-16 20:48:15 -07:00
|
|
|
final int actions = game.getNumActions() - startActions;
|
2011-06-10 21:27:24 -07:00
|
|
|
|
|
|
|
if (LOGGING) {
|
2011-06-16 20:48:15 -07:00
|
|
|
LENS.add(actions);
|
2011-06-10 21:27:24 -07:00
|
|
|
}
|
2011-06-10 20:45:40 -07:00
|
|
|
|
2011-04-08 20:00:06 -07:00
|
|
|
if (game.getLosingPlayer() == null) {
|
|
|
|
return 0;
|
|
|
|
} else if (game.getLosingPlayer() == game.getScorePlayer()) {
|
2011-06-16 20:48:15 -07:00
|
|
|
return -(1.0 - actions/((double)MAX_ACTIONS));
|
2011-04-05 22:27:39 -07:00
|
|
|
} else {
|
2011-06-16 20:48:15 -07:00
|
|
|
return 1.0 - actions/((double)MAX_ACTIONS);
|
2011-04-05 22:27:39 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-06-11 07:12:16 -07:00
|
|
|
private List<Object[]> getNextChoices(
|
|
|
|
final MagicGame game,
|
2011-06-13 01:25:31 -07:00
|
|
|
final boolean isRoot,
|
2011-06-11 07:12:16 -07:00
|
|
|
final boolean sim) {
|
2011-06-13 01:25:31 -07:00
|
|
|
|
2011-06-16 20:48:15 -07:00
|
|
|
final int startActions = game.getNumActions();
|
2011-06-06 18:53:35 -07:00
|
|
|
|
2011-06-16 20:48:15 -07:00
|
|
|
// simulate game until it is finished or simulated MAX_ACTIONS actions
|
|
|
|
while (!game.isFinished() && (game.getNumActions() - startActions) < MAX_ACTIONS) {
|
2011-06-17 01:35:17 -07:00
|
|
|
//do not accumulate score down the tree
|
|
|
|
game.setScore(0);
|
|
|
|
|
2011-04-07 23:44:28 -07:00
|
|
|
if (!game.hasNextEvent()) {
|
2011-04-05 22:27:39 -07:00
|
|
|
game.getPhase().executePhase(game);
|
2011-04-07 23:44:28 -07:00
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
//game has next event
|
|
|
|
final MagicEvent event = game.getNextEvent();
|
|
|
|
|
|
|
|
if (!event.hasChoice()) {
|
|
|
|
game.executeNextEvent(MagicEvent.NO_CHOICE_RESULTS);
|
|
|
|
continue;
|
|
|
|
}
|
2011-06-06 18:53:35 -07:00
|
|
|
|
2011-04-07 23:44:28 -07:00
|
|
|
//event has choice
|
2011-06-06 18:53:35 -07:00
|
|
|
|
|
|
|
if (sim) {
|
|
|
|
//get simulation choice and execute
|
2011-06-10 23:19:49 -07:00
|
|
|
final Object[] choice = event.getSimulationChoiceResult(game);
|
2011-06-14 00:24:07 -07:00
|
|
|
assert choice != null : "ERROR! MCTS: no choice found during sim";
|
|
|
|
game.executeNextEvent(choice);
|
2011-04-07 23:44:28 -07:00
|
|
|
} else {
|
2011-06-10 06:42:45 -07:00
|
|
|
//get list of possible AI choices
|
2011-06-06 18:53:35 -07:00
|
|
|
final List<Object[]> choices = event.getArtificialChoiceResults(game);
|
|
|
|
final int size = choices.size();
|
2011-06-14 00:24:07 -07:00
|
|
|
assert size > 0 : "ERROR! MCTS: no choice found";
|
|
|
|
if (size == 1) {
|
|
|
|
//single choice
|
2011-06-06 18:53:35 -07:00
|
|
|
game.executeNextEvent(choices.get(0));
|
|
|
|
} else {
|
|
|
|
//multiple choice
|
|
|
|
return choices;
|
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
2011-06-06 18:53:35 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
//game is finished
|
2011-06-10 07:09:40 -07:00
|
|
|
return null;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//only store one copy of MagicGame
|
|
|
|
//each tree node stores the choice from the parent that leads to this node
|
|
|
|
//so we only need one copy of MagicGame for MCTSAI
|
2011-04-07 02:19:23 -07:00
|
|
|
class MCTSGameTree implements Iterable<MCTSGameTree> {
|
2011-06-09 23:23:20 -07:00
|
|
|
|
2011-06-09 22:36:56 -07:00
|
|
|
private final LinkedList<MCTSGameTree> children = new LinkedList<MCTSGameTree>();
|
2011-06-10 20:10:50 -07:00
|
|
|
private final int choice;
|
2011-06-14 00:24:07 -07:00
|
|
|
private final int checksum;
|
2011-06-10 20:10:50 -07:00
|
|
|
private boolean isAI;
|
2011-06-15 21:17:26 -07:00
|
|
|
private boolean isCached = false;
|
2011-06-14 19:40:21 -07:00
|
|
|
private int maxChildren = -1;
|
2011-06-10 20:10:50 -07:00
|
|
|
private int numLose = 0;
|
2011-04-03 20:07:36 -07:00
|
|
|
private int numSim = 0;
|
2011-04-07 23:44:28 -07:00
|
|
|
private int evalScore = 0;
|
2011-06-14 00:24:07 -07:00
|
|
|
private double score = 0;
|
2011-06-14 18:59:08 -07:00
|
|
|
public String desc;
|
2011-06-14 20:23:22 -07:00
|
|
|
public String[] choicesStr;
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-06-14 00:24:07 -07:00
|
|
|
public MCTSGameTree(final int choice, final int evalScore, final int checksum) {
|
2011-04-07 23:44:28 -07:00
|
|
|
this.evalScore = evalScore;
|
2011-04-03 20:07:36 -07:00
|
|
|
this.choice = choice;
|
2011-06-14 00:24:07 -07:00
|
|
|
this.checksum = checksum;
|
|
|
|
}
|
|
|
|
|
2011-06-15 21:17:26 -07:00
|
|
|
public boolean isCached() {
|
|
|
|
return isCached;
|
|
|
|
}
|
|
|
|
|
|
|
|
public void setCached() {
|
|
|
|
isCached = true;
|
|
|
|
}
|
|
|
|
|
2011-06-15 00:59:27 -07:00
|
|
|
public boolean hasDetails() {
|
2011-06-15 21:17:26 -07:00
|
|
|
return maxChildren != -1;
|
2011-06-15 00:59:27 -07:00
|
|
|
}
|
|
|
|
|
2011-06-14 00:24:07 -07:00
|
|
|
public int getChecksum() {
|
|
|
|
return checksum;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-06-10 20:10:50 -07:00
|
|
|
|
2011-06-14 20:23:22 -07:00
|
|
|
public void setChoicesStr(List<Object[]> choices) {
|
|
|
|
choicesStr = new String[choices.size()];
|
|
|
|
for (int i = 0; i < choices.size(); i++) {
|
2011-06-16 00:12:14 -07:00
|
|
|
choicesStr[i] = MCTSAI.obj2String(choices.get(i)[0]);
|
2011-06-14 20:23:22 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-06-10 20:10:50 -07:00
|
|
|
public void setMaxChildren(final int mc) {
|
|
|
|
maxChildren = mc;
|
|
|
|
}
|
2011-06-14 19:40:21 -07:00
|
|
|
|
|
|
|
public int getMaxChildren() {
|
|
|
|
return maxChildren;
|
|
|
|
}
|
2011-06-10 20:10:50 -07:00
|
|
|
|
|
|
|
public boolean isAI() {
|
|
|
|
return isAI;
|
|
|
|
}
|
|
|
|
|
|
|
|
public void setIsAI(final boolean ai) {
|
|
|
|
this.isAI = ai;
|
|
|
|
}
|
|
|
|
|
|
|
|
public boolean isSolved() {
|
|
|
|
return evalScore == Integer.MAX_VALUE || evalScore == Integer.MIN_VALUE;
|
|
|
|
}
|
2011-04-07 02:19:23 -07:00
|
|
|
|
2011-04-10 07:35:57 -07:00
|
|
|
public void updateScore(final double score) {
|
|
|
|
this.score += score;
|
|
|
|
numSim += 1;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
|
|
|
|
2011-06-10 20:10:50 -07:00
|
|
|
public boolean isAIWin() {
|
|
|
|
return evalScore == Integer.MAX_VALUE;
|
|
|
|
}
|
|
|
|
|
|
|
|
public boolean isAILose() {
|
|
|
|
return evalScore == Integer.MIN_VALUE;
|
|
|
|
}
|
|
|
|
|
|
|
|
public void incLose() {
|
|
|
|
numLose++;
|
|
|
|
if (numLose == maxChildren) {
|
|
|
|
if (isAI) {
|
|
|
|
setAILose();
|
|
|
|
} else {
|
|
|
|
setAIWin();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
public int getChoice() {
|
|
|
|
return choice;
|
|
|
|
}
|
|
|
|
|
2011-06-10 20:10:50 -07:00
|
|
|
public void setAIWin() {
|
|
|
|
evalScore = Integer.MAX_VALUE;
|
|
|
|
}
|
|
|
|
|
|
|
|
public void setAILose() {
|
|
|
|
evalScore = Integer.MIN_VALUE;
|
|
|
|
}
|
|
|
|
|
2011-04-07 23:44:28 -07:00
|
|
|
public int getEvalScore() {
|
|
|
|
return evalScore;
|
|
|
|
}
|
|
|
|
|
2011-04-10 02:44:39 -07:00
|
|
|
public double getScore() {
|
2011-04-03 20:07:36 -07:00
|
|
|
return score;
|
|
|
|
}
|
2011-06-10 20:10:50 -07:00
|
|
|
|
|
|
|
public double getRank() {
|
|
|
|
if (isAIWin()) {
|
2011-06-11 01:25:23 -07:00
|
|
|
return MCTSAI.BOOST + getNumSim();
|
2011-06-10 20:10:50 -07:00
|
|
|
} else if (isAILose()) {
|
2011-06-14 00:24:07 -07:00
|
|
|
return getNumSim();
|
2011-06-10 20:10:50 -07:00
|
|
|
} else {
|
|
|
|
return getNumSim();
|
|
|
|
}
|
|
|
|
}
|
2011-06-09 23:23:20 -07:00
|
|
|
|
2011-04-03 20:07:36 -07:00
|
|
|
public int getNumSim() {
|
|
|
|
return numSim;
|
|
|
|
}
|
|
|
|
|
|
|
|
public double getV() {
|
2011-04-10 02:44:39 -07:00
|
|
|
return score / numSim;
|
2011-04-03 20:07:36 -07:00
|
|
|
}
|
2011-04-10 07:35:57 -07:00
|
|
|
|
|
|
|
public double getSecureScore() {
|
|
|
|
return getV() + 1.0/Math.sqrt(getNumSim());
|
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
|
2011-04-05 22:27:39 -07:00
|
|
|
public void addChild(MCTSGameTree child) {
|
2011-06-14 19:40:21 -07:00
|
|
|
assert children.size() < maxChildren : "ERROR! Number of children nodes exceed maxChildren";
|
2011-04-03 20:07:36 -07:00
|
|
|
children.add(child);
|
|
|
|
}
|
2011-04-10 07:35:57 -07:00
|
|
|
|
2011-06-09 22:36:56 -07:00
|
|
|
public void removeLast() {
|
|
|
|
children.removeLast();
|
2011-06-09 22:15:48 -07:00
|
|
|
}
|
|
|
|
|
2011-04-10 07:35:57 -07:00
|
|
|
public MCTSGameTree first() {
|
|
|
|
return children.get(0);
|
|
|
|
}
|
|
|
|
|
|
|
|
public Iterator<MCTSGameTree> iterator() {
|
|
|
|
return children.iterator();
|
|
|
|
}
|
2011-04-03 20:07:36 -07:00
|
|
|
|
|
|
|
public int size() {
|
|
|
|
return children.size();
|
|
|
|
}
|
|
|
|
}
|
2011-06-09 21:03:27 -07:00
|
|
|
|
2011-06-14 00:24:07 -07:00
|
|
|
class NodeCache extends LinkedHashMap<Long, MCTSGameTree> {
|
2011-06-10 06:42:45 -07:00
|
|
|
private static final long serialVersionUID = 1L;
|
2011-06-09 21:03:27 -07:00
|
|
|
private final int capacity;
|
2011-06-14 00:24:07 -07:00
|
|
|
public NodeCache(int capacity) {
|
2011-06-09 21:03:27 -07:00
|
|
|
super(capacity + 1, 1.1f, true);
|
|
|
|
this.capacity = capacity;
|
|
|
|
}
|
|
|
|
|
|
|
|
protected boolean removeEldestEntry(Map.Entry eldest) {
|
|
|
|
return size() > capacity;
|
|
|
|
}
|
|
|
|
}
|