// Source: magarena/src/magic/ai/MCTSAI.java
// (664 lines, 21 KiB, Java; page residue from the repository web view: Raw | Normal View | History)

package magic.ai;
import java.util.*;
import magic.model.MagicGame;
import magic.model.MagicPlayer;
// 2011-06-16 00:12:14 -07:00 (blame annotation)
import magic.model.phase.MagicPhase;
import magic.model.event.MagicEvent;
// 2011-06-16 00:12:14 -07:00 (blame annotation)
import magic.model.choice.*;
/*
UCT algorithm from Kocsis and Szepesvari 2006
function playOneSeq(root)
nodes = [root]
while (nodes.last is not leaf) do
nodes append descendByUCB1(nodes.last)
//assume value of leaf nodes are known
//node.init is all elements except the last one
updateValue(nodes.init, -nodes.last.value)
function descendByUCB1(node)
nb = sum of nb in node's children
for each node n in node's children
if n.nb = 0
v[n] = infinity
else
v[n] = 1 - n.value/n.nb + sqrt(2 * log(nb) / n.nb)
return n that maximizes v[n]
function updateValue(nodes, value)
for each node n in nodes
n.value += value
n.nb += 1
value = 1 - value
Modified UCT for MoGo in Wang and Gelly 2007
function playOneGame(state)
create node root from current game state
init tree to empty tree
while there is time and memory
//build the game tree one node at a time
playOneSeqMC(root, tree)
return descendByUCB1(root)
function playOneSeqMC(root, tree)
nodes = [root]
while (nodes.last is not in the tree)
nodes append descendByUCB1(nodes.last)
tree add nodes.last
nodes.last.value = getValueByMC(nodes.last)
updateValue(nodes.init, -nodes.last.value)
function getValueByMC(node)
play one random game starting from node
return 1 if player 1 (max) wins, 0 if player 2 wins (min)
*/
//AI using Monte Carlo Tree Search
public class MCTSAI implements MagicAI {
private static final int MAX_ACTIONS = 10000;
2011-06-06 02:05:09 -07:00
private final boolean LOGGING;
2011-06-19 23:17:03 -07:00
private final boolean CHEAT;
private boolean ASSERT;
private final List<Integer> LENS = new LinkedList<Integer>();
//higher C -> more exploration less exploitation
2011-06-10 00:24:40 -07:00
static final double C = 1.0;
2011-06-15 01:16:41 -07:00
//boost score of win nodes by BOOST
static final int BOOST = 1000000;
2011-06-15 01:16:41 -07:00
//cache nodes to reuse them in later decision
2011-06-14 00:24:07 -07:00
private final NodeCache cache = new NodeCache(1000);
public MCTSAI() {
//no logging, cheats
this(false, true);
2011-06-19 23:17:03 -07:00
ASSERT = false;
assert ASSERT = true;
}
2011-06-15 01:16:41 -07:00
public MCTSAI(final boolean printLog, final boolean cheat) {
LOGGING = printLog || (System.getProperty("debug") != null);
CHEAT = cheat;
}
2011-06-16 00:12:14 -07:00
static public int obj2StringHash(Object obj) {
return obj2String(obj).hashCode();
}
2011-06-16 00:12:14 -07:00
static public String obj2String(Object obj) {
if (obj == null) {
return "null";
} else if (obj instanceof MagicBuilderPayManaCostResult) {
return ((MagicBuilderPayManaCostResult)obj).getText();
} else {
return obj.toString();
}
2011-06-16 00:12:14 -07:00
}
2011-06-16 00:12:14 -07:00
private void addNode(final MagicGame game, final MCTSGameTree node) {
if (node.isCached()) {
return;
}
2011-06-14 00:24:07 -07:00
final long gid = game.getGameId();
cache.put(gid, node);
node.setCached();
assert log("ADDED: " + game.getIdString());
2011-06-14 00:24:07 -07:00
}
private MCTSGameTree getNode(final MagicGame game, List<Object[]> rootChoices) {
final long gid = game.getGameId();
MCTSGameTree candidate = cache.get(gid);
if (candidate != null) {
assert log("CACHE HIT");
assert log("HIT : " + game.getIdString());
assert printNode(candidate, rootChoices);
2011-06-14 00:24:07 -07:00
return candidate;
} else {
assert log("CACHE MISS");
assert log("MISS : " + game.getIdString());
return new MCTSGameTree(-1, -1);
2011-06-14 00:24:07 -07:00
}
}
private boolean log(final String message) {
if (LOGGING) {
System.err.println(message);
}
return true;
}
2011-06-11 02:58:52 -07:00
private double UCT(final MCTSGameTree parent, final MCTSGameTree child) {
return (parent.isAI() ? 1.0 : -1.0) * child.getV() +
C * Math.sqrt(Math.log(parent.getNumSim()) / child.getNumSim());
}
public synchronized Object[] findNextEventChoiceResults(
2011-06-10 21:27:24 -07:00
final MagicGame startGame,
final MagicPlayer scorePlayer) {
final MagicGame choiceGame = new MagicGame(startGame, scorePlayer);
final MagicEvent event = choiceGame.getNextEvent();
final List<Object[]> rootChoices = event.getArtificialChoiceResults(choiceGame);
final int size = rootChoices.size();
final String pinfo = "MCTS " + scorePlayer.getIndex() + " (" + scorePlayer.getLife() + ")";
// No choice results
2011-06-14 00:24:07 -07:00
assert size > 0 : "ERROR! MCTS: no choice found at start";
// Single choice result
if (size == 1) {
return startGame.map(rootChoices.get(0));
}
//ArtificialLevel = number of seconds to run MCTSAI
//debugging: max time is 1 billion, max sim is 500
//normal : max time is 1000 * str, max sim is 1 billion
2011-06-19 23:17:03 -07:00
final int MAXTIME = ASSERT ?
1000000000 : 1000 * startGame.getArtificialLevel(scorePlayer.getIndex());
2011-06-19 23:17:03 -07:00
final int MAXSIM = ASSERT ?
500 : 1000000000;
final long STARTTIME = System.currentTimeMillis();
//root represents the start state
//final MCTSGameTree root = new MCTSGameTree(-1, -1, -1);
final MCTSGameTree root = getNode(startGame, rootChoices);
assert root.desc != (root.desc = "root");
LENS.clear();
2011-06-10 21:27:24 -07:00
//end simulations once root is solved or time is up
int sims = 0;
for (; System.currentTimeMillis() - STARTTIME < MAXTIME &&
sims < MAXSIM &&
!root.isAIWin(); sims++) {
//clone the MagicGame object for simulation
2011-06-10 21:27:24 -07:00
final MagicGame rootGame = new MagicGame(startGame, scorePlayer);
if (!CHEAT) {
2011-06-10 21:27:24 -07:00
rootGame.setKnownCards();
}
//pass in a clone of the state, genNewTreeNode grows the tree by one node
//and returns the path from the root to the new node
final LinkedList<MCTSGameTree> path = growTree(root, rootGame);
2011-06-14 00:24:07 -07:00
assert path.size() >= 2 : "ERROR! MCTS: length of path is " + path.size();
// play a simulated game to get score
// update all nodes along the path from root to new node
2011-06-10 21:27:24 -07:00
final double score = randomPlay(path.getLast(), rootGame);
2011-06-10 20:10:50 -07:00
// update score and game theoretic value along the chosen path
2011-06-10 20:10:50 -07:00
for (MCTSGameTree child = null, parent = null;
!path.isEmpty(); child = parent) {
2011-06-10 20:10:50 -07:00
parent = path.removeLast();
parent.updateScore(score);
if (child != null && child.isSolved()) {
if (parent.isAI() && child.isAIWin()) {
parent.setAIWin();
} else if (parent.isAI() && child.isAILose()) {
parent.incLose();
} else if (!parent.isAI() && child.isAIWin()) {
parent.incLose();
} else if (!parent.isAI() && child.isAILose()) {
parent.setAILose();
}
}
}
}
2011-06-10 00:24:40 -07:00
2011-06-14 00:24:07 -07:00
assert root.size() > 0 : "ERROR! MCTS: root has no children but there are " + size + " choices";
//select the best choice (child that has the highest secure score)
final MCTSGameTree first = root.first();
2011-06-10 20:10:50 -07:00
double maxR = first.getRank();
int bestC = first.getChoice();
for (MCTSGameTree node : root) {
2011-06-10 20:10:50 -07:00
final double R = node.getRank();
final int C = node.getChoice();
2011-06-10 20:10:50 -07:00
if (R > maxR) {
maxR = R;
bestC = C;
}
}
final Object[] selected = rootChoices.get(bestC);
if (LOGGING) {
2011-06-10 21:27:24 -07:00
final long duration = System.currentTimeMillis() - STARTTIME;
int minL = 1000000;
int maxL = -1;
int sumL = 0;
for (int len : LENS) {
sumL += len;
if (len > maxL) maxL = len;
if (len < minL) minL = len;
}
log("MCTS:\ttime: " + duration +
"\tsims: " + (root.getNumSim() - sims) + "+" + sims +
"\tmin: " + minL +
"\tmax: " + maxL +
"\tavg: " + (sumL / (LENS.size()+1)));
log(pinfo);
for (MCTSGameTree node : root) {
final StringBuffer out = new StringBuffer();
if (node.getChoice() == bestC) {
out.append("* ");
} else {
out.append(" ");
}
out.append('[');
out.append((int)(node.getV() * 100));
out.append('/');
out.append(node.getNumSim());
out.append('/');
if (node.isAIWin()) {
out.append("win");
} else if (node.isAILose()) {
out.append("lose");
} else {
out.append("?");
}
out.append(']');
out.append(CR2String(rootChoices.get(node.getChoice())));
log(out.toString());
}
}
2011-06-10 21:27:24 -07:00
return startGame.map(selected);
}
2011-06-11 02:58:52 -07:00
private static String CR2String(Object[] choiceResults) {
final StringBuffer buffer=new StringBuffer();
if (choiceResults!=null) {
buffer.append(" (");
boolean first=true;
for (final Object choiceResult : choiceResults) {
if (first) {
first=false;
} else {
buffer.append(',');
}
buffer.append(choiceResult);
}
buffer.append(')');
}
return buffer.toString();
}
2011-06-12 02:40:00 -07:00
2011-06-14 00:24:07 -07:00
private boolean checkNode(final MCTSGameTree curr, List<Object[]> choices) {
if (curr.getMaxChildren() != choices.size()) {
return false;
}
2011-06-18 08:55:38 -07:00
for (int i = 0; i < choices.size(); i++) {
final String checkStr = obj2String(choices.get(i)[0]);
if (!curr.choicesStr[i].equals(checkStr)) {
return false;
}
}
for (MCTSGameTree child : curr) {
2011-06-16 02:41:39 -07:00
final String checkStr = obj2String(choices.get(child.getChoice())[0]);
if (!child.desc.equals(checkStr)) {
2011-06-14 00:24:07 -07:00
return false;
}
}
2011-06-14 00:24:07 -07:00
return true;
}
private boolean printNode(final MCTSGameTree curr, List<Object[]> choices) {
if (curr.choicesStr != null) {
for (String str : curr.choicesStr) {
log("PAREN: " + str);
}
} else {
log("PAREN: not defined");
}
for (MCTSGameTree child : curr) {
log("CHILD: " + child.desc);
}
for (Object[] choice : choices) {
log("GAME : " + obj2String(choice[0]));
}
return true;
}
public boolean printPath(final List<MCTSGameTree> path) {
StringBuffer sb = new StringBuffer();
for (MCTSGameTree p : path) {
sb.append(" -> ").append(p.desc);
}
log(sb.toString());
return true;
}
private LinkedList<MCTSGameTree> growTree(final MCTSGameTree root, final MagicGame game) {
2011-06-10 20:10:50 -07:00
final LinkedList<MCTSGameTree> path = new LinkedList<MCTSGameTree>();
boolean found = false;
MCTSGameTree curr = root;
path.add(curr);
for (List<Object[]> choices = getNextChoices(game, false);
2011-06-04 01:50:52 -07:00
choices != null;
choices = getNextChoices(game, false)) {
assert choices.size() > 0 : "ERROR! No choice at start of genNewTreeNode";
2011-06-18 08:55:38 -07:00
assert !curr.hasDetails() || checkNode(curr, choices) :
"ERROR! Inconsistent node found" + "\n" + game + printPath(path) + printNode(curr, choices);
2011-06-04 01:50:52 -07:00
final MagicEvent event = game.getNextEvent();
//first time considering the choices available at this node,
//fill in additional details for curr
if (!curr.hasDetails()) {
curr.setIsAI(game.getScorePlayer() == event.getPlayer());
curr.setMaxChildren(choices.size());
assert curr.setChoicesStr(choices);
2011-06-14 00:24:07 -07:00
}
//look for first non root AI node along this path and add it to cache
if (!found && curr != root && curr.isAI()) {
found = true;
assert curr.isCached() || printPath(path);
addNode(game, curr);
}
//there are unexplored children of node
//assume we explore children of a node in increasing order of the choices
if (curr.size() < choices.size()) {
final int idx = curr.size();
Object[] choice = choices.get(idx);
2011-06-14 00:24:07 -07:00
game.executeNextEvent(choice);
final MCTSGameTree child = new MCTSGameTree(idx, game.getScore());
assert child.desc != (child.desc = obj2String(choice[0]));
curr.addChild(child);
path.add(child);
return path;
//all the children are in the tree, find the "best" child to explore
} else {
2011-06-11 07:12:16 -07:00
assert curr.size() == choices.size() : "ERROR! Different number of choices in node and game" +
printPath(path) + printNode(curr, choices);
2011-06-10 20:10:50 -07:00
MCTSGameTree next = null;
double bestV = Double.NEGATIVE_INFINITY;
for (MCTSGameTree child : curr) {
//skip won nodes
if (child.isAIWin()) {
2011-06-10 20:10:50 -07:00
continue;
}
2011-06-11 02:58:52 -07:00
final double v = UCT(curr, child);
if (v > bestV) {
bestV = v;
2011-06-10 20:10:50 -07:00
next = child;
}
}
2011-04-06 21:31:41 -07:00
//move down the tree
2011-06-10 20:10:50 -07:00
curr = next;
game.executeNextEvent(choices.get(curr.getChoice()));
path.add(curr);
}
}
2011-04-06 21:31:41 -07:00
return path;
}
2011-06-10 20:10:50 -07:00
private double randomPlay(final MCTSGameTree node, final MagicGame game) {
//terminal node, no need for random play
if (game.isFinished()) {
2011-06-14 00:24:07 -07:00
assert game.getLosingPlayer() != null : "ERROR! game finished but no losing player";
if (game.getLosingPlayer() == game.getScorePlayer()) {
2011-06-10 20:10:50 -07:00
node.setAILose();
return -1.0;
} else {
node.setAIWin();
return 1.0;
}
}
final int startActions = game.getNumActions();
getNextChoices(game, true);
final int actions = game.getNumActions() - startActions;
2011-06-10 21:27:24 -07:00
if (LOGGING) {
LENS.add(actions);
2011-06-10 21:27:24 -07:00
}
if (game.getLosingPlayer() == null) {
return 0;
} else if (game.getLosingPlayer() == game.getScorePlayer()) {
return -(1.0 - actions/((double)MAX_ACTIONS));
} else {
return 1.0 - actions/((double)MAX_ACTIONS);
}
}
2011-06-11 07:12:16 -07:00
private List<Object[]> getNextChoices(
final MagicGame game,
final boolean sim) {
final int startActions = game.getNumActions();
2011-06-21 00:36:52 -07:00
//use fact choices during simulation
game.setFastChoices(sim);
// simulate game until it is finished or simulated MAX_ACTIONS actions
while (!game.isFinished() && (game.getNumActions() - startActions) < MAX_ACTIONS) {
//do not accumulate score down the tree
game.setScore(0);
if (!game.hasNextEvent()) {
game.getPhase().executePhase(game);
continue;
}
//game has next event
final MagicEvent event = game.getNextEvent();
if (!event.hasChoice()) {
game.executeNextEvent(MagicEvent.NO_CHOICE_RESULTS);
continue;
}
//event has choice
if (sim) {
//get simulation choice and execute
2011-06-10 23:19:49 -07:00
final Object[] choice = event.getSimulationChoiceResult(game);
2011-06-14 00:24:07 -07:00
assert choice != null : "ERROR! MCTS: no choice found during sim";
game.executeNextEvent(choice);
} else {
//get list of possible AI choices
final List<Object[]> choices = event.getArtificialChoiceResults(game);
final int size = choices.size();
assert size > 0 : "ERROR! MCTS: no choice found getACR";
2011-06-14 00:24:07 -07:00
if (size == 1) {
//single choice
game.executeNextEvent(choices.get(0));
} else {
//multiple choice
return choices;
}
}
}
//game is finished
return null;
}
}
//each tree node stores the choice from the parent that leads to this node
class MCTSGameTree implements Iterable<MCTSGameTree> {
private final LinkedList<MCTSGameTree> children = new LinkedList<MCTSGameTree>();
2011-06-10 20:10:50 -07:00
private final int choice;
private boolean isAI;
private boolean isCached = false;
private int maxChildren = -1;
2011-06-10 20:10:50 -07:00
private int numLose = 0;
private int numSim = 0;
private int evalScore = 0;
2011-06-14 00:24:07 -07:00
private double score = 0;
public String desc;
public String[] choicesStr;
public MCTSGameTree(final int choice, final int evalScore) {
this.evalScore = evalScore;
this.choice = choice;
2011-06-14 00:24:07 -07:00
}
public boolean isCached() {
return isCached;
}
public void setCached() {
isCached = true;
}
public boolean hasDetails() {
return maxChildren != -1;
}
public boolean setChoicesStr(List<Object[]> choices) {
choicesStr = new String[choices.size()];
for (int i = 0; i < choices.size(); i++) {
2011-06-16 00:12:14 -07:00
choicesStr[i] = MCTSAI.obj2String(choices.get(i)[0]);
}
return true;
}
2011-06-10 20:10:50 -07:00
public void setMaxChildren(final int mc) {
maxChildren = mc;
}
public int getMaxChildren() {
return maxChildren;
}
2011-06-10 20:10:50 -07:00
public boolean isAI() {
return isAI;
}
public void setIsAI(final boolean ai) {
this.isAI = ai;
}
public boolean isSolved() {
return evalScore == Integer.MAX_VALUE || evalScore == Integer.MIN_VALUE;
}
public void updateScore(final double score) {
this.score += score;
numSim += 1;
}
2011-06-10 20:10:50 -07:00
public boolean isAIWin() {
return evalScore == Integer.MAX_VALUE;
}
public boolean isAILose() {
return evalScore == Integer.MIN_VALUE;
}
public void incLose() {
numLose++;
if (numLose == maxChildren) {
if (isAI) {
setAILose();
} else {
setAIWin();
}
}
}
public int getChoice() {
return choice;
}
2011-06-10 20:10:50 -07:00
public void setAIWin() {
evalScore = Integer.MAX_VALUE;
}
public void setAILose() {
evalScore = Integer.MIN_VALUE;
}
public int getEvalScore() {
return evalScore;
}
public double getScore() {
return score;
}
2011-06-10 20:10:50 -07:00
public double getRank() {
if (isAIWin()) {
return MCTSAI.BOOST + getNumSim();
2011-06-10 20:10:50 -07:00
} else if (isAILose()) {
2011-06-14 00:24:07 -07:00
return getNumSim();
2011-06-10 20:10:50 -07:00
} else {
return getNumSim();
}
}
public int getNumSim() {
return numSim;
}
public double getV() {
return score / numSim;
}
public double getSecureScore() {
return getV() + 1.0/Math.sqrt(getNumSim());
}
public void addChild(MCTSGameTree child) {
assert children.size() < maxChildren : "ERROR! Number of children nodes exceed maxChildren";
children.add(child);
}
public void removeLast() {
children.removeLast();
2011-06-09 22:15:48 -07:00
}
public MCTSGameTree first() {
return children.get(0);
}
public Iterator<MCTSGameTree> iterator() {
return children.iterator();
}
public int size() {
return children.size();
}
}
2011-06-14 00:24:07 -07:00
class NodeCache extends LinkedHashMap<Long, MCTSGameTree> {
private static final long serialVersionUID = 1L;
private final int capacity;
2011-06-14 00:24:07 -07:00
public NodeCache(int capacity) {
super(capacity + 1, 1.1f, true);
this.capacity = capacity;
}
protected boolean removeEldestEntry(Map.Entry eldest) {
return size() > capacity;
}
}