/*-------------------------------------------------------------------------+
|                                                                          |
| Copyright 2005-2011 The ConQAT Project                                   |
|                                                                          |
| Licensed under the Apache License, Version 2.0 (the "License");          |
| you may not use this file except in compliance with the License.         |
| You may obtain a copy of the License at                                  |
|                                                                          |
|    http://www.apache.org/licenses/LICENSE-2.0                            |
|                                                                          |
| Unless required by applicable law or agreed to in writing, software      |
| distributed under the License is distributed on an "AS IS" BASIS,        |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and      |
| limitations under the License.                                           |
+-------------------------------------------------------------------------*/
package org.conqat.lib.commons.algo;

import java.util.Arrays;
import java.util.List;

import org.conqat.lib.commons.collections.PairList;

/**
 * A class for calculating a maximum weight bipartite matching using an
 * augmenting path algorithm running in O(n^3*m), where n is the size of the
 * smaller node set and m the size of the larger one. In practice the running
 * time is usually much lower.
 * <p>
 * This class is not thread-safe!
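 * <p>
 * A minimal usage sketch (the string nodes and the equality-based weight
 * provider are purely illustrative and not part of this library):
 * 
 * <pre>
 * MaxWeightMatching&lt;String, String&gt; matcher = new MaxWeightMatching&lt;String, String&gt;();
 * PairList&lt;String, String&gt; matching = new PairList&lt;String, String&gt;();
 * double weight = matcher.calculateMatching(Arrays.asList("a", "b"),
 * 		Arrays.asList("c", "a", "b"), new IWeightProvider&lt;String, String&gt;() {
 * 			public double getConnectionWeight(String node1, String node2) {
 * 				return node1.equals(node2) ? 1 : 0;
 * 			}
 * 		}, matching);
 * </pre>
 * 
 * The returned value is the total weight of the matching; the matched pairs
 * are stored in the given {@link PairList}.
 * 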
 * @author hummelb
 * 
 * @param <N1>
 *            The first node type
 * @param <N2>
 *            The second node type
 */
public class MaxWeightMatching<N1, N2> {

	/**
	 * Flag indicating whether we are running in swapped mode. Swapped mode is
	 * needed as our algorithm requires the second set of nodes not to be
	 * smaller than the first set. If this is not the case, we just swap these
	 * sets, but we need this flag to adjust some parts of the code.
	 */
	private boolean swapped;

	/** Size of the first (or second if {@link #swapped}) node set. */
	private int size1;

	/** Size of the second (or first if {@link #swapped}) node set. */
	private int size2;

	/** The first node set. */
	private List<N1> nodes1;

	/** The second node set. */
	private List<N2> nodes2;

	/** The provider for the weights (i.e. weight matrix). */
	private IWeightProvider<N1, N2> weightProvider;

	/**
	 * Cache used to reduce the number of queries to {@link #weightProvider}.
	 * See {@link #getWeight(int, int)}.
	 * 
	 * Since the runtime of this algorithm is high (see class comment), we
	 * expect runtime to limit the data size on which this algorithm is called.
	 * We thus do not use a memory-sensitive caching scheme here.
	 */
	private Double[][] weightCache;

	/**
	 * This array stores for each node of the second set the index of the node
	 * from the first set it is matched to (or -1 if it is not in the
	 * matching). If {@link #swapped}, first and second set change meaning.
	 */
	private int[] mate = new int[16];

	/**
	 * This is used while searching for an augmenting path and stores the node
	 * index we came from.
	 */
	private int[] from = new int[16];

	/**
	 * This is used while searching for an augmenting path and stores the
	 * distance (i.e. weight sum) to this node.
	 */
	private double[] dist = new double[16];

	/**
	 * Calculates the maximum weight bipartite matching.
	 * 
	 * @param nodes1
	 *            the first set of nodes.
	 * @param nodes2
	 *            the second set of nodes.
	 * @param weightProvider
	 *            provides the weight for each pair of nodes.
	 * @param matching
	 *            if this is non-<code>null</code>, the matching (i.e. the
	 *            pairs of nodes matched onto each other) will be put into it.
	 * 
	 * @return the weight of the matching.
	 */
	public double calculateMatching(List<N1> nodes1, List<N2> nodes2, IWeightProvider<N1, N2> weightProvider,
			PairList<N1, N2> matching) {

		if (matching != null) {
			matching.clear();
		}

		if (nodes1.isEmpty() || nodes2.isEmpty()) {
			return 0;
		}

		init(nodes1, nodes2, weightProvider);
		prepareInternalArrays();

		for (int i = 0; i < size1; ++i) {
			augmentFrom(i);
		}

		double res = 0;
		for (int i = 0; i < size2; ++i) {
			if (mate[i] >= 0) {
				if (matching != null) {
					if (swapped) {
						matching.add(nodes1.get(i), nodes2.get(mate[i]));
					} else {
						matching.add(nodes1.get(mate[i]), nodes2.get(i));
					}
				}
				res += getWeight(mate[i], i);
			}
		}
		return res;
	}

	/**
	 * Initializes the data structures from the parameters to the
	 * {@link #calculateMatching(List, List, org.conqat.lib.commons.algo.MaxWeightMatching.IWeightProvider, PairList)}
	 * method.
	 */
	private void init(List<N1> nodes1, List<N2> nodes2, IWeightProvider<N1, N2> weightProvider) {
		if (nodes1.size() <= nodes2.size()) {
			size1 = nodes1.size();
			size2 = nodes2.size();
			swapped = false;
		} else {
			size1 = nodes2.size();
			size2 = nodes1.size();
			swapped = true;
		}
		this.nodes1 = nodes1;
		this.nodes2 = nodes2;
		this.weightProvider = weightProvider;
		weightCache = new Double[nodes1.size()][nodes2.size()];
	}

	/** Make sure all internal arrays are large enough. */
	private void prepareInternalArrays() {
		if (size2 > mate.length) {
			int newSize = mate.length;
			while (newSize < size2) {
				newSize *= 2;
			}
			mate = new int[newSize];
			from = new int[newSize];
			dist = new double[newSize];
		}

		Arrays.fill(mate, 0, size2, -1);
	}

	/**
	 * Calculate the maximum weight augmenting path starting from the given
	 * node (index) and augment along it.
	 */
	private void augmentFrom(int u) {
		for (int i = 0; i < size2; ++i) {
			from[i] = -1;
			dist[i] = getWeight(u, i);
		}
		bellmanFord();
		int target = findBestUnmatchedTarget();
		augmentAlongPath(u, target);
	}

	/** Calculate maximum weight paths using the Bellman-Ford algorithm. */
	private void bellmanFord() {
		boolean changed = true;
		while (changed) {
			changed = false;
			for (int i = 0; i < size2; ++i) {
				if (mate[i] < 0) {
					continue;
				}
				double w = getWeight(mate[i], i);
				for (int j = 0; j < size2; ++j) {
					if (i == j) {
						continue;
					}
					double newDist = dist[i] - w + getWeight(mate[i], j);
					if (newDist - 1e-15 > dist[j]) {
						dist[j] = newDist;
						from[j] = i;
						changed = true;
					}
				}
			}
		}
	}

	/** Find the best target which is not yet in the matching. */
	private int findBestUnmatchedTarget() {
		int target = -1;
		for (int i = 0; i < size2; ++i) {
			if (mate[i] < 0) {
				if (target < 0 || dist[i] > dist[target]) {
					target = i;
				}
			}
		}
		return target;
	}

	/**
	 * Augment along the given path to the target by adjusting the mate array.
	 */
	private void augmentAlongPath(int u, int target) {
		while (from[target] >= 0) {
			mate[target] = mate[from[target]];
			target = from[target];
		}
		mate[target] = u;
	}

	/**
	 * Returns the weight between two nodes (given as indices), handling
	 * swapping transparently. The weight is cached to (1) reduce the number of
	 * calls to the (potentially expensive) weight provider and (2) protect
	 * against non-deterministic weight providers that do not return consistent
	 * weights in queries with the same parameters.
	 */
	private double getWeight(int i1, int i2) {

		int k1 = i1;
		int k2 = i2;
		if (swapped) {
			k1 = i2;
			k2 = i1;
		}

		Double result = weightCache[k1][k2];
		if (result == null) {
			result = weightProvider.getConnectionWeight(nodes1.get(k1), nodes2.get(k2));
			weightCache[k1][k2] = result;
		}
		return result;
	}

	/**
	 * An interface providing the weight for a connection between two nodes.
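	 * <p>
	 * Since there is only a single method, an implementation can be given
	 * inline. A minimal sketch for string nodes (purely illustrative, not part
	 * of this library), written as a lambda expression if the surrounding code
	 * targets Java 8 or later:
	 * 
	 * <pre>
	 * IWeightProvider&lt;String, String&gt; provider = (node1, node2) -&gt; node1.equals(node2) ? 1 : 0;
	 * </pre>
	 */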
	public interface IWeightProvider<N1, N2> {

		/** Returns the weight of the connection between both nodes. */
		double getConnectionWeight(N1 node1, N2 node2);
	}
}