001 // --- BEGIN LICENSE BLOCK ---
002 /*
003 * Copyright (c) 2009, Mikio L. Braun
004 * Copyright (c) 2008, Johannes Schaback
005 * Copyright (c) 2009, Jan Saputra M??ller
006 * All rights reserved.
007 *
008 * Redistribution and use in source and binary forms, with or without
009 * modification, are permitted provided that the following conditions are
010 * met:
011 *
012 * * Redistributions of source code must retain the above copyright
013 * notice, this list of conditions and the following disclaimer.
014 *
015 * * Redistributions in binary form must reproduce the above
016 * copyright notice, this list of conditions and the following
017 * disclaimer in the documentation and/or other materials provided
018 * with the distribution.
019 *
020 * * Neither the name of the Technische Universit??t Berlin nor the
021 * names of its contributors may be used to endorse or promote
022 * products derived from this software without specific prior
023 * written permission.
024 *
025 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
026 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
027 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
028 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
029 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
030 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
031 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
032 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
033 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
034 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
035 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
036 */
037 // --- END LICENSE BLOCK ---
038
039 package org.jblas;
040
041 import org.jblas.exceptions.SizeException;
042 import org.jblas.ranges.Range;
043 import java.io.BufferedReader;
044 import java.io.DataInputStream;
045 import java.io.DataOutputStream;
046 import java.io.FileInputStream;
047 import java.io.FileOutputStream;
048 import java.io.IOException;
049 import java.io.InputStreamReader;
050
051 import java.io.PrintWriter;
052 import java.io.StringWriter;
053 import java.util.AbstractList;
054 import java.util.Arrays;
055 import java.util.Comparator;
056 import java.util.Iterator;
057 import java.util.LinkedList;
058 import java.util.List;
059
060 /**
061 * A general matrix class for <tt>float</tt> typed values.
062 *
 * Don't be intimidated by the large number of methods this class defines. Most
064 * are overloads provided for ease of use. For example, for each arithmetic operation,
065 * up to six overloaded versions exist to handle in-place computations, and
066 * scalar arguments.
067 *
068 * <h3>Construction</h3>
069 *
 * <p>To construct two-dimensional matrices, you can use the following constructors
071 * and static methods.</p>
072 *
073 * <table class="my">
074 * <tr><th>Method<th>Description
 * <tr><td>FloatMatrix(m,n, [value1, value2, value3...])<td>Values are used directly as the data array, i.e. filled in column by column.
 * <tr><td>FloatMatrix(new float[][] {{value1, value2, ...}, ...}<td>Inner arrays are rows.
077 * <tr><td>FloatMatrix.zeros(m,n) <td>Initial values set to 0.0f.
078 * <tr><td>FloatMatrix.ones(m,n) <td>Initial values set to 1.0f.
079 * <tr><td>FloatMatrix.rand(m,n) <td>Values drawn at random between 0.0f and 1.0f.
080 * <tr><td>FloatMatrix.randn(m,n) <td>Values drawn from normal distribution.
081 * <tr><td>FloatMatrix.eye(n) <td>Unit matrix (values 0.0f except for 1.0f on the diagonal).
082 * <tr><td>FloatMatrix.diag(array) <td>Diagonal matrix with given diagonal elements.
083 * </table>
084 *
085 * <p>Alternatively, you can construct (column) vectors, if you just supply the length
086 * using the following constructors and static methods.</p>
087 *
088 * <table class="my">
089 * <tr><th>Method<th>Description
090 * <tr><td>FloatMatrix(m)<td>Constructs a column vector.
091 * <tr><td>FloatMatrix(new float[] {value1, value2, ...})<td>Constructs a column vector.
 * <tr><td>FloatMatrix.zeros(m) <td>Initial values set to 0.0f.
 * <tr><td>FloatMatrix.ones(m) <td>Initial values set to 1.0f.
094 * <tr><td>FloatMatrix.rand(m) <td>Values drawn at random between 0.0f and 1.0f.
095 * <tr><td>FloatMatrix.randn(m) <td>Values drawn from normal distribution.
096 * </table>
097 *
 * <p>You can also construct new matrices by concatenating matrices either horizontally
099 * or vertically:</p>
100 *
101 * <table class="my">
102 * <tr><th>Method<th>Description
103 * <tr><td>x.concatHorizontally(y)<td>New matrix will be x next to y.
104 * <tr><td>x.concatVertically(y)<td>New matrix will be x atop y.
105 * </table>
106 *
107 * <h3>Element Access, Copying and Duplication</h3>
108 *
109 * <p>To access individual elements, or whole rows and columns, use the following
 * methods:</p>
111 *
112 * <table class="my">
113 * <tr><th>x.Method<th>Description
114 * <tr><td>x.get(i,j)<td>Get element in row i and column j.
115 * <tr><td>x.put(i, j, v)<td>Set element in row i and column j to value v
 * <tr><td>x.get(i)<td>Get the ith element of the matrix (linear index, counting down each column first).
 * <tr><td>x.put(i, v)<td>Set the ith element of the matrix (linear index, counting down each column first).
118 * <tr><td>x.getColumn(i)<td>Get a copy of column i.
119 * <tr><td>x.putColumn(i, c)<td>Put matrix c into column i.
120 * <tr><td>x.getRow(i)<td>Get a copy of row i.
121 * <tr><td>x.putRow(i, c)<td>Put matrix c into row i.
122 * <tr><td>x.swapColumns(i, j)<td>Swap the contents of columns i and j.
 * <tr><td>x.swapRows(i, j)<td>Swap the contents of rows i and j.
124 * </table>
125 *
126 * <p>For <tt>get</tt> and <tt>put</tt>, you can also pass integer arrays,
127 * FloatMatrix objects, or Range objects, which then specify the indices used
128 * as follows:
129 *
130 * <ul>
131 * <li><em>integer array:</em> the elements will be used as indices.
132 * <li><em>FloatMatrix object:</em> non-zero entries specify the indices.
133 * <li><em>Range object:</em> see below.
134 * </ul>
135 *
136 * <p>When using <tt>put</tt> with multiple indices, the assigned object must
137 * have the correct size or be a scalar.</p>
138 *
139 * <p>There exist the following Range objects. The Class <tt>RangeUtils</tt> also
 * contains a number of handy helper methods for constructing these ranges.</p>
141 * <table class="my">
142 * <tr><th>Class <th>RangeUtils method <th>Indices
143 * <tr><td>AllRange <td>all() <td>All legal indices.
144 * <tr><td>PointRange <td>point(i) <td> A single point.
145 * <tr><td>IntervalRange <td>interval(a, b)<td> All indices from a to b (inclusive)
146 * <tr><td rowspan=3>IndicesRange <td>indices(int[])<td> The specified indices.
147 * <tr><td>indices(FloatMatrix)<td>The specified indices.
148 * <tr><td>find(FloatMatrix)<td>The non-zero entries of the matrix.
149 * </table>
150 *
151 * <p>The following methods can be used for duplicating and copying matrices.</p>
152 *
153 * <table class="my">
154 * <tr><th>Method<th>Description
155 * <tr><td>x.dup()<td>Get a copy of x.
 * <tr><td>x.copy(y)<td>Copy the contents of y to x (possibly resizing x).
157 * </table>
158 *
159 * <h3>Size and Shape</h3>
160 *
 * <p>The following methods let you access the size of a matrix and change its size or shape.</p>
162 *
163 * <table class="my">
164 * <tr><th>x.Method<th>Description
165 * <tr><td>x.rows<td>Number of rows.
166 * <tr><td>x.columns<td>Number of columns.
167 * <tr><td>x.length<td>Total number of elements.
168 * <tr><td>x.isEmpty()<td>Checks whether rows == 0 and columns == 0.
169 * <tr><td>x.isRowVector()<td>Checks whether rows == 1.
170 * <tr><td>x.isColumnVector()<td>Checks whether columns == 1.
171 * <tr><td>x.isVector()<td>Checks whether rows == 1 or columns == 1.
172 * <tr><td>x.isSquare()<td>Checks whether rows == columns.
173 * <tr><td>x.isScalar()<td>Checks whether length == 1.
174 * <tr><td>x.resize(r, c)<td>Resize the matrix to r rows and c columns, discarding the content.
175 * <tr><td>x.reshape(r, c)<td>Resize the matrix to r rows and c columns.<br> Number of elements must not change.
176 * </table>
177 *
178 * <p>The size is stored in the <tt>rows</tt> and <tt>columns</tt> member variables.
179 * The total number of elements is stored in <tt>length</tt>. Do not change these
180 * values unless you know what you're doing!</p>
181 *
182 * <h3>Arithmetics</h3>
183 *
 * <p>The usual arithmetic operations are implemented. Each operation exists in an
 * in-place version, recognizable by the suffix <tt>"i"</tt>, to which you can supply
186 * the result matrix (or <tt>this</tt> is used, if missing). Using in-place operations
187 * can also lead to a smaller memory footprint, as the number of temporary objects
188 * which are directly garbage collected again is reduced.</p>
189 *
 * <p>Whenever you specify a result matrix, the result matrix must already have the
 * correct dimensions.</p>
192 *
193 * <p>For example, you can add two matrices using the <tt>add</tt> method. If you want
 * to store the result of <tt>x + y</tt> in <tt>z</tt>, type
 * <span class=code>
 * x.addi(y, z) // computes z = x + y.
197 * </span>
198 * Even in-place methods return the result, such that you can easily chain in-place methods,
199 * for example:
200 * <span class=code>
201 * x.addi(y).addi(z) // computes x += y; x += z
202 * </span></p>
203 *
204 * <p>Methods which operate element-wise only make sure that the length of the matrices
205 * is correct. Therefore, you can add a 3 * 3 matrix to a 1 * 9 matrix, for example.</p>
206 *
207 * <p>Finally, there exist versions which take floats instead of FloatMatrix Objects
208 * as arguments. These then compute the operation with the same value as the
209 * right-hand-side. The same effect can be achieved by passing a FloatMatrix with
210 * exactly one element.</p>
211 *
212 * <table class="my">
213 * <tr><th>Operation <th>Method <th>Comment
214 * <tr><td>x + y <td>x.add(y) <td>
215 * <tr><td>x - y <td>x.sub(y), y.rsub(x) <td>rsub subtracts left from right hand side
216 * <tr><td rowspan=3>x * y <td>x.mul(y) <td>element-wise multiplication
217 * <tr> <td>x.mmul(y)<td>matrix-matrix multiplication
218 * <tr> <td>x.dot(y) <td>scalar-product
219 * <tr><td>x / y <td>x.div(y), y.rdiv(x) <td>rdiv divides right hand side by left hand side.
220 * <tr><td>- x <td>x.neg() <td>
221 * </table>
222 *
223 * <p>There also exist operations which work on whole columns or rows.</p>
224 *
225 * <table class="my">
226 * <tr><th>Method <th>Description
227 * <tr><td>x.addRowVector<td>adds a vector to each row (addiRowVector works in-place)
228 * <tr><td>x.addColumnVector<td>adds a vector to each column
229 * <tr><td>x.subRowVector<td>subtracts a vector from each row
230 * <tr><td>x.subColumnVector<td>subtracts a vector from each column
231 * <tr><td>x.mulRow<td>Multiplies a row by a scalar
 * <tr><td>x.mulColumn<td>Multiplies a column by a scalar
233 * </table>
234 *
235 * <p>In principle, you could achieve the same result by first calling getColumn(),
236 * adding, and then calling putColumn, but these methods are much faster.</p>
237 *
238 * <p>The following comparison operations are available</p>
239 *
240 * <table class="my">
241 * <tr><th>Operation <th>Method
 * <tr><td>x &lt; y <td>x.lt(y)
 * <tr><td>x &lt;= y <td>x.le(y)
 * <tr><td>x &gt; y <td>x.gt(y)
 * <tr><td>x &gt;= y <td>x.ge(y)
246 * <tr><td>x == y <td>x.eq(y)
247 * <tr><td>x != y <td>x.ne(y)
248 * </table>
249 *
250 * <p> Logical operations are also supported. For these operations, a value different from
251 * zero is treated as "true" and zero is treated as "false". All operations are carried
252 * out elementwise.</p>
253 *
254 * <table class="my">
255 * <tr><th>Operation <th>Method
 * <tr><td>x &amp; y <td>x.and(y)
257 * <tr><td>x | y <td>x.or(y)
258 * <tr><td>x ^ y <td>x.xor(y)
259 * <tr><td>! x <td>x.not()
260 * </table>
261 *
262 * <p>Finally, there are a few more methods to compute various things:</p>
263 *
264 * <table class="my">
265 * <tr><th>Method <th>Description
266 * <tr><td>x.max() <td>Return maximal element
267 * <tr><td>x.argmax() <td>Return index of largest element
268 * <tr><td>x.min() <td>Return minimal element
 * <tr><td>x.argmin() <td>Return index of smallest element
270 * <tr><td>x.columnMins() <td>Return column-wise minima
271 * <tr><td>x.columnArgmins() <td>Return column-wise index of minima
272 * <tr><td>x.columnMaxs() <td>Return column-wise maxima
273 * <tr><td>x.columnArgmaxs() <td>Return column-wise index of maxima
274 * </table>
275 *
276 * @author Mikio Braun, Johannes Schaback
277 */
278 public class FloatMatrix {
279
/** Number of rows. */
public int rows;
/** Number of columns. */
public int columns;
/** Total number of elements; the constructors and resize() set it to rows * columns. */
public int length;
/**
 * The matrix elements as one flat array. The copying code in this class
 * (concatHorizontally, concatVertically, getColumns) copies each column as a
 * contiguous run of <tt>rows</tt> elements, so the layout appears to be
 * column-major. NOTE(review): confirm against index(), which is not shown here.
 */
public float[] data = null; // layout: see Javadoc above (columns contiguous)
/**
 * A shared 0 x 0 matrix whose data array is null.
 * NOTE(review): this instance is mutable (public fields, resize()), so callers must never modify it.
 */
public static final FloatMatrix EMPTY = new FloatMatrix();
289
290 /**************************************************************************
291 *
292 * Constructors and factory functions
293 *
294 **************************************************************************/
295 /** Create a new matrix with <i>newRows</i> rows, <i>newColumns</i> columns
296 * using <i>newData></i> as the data. The length of the data is not checked!
297 */
298 public FloatMatrix(int newRows, int newColumns, float... newData) {
299 rows = newRows;
300 columns = newColumns;
301 length = rows * columns;
302
303 if (newData != null && newData.length != newRows * newColumns) {
304 throw new IllegalArgumentException(
305 "Passed data must match matrix dimensions.");
306 }
307
308 data = newData;
309 //System.err.printf("%d * %d matrix created\n", rows, columns);
310 }
311
/**
 * Creates a zero-filled matrix of the given shape.
 *
 * @param newRows number of rows of the new matrix
 * @param newColumns number of columns of the new matrix
 */
public FloatMatrix(int newRows, int newColumns) {
    this(newRows, newColumns, new float[newColumns * newRows]);
}
320
/**
 * Creates an empty 0 x 0 matrix. Its <tt>data</tt> array is <tt>null</tt>.
 */
public FloatMatrix() {
    this(0, 0, (float[]) null);
}
327
/**
 * Creates a zero-filled column vector, that is, a <tt>len</tt> x 1 matrix.
 *
 * @param len number of elements (rows) of the vector
 */
public FloatMatrix(int len) {
    this(len, 1);
}
335
/**
 * Creates a column vector (<tt>newData.length</tt> x 1) backed directly by
 * <i>newData</i>. The array is not copied.
 *
 * @param newData the elements of the vector; must not be <tt>null</tt>
 */
public FloatMatrix(float[] newData) {
    // Delegate straight to the data constructor instead of allocating a
    // zero-filled array that would immediately be thrown away.
    this(newData.length, 1, newData);
}
340
/**
 * Creates a matrix whose shape and contents are read from a file by
 * calling <tt>load</tt>.
 *
 * @param filename path of the file to read the matrix from
 * @throws IOException propagated from <tt>load</tt> if reading fails
 */
public FloatMatrix(String filename) throws IOException {
    this.load(filename);
}
349
/**
 * Creates a new <i>n</i> times <i>m</i> <tt>FloatMatrix</tt> from a
 * two-dimensional array. The outer array holds the <i>n</i> rows; each inner
 * array holds the <i>m</i> entries of one row. For example, the code <br/><br/>
 * <code>new FloatMatrix(new float[][]{{1f, 2f, 3f}, {4f, 5f, 6f}, {7f, 8f, 9f}}).print();</code><br/><br/>
 * constructs the following matrix:
 * <pre>
 * 1.0f 2.0f 3.0f
 * 4.0f 5.0f 6.0f
 * 7.0f 8.0f 9.0f
 * </pre>
 * An empty outer array yields a 0 x 0 matrix.
 *
 * @param data array of rows; every row must have the same length
 * @throws IllegalArgumentException if the rows do not all have the same length
 */
public FloatMatrix(float[][] data) {
    this(data.length, data.length == 0 ? 0 : data[0].length);

    // Check explicitly: a Java assert is disabled by default, which would let
    // ragged input be silently truncated or fail later with an index exception.
    for (int r = 0; r < rows; r++) {
        if (data[r].length != columns) {
            throw new IllegalArgumentException("All rows must have the same length (row "
                    + r + " has " + data[r].length + " != " + columns + ").");
        }
    }

    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < columns; c++) {
            put(r, c, data[r][c]);
        }
    }
}
377
378 /** Create matrix with random values uniformly in 0..1. */
379 public static FloatMatrix rand(int rows, int columns) {
380 FloatMatrix m = new FloatMatrix(rows, columns);
381
382 java.util.Random r = new java.util.Random();
383 for (int i = 0; i < rows * columns; i++) {
384 m.data[i] = r.nextFloat();
385 }
386
387 return m;
388 }
389
390 /** Creates a row vector with random values uniformly in 0..1. */
391 public static FloatMatrix rand(int len) {
392 return rand(len, 1);
393 }
394
395 /** Create matrix with normally distributed random values. */
396 public static FloatMatrix randn(int rows, int columns) {
397 FloatMatrix m = new FloatMatrix(rows, columns);
398
399 java.util.Random r = new java.util.Random();
400 for (int i = 0; i < rows * columns; i++) {
401 m.data[i] = (float) r.nextGaussian();
402 }
403
404 return m;
405 }
406
407 /** Create row vector with normally distributed random values. */
408 public static FloatMatrix randn(int len) {
409 return randn(len, 1);
410 }
411
412 /** Creates a new matrix in which all values are equal 0. */
413 public static FloatMatrix zeros(int rows, int columns) {
414 return new FloatMatrix(rows, columns);
415 }
416
417 /** Creates a row vector of given length. */
418 public static FloatMatrix zeros(int length) {
419 return zeros(length, 1);
420 }
421
422 /** Creates a new matrix in which all values are equal 1. */
423 public static FloatMatrix ones(int rows, int columns) {
424 FloatMatrix m = new FloatMatrix(rows, columns);
425
426 for (int i = 0; i < rows * columns; i++) {
427 m.put(i, 1.0f);
428 }
429
430 return m;
431 }
432
433 /** Creates a row vector with all elements equal to 1. */
434 public static FloatMatrix ones(int length) {
435 return ones(length, 1);
436 }
437
438 /** Construct a new n-by-n identity matrix. */
439 public static FloatMatrix eye(int n) {
440 FloatMatrix m = new FloatMatrix(n, n);
441
442 for (int i = 0; i < n; i++) {
443 m.put(i, i, 1.0f);
444 }
445
446 return m;
447 }
448
449 /**
450 * Creates a new matrix where the values of the given vector are the diagonal values of
451 * the matrix.
452 */
453 public static FloatMatrix diag(FloatMatrix x) {
454 FloatMatrix m = new FloatMatrix(x.length, x.length);
455
456 for (int i = 0; i < x.length; i++) {
457 m.put(i, i, x.get(i));
458 }
459
460 return m;
461 }
462
463 /**
464 * Create a 1-by-1 matrix. For many operations, this matrix functions like a
465 * normal float.
466 */
467 public static FloatMatrix scalar(float s) {
468 FloatMatrix m = new FloatMatrix(1, 1);
469 m.put(0, 0, s);
470 return m;
471 }
472
473 /** Test whether a matrix is scalar. */
474 public boolean isScalar() {
475 return length == 1;
476 }
477
478 /** Return the first element of the matrix. */
479 public float scalar() {
480 return get(0);
481 }
482
483 /**
484 * Concatenates two matrices horizontally. Matrices must have identical
485 * numbers of rows.
486 */
487 public static FloatMatrix concatHorizontally(FloatMatrix A, FloatMatrix B) {
488 if (A.rows != B.rows) {
489 throw new SizeException("Matrices don't have same number of rows.");
490 }
491
492 FloatMatrix result = new FloatMatrix(A.rows, A.columns + B.columns);
493 SimpleBlas.copy(A, result);
494 JavaBlas.rcopy(B.length, B.data, 0, 1, result.data, A.length, 1);
495 return result;
496 }
497
498 /**
499 * Concatenates two matrices vertically. Matrices must have identical
500 * numbers of columns.
501 */
502 public static FloatMatrix concatVertically(FloatMatrix A, FloatMatrix B) {
503 if (A.columns != B.columns) {
504 throw new SizeException("Matrices don't have same number of columns (" + A.columns + " != " + B.columns + ".");
505 }
506
507 FloatMatrix result = new FloatMatrix(A.rows + B.rows, A.columns);
508
509 for (int i = 0; i < A.columns; i++) {
510 JavaBlas.rcopy(A.rows, A.data, A.index(0, i), 1, result.data, result.index(0, i), 1);
511 JavaBlas.rcopy(B.rows, B.data, B.index(0, i), 1, result.data, result.index(A.rows, i), 1);
512 }
513
514 return result;
515 }
516
517 /**************************************************************************
518 * Working with slices (Man! 30+ methods just to make this a bit flexible...)
519 */
520 /** Get all elements specified by the linear indices. */
521 public FloatMatrix get(int[] indices) {
522 FloatMatrix result = new FloatMatrix(indices.length);
523
524 for (int i = 0; i < indices.length; i++) {
525 result.put(i, get(indices[i]));
526 }
527
528 return result;
529 }
530
531 /** Get all elements for a given row and the specified columns. */
532 public FloatMatrix get(int r, int[] indices) {
533 FloatMatrix result = new FloatMatrix(1, indices.length);
534
535 for (int i = 0; i < indices.length; i++) {
536 result.put(i, get(r, indices[i]));
537 }
538
539 return result;
540 }
541
542 /** Get all elements for a given column and the specified rows. */
543 public FloatMatrix get(int[] indices, int c) {
544 FloatMatrix result = new FloatMatrix(indices.length, c);
545
546 for (int i = 0; i < indices.length; i++) {
547 result.put(i, get(indices[i], c));
548 }
549
550 return result;
551 }
552
553 /** Get all elements from the specified rows and columns. */
554 public FloatMatrix get(int[] rindices, int[] cindices) {
555 FloatMatrix result = new FloatMatrix(rindices.length, cindices.length);
556
557 for (int i = 0; i < rindices.length; i++) {
558 for (int j = 0; j < cindices.length; j++) {
559 result.put(i, j, get(rindices[i], cindices[j]));
560 }
561 }
562
563 return result;
564 }
565
566 /** Get elements from specified rows and columns. */
567 public FloatMatrix get(Range rs, Range cs) {
568 rs.init(0, rows - 1);
569 cs.init(0, columns - 1);
570 FloatMatrix result = new FloatMatrix(rs.length(), cs.length());
571
572 for (; !rs.hasMore(); rs.next()) {
573 for (; !cs.hasMore(); cs.next()) {
574 result.put(rs.index(), cs.index(), get(rs.value(), cs.value()));
575 }
576 }
577
578 return result;
579 }
580
581 /** Get elements specified by the non-zero entries of the passed matrix. */
582 public FloatMatrix get(FloatMatrix indices) {
583 return get(indices.findIndices());
584 }
585
586 /**
587 * Get elements from a row and columns as specified by the non-zero entries of
588 * a matrix.
589 */
590 public FloatMatrix get(int r, FloatMatrix indices) {
591 return get(r, indices.findIndices());
592 }
593
594 /**
595 * Get elements from a column and rows as specified by the non-zero entries of
596 * a matrix.
597 */
598 public FloatMatrix get(FloatMatrix indices, int c) {
599 return get(indices.findIndices(), c);
600 }
601
602 /**
603 * Get elements from columns and rows as specified by the non-zero entries of
604 * the passed matrices.
605 */
606 public FloatMatrix get(FloatMatrix rindices, FloatMatrix cindices) {
607 return get(rindices.findIndices(), cindices.findIndices());
608 }
609
610 /** Return all elements with linear index a, a + 1, ..., b - 1.*/
611 public FloatMatrix getRange(int a, int b) {
612 FloatMatrix result = new FloatMatrix(b - a);
613
614 for (int k = 0; k < b - a; k++) {
615 result.put(k, get(a + k));
616 }
617
618 return result;
619 }
620
621 /** Get elements from a row and columns <tt>a</tt> to <tt>b</tt>. */
622 public FloatMatrix getColumnRange(int r, int a, int b) {
623 FloatMatrix result = new FloatMatrix(1, b - a);
624
625 for (int k = 0; k < b - a; k++) {
626 result.put(k, get(r, a + k));
627 }
628
629 return result;
630 }
631
632 /** Get elements from a column and rows <tt>a/tt> to <tt>b</tt>. */
633 public FloatMatrix getRowRange(int a, int b, int c) {
634 FloatMatrix result = new FloatMatrix(b - a);
635
636 for (int k = 0; k < b - a; k++) {
637 result.put(k, get(a + k, c));
638 }
639
640 return result;
641 }
642
643 /**
644 * Get elements from rows <tt>ra</tt> to <tt>rb</tt> and
645 * columns <tt>ca</tt> to <tt>cb</tt>.
646 */
647 public FloatMatrix getRange(int ra, int rb, int ca, int cb) {
648 FloatMatrix result = new FloatMatrix(rb - ra, cb - ca);
649
650 for (int i = 0; i < rb - ra; i++) {
651 for (int j = 0; j < cb - ca; j++) {
652 result.put(i, j, get(ra + i, ca + j));
653 }
654 }
655
656 return result;
657 }
658
659 /** Get whole rows from the passed indices. */
660 public FloatMatrix getRows(int[] rindices) {
661 FloatMatrix result = new FloatMatrix(rindices.length, columns);
662 for (int i = 0; i < rindices.length; i++) {
663 JavaBlas.rcopy(columns, data, index(rindices[i], 0), rows, result.data, result.index(i, 0), result.rows);
664 }
665 return result;
666 }
667
668 /** Get whole rows as specified by the non-zero entries of a matrix. */
669 public FloatMatrix getRows(FloatMatrix rindices) {
670 return getRows(rindices.findIndices());
671 }
672
673 /** Get whole columns from the passed indices. */
674 public FloatMatrix getColumns(int[] cindices) {
675 FloatMatrix result = new FloatMatrix(rows, cindices.length);
676 for (int i = 0; i < cindices.length; i++) {
677 JavaBlas.rcopy(rows, data, index(0, cindices[i]), 1, result.data, result.index(0, i), 1);
678 }
679 return result;
680 }
681
682 /** Get whole columns as specified by the non-zero entries of a matrix. */
683 public FloatMatrix getColumns(FloatMatrix cindices) {
684 return getColumns(cindices.findIndices());
685 }
686
687 /**
688 * Assert that the matrix has a certain length.
689 * @throws SizeException
690 */
691 public void checkLength(int l) {
692 if (length != l) {
693 throw new SizeException("Matrix does not have the necessary length (" + length + " != " + l + ").");
694 }
695 }
696
697 /**
698 * Asserts that the matrix has a certain number of rows.
699 * @throws SizeException
700 */
701 public void checkRows(int r) {
702 if (rows != r) {
703 throw new SizeException("Matrix does not have the necessary number of rows (" + rows + " != " + r + ").");
704 }
705 }
706
707 /**
708 * Asserts that the amtrix has a certain number of columns.
709 * @throws SizeException
710 */
711 public void checkColumns(int c) {
712 if (columns != c) {
713 throw new SizeException("Matrix does not have the necessary number of columns (" + columns + " != " + c + ").");
714 }
715 }
716
717 /** Set elements in linear ordering in the specified indices. */
718 public FloatMatrix put(int[] indices, FloatMatrix x) {
719 if (x.isScalar()) {
720 return put(indices, x.scalar());
721 }
722 x.checkLength(indices.length);
723
724 for (int i = 0; i < indices.length; i++) {
725 put(indices[i], x.get(i));
726 }
727
728 return this;
729 }
730
731 /** Set multiple elements in a row. */
732 public FloatMatrix put(int r, int[] indices, FloatMatrix x) {
733 if (x.isScalar()) {
734 return put(r, indices, x.scalar());
735 }
736 x.checkColumns(indices.length);
737
738 for (int i = 0; i < indices.length; i++) {
739 put(r, indices[i], x.get(i));
740 }
741
742 return this;
743 }
744
745 /** Set multiple elements in a row. */
746 public FloatMatrix put(int[] indices, int c, FloatMatrix x) {
747 if (x.isScalar()) {
748 return put(indices, c, x.scalar());
749 }
750 x.checkRows(indices.length);
751
752 for (int i = 0; i < indices.length; i++) {
753 put(indices[i], c, x.get(i));
754 }
755
756 return this;
757 }
758
759 /** Put a sub-matrix as specified by the indices. */
760 public FloatMatrix put(int[] rindices, int[] cindices, FloatMatrix x) {
761 if (x.isScalar()) {
762 return put(rindices, cindices, x.scalar());
763 }
764 x.checkRows(rindices.length);
765 x.checkColumns(cindices.length);
766
767 for (int i = 0; i < rindices.length; i++) {
768 for (int j = 0; j < cindices.length; j++) {
769 put(rindices[i], cindices[j], x.get(i, j));
770 }
771 }
772
773 return this;
774 }
775
776 /** Put a matrix into specified indices. */
777 public FloatMatrix put(Range rs, Range cs, FloatMatrix x) {
778 rs.init(0, rows - 1);
779 cs.init(0, columns - 1);
780
781 x.checkRows(rs.length());
782 x.checkColumns(cs.length());
783
784 for (; rs.hasMore(); rs.next()) {
785 for (; cs.hasMore(); cs.next()) {
786 put(rs.value(), cs.value(), x.get(rs.index(), cs.index()));
787 }
788 }
789
790 return this;
791 }
792
793 /** Put a single value into the specified indices (linear adressing). */
794 public FloatMatrix put(int[] indices, float v) {
795 for (int i = 0; i < indices.length; i++) {
796 put(indices[i], v);
797 }
798
799 return this;
800 }
801
802 /** Put a single value into a row and the specified columns. */
803 public FloatMatrix put(int r, int[] indices, float v) {
804 for (int i = 0; i < indices.length; i++) {
805 put(r, indices[i], v);
806 }
807
808 return this;
809 }
810
811 /** Put a single value into the specified rows of a column. */
812 public FloatMatrix put(int[] indices, int c, float v) {
813 for (int i = 0; i < indices.length; i++) {
814 put(indices[i], c, v);
815 }
816
817 return this;
818 }
819
820 /** Put a single value into the specified rows and columns. */
821 public FloatMatrix put(int[] rindices, int[] cindices, float v) {
822 for (int i = 0; i < rindices.length; i++) {
823 for (int j = 0; j < cindices.length; j++) {
824 put(rindices[i], cindices[j], v);
825 }
826 }
827
828 return this;
829 }
830
831 /**
832 * Put a sub-matrix into the indices specified by the non-zero entries
833 * of <tt>indices</tt> (linear adressing).
834 */
835 public FloatMatrix put(FloatMatrix indices, FloatMatrix v) {
836 return put(indices.findIndices(), v);
837 }
838
839 /** Put a sub-vector into the specified columns (non-zero entries of <tt>indices</tt>) of a row. */
840 public FloatMatrix put(int r, FloatMatrix indices, FloatMatrix v) {
841 return put(r, indices.findIndices(), v);
842 }
843
844 /** Put a sub-vector into the specified rows (non-zero entries of <tt>indices</tt>) of a column. */
845 public FloatMatrix put(FloatMatrix indices, int c, FloatMatrix v) {
846 return put(indices.findIndices(), c, v);
847 }
848
849 /**
850 * Put a sub-matrix into the specified rows and columns (non-zero entries of
851 * <tt>rindices</tt> and <tt>cindices</tt>.
852 */
853 public FloatMatrix put(FloatMatrix rindices, FloatMatrix cindices, FloatMatrix v) {
854 return put(rindices.findIndices(), cindices.findIndices(), v);
855 }
856
857 /**
858 * Put a single value into the elements specified by the non-zero
859 * entries of <tt>indices</tt> (linear adressing).
860 */
861 public FloatMatrix put(FloatMatrix indices, float v) {
862 return put(indices.findIndices(), v);
863 }
864
865 /**
866 * Put a single value into the specified columns (non-zero entries of
867 * <tt>indices</tt>) of a row.
868 */
869 public FloatMatrix put(int r, FloatMatrix indices, float v) {
870 return put(r, indices.findIndices(), v);
871 }
872
873 /**
874 * Put a single value into the specified rows (non-zero entries of
875 * <tt>indices</tt>) of a column.
876 */
877 public FloatMatrix put(FloatMatrix indices, int c, float v) {
878 return put(indices.findIndices(), c, v);
879 }
880
881 /**
882 * Put a single value in the specified rows and columns (non-zero entries
883 * of <tt>rindices</tt> and <tt>cindices</tt>.
884 */
885 public FloatMatrix put(FloatMatrix rindices, FloatMatrix cindices, float v) {
886 return put(rindices.findIndices(), cindices.findIndices(), v);
887 }
888
889 /** Find the linear indices of all non-zero elements. */
890 public int[] findIndices() {
891 int len = 0;
892 for (int i = 0; i < length; i++) {
893 if (get(i) != 0.0f) {
894 len++;
895 }
896 }
897
898 int[] indices = new int[len];
899 int c = 0;
900
901 for (int i = 0; i < length; i++) {
902 if (get(i) != 0.0f) {
903 indices[c++] = i;
904 }
905 }
906
907 return indices;
908 }
909
910 /**************************************************************************
911 * Basic operations (copying, resizing, element access)
912 */
913 /** Return transposed copy of this matrix. */
914 public FloatMatrix transpose() {
915 FloatMatrix result = new FloatMatrix(columns, rows);
916
917 for (int i = 0; i < rows; i++) {
918 for (int j = 0; j < columns; j++) {
919 result.put(j, i, get(i, j));
920 }
921 }
922
923 return result;
924 }
925
926 /**
927 * Compare two matrices. Returns true if and only if other is also a
928 * FloatMatrix which has the same size and the maximal absolute
929 * difference in matrix elements is smaller thatn 1e-6.
930 */
931 public boolean equals(Object o) {
932 if (!(o instanceof FloatMatrix)) {
933 return false;
934 }
935
936 FloatMatrix other = (FloatMatrix) o;
937
938 if (!sameSize(other)) {
939 return false;
940 }
941
942 FloatMatrix diff = MatrixFunctions.absi(sub(other));
943
944 return diff.max() / (rows * columns) < 1e-6;
945 }
946
947 /** Resize the matrix. All elements will be set to zero. */
948 public void resize(int newRows, int newColumns) {
949 rows = newRows;
950 columns = newColumns;
951 length = newRows * newColumns;
952 data = new float[rows * columns];
953 }
954
955 /** Reshape the matrix. Number of elements must not change. */
956 public FloatMatrix reshape(int newRows, int newColumns) {
957 if (length != newRows * newColumns) {
958 throw new IllegalArgumentException(
959 "Number of elements must not change.");
960 }
961
962 rows = newRows;
963 columns = newColumns;
964
965 return this;
966 }
967
968 /** Generate a new matrix which has the given number of replications of this. */
969 public FloatMatrix repmat(int rowMult, int columnMult) {
970 FloatMatrix result = new FloatMatrix(rows * rowMult, columns * columnMult);
971
972 for (int c = 0; c < columnMult; c++)
973 for (int r = 0; r < rowMult; r++)
974 for (int i = 0; i < rows; i++)
975 for (int j = 0; j < columns; j++)
976 result.put(r * rows + i, c * columns + j, get(i, j));
977 return result;
978 }
979
980 /** Checks whether two matrices have the same size. */
981 public boolean sameSize(FloatMatrix a) {
982 return rows == a.rows && columns == a.columns;
983 }
984
985 /** Throws SizeException unless two matrices have the same size. */
986 public void assertSameSize(FloatMatrix a) {
987 if (!sameSize(a)) {
988 throw new SizeException("Matrices must have the same size.");
989 }
990 }
991
992 /** Checks whether two matrices can be multiplied (that is, number of columns of
993 * this must equal number of rows of a. */
994 public boolean multipliesWith(FloatMatrix a) {
995 return columns == a.rows;
996 }
997
998 /** Throws SizeException unless matrices can be multiplied with one another. */
999 public void assertMultipliesWith(FloatMatrix a) {
1000 if (!multipliesWith(a)) {
1001 throw new SizeException("Number of columns of left matrix must be equal to number of rows of right matrix.");
1002 }
1003 }
1004
1005 /** Checks whether two matrices have the same length. */
1006 public boolean sameLength(FloatMatrix a) {
1007 return length == a.length;
1008 }
1009
1010 /** Throws SizeException unless matrices have the same length. */
1011 public void assertSameLength(FloatMatrix a) {
1012 if (!sameLength(a)) {
1013 throw new SizeException("Matrices must have same length (is: " + length + " and " + a.length + ")");
1014 }
1015 }
1016
1017 /** Copy FloatMatrix a to this. this a is resized if necessary. */
1018 public FloatMatrix copy(FloatMatrix a) {
1019 if (!sameSize(a)) {
1020 resize(a.rows, a.columns);
1021 }
1022
1023 System.arraycopy(a.data, 0, data, 0, length);
1024 return a;
1025 }
1026
1027 /**
1028 * Returns a duplicate of this matrix. Geometry is the same (including offsets, transpose, etc.),
1029 * but the buffer is not shared.
1030 */
1031 public FloatMatrix dup() {
1032 FloatMatrix out = new FloatMatrix(rows, columns);
1033
1034 JavaBlas.rcopy(length, data, 0, 1, out.data, 0, 1);
1035
1036 return out;
1037 }
1038
1039 /** Swap two columns of a matrix. */
1040 public FloatMatrix swapColumns(int i, int j) {
1041 NativeBlas.sswap(rows, data, index(0, i), 1, data, index(0, j), 1);
1042 return this;
1043 }
1044
1045 /** Swap two rows of a matrix. */
1046 public FloatMatrix swapRows(int i, int j) {
1047 NativeBlas.sswap(columns, data, index(i, 0), rows, data, index(j, 0), rows);
1048 return this;
1049 }
1050
1051 /** Set matrix element */
1052 public FloatMatrix put(int rowIndex, int columnIndex, float value) {
1053 data[index(rowIndex, columnIndex)] = value;
1054 return this;
1055 }
1056
1057 /** Retrieve matrix element */
1058 public float get(int rowIndex, int columnIndex) {
1059 return data[index(rowIndex, columnIndex)];
1060 }
1061
1062 /** Get index of an element */
1063 public int index(int rowIndex, int columnIndex) {
1064 return rowIndex + rows * columnIndex;
1065 }
1066
1067 /** Compute the row index of a linear index. */
1068 public int indexRows(int i) {
1069 return i / rows;
1070 }
1071
1072 /** Compute the column index of a linear index. */
1073 public int indexColumns(int i) {
1074 return i - indexRows(i) * rows;
1075 }
1076
1077 /** Get a matrix element (linear indexing). */
1078 public float get(int i) {
1079 return data[i];
1080 }
1081
1082 /** Set a matrix element (linear indexing). */
1083 public FloatMatrix put(int i, float v) {
1084 data[i] = v;
1085 return this;
1086 }
1087
1088 /** Set all elements to a value. */
1089 public FloatMatrix fill(float value) {
1090 for (int i = 0; i < length; i++) {
1091 put(i, value);
1092 }
1093 return this;
1094 }
1095
1096 /** Get number of rows. */
1097 public int getRows() {
1098 return rows;
1099 }
1100
1101 /** Get number of columns. */
1102 public int getColumns() {
1103 return columns;
1104 }
1105
1106 /** Get total number of elements. */
1107 public int getLength() {
1108 return length;
1109 }
1110
1111 /** Checks whether the matrix is empty. */
1112 public boolean isEmpty() {
1113 return columns == 0 || rows == 0;
1114 }
1115
1116 /** Checks whether the matrix is square. */
1117 public boolean isSquare() {
1118 return columns == rows;
1119 }
1120
1121 /** Throw SizeException unless matrix is square. */
1122 public void assertSquare() {
1123 if (!isSquare()) {
1124 throw new SizeException("Matrix must be square!");
1125 }
1126 }
1127
1128 /** Checks whether the matrix is a vector. */
1129 public boolean isVector() {
1130 return columns == 1 || rows == 1;
1131 }
1132
1133 /** Checks whether the matrix is a row vector. */
1134 public boolean isRowVector() {
1135 return rows == 1;
1136 }
1137
1138 /** Checks whether the matrix is a column vector. */
1139 public boolean isColumnVector() {
1140 return columns == 1;
1141 }
1142
1143 /** Returns the diagonal of the matrix. */
1144 public FloatMatrix diag() {
1145 assertSquare();
1146 FloatMatrix d = new FloatMatrix(rows);
1147 JavaBlas.rcopy(rows, data, 0, rows + 1, d.data, 0, 1);
1148 return d;
1149 }
1150
1151 /** Pretty-print this matrix to <tt>System.out</tt>. */
1152 public void print() {
1153 System.out.println(toString());
1154 }
1155
1156 /** Generate string representation of the matrix. */
1157 @Override
1158 public String toString() {
1159 StringBuilder s = new StringBuilder();
1160
1161 s.append("[");
1162
1163 for (int i = 0; i < rows; i++) {
1164 for (int j = 0; j < columns; j++) {
1165 s.append(get(i, j));
1166 if (j < columns - 1) {
1167 s.append(", ");
1168 }
1169 }
1170 if (i < rows - 1) {
1171 s.append("; ");
1172 }
1173 }
1174
1175 s.append("]");
1176
1177 return s.toString();
1178 }
1179
1180 /**
1181 * Generate string representation of the matrix, with specified
1182 * format for the entries. For example, <code>x.toString("%.1f")</code>
1183 * generates a string representations having only one position after the
1184 * decimal point.
1185 */
1186 public String toString(String fmt) {
1187 StringWriter s = new StringWriter();
1188 PrintWriter p = new PrintWriter(s);
1189
1190 p.print("[");
1191
1192 for (int r = 0; r < rows; r++) {
1193 for (int c = 0; c < columns; c++) {
1194 p.printf(fmt, get(r, c));
1195 if (c < columns - 1) {
1196 p.print(", ");
1197 }
1198 }
1199 if (r < rows - 1) {
1200 p.print("; ");
1201 }
1202 }
1203
1204 p.print("]");
1205
1206 return s.toString();
1207 }
1208
1209 /** Converts the matrix to a one-dimensional array of floats. */
1210 public float[] toArray() {
1211 float[] array = new float[length];
1212
1213 System.arraycopy(data, 0, array, 0, length);
1214
1215 return array;
1216 }
1217
1218 /** Converts the matrix to a two-dimensional array of floats. */
1219 public float[][] toArray2() {
1220 float[][] array = new float[rows][columns];
1221
1222 for (int r = 0; r < rows; r++) {
1223 for (int c = 0; c < columns; c++) {
1224 array[r][c] = get(r, c);
1225 }
1226 }
1227
1228 return array;
1229 }
1230
1231 /** Converts the matrix to a one-dimensional array of integers. */
1232 public int[] toIntArray() {
1233 int[] array = new int[length];
1234
1235 for (int i = 0; i < length; i++) {
1236 array[i] = (int) Math.rint(get(i));
1237 }
1238
1239 return array;
1240 }
1241
1242 /** Convert the matrix to a two-dimensional array of integers. */
1243 public int[][] toIntArray2() {
1244 int[][] array = new int[rows][columns];
1245
1246 for (int r = 0; r < rows; r++) {
1247 for (int c = 0; c < columns; c++) {
1248 array[r][c] = (int) Math.rint(get(r, c));
1249 }
1250 }
1251
1252 return array;
1253 }
1254
1255 /** Convert the matrix to a one-dimensional array of boolean values. */
1256 public boolean[] toBooleanArray() {
1257 boolean[] array = new boolean[length];
1258
1259 for (int i = 0; i < length; i++) {
1260 array[i] = get(i) != 0.0f ? true : false;
1261 }
1262
1263 return array;
1264 }
1265
1266 /** Convert the matrix to a two-dimensional array of boolean values. */
1267 public boolean[][] toBooleanArray2() {
1268 boolean[][] array = new boolean[rows][columns];
1269
1270 for (int r = 0; r < rows; r++) {
1271 for (int c = 0; c < columns; c++) {
1272 array[r][c] = get(r, c) != 0.0f ? true : false;
1273 }
1274 }
1275
1276 return array;
1277 }
1278
1279 /** Convert matrix to FloatMatrix. */
1280 public FloatMatrix toFloatMatrix() {
1281 FloatMatrix result = new FloatMatrix(rows, columns);
1282
1283 for (int c = 0; c < columns; c++) {
1284 for (int r = 0; r < rows; r++) {
1285 result.put(r, c, (float) get(r, c));
1286 }
1287 }
1288
1289 return result;
1290 }
1291
1292 /**
1293 * A wrapper which allows to view a matrix as a List of Doubles (read-only!).
1294 * Also implements the {@link ConvertsToFloatMatrix} interface.
1295 */
1296 public class ElementsAsListView extends AbstractList<Float> implements ConvertsToFloatMatrix {
1297 private FloatMatrix me;
1298
1299 public ElementsAsListView(FloatMatrix me) {
1300 this.me = me;
1301 }
1302
1303 @Override
1304 public Float get(int index) {
1305 return me.get(index);
1306 }
1307
1308 @Override
1309 public int size() {
1310 return me.length;
1311 }
1312
1313 public FloatMatrix convertToFloatMatrix() {
1314 return me;
1315 }
1316 }
1317
1318 public class RowsAsListView extends AbstractList<FloatMatrix> implements ConvertsToFloatMatrix {
1319 private FloatMatrix me;
1320
1321 public RowsAsListView(FloatMatrix me) {
1322 this.me = me;
1323 }
1324
1325 @Override
1326 public FloatMatrix get(int index) {
1327 return getRow(index);
1328 }
1329
1330 @Override
1331 public int size() {
1332 return rows;
1333 }
1334
1335 public FloatMatrix convertToFloatMatrix() {
1336 return me;
1337 }
1338 }
1339
1340 public class ColumnsAsListView extends AbstractList<FloatMatrix> implements ConvertsToFloatMatrix {
1341 private FloatMatrix me;
1342
1343 public ColumnsAsListView(FloatMatrix me) {
1344 this.me = me;
1345 }
1346
1347 @Override
1348 public FloatMatrix get(int index) {
1349 return getColumn(index);
1350 }
1351
1352 @Override
1353 public int size() {
1354 return columns;
1355 }
1356
1357 public FloatMatrix convertToFloatMatrix() {
1358 return me;
1359 }
1360 }
1361
1362 public List<Float> elementsAsList() {
1363 return new ElementsAsListView(this);
1364 }
1365
1366 public List<FloatMatrix> rowsAsList() {
1367 return new RowsAsListView(this);
1368 }
1369
1370 public List<FloatMatrix> columnsAsList() {
1371 return new ColumnsAsListView(this);
1372 }
1373
1374 /**************************************************************************
1375 * Arithmetic Operations
1376 */
1377 /**
1378 * Ensures that the result vector has the same length as this. If not,
1379 * resizing result is tried, which fails if result == this or result == other.
1380 */
1381 private void ensureResultLength(FloatMatrix other, FloatMatrix result) {
1382 if (!sameLength(result)) {
1383 if (result == this || result == other) {
1384 throw new SizeException("Cannot resize result matrix because it is used in-place.");
1385 }
1386 result.resize(rows, columns);
1387 }
1388 }
1389
1390 /** Add two matrices (in-place). */
1391 public FloatMatrix addi(FloatMatrix other, FloatMatrix result) {
1392 if (other.isScalar()) {
1393 return addi(other.scalar(), result);
1394 }
1395 if (isScalar()) {
1396 return other.addi(scalar(), result);
1397 }
1398
1399 assertSameLength(other);
1400 ensureResultLength(other, result);
1401
1402 if (result == this) {
1403 SimpleBlas.axpy(1.0f, other, result);
1404 } else if (result == other) {
1405 SimpleBlas.axpy(1.0f, this, result);
1406 } else {
1407 /*SimpleBlas.copy(this, result);
1408 SimpleBlas.axpy(1.0f, other, result);*/
1409 JavaBlas.rzgxpy(length, result.data, data, other.data);
1410 }
1411
1412 return result;
1413 }
1414
1415 /** Add a scalar to a matrix (in-place). */
1416 public FloatMatrix addi(float v, FloatMatrix result) {
1417 ensureResultLength(null, result);
1418
1419 for (int i = 0; i < length; i++) {
1420 result.put(i, get(i) + v);
1421 }
1422 return result;
1423 }
1424
1425 /** Subtract two matrices (in-place). */
1426 public FloatMatrix subi(FloatMatrix other, FloatMatrix result) {
1427 if (other.isScalar()) {
1428 return subi(other.scalar(), result);
1429 }
1430 if (isScalar()) {
1431 return other.rsubi(scalar(), result);
1432 }
1433
1434 assertSameLength(other);
1435 ensureResultLength(other, result);
1436
1437 if (result == this) {
1438 SimpleBlas.axpy(-1.0f, other, result);
1439 } else if (result == other) {
1440 SimpleBlas.scal(-1.0f, result);
1441 SimpleBlas.axpy(1.0f, this, result);
1442 } else {
1443 SimpleBlas.copy(this, result);
1444 SimpleBlas.axpy(-1.0f, other, result);
1445 }
1446 return result;
1447 }
1448
1449 /** Subtract a scalar from a matrix (in-place). */
1450 public FloatMatrix subi(float v, FloatMatrix result) {
1451 ensureResultLength(null, result);
1452
1453 for (int i = 0; i < length; i++) {
1454 result.put(i, get(i) - v);
1455 }
1456 return result;
1457 }
1458
1459 /**
1460 * Subtract two matrices, but subtract first from second matrix, that is,
1461 * compute <em>result = other - this</em> (in-place).
1462 * */
1463 public FloatMatrix rsubi(FloatMatrix other, FloatMatrix result) {
1464 return other.subi(this, result);
1465 }
1466
1467 /** Subtract a matrix from a scalar (in-place). */
1468 public FloatMatrix rsubi(float a, FloatMatrix result) {
1469 ensureResultLength(null, result);
1470
1471 for (int i = 0; i < length; i++) {
1472 result.put(i, a - get(i));
1473 }
1474 return result;
1475 }
1476
1477 /** Elementwise multiplication (in-place). */
1478 public FloatMatrix muli(FloatMatrix other, FloatMatrix result) {
1479 if (other.isScalar()) {
1480 return muli(other.scalar(), result);
1481 }
1482 if (isScalar()) {
1483 return other.muli(scalar(), result);
1484 }
1485
1486 assertSameLength(other);
1487 ensureResultLength(other, result);
1488
1489 for (int i = 0; i < length; i++) {
1490 result.put(i, get(i) * other.get(i));
1491 }
1492 return result;
1493 }
1494
1495 /** Elementwise multiplication with a scalar (in-place). */
1496 public FloatMatrix muli(float v, FloatMatrix result) {
1497 ensureResultLength(null, result);
1498
1499 for (int i = 0; i < length; i++) {
1500 result.put(i, get(i) * v);
1501 }
1502 return result;
1503 }
1504
1505 /** Matrix-matrix multiplication (in-place). */
1506 public FloatMatrix mmuli(FloatMatrix other, FloatMatrix result) {
1507 if (other.isScalar()) {
1508 return muli(other.scalar(), result);
1509 }
1510 if (isScalar()) {
1511 return other.muli(scalar(), result);
1512 }
1513
1514 /* check sizes and resize if necessary */
1515 assertMultipliesWith(other);
1516 if (result.rows != rows || result.columns != other.columns) {
1517 if (result != this && result != other) {
1518 result.resize(rows, other.columns);
1519 } else {
1520 throw new SizeException("Cannot resize result matrix because it is used in-place.");
1521 }
1522 }
1523
1524 if (result == this || result == other) {
1525 /* actually, blas cannot do multiplications in-place. Therefore, we will fake by
1526 * allocating a temporary object on the side and copy the result later.
1527 */
1528 FloatMatrix temp = new FloatMatrix(result.rows, result.columns);
1529 if (other.columns == 1) {
1530 SimpleBlas.gemv(1.0f, this, other, 0.0f, temp);
1531 } else {
1532 SimpleBlas.gemm(1.0f, this, other, 0.0f, temp);
1533 }
1534 SimpleBlas.copy(temp, result);
1535 } else {
1536 if (other.columns == 1) {
1537 SimpleBlas.gemv(1.0f, this, other, 0.0f, result);
1538 } else {
1539 SimpleBlas.gemm(1.0f, this, other, 0.0f, result);
1540 }
1541 }
1542 return result;
1543 }
1544
1545 /** Matrix-matrix multiplication with a scalar (for symmetry, does the
1546 * same as <code>muli(scalar)</code> (in-place).
1547 */
1548 public FloatMatrix mmuli(float v, FloatMatrix result) {
1549 return muli(v, result);
1550 }
1551
1552 /** Elementwise division (in-place). */
1553 public FloatMatrix divi(FloatMatrix other, FloatMatrix result) {
1554 if (other.isScalar()) {
1555 return divi(other.scalar(), result);
1556 }
1557 if (isScalar()) {
1558 return other.rdivi(scalar(), result);
1559 }
1560
1561 assertSameLength(other);
1562 ensureResultLength(other, result);
1563
1564 for (int i = 0; i < length; i++) {
1565 result.put(i, get(i) / other.get(i));
1566 }
1567 return result;
1568 }
1569
1570 /** Elementwise division with a scalar (in-place). */
1571 public FloatMatrix divi(float a, FloatMatrix result) {
1572 ensureResultLength(null, result);
1573
1574 for (int i = 0; i < length; i++) {
1575 result.put(i, get(i) / a);
1576 }
1577 return result;
1578 }
1579
1580 /**
1581 * Elementwise division, with operands switched. Computes
1582 * <code>result = other / this</code> (in-place). */
1583 public FloatMatrix rdivi(FloatMatrix other, FloatMatrix result) {
1584 return other.divi(this, result);
1585 }
1586
1587 /** (Elementwise) division with a scalar, with operands switched. Computes
1588 * <code>result = a / this</code> (in-place). */
1589 public FloatMatrix rdivi(float a, FloatMatrix result) {
1590 ensureResultLength(null, result);
1591
1592 for (int i = 0; i < length; i++) {
1593 result.put(i, a / get(i));
1594 }
1595 return result;
1596 }
1597
1598 /** Negate each element (in-place). */
1599 public FloatMatrix negi() {
1600 for (int i = 0; i < length; i++) {
1601 put(i, -get(i));
1602 }
1603 return this;
1604 }
1605
1606 /** Negate each element. */
1607 public FloatMatrix neg() {
1608 return dup().negi();
1609 }
1610
1611 /** Maps zero to 1.0f and all non-zero values to 0.0f (in-place). */
1612 public FloatMatrix noti() {
1613 for (int i = 0; i < length; i++) {
1614 put(i, get(i) == 0.0f ? 1.0f : 0.0f);
1615 }
1616 return this;
1617 }
1618
1619 /** Maps zero to 1.0f and all non-zero values to 0.0f. */
1620 public FloatMatrix not() {
1621 return dup().noti();
1622 }
1623
1624 /** Maps zero to 0.0f and all non-zero values to 1.0f (in-place). */
1625 public FloatMatrix truthi() {
1626 for (int i = 0; i < length; i++) {
1627 put(i, get(i) == 0.0f ? 0.0f : 1.0f);
1628 }
1629 return this;
1630 }
1631
1632 /** Maps zero to 0.0f and all non-zero values to 1.0f. */
1633 public FloatMatrix truth() {
1634 return dup().truthi();
1635 }
1636
1637 /**
1638 * Calculate matrix exponential of a square matrix.
1639 *
1640 * A scaled Pade approximation algorithm is used.
1641 * The algorithm has been directly translated from Golub & Van Loan "Matrix Computations",
1642 * algorithm 11.3f.1. Special Horner techniques from 11.2f are also used to minimize the number
1643 * of matrix multiplications.
1644 *
1645 * @param A square matrix
1646 * @return matrix exponential of A
1647 */
1648 public static FloatMatrix expm(FloatMatrix A)
1649 {
1650 // constants for pade approximation
1651 final float c0 = 1.0f;
1652 final float c1 = 0.5f;
1653 final float c2 = 0.12f;
1654 final float c3 = 0.01833333333333333f;
1655 final float c4 = 0.0019927536231884053f;
1656 final float c5 = 1.630434782608695E-4f;
1657 final float c6 = 1.0351966873706E-5f;
1658 final float c7 = 5.175983436853E-7f;
1659 final float c8 = 2.0431513566525E-8f;
1660 final float c9 = 6.306022705717593E-10f;
1661 final float c10 = 1.4837700484041396E-11f;
1662 final float c11 = 2.5291534915979653E-13f;
1663 final float c12 = 2.8101705462199615E-15f;
1664 final float c13 = 1.5440497506703084E-17f;
1665
1666 int j = Math.max(0, 1 + (int)Math.floor(Math.log(A.normmax())/Math.log(2)));
1667 FloatMatrix As = A.div((float)Math.pow(2, j)); // scaled version of A
1668 int n = A.getRows();
1669
1670 // calculate D and N using special Horner techniques
1671 FloatMatrix As_2 = As.mmul(As);
1672 FloatMatrix As_4 = As_2.mmul(As_2);
1673 FloatMatrix As_6 = As_4.mmul(As_2);
1674 // U = c0*I + c2*A^2 + c4*A^4 + (c6*I + c8*A^2 + c10*A^4 + c12*A^6)*A^6
1675 FloatMatrix U = FloatMatrix.eye(n).muli(c0).addi(As_2.mul(c2)).addi(As_4.mul(c4)).addi(
1676 FloatMatrix.eye(n).muli(c6).addi(As_2.mul(c8)).addi(As_4.mul(c10)).addi(As_6.mul(c12)).mmuli(As_6));
1677 // V = c1*I + c3*A^2 + c5*A^4 + (c7*I + c9*A^2 + c11*A^4 + c13*A^6)*A^6
1678 FloatMatrix V = FloatMatrix.eye(n).muli(c1).addi(As_2.mul(c3)).addi(As_4.mul(c5)).addi(
1679 FloatMatrix.eye(n).muli(c7).addi(As_2.mul(c9)).addi(As_4.mul(c11)).addi(As_6.mul(c13)).mmuli(As_6));
1680
1681 FloatMatrix AV = As.mmuli(V);
1682 FloatMatrix N = U.add(AV);
1683 FloatMatrix D = U.subi(AV);
1684
1685 // solve DF = N for F
1686 FloatMatrix F = Solve.solve(D, N);
1687
1688 // now square j times
1689 for(int k = 0; k < j; k++)
1690 {
1691 F.mmuli(F);
1692 }
1693
1694 return F;
1695 }
1696
1697 /****************************************************************
1698 * Rank one-updates
1699 */
1700 /** Computes a rank-1-update A = A + alpha * x * y'. */
1701 public FloatMatrix rankOneUpdate(float alpha, FloatMatrix x, FloatMatrix y) {
1702 if (rows != x.length) {
1703 throw new SizeException("Vector x has wrong length (" + x.length + " != " + rows + ").");
1704 }
1705 if (columns != y.length) {
1706 throw new SizeException("Vector y has wrong length (" + x.length + " != " + columns + ").");
1707 }
1708
1709 SimpleBlas.ger(alpha, x, y, this);
1710 return this;
1711 }
1712
1713 /** Computes a rank-1-update A = A + alpha * x * x'. */
1714 public FloatMatrix rankOneUpdate(float alpha, FloatMatrix x) {
1715 return rankOneUpdate(alpha, x, x);
1716 }
1717
1718 /** Computes a rank-1-update A = A + x * x'. */
1719 public FloatMatrix rankOneUpdate(FloatMatrix x) {
1720 return rankOneUpdate(1.0f, x, x);
1721 }
1722
1723 /** Computes a rank-1-update A = A + x * y'. */
1724 public FloatMatrix rankOneUpdate(FloatMatrix x, FloatMatrix y) {
1725 return rankOneUpdate(1.0f, x, y);
1726 }
1727
1728 /****************************************************************
1729 * Logical operations
1730 */
1731 /** Returns the minimal element of the matrix. */
1732 public float min() {
1733 if (isEmpty()) {
1734 return Float.POSITIVE_INFINITY;
1735 }
1736 float v = Float.POSITIVE_INFINITY;
1737 for (int i = 0; i < length; i++) {
1738 if (!Float.isNaN(get(i)) && get(i) < v) {
1739 v = get(i);
1740 }
1741 }
1742
1743 return v;
1744 }
1745
1746 /**
1747 * Returns the linear index of the minimal element. If there are
1748 * more than one elements with this value, the first one is returned.
1749 */
1750 public int argmin() {
1751 if (isEmpty()) {
1752 return -1;
1753 }
1754 float v = Float.POSITIVE_INFINITY;
1755 int a = -1;
1756 for (int i = 0; i < length; i++) {
1757 if (!Float.isNaN(get(i)) && get(i) < v) {
1758 v = get(i);
1759 a = i;
1760 }
1761 }
1762
1763 return a;
1764 }
1765
1766 /**
1767 * Computes the minimum between two matrices. Returns the smaller of the
1768 * corresponding elements in the matrix (in-place).
1769 */
1770 public FloatMatrix mini(FloatMatrix other, FloatMatrix result) {
1771 if (result == this) {
1772 for (int i = 0; i < length; i++) {
1773 if (get(i) > other.get(i)) {
1774 put(i, other.get(i));
1775 }
1776 }
1777 } else {
1778 for (int i = 0; i < length; i++) {
1779 if (get(i) > other.get(i)) {
1780 result.put(i, other.get(i));
1781 } else {
1782 result.put(i, get(i));
1783 }
1784 }
1785 }
1786 return this;
1787 }
1788
1789 /**
1790 * Computes the minimum between two matrices. Returns the smaller of the
1791 * corresponding elements in the matrix (in-place on this).
1792 */
1793 public FloatMatrix mini(FloatMatrix other) {
1794 return mini(other, this);
1795 }
1796
1797 /**
1798 * Computes the minimum between two matrices. Returns the smaller of the
1799 * corresponding elements in the matrix (in-place on this).
1800 */
1801 public FloatMatrix min(FloatMatrix other) {
1802 return mini(other, new FloatMatrix(rows, columns));
1803 }
1804
1805 public FloatMatrix mini(float v, FloatMatrix result) {
1806 if (result == this) {
1807 for (int i = 0; i < length; i++) {
1808 if (get(i) > v) {
1809 result.put(i, v);
1810 }
1811 }
1812 } else {
1813 for (int i = 0; i < length; i++) {
1814 if (get(i) > v) {
1815 result.put(i, v);
1816 } else {
1817 result.put(i, get(i));
1818 }
1819 }
1820
1821 }
1822 return this;
1823 }
1824
1825 public FloatMatrix mini(float v) {
1826 return mini(v, this);
1827 }
1828
1829 public FloatMatrix min(float v) {
1830 return mini(v, new FloatMatrix(rows, columns));
1831 }
1832
1833 /** Returns the maximal element of the matrix. */
1834 public float max() {
1835 if (isEmpty()) {
1836 return Float.NEGATIVE_INFINITY;
1837 }
1838 float v = Float.NEGATIVE_INFINITY;
1839 for (int i = 0; i < length; i++) {
1840 if (!Float.isNaN(get(i)) && get(i) > v) {
1841 v = get(i);
1842 }
1843 }
1844 return v;
1845 }
1846
1847 /**
1848 * Returns the linear index of the maximal element of the matrix. If
1849 * there are more than one elements with this value, the first one
1850 * is returned.
1851 */
1852 public int argmax() {
1853 if (isEmpty()) {
1854 return -1;
1855 }
1856 float v = Float.NEGATIVE_INFINITY;
1857 int a = -1;
1858 for (int i = 0; i < length; i++) {
1859 if (!Float.isNaN(get(i)) && get(i) > v) {
1860 v = get(i);
1861 a = i;
1862 }
1863 }
1864
1865 return a;
1866 }
1867
1868 /**
1869 * Computes the maximum between two matrices. Returns the larger of the
1870 * corresponding elements in the matrix (in-place).
1871 */
1872 public FloatMatrix maxi(FloatMatrix other, FloatMatrix result) {
1873 if (result == this) {
1874 for (int i = 0; i < length; i++) {
1875 if (get(i) < other.get(i)) {
1876 put(i, other.get(i));
1877 }
1878 }
1879 } else {
1880 for (int i = 0; i < length; i++) {
1881 if (get(i) < other.get(i)) {
1882 result.put(i, other.get(i));
1883 } else {
1884 result.put(i, get(i));
1885 }
1886 }
1887 }
1888 return this;
1889 }
1890
1891 /**
1892 * Computes the maximum between two matrices. Returns the smaller of the
1893 * corresponding elements in the matrix (in-place on this).
1894 */
1895 public FloatMatrix maxi(FloatMatrix other) {
1896 return maxi(other, this);
1897 }
1898
1899 /**
1900 * Computes the maximum between two matrices. Returns the larger of the
1901 * corresponding elements in the matrix (in-place on this).
1902 */
1903 public FloatMatrix max(FloatMatrix other) {
1904 return maxi(other, new FloatMatrix(rows, columns));
1905 }
1906
1907 public FloatMatrix maxi(float v, FloatMatrix result) {
1908 if (result == this) {
1909 for (int i = 0; i < length; i++) {
1910 if (get(i) < v) {
1911 result.put(i, v);
1912 }
1913 }
1914 } else {
1915 for (int i = 0; i < length; i++) {
1916 if (get(i) < v) {
1917 result.put(i, v);
1918 } else {
1919 result.put(i, get(i));
1920 }
1921 }
1922
1923 }
1924 return this;
1925 }
1926
1927 public FloatMatrix maxi(float v) {
1928 return maxi(v, this);
1929 }
1930
1931 public FloatMatrix max(float v) {
1932 return maxi(v, new FloatMatrix(rows, columns));
1933 }
1934
1935 /** Computes the sum of all elements of the matrix. */
1936 public float sum() {
1937 float s = 0.0f;
1938 for (int i = 0; i < length; i++) {
1939 s += get(i);
1940 }
1941 return s;
1942 }
1943
1944 /** Computes the product of all elements of the matrix */
1945 public float prod() {
1946 float p = 1.0f;
1947 for (int i = 0; i < length; i++) {
1948 p *= get(i);
1949 }
1950 return p;
1951 }
1952
1953 /**
1954 * Computes the mean value of all elements in the matrix,
1955 * that is, <code>x.sum() / x.length</code>.
1956 */
1957 public float mean() {
1958 return sum() / length;
1959 }
1960
1961 /**
1962 * Computes the cumulative sum, that is, the sum of all elements
1963 * of the matrix up to a given index in linear addressing (in-place).
1964 */
1965 public FloatMatrix cumulativeSumi() {
1966 float s = 0.0f;
1967 for (int i = 0; i < length; i++) {
1968 s += get(i);
1969 put(i, s);
1970 }
1971 return this;
1972 }
1973
1974 /**
1975 * Computes the cumulative sum, that is, the sum of all elements
1976 * of the matrix up to a given index in linear addressing.
1977 */
1978 public FloatMatrix cumulativeSum() {
1979 return dup().cumulativeSumi();
1980 }
1981
1982 /** The scalar product of this with other. */
1983 public float dot(FloatMatrix other) {
1984 return SimpleBlas.dot(this, other);
1985 }
1986
1987 /**
1988 * Computes the projection coefficient of other on this.
1989 *
1990 * The returned scalar times <tt>this</tt> is the orthogonal projection
1991 * of <tt>other</tt> on <tt>this</tt>.
1992 */
1993 public float project(FloatMatrix other) {
1994 other.checkLength(length);
1995 float norm = 0, dot = 0;
1996 for (int i = 0; i < this.length; i++) {
1997 float x = get(i);
1998 norm += x*x;
1999 dot += x*other.get(i);
2000 }
2001 return dot/norm;
2002 }
2003
2004 /**
2005 * The Euclidean norm of the matrix as vector, also the Frobenius
2006 * norm of the matrix.
2007 */
2008 public float norm2() {
2009 float norm = 0.0f;
2010 for (int i = 0; i < length; i++) {
2011 norm += get(i) * get(i);
2012 }
2013 return (float)Math.sqrt(norm);
2014 }
2015
2016 /**
2017 * The maximum norm of the matrix (maximal absolute value of the elements).
2018 */
2019 public float normmax() {
2020 float max = 0.0f;
2021 for (int i = 0; i < length; i++) {
2022 float a = Math.abs(get(i));
2023 if (a > max)
2024 max = a;
2025 }
2026 return max;
2027 }
2028
2029 /**
2030 * The 1-norm of the matrix as vector (sum of absolute values of elements).
2031 */
2032 public float norm1() {
2033 float norm = 0.0f;
2034 for (int i = 0; i < length; i++) {
2035 norm += Math.abs(get(i));
2036 }
2037 return norm;
2038 }
2039
2040 /**
2041 * Return a new matrix with all elements sorted.
2042 */
2043 public FloatMatrix sort() {
2044 float array[] = toArray();
2045 java.util.Arrays.sort(array);
2046 return new FloatMatrix(rows, columns, array);
2047 }
2048
2049 /**
2050 * Sort elements in-place.
2051 */
2052 public FloatMatrix sorti() {
2053 Arrays.sort(data);
2054 return this;
2055 }
2056
2057 /**
2058 * Get the sorting permutation.
2059 *
2060 * @return an int[] array such that which indexes the elements in sorted
2061 * order.
2062 */
2063 public int[] sortingPermutation() {
2064 Integer[] indices = new Integer[length];
2065
2066 for (int i = 0; i < length; i++) {
2067 indices[i] = i;
2068 }
2069
2070 final float[] array = data;
2071
2072 Arrays.sort(indices, new Comparator() {
2073
2074 public int compare(Object o1, Object o2) {
2075 int i = (Integer) o1;
2076 int j = (Integer) o2;
2077 if (array[i] < array[j]) {
2078 return -1;
2079 } else if (array[i] == array[j]) {
2080 return 0;
2081 } else {
2082 return 1;
2083 }
2084 }
2085 });
2086
2087 int[] result = new int[length];
2088
2089 for (int i = 0; i < length; i++) {
2090 result[i] = indices[i];
2091 }
2092
2093 return result;
2094 }
2095
2096 /**
2097 * Sort columns (in-place).
2098 */
2099 public FloatMatrix sortColumnsi() {
2100 for (int i = 0; i < length; i += rows) {
2101 Arrays.sort(data, i, i + rows);
2102 }
2103 return this;
2104 }
2105
2106 /** Sort columns. */
2107 public FloatMatrix sortColumns() {
2108 return dup().sortColumnsi();
2109 }
2110
2111 /** Return matrix of indices which sort all columns. */
2112 public int[][] columnSortingPermutations() {
2113 int[][] result = new int[columns][];
2114
2115 FloatMatrix temp = new FloatMatrix(rows);
2116 for (int c = 0; c < columns; c++) {
2117 result[c] = getColumn(c, temp).sortingPermutation();
2118 }
2119
2120 return result;
2121 }
2122
2123 /** Sort rows (in-place). */
2124 public FloatMatrix sortRowsi() {
2125 // actually, this is much harder because the data is not consecutive
2126 // in memory...
2127 FloatMatrix temp = new FloatMatrix(columns);
2128 for (int r = 0; r < rows; r++) {
2129 putRow(r, getRow(r, temp).sorti());
2130 }
2131 return this;
2132 }
2133
2134 /** Sort rows. */
2135 public FloatMatrix sortRows() {
2136 return dup().sortRowsi();
2137 }
2138
2139 /** Return matrix of indices which sort all columns. */
2140 public int[][] rowSortingPermutations() {
2141 int[][] result = new int[rows][];
2142
2143 FloatMatrix temp = new FloatMatrix(columns);
2144 for (int r = 0; r < rows; r++) {
2145 result[r] = getRow(r, temp).sortingPermutation();
2146 }
2147
2148 return result;
2149 }
2150
2151 /** Return a vector containing the sums of the columns (having number of columns many entries) */
2152 public FloatMatrix columnSums() {
2153 if (rows == 1) {
2154 return dup();
2155 } else {
2156 FloatMatrix v = new FloatMatrix(1, columns);
2157
2158 for (int c = 0; c < columns; c++) {
2159 for (int r = 0; r < rows; r++) {
2160 v.put(c, v.get(c) + get(r, c));
2161 }
2162 }
2163
2164 return v;
2165 }
2166 }
2167
2168 /** Return a vector containing the means of all columns. */
2169 public FloatMatrix columnMeans() {
2170 return columnSums().divi(rows);
2171 }
2172
2173 /** Return a vector containing the sum of the rows. */
2174 public FloatMatrix rowSums() {
2175 if (columns == 1) {
2176 return dup();
2177 } else {
2178 FloatMatrix v = new FloatMatrix(rows);
2179
2180 for (int c = 0; c < columns; c++) {
2181 for (int r = 0; r < rows; r++) {
2182 v.put(r, v.get(r) + get(r, c));
2183 }
2184 }
2185
2186 return v;
2187 }
2188 }
2189
2190 /** Return a vector containing the means of the rows. */
2191 public FloatMatrix rowMeans() {
2192 return rowSums().divi(columns);
2193 }
2194
2195 /** Get a copy of a column. */
2196 public FloatMatrix getColumn(int c) {
2197 return getColumn(c, new FloatMatrix(rows, 1));
2198 }
2199
2200 /** Copy a column to the given vector. */
2201 public FloatMatrix getColumn(int c, FloatMatrix result) {
2202 result.checkLength(rows);
2203 JavaBlas.rcopy(rows, data, index(0, c), 1, result.data, 0, 1);
2204 return result;
2205 }
2206
2207 /** Copy a column back into the matrix. */
2208 public void putColumn(int c, FloatMatrix v) {
2209 JavaBlas.rcopy(rows, v.data, 0, 1, data, index(0, c), 1);
2210 }
2211
2212 /** Get a copy of a row. */
2213 public FloatMatrix getRow(int r) {
2214 return getRow(r, new FloatMatrix(1, columns));
2215 }
2216
2217 /** Copy a row to a given vector. */
2218 public FloatMatrix getRow(int r, FloatMatrix result) {
2219 result.checkLength(columns);
2220 JavaBlas.rcopy(columns, data, index(r, 0), rows, result.data, 0, 1);
2221 return result;
2222 }
2223
2224 /** Copy a row back into the matrix. */
2225 public void putRow(int r, FloatMatrix v) {
2226 JavaBlas.rcopy(columns, v.data, 0, 1, data, index(r, 0), rows);
2227 }
2228
2229 /** Return column-wise minimums. */
2230 public FloatMatrix columnMins() {
2231 FloatMatrix mins = new FloatMatrix(1, columns);
2232 for (int c = 0; c < columns; c++) {
2233 mins.put(c, getColumn(c).min());
2234 }
2235 return mins;
2236 }
2237
2238 /** Return index of minimal element per column. */
2239 public int[] columnArgmins() {
2240 int[] argmins = new int[columns];
2241 for (int c = 0; c < columns; c++) {
2242 argmins[c] = getColumn(c).argmin();
2243 }
2244 return argmins;
2245 }
2246
2247 /** Return column-wise maximums. */
2248 public FloatMatrix columnMaxs() {
2249 FloatMatrix maxs = new FloatMatrix(1, columns);
2250 for (int c = 0; c < columns; c++) {
2251 maxs.put(c, getColumn(c).max());
2252 }
2253 return maxs;
2254 }
2255
2256 /** Return index of minimal element per column. */
2257 public int[] columnArgmaxs() {
2258 int[] argmaxs = new int[columns];
2259 for (int c = 0; c < columns; c++) {
2260 argmaxs[c] = getColumn(c).argmax();
2261 }
2262 return argmaxs;
2263 }
2264
2265 /** Return row-wise minimums. */
2266 public FloatMatrix rowMins() {
2267 FloatMatrix mins = new FloatMatrix(rows);
2268 for (int c = 0; c < rows; c++) {
2269 mins.put(c, getRow(c).min());
2270 }
2271 return mins;
2272 }
2273
2274 /** Return index of minimal element per row. */
2275 public int[] rowArgmins() {
2276 int[] argmins = new int[rows];
2277 for (int c = 0; c < rows; c++) {
2278 argmins[c] = getRow(c).argmin();
2279 }
2280 return argmins;
2281 }
2282
2283 /** Return row-wise maximums. */
2284 public FloatMatrix rowMaxs() {
2285 FloatMatrix maxs = new FloatMatrix(rows);
2286 for (int c = 0; c < rows; c++) {
2287 maxs.put(c, getRow(c).max());
2288 }
2289 return maxs;
2290 }
2291
2292 /** Return index of minimal element per row. */
2293 public int[] rowArgmaxs() {
2294 int[] argmaxs = new int[rows];
2295 for (int c = 0; c < rows; c++) {
2296 argmaxs[c] = getRow(c).argmax();
2297 }
2298 return argmaxs;
2299 }
2300
2301 /**************************************************************************
2302 * Elementwise Functions
2303 */
2304 /** Add a row vector to all rows of the matrix (in place). */
2305 public FloatMatrix addiRowVector(FloatMatrix x) {
2306 x.checkLength(columns);
2307 for (int c = 0; c < columns; c++)
2308 for (int r = 0; r < rows; r++)
2309 put(r, c, get(r, c) + x.get(c));
2310 return this;
2311 }
2312
2313 /** Add a row to all rows of the matrix. */
2314 public FloatMatrix addRowVector(FloatMatrix x) {
2315 return dup().addiRowVector(x);
2316 }
2317
2318 /** Add a vector to all columns of the matrix (in-place). */
2319 public FloatMatrix addiColumnVector(FloatMatrix x) {
2320 x.checkLength(rows);
2321 for (int c = 0; c < columns; c++)
2322 for (int r = 0; r < rows; r++)
2323 put(r, c, get(r, c) + x.get(r));
2324 return this;
2325 }
2326
2327 /** Add a vector to all columns of the matrix. */
2328 public FloatMatrix addColumnVector(FloatMatrix x) {
2329 return dup().addiColumnVector(x);
2330 }
2331
2332 /** Subtract a row vector from all rows of the matrix (in-place). */
2333 public FloatMatrix subiRowVector(FloatMatrix x) {
2334 // This is a bit crazy, but a row vector must have as length as the columns of the matrix.
2335 x.checkLength(columns);
2336 for (int c = 0; c < columns; c++)
2337 for (int r = 0; r < rows; r++)
2338 put(r, c, get(r, c) - x.get(c));
2339 return this;
2340 }
2341
2342 /** Subtract a row vector from all rows of the matrix. */
2343 public FloatMatrix subRowVector(FloatMatrix x) {
2344 return dup().subiRowVector(x);
2345 }
2346
2347 /** Subtract a column vector from all columns of the matrix (in-place). */
2348 public FloatMatrix subiColumnVector(FloatMatrix x) {
2349 x.checkLength(rows);
2350 for (int c = 0; c < columns; c++)
2351 for (int r = 0; r < rows; r++)
2352 put(r, c, get(r, c) - x.get(r));
2353 return this;
2354 }
2355
2356 /** Subtract a vector from all columns of the matrix. */
2357 public FloatMatrix subColumnVector(FloatMatrix x) {
2358 return dup().subiColumnVector(x);
2359 }
2360
2361 /** Multiply a row by a scalar. */
2362 public FloatMatrix mulRow(int r, float scale) {
2363 NativeBlas.sscal(columns, scale, data, index(r, 0), rows);
2364 return this;
2365 }
2366
2367 /** Multiply a column by a scalar. */
2368 public FloatMatrix mulColumn(int c, float scale) {
2369 NativeBlas.sscal(rows, scale, data, index(0, c), 1);
2370 return this;
2371 }
2372
2373 /** Multiply all columns with a column vector (in-place). */
2374 public FloatMatrix muliColumnVector(FloatMatrix x) {
2375 x.checkLength(rows);
2376 for (int c = 0; c < columns; c++) {
2377 for (int r = 0; r < rows; r++)
2378 put(r, c, get(r, c) * x.get(r));
2379 }
2380 return this;
2381 }
2382
2383 /** Multiply all columns with a column vector. */
2384 public FloatMatrix mulColumnVector(FloatMatrix x) {
2385 return dup().muliColumnVector(x);
2386 }
2387
2388 /** Multiply all rows with a row vector (in-place). */
2389 public FloatMatrix muliRowVector(FloatMatrix x) {
2390 x.checkLength(columns);
2391 for (int c = 0; c < columns; c++)
2392 for (int r = 0; r < rows; r++)
2393 put(r, c, get(r, c) * x.get(c));
2394 return this;
2395 }
2396
2397 /** Multiply all rows with a row vector. */
2398 public FloatMatrix mulRowVector(FloatMatrix x) {
2399 return dup().muliRowVector(x);
2400 }
2401
2402 public FloatMatrix diviRowVector(FloatMatrix x) {
2403 x.checkLength(columns);
2404 for (int c = 0; c < columns; c++)
2405 for (int r = 0; r < rows; r++)
2406 put(r, c, get(r, c) / x.get(c));
2407 return this;
2408 }
2409
2410 public FloatMatrix divRowVector(FloatMatrix x) {
2411 return dup().diviRowVector(x);
2412 }
2413
2414 public FloatMatrix diviColumnVector(FloatMatrix x) {
2415 x.checkLength(rows);
2416 for (int c = 0; c < columns; c++)
2417 for (int r = 0; r < rows; r++)
2418 put(r, c, get(r, c) / x.get(r));
2419 return this;
2420 }
2421
2422 public FloatMatrix divColumnVector(FloatMatrix x) {
2423 return dup().diviColumnVector(x);
2424 }
2425
2426 /**
2427 * Writes out this matrix to the given data stream.
2428 * @param dos the data output stream to write to.
2429 * @throws IOException
2430 */
2431 public void out(DataOutputStream dos) throws IOException {
2432 dos.writeUTF("float");
2433 dos.writeInt(columns);
2434 dos.writeInt(rows);
2435
2436 dos.writeInt(data.length);
2437 for (int i = 0; i < data.length; i++) {
2438 dos.writeDouble(data[i]);
2439 }
2440 }
2441
2442 /**
2443 * Reads in a matrix from the given data stream. Note
2444 * that the old data of this matrix will be discarded.
2445 * @param dis the data input stream to read from.
2446 * @throws IOException
2447 */
2448 public void in(DataInputStream dis) throws IOException {
2449 if (!dis.readUTF().equals("float")) {
2450 throw new IllegalStateException("The matrix in the specified file is not of the correct type!");
2451 }
2452
2453 this.columns = dis.readInt();
2454 this.rows = dis.readInt();
2455
2456 final int MAX = dis.readInt();
2457 data = new float[MAX];
2458 for (int i = 0; i < MAX; i++) {
2459 data[i] = dis.readFloat();
2460 }
2461 }
2462
2463 /**
2464 * Saves this matrix to the specified file.
2465 * @param filename the file to write the matrix in.
2466 * @throws IOException thrown on errors while writing the matrix to the file
2467 */
2468 public void save(String filename) throws IOException {
2469 DataOutputStream dos = new DataOutputStream(new FileOutputStream(filename, false));
2470 this.out(dos);
2471 }
2472
2473 /**
2474 * Loads a matrix from a file into this matrix. Note that the old data
2475 * of this matrix will be discarded.
2476 * @param filename the file to read the matrix from
2477 * @throws IOException thrown on errors while reading the matrix
2478 */
2479 public void load(String filename) throws IOException {
2480 DataInputStream dis = new DataInputStream(new FileInputStream(filename));
2481 this.in(dis);
2482 }
2483
2484 public static FloatMatrix loadAsciiFile(String filename) throws IOException {
2485 BufferedReader is = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
2486
2487 // Go through file and count columns and rows. What makes this endeavour a bit difficult is
2488 // that files can have leading or trailing spaces leading to spurious fields
2489 // after String.split().
2490 String line;
2491 int rows = 0;
2492 int columns = -1;
2493 while ((line = is.readLine()) != null) {
2494 String[] elements = line.split("\\s+");
2495 int numElements = elements.length;
2496 if (elements[0].length() == 0) {
2497 numElements--;
2498 }
2499 if (elements[elements.length - 1].length() == 0) {
2500 numElements--;
2501 }
2502
2503 if (columns == -1) {
2504 columns = numElements;
2505 } else {
2506 if (columns != numElements) {
2507 throw new IOException("Number of elements changes in line " + line + ".");
2508 }
2509 }
2510
2511 rows++;
2512 }
2513 is.close();
2514
2515 // Go through file a second time process the actual data.
2516 is = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
2517 FloatMatrix result = new FloatMatrix(rows, columns);
2518 int r = 0;
2519 while ((line = is.readLine()) != null) {
2520 String[] elements = line.split("\\s+");
2521 int firstElement = (elements[0].length() == 0) ? 1 : 0;
2522 for (int c = 0, cc = firstElement; c < columns; c++, cc++) {
2523 result.put(r, c, Float.valueOf(elements[cc]));
2524 }
2525 r++;
2526 }
2527 return result;
2528 }
2529
2530 public static FloatMatrix loadCSVFile(String filename) throws IOException {
2531 BufferedReader is = new BufferedReader(new InputStreamReader(new FileInputStream(filename)));
2532
2533 List<FloatMatrix> rows = new LinkedList<FloatMatrix>();
2534 String line;
2535 int columns = -1;
2536 while ((line = is.readLine()) != null) {
2537 String[] elements = line.split(",");
2538 int numElements = elements.length;
2539 if (elements[0].length() == 0) {
2540 numElements--;
2541 }
2542 if (elements[elements.length - 1].length() == 0) {
2543 numElements--;
2544 }
2545
2546 if (columns == -1) {
2547 columns = numElements;
2548 } else {
2549 if (columns != numElements) {
2550 throw new IOException("Number of elements changes in line " + line + ".");
2551 }
2552 }
2553
2554 FloatMatrix row = new FloatMatrix(columns);
2555 for (int c = 0; c < columns; c++)
2556 row.put(c, Float.valueOf(elements[c]));
2557 rows.add(row);
2558 }
2559 is.close();
2560
2561 System.out.println("Done reading file");
2562
2563 FloatMatrix result = new FloatMatrix(rows.size(), columns);
2564 int r = 0;
2565 Iterator<FloatMatrix> ri = rows.iterator();
2566 while (ri.hasNext()) {
2567 result.putRow(r, ri.next());
2568 r++;
2569 }
2570 return result;
2571 }
2572
2573 /****************************************************************
2574 * Autogenerated code
2575 */
2576 /***** Code for operators ***************************************/
2577
2578 /* Overloads for the usual arithmetic operations */
2579 /*#
2580 def gen_overloads(base, result_rows, result_cols, verb=''); <<-EOS
2581 #{doc verb.capitalize + " a matrix (in place)."}
2582 public FloatMatrix #{base}i(FloatMatrix other) {
2583 return #{base}i(other, this);
2584 }
2585
2586 #{doc verb.capitalize + " a matrix (in place)."}
2587 public FloatMatrix #{base}(FloatMatrix other) {
2588 return #{base}i(other, new FloatMatrix(#{result_rows}, #{result_cols}));
2589 }
2590
2591 #{doc verb.capitalize + " a scalar (in place)."}
2592 public FloatMatrix #{base}i(float v) {
2593 return #{base}i(v, this);
2594 }
2595
2596 #{doc verb.capitalize + " a scalar."}
2597 public FloatMatrix #{base}(float v) {
2598 return #{base}i(v, new FloatMatrix(rows, columns));
2599 }
2600 EOS
2601 end
2602 #*/
2603
2604 /* Generating code for logical operators. This not only generates the stubs
2605 * but really all of the code.
2606 */
2607 /*#
2608 def gen_compare(name, op, cmp); <<-EOS
2609 #{doc 'Test for ' + cmp + ' (in-place).'}
2610 public FloatMatrix #{name}i(FloatMatrix other, FloatMatrix result) {
2611 if (other.isScalar())
2612 return #{name}i(other.scalar(), result);
2613
2614 assertSameLength(other);
2615 ensureResultLength(other, result);
2616
2617 for (int i = 0; i < length; i++)
2618 result.put(i, get(i) #{op} other.get(i) ? 1.0f : 0.0f);
2619 return result;
2620 }
2621
2622 #{doc 'Test for ' + cmp + ' (in-place).'}
2623 public FloatMatrix #{name}i(FloatMatrix other) {
2624 return #{name}i(other, this);
2625 }
2626
2627 #{doc 'Test for ' + cmp + '.'}
2628 public FloatMatrix #{name}(FloatMatrix other) {
2629 return #{name}i(other, new FloatMatrix(rows, columns));
2630 }
2631
2632 #{doc 'Test for ' + cmp + ' against a scalar (in-place).'}
2633 public FloatMatrix #{name}i(float value, FloatMatrix result) {
2634 ensureResultLength(null, result);
2635 for (int i = 0; i < length; i++)
2636 result.put(i, get(i) #{op} value ? 1.0f : 0.0f);
2637 return result;
2638 }
2639
2640 #{doc 'Test for ' + cmp + ' against a scalar (in-place).'}
2641 public FloatMatrix #{name}i(float value) {
2642 return #{name}i(value, this);
2643 }
2644
2645 #{doc 'test for ' + cmp + ' against a scalar.'}
2646 public FloatMatrix #{name}(float value) {
2647 return #{name}i(value, new FloatMatrix(rows, columns));
2648 }
2649 EOS
2650 end
2651 #*/
2652 /*#
2653 def gen_logical(name, op, cmp); <<-EOS
2654 #{doc 'Compute elementwise ' + cmp + ' (in-place).'}
2655 public FloatMatrix #{name}i(FloatMatrix other, FloatMatrix result) {
2656 assertSameLength(other);
2657 ensureResultLength(other, result);
2658
2659 for (int i = 0; i < length; i++)
2660 result.put(i, (get(i) != 0.0f) #{op} (other.get(i) != 0.0f) ? 1.0f : 0.0f);
2661 return result;
2662 }
2663
2664 #{doc 'Compute elementwise ' + cmp + ' (in-place).'}
2665 public FloatMatrix #{name}i(FloatMatrix other) {
2666 return #{name}i(other, this);
2667 }
2668
2669 #{doc 'Compute elementwise ' + cmp + '.'}
2670 public FloatMatrix #{name}(FloatMatrix other) {
2671 return #{name}i(other, new FloatMatrix(rows, columns));
2672 }
2673
2674 #{doc 'Compute elementwise ' + cmp + ' against a scalar (in-place).'}
2675 public FloatMatrix #{name}i(float value, FloatMatrix result) {
2676 ensureResultLength(null, result);
2677 boolean val = (value != 0.0f);
2678 for (int i = 0; i < length; i++)
2679 result.put(i, (get(i) != 0.0f) #{op} val ? 1.0f : 0.0f);
2680 return result;
2681 }
2682
2683 #{doc 'Compute elementwise ' + cmp + ' against a scalar (in-place).'}
2684 public FloatMatrix #{name}i(float value) {
2685 return #{name}i(value, this);
2686 }
2687
2688 #{doc 'Compute elementwise ' + cmp + ' against a scalar.'}
2689 public FloatMatrix #{name}(float value) {
2690 return #{name}i(value, new FloatMatrix(rows, columns));
2691 }
2692 EOS
2693 end
2694 #*/
2695
2696 /*# collect(gen_overloads('add', 'rows', 'columns', 'add'),
2697 gen_overloads('sub', 'rows', 'columns', 'subtract'),
2698 gen_overloads('rsub', 'rows', 'columns', '(right-)subtract'),
2699 gen_overloads('div', 'rows', 'columns', 'elementwise divide by'),
2700 gen_overloads('rdiv', 'rows', 'columns', '(right-)elementwise divide by'),
2701 gen_overloads('mul', 'rows', 'columns', 'elementwise multiply by'),
2702 gen_overloads('mmul', 'rows', 'other.columns', 'matrix-multiply by'),
2703 gen_compare('lt', '<', '"less than"'),
2704 gen_compare('gt', '>', '"greater than"'),
2705 gen_compare('le', '<=', '"less than or equal"'),
2706 gen_compare('ge', '>=', '"greater than or equal"'),
2707 gen_compare('eq', '==', 'equality'),
2708 gen_compare('ne', '!=', 'inequality'),
2709 gen_logical('and', '&', 'logical and'),
2710 gen_logical('or', '|', 'logical or'),
2711 gen_logical('xor', '^', 'logical xor'))
2712 #*/
2713 //RJPP-BEGIN------------------------------------------------------------
/** Add a matrix (in place): the sum is written into this matrix via addi(other, this). */
public FloatMatrix addi(FloatMatrix other) {
    return addi(other, this);
}

