-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathim_spring_solver.cpp
More file actions
157 lines (128 loc) · 5.05 KB
/
Copy pathim_spring_solver.cpp
File metadata and controls
157 lines (128 loc) · 5.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include "spring_solver.h"
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>
#include <cvode/cvode.h> /* main integrator header file */
#include <cvode/cvode_spgmr.h> /* prototypes & constants for CVSPGMR solver */
#include <cvode/cvode_spbcgs.h> /* prototypes & constants for CVSPBCG solver */
#include <cvode/cvode_sptfqmr.h> /* prototypes & constants for CVSPTFQMR solver */
#include <nvector/nvector_serial.h> /* serial N_Vector types, fct. and macros */
#include <sundials/sundials_dense.h> /* use generic DENSE solver in preconditioning */
#include <sundials/sundials_types.h> /* definition of realtype */
#include <sundials/sundials_math.h> /* contains the macros ABS, SUNSQR, and EXP */
// File-local helpers used by IM_SPRING_SOLVER::doSolve; definitions follow below.
// Copy the packed state vector (6 entries per vertex: x[0..2] then v[0..2]) into the mesh.
static void updateSpringVertex(N_Vector, std::vector<SpringVertex*>&);
// Inverse of updateSpringVertex: pack mesh positions/velocities into the state vector.
static void setInitialCondition(N_Vector, std::vector<SpringVertex*>&);
// Right-hand side registered with CVodeInit: udot = [v, accel] for every vertex.
static int f(realtype, N_Vector, N_Vector, void*);
// Dense Jacobian callback, currently disabled; doSolve uses the Krylov solver CVSpbcg instead.
/*static int Jac(long int N, long int mu, long int ml,
realtype t, N_Vector u, N_Vector fu,
DlsMat J, void *user_data,
N_Vector tmp1, N_Vector tmp2, N_Vector tmp3);*/
// Print a diagnostic and return 1 when a call failed:
// opt 0/2 test a returned pointer for NULL, opt 1 tests an int return code for < 0.
static int check_flag(void *flagvalue, const char *funcname, int opt);
void IM_SPRING_SOLVER::doSolve(double t) {
int NEQ = pts.size() * 6; //number of equations
//const int MAX_NB = 7; //number of neighbours
//create vector
N_Vector u = N_VNew_Serial(NEQ);
//set tolerances
realtype reltol = RCONST(0.0);
realtype abstol = RCONST(1.0e-5);
//set initial condition
setInitialCondition(u, pts);
//Create ode solver
void *cvode_mem = CVodeCreate(CV_BDF, CV_NEWTON);
//Initialize
int flag = CVodeInit(cvode_mem, f, 0, u);
//set tolerance
flag = CVodeSStolerances(cvode_mem, reltol, abstol);
//sIMet pointer to user-defined data
flag = CVodeSetUserData(cvode_mem, (void*)this);
//call CVSpgmr to specify the KSP linear solver
flag = CVSpbcg(cvode_mem, PREC_NONE, 0);
//try to add preconditioner if iteration too many times
//flag = CVSpilsSetPreconditioner(cvode_mem, Precond, PSolve);
realtype tret;
long nst;
//realtype umax = N_VMaxNorm(u);
for (size_t i = 0; i < ext_forces.size(); ++i)
ext_forces[i]->computeExternalForce();
//solve ode
flag = CVode(cvode_mem, t, u, &tret, CV_NORMAL);
flag = CVodeGetNumSteps(cvode_mem, &nst);
if (check_flag(&flag, "CVodeGetNumSteps", 1))
throw std::runtime_error("CVodeGetNumSteps");
updateSpringVertex(u, pts);
N_VDestroy_Serial(u);
CVodeFree(&cvode_mem);
}
static void updateSpringVertex(N_Vector u, std::vector<SpringVertex*>& pts) {
realtype *udata = N_VGetArrayPointer_Serial(u);
for (size_t i = 0; i < pts.size(); ++i) {
pts[i]->x[0] = udata[i*6];
pts[i]->x[1] = udata[i*6+1];
pts[i]->x[2] = udata[i*6+2];
pts[i]->v[0] = udata[i*6+3];
pts[i]->v[1] = udata[i*6+4];
pts[i]->v[2] = udata[i*6+5];
}
}
static void setInitialCondition(N_Vector u, std::vector<SpringVertex*>& pts) {
realtype *udata = N_VGetArrayPointer_Serial(u);
for (size_t i = 0; i < pts.size(); ++i) {
udata[i*6] = pts[i]->x[0];
udata[i*6+1] = pts[i]->x[1];
udata[i*6+2] = pts[i]->x[2];
udata[i*6+3] = pts[i]->v[0];
udata[i*6+4] = pts[i]->v[1];
udata[i*6+5] = pts[i]->v[2];
}
}
/*
 * CVODE right-hand side: udot = d/dt [x, v] = [v, accel] for every vertex.
 *
 * user_data is the IM_SPRING_SOLVER registered in doSolve. The trial state u
 * is first pushed into the mesh. Accelerations are then evaluated for all
 * vertices before any derivative is written. Always returns 0 (success).
 */
