/*-------------------------------------------------------------------------+
|                                                                          |
| Copyright 2005-2011 The ConQAT Project                                   |
|                                                                          |
| Licensed under the Apache License, Version 2.0 (the "License");          |
| you may not use this file except in compliance with the License.         |
| You may obtain a copy of the License at                                  |
|                                                                          |
| http://www.apache.org/licenses/LICENSE-2.0                               |
|                                                                          |
| Unless required by applicable law or agreed to in writing, software      |
| distributed under the License is distributed on an "AS IS" BASIS,        |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and      |
| limitations under the License.                                           |
+-------------------------------------------------------------------------*/
package org.conqat.lib.commons.algo;

import java.util.Arrays;
import java.util.List;

import org.conqat.lib.commons.collections.PairList;

/**
 * A class for calculating a maximum weight matching using an augmenting path
 * algorithm running in O(n^3*m), where n is the size of the smaller node set
 * and m the size of the larger one. In practice, the running time is usually
 * much lower.
 * <p>
 * This class is not thread safe!
 *
 * @author hummelb
 *
 * @param <N1>
 *            The first node type
 * @param <N2>
 *            The second node type
 */
public class MaxWeightMatching<N1, N2> {

    /**
     * Flag indicating whether we are running in swapped mode. Swapped mode is
     * needed as our algorithm requires the second set of nodes to be at least
     * as large as the first set. If this is not the case, we simply swap the
     * two sets, but we need this flag to adjust some parts of the code.
     */
    private boolean swapped;

    /** Size of the first (or second if {@link #swapped}) node set. */
    private int size1;

    /** Size of the second (or first if {@link #swapped}) node set. */
    private int size2;

    /** The first node set. */
    private List<N1> nodes1;

    /** The second node set. */
    private List<N2> nodes2;

    /** The provider for the weights (i.e. the weight matrix). */
    private IWeightProvider<N1, N2> weightProvider;

    /**
     * Cache used to reduce the number of queries to {@link #weightProvider}.
     * See {@link #getWeight(int, int)}.
     *
     * Since the runtime of this algorithm is high (see the class comment), we
     * expect the runtime to limit the size of the data this algorithm is
     * called on. We thus do not use a memory-sensitive caching scheme here.
     */
    private Double[][] weightCache;

    /**
     * This array stores for each node of the second set the index of the node
     * from the first set it is matched to (or -1 if it is not in the
     * matching). If {@link #swapped}, first and second set change meaning.
     */
    private int[] mate = new int[16];

    /**
     * This is used while searching the shortest path and stores the node
     * index we came from.
     */
    private int[] from = new int[16];

    /**
     * This is used while searching the shortest path and stores the distance
     * (i.e. weight sum) to this node.
     */
    private double[] dist = new double[16];

    /**
     * Calculates the maximum weight bipartite matching.
     *
     * @param matching
     *            if this is non-<code>null</code>, the matching (i.e. the
     *            pairs of nodes matched onto each other) will be put into it.
     *
     * @return the weight of the matching.
     */
    public double calculateMatching(List<N1> nodes1, List<N2> nodes2, IWeightProvider<N1, N2> weightProvider,
            PairList<N1, N2> matching) {

        if (matching != null) {
            matching.clear();
        }

        if (nodes1.isEmpty() || nodes2.isEmpty()) {
            return 0;
        }

        init(nodes1, nodes2, weightProvider);
        prepareInternalArrays();

        for (int i = 0; i < size1; ++i) {
            augmentFrom(i);
        }

        double res = 0;
        for (int i = 0; i < size2; ++i) {
            if (mate[i] >= 0) {
                if (matching != null) {
                    if (swapped) {
                        matching.add(nodes1.get(i), nodes2.get(mate[i]));
                    } else {
                        matching.add(nodes1.get(mate[i]), nodes2.get(i));
                    }
                }
                res += getWeight(mate[i], i);
            }
        }
        return res;
    }

    /**
     * Initializes the data structures from the parameters to the
     * {@link #calculateMatching(List, List, org.conqat.lib.commons.algo.MaxWeightMatching.IWeightProvider, PairList)}
     * method.
     */
    private void init(List<N1> nodes1, List<N2> nodes2, IWeightProvider<N1, N2> weightProvider) {
        if (nodes1.size() <= nodes2.size()) {
            size1 = nodes1.size();
            size2 = nodes2.size();
            swapped = false;
        } else {
            size1 = nodes2.size();
            size2 = nodes1.size();
            swapped = true;
        }
        this.nodes1 = nodes1;
        this.nodes2 = nodes2;
        this.weightProvider = weightProvider;
        weightCache = new Double[nodes1.size()][nodes2.size()];
    }

    /** Makes sure all internal arrays are large enough. */
    private void prepareInternalArrays() {
        if (size2 > mate.length) {
            int newSize = mate.length;
            while (newSize < size2) {
                newSize *= 2;
            }
            mate = new int[newSize];
            from = new int[newSize];
            dist = new double[newSize];
        }

        Arrays.fill(mate, 0, size2, -1);
    }

    /**
     * Calculates the shortest augmenting path starting from the given node
     * (index) and augments along it.
     */
    private void augmentFrom(int u) {
        for (int i = 0; i < size2; ++i) {
            from[i] = -1;
            dist[i] = getWeight(u, i);
        }
        bellmanFord();
        int target = findBestUnmatchedTarget();
        augmentAlongPath(u, target);
    }

    /** Calculates the shortest path using the Bellman-Ford algorithm. */
    private void bellmanFord() {
        boolean changed = true;
        while (changed) {
            changed = false;
            for (int i = 0; i < size2; ++i) {
                if (mate[i] < 0) {
                    continue;
                }
                double w = getWeight(mate[i], i);
                for (int j = 0; j < size2; ++j) {
                    if (i == j) {
                        continue;
                    }
                    double newDist = dist[i] - w + getWeight(mate[i], j);
                    if (newDist - 1e-15 > dist[j]) {
                        dist[j] = newDist;
                        from[j] = i;
                        changed = true;
                    }
                }
            }
        }
    }

    /** Finds the best target which is not yet in the matching. */
    private int findBestUnmatchedTarget() {
        int target = -1;
        for (int i = 0; i < size2; ++i) {
            if (mate[i] < 0) {
                if (target < 0 || dist[i] > dist[target]) {
                    target = i;
                }
            }
        }
        return target;
    }

    /**
     * Augments along the given path to the target by adjusting the mate
     * array.
     */
    private void augmentAlongPath(int u, int target) {
        while (from[target] >= 0) {
            mate[target] = mate[from[target]];
            target = from[target];
        }
        mate[target] = u;
    }

    /**
     * Returns the weight between two nodes (=indices), handling swapping
     * transparently.
     * The weight is cached to (1) reduce the number of calls to the
     * (potentially expensive) weight provider and (2) protect against
     * non-deterministic weight providers that do not return consistent
     * weights for queries with the same parameters.
     */
    private double getWeight(int i1, int i2) {

        int k1 = i1;
        int k2 = i2;
        if (swapped) {
            k1 = i2;
            k2 = i1;
        }

        Double result = weightCache[k1][k2];
        if (result == null) {
            result = weightProvider.getConnectionWeight(nodes1.get(k1), nodes2.get(k2));
            weightCache[k1][k2] = result;
        }
        return result;
    }

    /** An interface providing the weight for a connection between two nodes. */
    public interface IWeightProvider<N1, N2> {

        /** Returns the weight of the connection between both nodes. */
        double getConnectionWeight(N1 node1, N2 node2);
    }
}
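
/*
 * Hypothetical usage sketch; NOT part of the original ConQAT sources. It
 * matches two small word lists, weighting each pair by the length of the
 * common prefix of the two words. The example class name and the weight
 * function are illustrative assumptions; the PairList accessors used below
 * (size, getFirst, getSecond) are assumed from their typical usage in this
 * library.
 */
class MaxWeightMatchingUsageExample {

    public static void main(String[] args) {
        List<String> left = Arrays.asList("alpha", "beta", "gamma");
        List<String> right = Arrays.asList("alphabet", "bet", "gamut", "delta");

        // Assumed weight function: length of the common prefix of both words.
        MaxWeightMatching.IWeightProvider<String, String> weights = new MaxWeightMatching.IWeightProvider<String, String>() {
            @Override
            public double getConnectionWeight(String node1, String node2) {
                int common = 0;
                while (common < node1.length() && common < node2.length()
                        && node1.charAt(common) == node2.charAt(common)) {
                    common++;
                }
                return common;
            }
        };

        // The matched pairs are collected in the PairList; the return value
        // is the total weight of the matching.
        PairList<String, String> matching = new PairList<String, String>();
        double totalWeight = new MaxWeightMatching<String, String>().calculateMatching(left, right, weights, matching);

        System.out.println("total weight: " + totalWeight);
        for (int i = 0; i < matching.size(); i++) {
            System.out.println(matching.getFirst(i) + " -> " + matching.getSecond(i));
        }
    }
}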