/** Add a matrix; the sum is written into a newly allocated rows x columns matrix. */
public FloatMatrix add(FloatMatrix other) {
    return addi(other, new FloatMatrix(rows, columns));
}

/** Add a scalar to every element (in place): the result is written into this matrix. */
public FloatMatrix addi(float v) {
    return addi(v, this);
}

/** Add a scalar to every element; the result is written into a new rows x columns matrix. */
public FloatMatrix add(float v) {
    return addi(v, new FloatMatrix(rows, columns));
}
2733
/** Subtract a matrix (in place): the difference is written into this matrix via subi(other, this). */
public FloatMatrix subi(FloatMatrix other) {
    return subi(other, this);
}

/** Subtract a matrix; the difference is written into a newly allocated rows x columns matrix. */
public FloatMatrix sub(FloatMatrix other) {
    return subi(other, new FloatMatrix(rows, columns));
}

/** Subtract a scalar from every element (in place): the result is written into this matrix. */
public FloatMatrix subi(float v) {
    return subi(v, this);
}

/** Subtract a scalar from every element; the result is written into a new rows x columns matrix. */
public FloatMatrix sub(float v) {
    return subi(v, new FloatMatrix(rows, columns));
}
2753
/**
 * (Right-)subtract a matrix (in place): operands are presumably swapped
 * relative to subi (other - this; see rsubi(other, result)). The result is
 * written into this matrix.
 */