static int f(realtype t, N_Vector u, N_Vector udot, void* user_data) {
    IM_SPRING_SOLVER *solver = static_cast<IM_SPRING_SOLVER*>(user_data);
    std::vector<SpringVertex*>& mesh = solver->getSpringMesh();
    realtype *rhs = N_VGetArrayPointer_Serial(udot);

    updateSpringVertex(u, mesh);
    for (size_t k = 0; k < mesh.size(); ++k)
        solver->computeAccel(mesh[k]);

    for (size_t k = 0; k < mesh.size(); ++k) {
        SpringVertex *vert = mesh[k];
        realtype *slot = rhs + 6 * k;
        for (int d = 0; d < 3; ++d) {
            slot[d] = vert->v[d];         // dx/dt = v
            slot[3 + d] = vert->accel[d]; // dv/dt = accel
        }
    }
    return 0;
}
/*static int Jac(long int N, long int mu, long int ml,
realtype t, N_Vector u, N_Vector fu,
DlsMat J, void *user_data,
N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
{
return 0;
}*/
/*
 * Report a failed SUNDIALS (or allocation) call on stderr.
 *
 * @param flagvalue  opt == 0 or 2: the pointer returned by the call.
 *                   opt == 1: address of the call's int return code.
 * @param funcname   name of the checked function, used in the message.
 * @param opt        0 = SUNDIALS allocator (fail on NULL),
 *                   1 = int return code (fail if < 0),
 *                   2 = plain allocator (fail on NULL).
 * @return 1 if a failure was detected (after printing a message), else 0.
 *         Unknown opt values return 0.
 */
static int check_flag(void *flagvalue, const char *funcname, int opt)
{
    if (opt == 0) {
        /* SUNDIALS function returned NULL pointer - no memory allocated */
        if (flagvalue == NULL) {
            fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n",
                    funcname);
            return 1;
        }
    } else if (opt == 1) {
        /* The original dereferenced flagvalue unconditionally here. */
        if (flagvalue == NULL) {
            fprintf(stderr, "\nSUNDIALS_ERROR: %s() check received a NULL flag pointer\n\n",
                    funcname);
            return 1;
        }
        const int errflag = *static_cast<int *>(flagvalue);
        /* Negative return codes signal failure */
        if (errflag < 0) {
            fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with flag = %d\n\n",
                    funcname, errflag);
            return 1;
        }
    } else if (opt == 2) {
        /* Function returned NULL pointer - no memory allocated */
        if (flagvalue == NULL) {
            fprintf(stderr, "\nMEMORY_ERROR: %s() failed - returned NULL pointer\n\n",
                    funcname);
            return 1;
        }
    }
    return 0;
}