DOLFIN
DOLFIN C++ interface
SUNDIALSNVector.h
1 // Copyright (C) 2017 Chris Hadjigeorgiou and Chris Richardson
2 //
3 // This file is part of DOLFIN.
4 //
5 // DOLFIN is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU Lesser General Public License as published by
7 // the Free Software Foundation, either version 3 of the License, or
8 // (at your option) any later version.
9 //
10 // DOLFIN is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 // GNU Lesser General Public License for more details.
14 //
15 // You should have received a copy of the GNU Lesser General Public License
16 // along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
17 
18 
19 #ifndef __DOLFIN_N_VECTOR_H
20 #define __DOLFIN_N_VECTOR_H
21 
22 #ifdef HAS_SUNDIALS
23 
24 #include <string>
25 #include <utility>
26 #include <memory>
27 #include <dolfin/common/types.h>
28 #include <sundials/sundials_nvector.h>
29 #include "DefaultFactory.h"
30 #include "GenericVector.h"
31 #include "Vector.h"
32 
33 namespace dolfin
34 {
39  {
40  public:
41 
45  SUNDIALSNVector(MPI_Comm comm=MPI_COMM_WORLD)
46  {
47  DefaultFactory factory;
48  vector = factory.create_vector(comm);
49  }
50 
56  SUNDIALSNVector(MPI_Comm comm, std::size_t N)
57  {
58  DefaultFactory factory;
59  vector = factory.create_vector(comm);
60  vector->init(N);
61  N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
62  N_V->ops = &ops;
63  N_V->content = (void *)(this);
64  }
65 
69  SUNDIALSNVector(const SUNDIALSNVector& x) : vector(x.vec()->copy()) {}
70 
74  SUNDIALSNVector(const GenericVector& x) : vector(x.copy())
75  {
76  N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
77  N_V->ops = &ops;
78  N_V->content = (void *)(this);
79  }
80 
84  SUNDIALSNVector(std::shared_ptr<GenericVector> x) : vector(x)
85  {
86  N_V = std::unique_ptr<_generic_N_Vector>(new _generic_N_Vector);
87  N_V->ops = &ops;
88  N_V->content = (void *)(this);
89  }
90  //-----------------------------------------------------------------------------
91 
95  N_Vector nvector() const
96  {
97  N_V->content = (void *)(this);
98  return N_V.get();
99  }
100 
104  std::shared_ptr<GenericVector> vec() const
105  {
106  return vector;
107  }
108 
111  { *vector = *x.vector; return *this; }
112 
113  private:
114 
115  //--- Implementation of N_Vector ops
116 
117  // Get ID for custom SUNDIALSNVector implementation
118  static N_Vector_ID N_VGetVectorID(N_Vector nv)
119  {
120  dolfin_debug("N_VGetVectorID");
121  return SUNDIALS_NVEC_CUSTOM;
122  }
123 
124  // Sets the components of the N_Vector z to be the absolute values of the
125  // components of the N_Vector x
126  static void N_VAbs(N_Vector x, N_Vector z)
127  {
128  dolfin_debug("N_VAbs");
129  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
130  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
131 
132  *vz = *vx;
133  vz->abs();
134  }
135 
137  static void N_VConst(double c, N_Vector z)
138  {
139  dolfin_debug("N_VConst");
140  auto v = static_cast<SUNDIALSNVector *>(z->content)->vec();
141  *v = c;
142  }
143 
147  static N_Vector N_VClone(N_Vector z)
148  {
149  dolfin_debug("N_VClone");
150  auto vz = static_cast<const SUNDIALSNVector *>(z->content);
151 
152  SUNDIALSNVector *new_vector = new SUNDIALSNVector(*vz);
153 
154  _generic_N_Vector *V = new _generic_N_Vector;
155  V->ops = z->ops;
156  V->content = (void *)(new_vector);
157 
158  return V;
159  }
160 
163  static N_Vector N_VCloneEmpty(N_Vector x)
164  {
165  dolfin_debug("N_VCloneEmpty");
166  dolfin_not_implemented();
167  return NULL;
168  }
169 
172  static void N_VDestroy(N_Vector z)
173  {
174  dolfin_debug("N_VDestroy");
175  delete (SUNDIALSNVector*)(z->content);
176  delete z;
177  }
178 
181  static void N_VSpace(N_Vector x, long int *y, long int *z)
182  {
183  dolfin_debug("N_VSpace");
184  dolfin_not_implemented();
185  }
186 
188  static double* N_VGetArrayPointer(N_Vector x)
189  {
190  dolfin_debug("N_VGetArrayPointer");
191  dolfin_not_implemented();
192  return NULL;
193  }
194 
196  static void N_VSetArrayPointer(double* c,N_Vector x)
197  {
198  dolfin_debug("N_VSetArrayPointer");
199  dolfin_not_implemented();
200  }
201 
204  static void N_VProd(N_Vector x, N_Vector y, N_Vector z)
205  {
206  dolfin_debug("N_VProd");
207  auto vx = static_cast<const SUNDIALSNVector*>(x->content)->vec();
208  auto vy = static_cast<const SUNDIALSNVector*>(y->content)->vec();
209  auto vz = static_cast<SUNDIALSNVector*>(z->content)->vec();
210 
211  // Copy x to z
212  *vz = *vx;
213  // Multiply by y
214  *vz *= *vy;
215  }
216 
219  static void N_VDiv(N_Vector x, N_Vector y, N_Vector z)
220  {
221  dolfin_debug("N_VDiv");
222 
223  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
224  auto vy = static_cast<const SUNDIALSNVector *>(y->content)->vec();
225  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
226 
227  std::vector<double> xdata;
228  vx->get_local(xdata);
229  std::vector<double> ydata;
230  vy->get_local(ydata);
231  for (unsigned int i = 0; i != xdata.size(); ++i)
232  xdata[i] /= ydata[i];
233 
234  vz->set_local(xdata);
235  vz->apply("insert");
236 
237  }
238 
240  static void N_VScale(double c, N_Vector x, N_Vector z)
241  {
242  dolfin_debug("N_VScale");
243  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
244  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
245 
246  // z = c*x
247  *vz = *vx;
248  *vz *= c;
249  }
250 
253  static void N_VInv(N_Vector x, N_Vector z)
254  {
255  dolfin_debug("N_VInv");
256  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
257  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
258 
259  // z = 1/x
260  std::vector<double> xvals;
261  vx->get_local(xvals);
262  for (auto &val : xvals)
263  val = 1.0/val;
264  vz->set_local(xvals);
265  vz->apply("insert");
266  }
267 
270  static void N_VAddConst(N_Vector x, double c, N_Vector z)
271  {
272  dolfin_debug("N_VAddConst");
273  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
274  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
275 
276  *vz = *vx;
277  *vz += c;
278  }
279 
281  static double N_VDotProd(N_Vector x, N_Vector z)
282  {
283  dolfin_debug("N_VDotProd");
284  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
285  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
286 
287  return vx->inner(*vz);
288  }
289 
291  static double N_VMaxNorm(N_Vector x)
292  {
293  dolfin_debug("N_VMaxNorm");
294  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
295  auto vy = vx->copy();
296  vy->abs();
297  return vy->max();
298  }
299 
301  static double N_VMin(N_Vector x)
302  {
303  dolfin_debug("N_VMin");
304  return (static_cast<const SUNDIALSNVector *>(x->content)->vec())->min();
305  }
306 
309  static void N_VLinearSum(double a, N_Vector x, double b, N_Vector y, N_Vector z)
310  {
311  dolfin_debug("N_VLinearSum");
312  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
313  auto vy = static_cast<const SUNDIALSNVector *>(y->content)->vec();
314  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
315 
316  std::vector<double> xdata;
317  vx->get_local(xdata);
318  std::vector<double> ydata;
319  vy->get_local(ydata);
320 
321  for (unsigned int i = 0; i != xdata.size(); ++i)
322  xdata[i] = a*xdata[i] + b*ydata[i];
323 
324  vz->set_local(xdata);
325  vz->apply("insert");
326  }
327 
330  static double N_VWrmsNorm(N_Vector x, N_Vector z)
331  {
332  dolfin_debug("N_VWrmsNorm");
333  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
334  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
335 
336  auto y = vx->copy();
337  *y *= *vz;
338  return y->norm("l2")/std::sqrt(y->size());
339  }
340 
344  static double N_VWrmsNormMask(N_Vector x, N_Vector y, N_Vector z)
345  {
346  dolfin_debug("N_VWrmsNormMask");
347  dolfin_not_implemented();
348  return 0.0;
349  }
350 
353  static double N_VWl2Norm(N_Vector x, N_Vector z )
354  {
355  dolfin_debug("N_VWl2Norm");
356  dolfin_not_implemented();
357  return 0.0;
358  }
359 
361  static double N_VL1Norm(N_Vector x )
362  {
363  dolfin_debug("N_VL1Norm");
364  dolfin_not_implemented();
365  return 0.0;
366  }
367 
370  static void N_VCompare(double c, N_Vector x, N_Vector z)
371  {
372  dolfin_debug("N_VCompare");
373  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
374  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
375  std::vector<double> xvals;
376  vx->get_local(xvals);
377  for (auto &val : xvals)
378  val = (std::abs(val) >= c) ? 1.0 : 0.0;
379  vz->set_local(xvals);
380  vz->apply("insert");
381  }
382 
385  static int N_VInvTest(N_Vector x, N_Vector z)
386  {
387  dolfin_debug("N_VInvTest");
388  int no_zero_found = true;
389  auto vx = static_cast<const SUNDIALSNVector *>(x->content)->vec();
390  auto vz = static_cast<SUNDIALSNVector *>(z->content)->vec();
391 
392  std::vector<double> xvals;
393  vx->get_local(xvals);
394  for (auto &val : xvals)
395  if(val != 0)
396  val = 1.0/val;
397  else
398  no_zero_found = false;
399  vz->set_local(xvals);
400 
401  vz->apply("insert");
402 
403  return no_zero_found;
404  }
405 
406 
409  static double N_VMinQuotient(N_Vector x, N_Vector z )
410  {
411  dolfin_debug("N_VConstrMask");
412  dolfin_not_implemented();
413  return 0.0;
414  }
415 
417  static int N_VConstrMask(N_Vector x, N_Vector y, N_Vector z )
418  {
419  dolfin_debug("N_VConstrMask");
420  dolfin_not_implemented();
421  return 0;
422  }
423 
424  // Pointer to concrete implementation
425  std::shared_ptr<GenericVector> vector;
426 
427  // Pointer to SUNDIALS struct
428  std::unique_ptr<_generic_N_Vector> N_V;
429 
430  // Structure containing function pointers to vector operations
431  struct _generic_N_Vector_Ops ops = {N_VGetVectorID, // N_Vector_ID (*N_VGetVectorID)(SUNDIALSNVector);
432  N_VClone, // NVector (*N_VClone)(NVector);
433  N_VCloneEmpty, // NVector (*N_VCloneEmpty)(NVector);
434  N_VDestroy, // void (*N_VDestroy)(NVector);
435  NULL, //N_VSpace, // void (*N_VSpace)(NVector, long int *, long int *);
436  N_VGetArrayPointer, // realtype* (*N_VGetArrayPointer)(NVector);
437  N_VSetArrayPointer, // void (*N_VSetArrayPointer)(realtype *, NVector);
438  N_VLinearSum, // void (*N_VLinearSum)(realtype, NVector, realtype, NVector, NVector);
439  N_VConst, // void (*N_VConst)(realtype, NVector);
440  N_VProd, // void (*N_VProd)(NVector, NVector, NVector);
441  N_VDiv, // void (*N_VDiv)(NVector, NVector, NVector);
442  N_VScale, // void (*N_VScale)(realtype, NVector, NVector);
443  N_VAbs, // void (*N_VAbs)(NVector, NVector);
444  N_VInv, // void (*N_VInv)(NVector, NVector);
445  N_VAddConst, // void (*N_VAddConst)(NVector, realtype, NVector);
446  N_VDotProd, // realtype (*N_VDotProd)(NVector, NVector);
447  N_VMaxNorm, // realtype (*N_VMaxNorm)(NVector);
448  N_VWrmsNorm, // realtype (*N_VWrmsNorm)(NVector, NVector);
449  N_VWrmsNormMask, // realtype (*N_VWrmsNormMask)(NVector, NVector, NVector);
450  N_VMin, // realtype (*N_VMin)(NVector);
451  N_VWl2Norm, // realtype (*N_VWl2Norm)(NVector, NVector);
452  N_VL1Norm, // realtype (*N_VL1Norm)(NVector);
453  N_VCompare, // void (*N_VCompare)(realtype, NVector, NVector);
454  N_VInvTest, // booleantype (*N_VInvtest)(NVector, NVector);
455  N_VConstrMask, // booleantype (*N_VConstrMask)(NVector, NVector, NVector);
456  N_VMinQuotient}; // realtype (*N_VMinQuotient)(NVector, NVector);
457  };
458 
459 
460 }
461 
462 #endif
463 
464 #endif
const SUNDIALSNVector & operator=(const SUNDIALSNVector &x)
Assignment operator.
Definition: SUNDIALSNVector.h:110
SUNDIALSNVector(const SUNDIALSNVector &x)
Definition: SUNDIALSNVector.h:69
SUNDIALSNVector(MPI_Comm comm=MPI_COMM_WORLD)
Definition: SUNDIALSNVector.h:45
std::shared_ptr< GenericVector > vec() const
Definition: SUNDIALSNVector.h:104
Definition: adapt.h:29
virtual std::shared_ptr< GenericVector > create_vector(MPI_Comm comm) const
Create empty vector.
Definition: DefaultFactory.cpp:37
Definition: SUNDIALSNVector.h:38
SUNDIALSNVector(std::shared_ptr< GenericVector > x)
Definition: SUNDIALSNVector.h:84
SUNDIALSNVector(MPI_Comm comm, std::size_t N)
Definition: SUNDIALSNVector.h:56
N_Vector nvector() const
Definition: SUNDIALSNVector.h:95
This class defines a common interface for vectors.
Definition: GenericVector.h:47
Default linear algebra factory based on global parameter "linear_algebra_backend".
Definition: DefaultFactory.h:35
SUNDIALSNVector(const GenericVector &x)
Definition: SUNDIALSNVector.h:74