public FloatMatrix rsubi(FloatMatrix other) {
    return rsubi(other, this);
}

/** (Right-)subtract a matrix; the result is written into a newly allocated rows x columns matrix. */
public FloatMatrix rsub(FloatMatrix other) {
    return rsubi(other, new FloatMatrix(rows, columns));
}

/** (Right-)subtract a scalar (in place): the result is written into this matrix. */
public FloatMatrix rsubi(float v) {
    return rsubi(v, this);
}

/** (Right-)subtract a scalar; the result is written into a new rows x columns matrix. */
public FloatMatrix rsub(float v) {
    return rsubi(v, new FloatMatrix(rows, columns));
}
2773
/** Elementwise divide by a matrix (in place): the quotient is written into this matrix. */
public FloatMatrix divi(FloatMatrix other) {
    return divi(other, this);
}

/** Elementwise divide by a matrix; the quotient is written into a new rows x columns matrix. */
public FloatMatrix div(FloatMatrix other) {
    return divi(other, new FloatMatrix(rows, columns));
}

/** Elementwise divide by a scalar (in place): the result is written into this matrix. */
public FloatMatrix divi(float v) {
    return divi(v, this);
}

/** Elementwise divide by a scalar; the result is written into a new rows x columns matrix. */
public FloatMatrix div(float v) {
    return divi(v, new FloatMatrix(rows, columns));
}
2793
/**
 * (Right-)elementwise divide by a matrix (in place): operands are presumably
 * swapped relative to divi (other / this; see rdivi(other, result)). The
 * result is written into this matrix.
 */
public FloatMatrix rdivi(FloatMatrix other) {
    return rdivi(other, this);
}

/** (Right-)elementwise divide by a matrix; the result is written into a new rows x columns matrix. */
public FloatMatrix rdiv(FloatMatrix other) {
    return rdivi(other, new FloatMatrix(rows, columns));
}

/** (Right-)elementwise divide by a scalar (in place): the result is written into this matrix. */
public FloatMatrix rdivi(float v) {
    return rdivi(v, this);
}

/** (Right-)elementwise divide by a scalar; the result is written into a new rows x columns matrix. */
public FloatMatrix rdiv(float v) {
    return rdivi(v, new FloatMatrix(rows, columns));
}
2813
/** Elementwise multiply by a matrix (in place): the product is written into this matrix. */
public FloatMatrix muli(FloatMatrix other) {
    return muli(other, this);
}

/** Elementwise multiply by a matrix; the product is written into a new rows x columns matrix. */
public FloatMatrix mul(FloatMatrix other) {
    return muli(other, new FloatMatrix(rows, columns));
}

/** Elementwise multiply by a scalar (in place): the result is written into this matrix. */
public FloatMatrix muli(float v) {
    return muli(v, this);
}

/** Elementwise multiply by a scalar; the result is written into a new rows x columns matrix. */
public FloatMatrix mul(float v) {
    return muli(v, new FloatMatrix(rows, columns));
}
2833
/**
 * Matrix-multiply by a matrix (in place): the product is written into this
 * matrix. NOTE(review): the product has shape rows x other.columns, so this
 * presumably only makes sense for square {@code other}; also confirm that
 * mmuli(other, result) handles {@code result} aliasing {@code this}.
 */
public FloatMatrix mmuli(FloatMatrix other) {
    return mmuli(other, this);
}

/** Matrix-multiply by a matrix; the product is written into a new rows x other.columns matrix. */
public FloatMatrix mmul(FloatMatrix other) {
    return mmuli(other, new FloatMatrix(rows, other.columns));
}

/** Matrix-multiply by a scalar (in place): the result is written into this matrix. */
public FloatMatrix mmuli(float v) {
    return mmuli(v, this);
}

/** Matrix-multiply by a scalar; the result is written into a new rows x columns matrix. */
public FloatMatrix mmul(float v) {
    return mmuli(v, new FloatMatrix(rows, columns));
}
2853
/**
 * Test for "less than" elementwise, writing 1.0f where this &lt; other and
 * 0.0f elsewhere into {@code result}. If {@code other.isScalar()}, the
 * scalar overload is used with {@code other.scalar()}.
 */
public FloatMatrix lti(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return lti(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) < other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for "less than" (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix lti(FloatMatrix other) {
    return lti(other, this);
}

/** Test for "less than"; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix lt(FloatMatrix other) {
    return lti(other, new FloatMatrix(rows, columns));
}

/** Test for "less than" against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix lti(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) < value ? 1.0f : 0.0f);
    return result;
}

/** Test for "less than" against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix lti(float value) {
    return lti(value, this);
}

/** Test for "less than" against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix lt(float value) {
    return lti(value, new FloatMatrix(rows, columns));
}
2894
/**
 * Test for "greater than" elementwise, writing 1.0f where this &gt; other and
 * 0.0f elsewhere into {@code result}. If {@code other.isScalar()}, the
 * scalar overload is used with {@code other.scalar()}.
 */
public FloatMatrix gti(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return gti(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) > other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for "greater than" (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix gti(FloatMatrix other) {
    return gti(other, this);
}

/** Test for "greater than"; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix gt(FloatMatrix other) {
    return gti(other, new FloatMatrix(rows, columns));
}

/** Test for "greater than" against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix gti(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) > value ? 1.0f : 0.0f);
    return result;
}

/** Test for "greater than" against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix gti(float value) {
    return gti(value, this);
}

/** Test for "greater than" against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix gt(float value) {
    return gti(value, new FloatMatrix(rows, columns));
}
2935
/**
 * Test for "less than or equal" elementwise, writing 1.0f where
 * this &lt;= other and 0.0f elsewhere into {@code result}. If
 * {@code other.isScalar()}, the scalar overload is used with
 * {@code other.scalar()}.
 */
public FloatMatrix lei(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return lei(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) <= other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for "less than or equal" (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix lei(FloatMatrix other) {
    return lei(other, this);
}

/** Test for "less than or equal"; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix le(FloatMatrix other) {
    return lei(other, new FloatMatrix(rows, columns));
}

/** Test for "less than or equal" against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix lei(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) <= value ? 1.0f : 0.0f);
    return result;
}

/** Test for "less than or equal" against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix lei(float value) {
    return lei(value, this);
}

/** Test for "less than or equal" against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix le(float value) {
    return lei(value, new FloatMatrix(rows, columns));
}
2976
/**
 * Test for "greater than or equal" elementwise, writing 1.0f where
 * this &gt;= other and 0.0f elsewhere into {@code result}. If
 * {@code other.isScalar()}, the scalar overload is used with
 * {@code other.scalar()}.
 */
public FloatMatrix gei(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return gei(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) >= other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for "greater than or equal" (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix gei(FloatMatrix other) {
    return gei(other, this);
}

/** Test for "greater than or equal"; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix ge(FloatMatrix other) {
    return gei(other, new FloatMatrix(rows, columns));
}

/** Test for "greater than or equal" against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix gei(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) >= value ? 1.0f : 0.0f);
    return result;
}

/** Test for "greater than or equal" against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix gei(float value) {
    return gei(value, this);
}

/** Test for "greater than or equal" against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix ge(float value) {
    return gei(value, new FloatMatrix(rows, columns));
}
3017
/**
 * Test for equality elementwise, writing 1.0f where this == other and 0.0f
 * elsewhere into {@code result} (Java float ==, so NaN never equals
 * anything). If {@code other.isScalar()}, the scalar overload is used with
 * {@code other.scalar()}.
 */
public FloatMatrix eqi(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return eqi(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) == other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for equality (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix eqi(FloatMatrix other) {
    return eqi(other, this);
}

/** Test for equality; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix eq(FloatMatrix other) {
    return eqi(other, new FloatMatrix(rows, columns));
}

/** Test for equality against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix eqi(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) == value ? 1.0f : 0.0f);
    return result;
}

/** Test for equality against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix eqi(float value) {
    return eqi(value, this);
}

/** Test for equality against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix eq(float value) {
    return eqi(value, new FloatMatrix(rows, columns));
}
3058
/**
 * Test for inequality elementwise, writing 1.0f where this != other and 0.0f
 * elsewhere into {@code result}. If {@code other.isScalar()}, the scalar
 * overload is used with {@code other.scalar()}.
 */
public FloatMatrix nei(FloatMatrix other, FloatMatrix result) {
    if (other.isScalar())
        return nei(other.scalar(), result);

    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, get(i) != other.get(i) ? 1.0f : 0.0f);
    return result;
}

/** Test for inequality (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix nei(FloatMatrix other) {
    return nei(other, this);
}

/** Test for inequality; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix ne(FloatMatrix other) {
    return nei(other, new FloatMatrix(rows, columns));
}

/** Test for inequality against a scalar, writing 1.0f/0.0f per element into {@code result}. */
public FloatMatrix nei(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    for (int i = 0; i < length; i++)
        result.put(i, get(i) != value ? 1.0f : 0.0f);
    return result;
}

/** Test for inequality against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix nei(float value) {
    return nei(value, this);
}

/** Test for inequality against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix ne(float value) {
    return nei(value, new FloatMatrix(rows, columns));
}
3099
/**
 * Compute elementwise logical and, writing the 0/1 result into
 * {@code result}. Any element != 0.0f (including NaN) counts as true. Unlike
 * the comparison operators, there is no scalar fallback for a 1x1 other.
 */
public FloatMatrix andi(FloatMatrix other, FloatMatrix result) {
    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, (get(i) != 0.0f) & (other.get(i) != 0.0f) ? 1.0f : 0.0f);
    return result;
}

/** Compute elementwise logical and (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix andi(FloatMatrix other) {
    return andi(other, this);
}

/** Compute elementwise logical and; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix and(FloatMatrix other) {
    return andi(other, new FloatMatrix(rows, columns));
}

/** Compute elementwise logical and against a scalar (non-zero means true), writing 0/1 into {@code result}. */
public FloatMatrix andi(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    boolean val = (value != 0.0f);
    for (int i = 0; i < length; i++)
        result.put(i, (get(i) != 0.0f) & val ? 1.0f : 0.0f);
    return result;
}

/** Compute elementwise logical and against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix andi(float value) {
    return andi(value, this);
}

/** Compute elementwise logical and against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix and(float value) {
    return andi(value, new FloatMatrix(rows, columns));
}
3138
/**
 * Compute elementwise logical or, writing the 0/1 result into
 * {@code result}. Any element != 0.0f (including NaN) counts as true. Unlike
 * the comparison operators, there is no scalar fallback for a 1x1 other.
 */
public FloatMatrix ori(FloatMatrix other, FloatMatrix result) {
    assertSameLength(other);
    ensureResultLength(other, result);

    for (int i = 0; i < length; i++)
        result.put(i, (get(i) != 0.0f) | (other.get(i) != 0.0f) ? 1.0f : 0.0f);
    return result;
}

/** Compute elementwise logical or (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix ori(FloatMatrix other) {
    return ori(other, this);
}

/** Compute elementwise logical or; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix or(FloatMatrix other) {
    return ori(other, new FloatMatrix(rows, columns));
}

/** Compute elementwise logical or against a scalar (non-zero means true), writing 0/1 into {@code result}. */
public FloatMatrix ori(float value, FloatMatrix result) {
    ensureResultLength(null, result);
    boolean val = (value != 0.0f);
    for (int i = 0; i < length; i++)
        result.put(i, (get(i) != 0.0f) | val ? 1.0f : 0.0f);
    return result;
}

/** Compute elementwise logical or against a scalar (in-place): the 0/1 result overwrites this matrix. */
public FloatMatrix ori(float value) {
    return ori(value, this);
}

/** Compute elementwise logical or against a scalar; the 0/1 result goes into a new rows x columns matrix. */
public FloatMatrix or(float value) {
    return ori(value, new FloatMatrix(rows, columns));
}
3177
3178 /** Compute elementwise logical xor (in-place). */
3179 public FloatMatrix xori(FloatMatrix other, FloatMatrix result) {
3180 assertSameLength(other);
3181 ensureResultLength(other, result);
3182
3183 for (int i = 0; i < length; i++)
3184 result.put(i, (get(i) != 0.0f) ^ (other.get(i) != 0.0f) ? 1.0f : 0.0f);
3185 return result;
3186 }
3187
3188 /** Compute elementwise logical xor (in-place). */
3189 public FloatMatrix xori(FloatMatrix other) {
3190 return xori(other, this);
3191 }
3192
3193 /** Compute elementwise logical xor. */
3194 public FloatMatrix xor(FloatMatrix other) {
3195 return xori(other, new FloatMatrix(rows, columns));
3196 }
3197
3198 /** Compute elementwise logical xor against a scalar (in-place). */
3199 public FloatMatrix xori(float value, FloatMatrix result) {
3200 ensureResultLength(null, result);
3201 boolean val = (value != 0.0f);
3202 for (int i = 0; i < length; i++)
3203 result.put(i, (get(i) != 0.0f) ^ val ? 1.0f : 0.0f);
3204 return result;
3205 }
3206
3207 /** Compute elementwise logical xor against a scalar (in-place). */
3208 public FloatMatrix xori(float value) {
3209 return xori(value, this);
3210 }
3211
3212 /** Compute elementwise logical xor against a scalar. */
3213 public FloatMatrix xor(float value) {
3214 return xori(value, new FloatMatrix(rows, columns));
3215 }
3216 //RJPP-END--------------------------------------------------------------
3217 }