Actual source code: bjkokkos.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <petsc/private/pcimpl.h>
  3: #include <petsc/private/kspimpl.h>
  4: #include <petscksp.h>
  5: #include "petscsection.h"
  6: #include <petscdmcomposite.h>
  7: #include <Kokkos_Core.hpp>

  9: typedef Kokkos::TeamPolicy<>::member_type team_member;

 11: #include <../src/mat/impls/aij/seq/aij.h>
 12: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>

 14: #define PCBJKOKKOS_SHARED_LEVEL 1
 15: #define PCBJKOKKOS_VEC_SIZE 16
 16: #define PCBJKOKKOS_TEAM_SIZE 16
 17: #define PCBJKOKKOS_VERBOSE_LEVEL 0
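     // These constants parameterize the Kokkos kernels below: TEAM_SIZE and VEC_SIZE are the team and
     // vector lengths of the TeamPolicy (one team per block; team_size falls back to 1 on host-like
     // backends), SHARED_LEVEL selects the scratch level that holds per-team work vectors, and
     // VERBOSE_LEVEL gates the compiled-in diagnostic output.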

 19: typedef enum {BATCH_KSP_BICG_IDX,BATCH_KSP_TFQMR_IDX,BATCH_KSP_GMRES_IDX,NUM_BATCH_TYPES} KSPIndex;
 20: typedef struct {
 21:   Vec                                              vec_diag;
 22:   PetscInt                                         nBlocks; /* total number of blocks */
 23:   PetscInt                                         n; // cache host version of d_bid_eqOffset_k[nBlocks]
 24:   KSP                                              ksp; // Used just for options. Should have one for each block
 25:   Kokkos::View<PetscInt*, Kokkos::LayoutRight>     *d_bid_eqOffset_k;
 26:   Kokkos::View<PetscScalar*, Kokkos::LayoutRight>  *d_idiag_k;
 27:   Kokkos::View<PetscInt*>                          *d_isrow_k;
 28:   Kokkos::View<PetscInt*>                          *d_isicol_k;
 29:   KSPIndex                                         ksp_type_idx;
 30:   PetscInt                                         nwork;
 31:   PetscInt                                         const_block_size; // used to decide to use shared memory for work vectors
 32:   PetscInt                                         *dm_Nf;  // Number of fields in each DM
 33:   PetscInt                                         num_dms;
 34:   // diagnostics
 35:   PetscBool                                        reason;
 36:   PetscBool                                        monitor;
 37:   PetscInt                                         batch_target;
 38: } PC_PCBJKOKKOS;
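     // Block b owns the contiguous equation range [d_bid_eqOffset_k[b], d_bid_eqOffset_k[b+1]) in the
     // permuted ordering; d_isrow_k/d_isicol_k hold the permutation (and its inverse) from
     // MatGetOrdering that is used to permute the matrix into this block-diagonal form.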

 40: static PetscErrorCode  PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
 41: {
 42:   const char    *prefix;
 43:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
 44:   DM             dm;

 46:   KSPCreate(PetscObjectComm((PetscObject)pc),&jac->ksp);
 47:   KSPSetErrorIfNotConverged(jac->ksp,pc->erroriffailure);
 48:   PetscObjectIncrementTabLevel((PetscObject)jac->ksp,(PetscObject)pc,1);
 49:   PCGetOptionsPrefix(pc,&prefix);
 50:   KSPSetOptionsPrefix(jac->ksp,prefix);
 51:   KSPAppendOptionsPrefix(jac->ksp,"pc_bjkokkos_");
 52:   PCGetDM(pc,&dm);
 53:   if (dm) {
 54:     KSPSetDM(jac->ksp, dm);
 55:     KSPSetDMActive(jac->ksp, PETSC_FALSE);
 56:   }
 57:   jac->reason       = PETSC_FALSE;
 58:   jac->monitor      = PETSC_FALSE;
 59:   jac->batch_target = 0;
 60:   return 0;
 61: }

 63: // y <-- Ax
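     // Per-team sparse mat-vec on one block: for each permuted row rowb in [start,end) the original
     // matrix row is ic[rowb], and a global column index aj[i] maps back into the block through r[], i.e.
     //   y_loc[rowb-start] = sum_i aa[i] * x_loc[ r[aj[i]] - start ]
     // TeamThreadRange parallelizes over rows; ThreadVectorRange reduces over the nonzeros of a row.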
 64: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team,  const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
 65: {
 66:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
 67:       int rowa = ic[rowb];
 68:       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
 69:       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
 70:       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
 71:       PetscScalar sum;
 72:       Kokkos::parallel_reduce(Kokkos::ThreadVectorRange (team, n), [=] (const int i, PetscScalar& lsum) {
 73:           lsum += aa[i] * x_loc[r[aj[i]]-start];
 74:         }, sum);
 75:       Kokkos::single(Kokkos::PerThread (team),[=]() {y_loc[rowb-start] = sum;});
 76:     });
 77:   team.team_barrier();
 78:   return 0;
 79: }

 81: // y <-- A^T x (TODO: temp buffer per thread with reduction at end?)
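     // Per-team transpose mat-vec on one block: y_loc is zeroed first, then each row scatters
     // aa[i]*x_loc[rowb-start] into y_loc[r[aj[i]]-start]; different rows can update the same output
     // entry concurrently, hence the Kokkos::atomic_fetch_add below.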
 82: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
 83: {
 84:   Kokkos::parallel_for(Kokkos::TeamVectorRange(team,end-start), [=] (int i) { y_loc[i] = 0;});
 85:   team.team_barrier();
 86:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
 87:       int rowa = ic[rowb];
 88:       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
 89:       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
 90:       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
 91:       const PetscScalar xx = x_loc[rowb-start]; // rowb = ic[rowa] = ic[r[rowb]]
 92:       Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,n), [=] (const int &i) {
 93:           PetscScalar val = aa[i] * xx;
 94:           Kokkos::atomic_fetch_add(&y_loc[r[aj[i]]-start], val);
 95:         });
 96:     });
 97:   team.team_barrier();
 98:   return 0;
 99: }

101: typedef struct Batch_MetaData_TAG
102: {
103:   PetscInt           flops;
104:   PetscInt           its;
105:   KSPConvergedReason reason;
106: } Batch_MetaData;

108: // Solve Ax = b with TFQMR, right preconditioned (solve (A B^-1)(B x) = b) so the computed residual is the un-preconditioned residual
109: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
110: {
111:   using Kokkos::parallel_reduce;
112:   using Kokkos::parallel_for;
113:   int               Nblk = end-start, i,m;
114:   PetscReal         dp,dpold,w,dpest,tau,psi,cm,r0;
115:   PetscScalar       *ptr = work_space, rho,rhoold,a,s,b,eta,etaold,psiold,cf,dpi;
116:   const PetscScalar *Diag = &glb_idiag[start];
117:   PetscScalar       *XX = ptr; ptr += stride;
118:   PetscScalar       *R = ptr; ptr += stride;
119:   PetscScalar       *RP = ptr; ptr += stride;
120:   PetscScalar       *V = ptr; ptr += stride;
121:   PetscScalar       *T = ptr; ptr += stride;
122:   PetscScalar       *Q = ptr; ptr += stride;
123:   PetscScalar       *P = ptr; ptr += stride;
124:   PetscScalar       *U = ptr; ptr += stride;
125:   PetscScalar       *D = ptr; ptr += stride;
126:   PetscScalar       *AUQ = V;
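       // The nine vectors above are consecutive slices of work_space, each of length 'stride' (AUQ
       // aliases V to save a slot); PCSetUp_BJKOKKOS sets nwork = 10 for TFQMR, so the buffer passed
       // in by PCApply_BJKOKKOS is large enough.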

128:   // init: get b, zero x
129:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
130:       int rowa = ic[rowb];
131:       R[rowb-start] = glb_b[rowa];
132:       XX[rowb-start] = 0;
133:     });
134:   team.team_barrier();
135:   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
136:   team.team_barrier();
137:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
138:   // diagnostics
139: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
140:   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
141: #endif
142:   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
143:   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}

145:   /* Make the initial Rp = R */
146:   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {RP[idx] = R[idx];});
147:   team.team_barrier();
148:   /* Set the initial conditions */
149:   etaold = 0.0;
150:   psiold = 0.0;
151:   tau    = dp;
152:   dpold  = dp;

154:   /* rhoold = (r,rp)     */
155:   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rhoold);
156:   team.team_barrier();
157:   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx]; P[idx] = R[idx]; T[idx] = Diag[idx]*P[idx]; D[idx] = 0;});
158:   team.team_barrier();
159:   MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);

161:   i=0;
162:   do {
163:     /* s <- (v,rp)          */
164:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += V[idx]*PetscConj(RP[idx]);}, s);
165:     team.team_barrier();
166:     a    = rhoold / s;                              /* a <- rho / s         */
167:     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
168:     /* t <- u + q           */
169:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Q[idx] = U[idx] - a*V[idx]; T[idx] = U[idx] + Q[idx];});
170:     team.team_barrier();
171:     // KSP_PCApplyBAorAB
172:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*T[idx]; });
173:     team.team_barrier();
174:     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,AUQ);
175:     /* r <- r - a K (u + q) */
176:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {R[idx] = R[idx] - a*AUQ[idx]; });
177:     team.team_barrier();
178:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
179:     team.team_barrier();
180:     dp = PetscSqrtReal(PetscRealPart(dpi));
181:     for (m=0; m<2; m++) {
182:       if (!m) w = PetscSqrtReal(dp*dpold);
183:       else w = dp;
184:       psi = w / tau;
185:       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
186:       tau = tau * psi * cm;
187:       eta = cm * cm * a;
188:       cf  = psiold * psiold * etaold / a;
189:       if (!m) {
190:         /* D = U + cf D */
191:         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = U[idx] + cf*D[idx]; });
192:       } else {
193:         /* D = Q + cf D */
194:         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = Q[idx] + cf*D[idx]; });
195:       }
196:       team.team_barrier();
197:       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + eta*D[idx]; });
198:       team.team_barrier();
199:       dpest = PetscSqrtReal(2*i + m + 2.0) * tau;
200: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
201:       if (monitor && m==1) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dpest);});
202: #endif
203:       if (dpest < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
204:       if (dpest/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
205: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
206:       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dpest,r0);}); goto done;}
207: #else
208:       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
209: #endif
210:       if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}

212:       etaold = eta;
213:       psiold = psi;
214:     }

216:     /* rho <- (r,rp)       */
217:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rho);
218:     team.team_barrier();
219:     b    = rho / rhoold;                            /* b <- rho / rhoold   */
220:     /* u <- r + b q        */
221:     /* p <- u + b(q + b p) */
222:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx] + b*Q[idx]; Q[idx] = Q[idx] + b*P[idx]; P[idx] = U[idx] + b*Q[idx];});
223:     /* v <- K p  */
224:     team.team_barrier();
225:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*P[idx]; });
226:     team.team_barrier();
227:     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);

229:     rhoold = rho;
230:     dpold  = dp;

232:     i++;
233:   } while (i<maxit);
234:   done:
235:   // KSPUnwindPreconditioner
236:   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = Diag[idx]*XX[idx]; });
237:   team.team_barrier();
238:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
239:       int rowa = ic[rowb];
240:       glb_x[rowa] = XX[rowb-start];
241:     });
242:   metad->its = i+1;
243:   if (1) {
244:     int nnz;
245:     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
246:     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
247:   } else {
248:     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
249:   }
250:   return 0;
251: }

253: // Solve Ax = b with preconditioned BiCG
254: KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
255: {
256:   using Kokkos::parallel_reduce;
257:   using Kokkos::parallel_for;
258:   int               Nblk = end-start, i;
259:   PetscReal         dp, r0;
260:   PetscScalar       *ptr = work_space, dpi, a=1.0, beta, betaold=1.0, b, b2, ma, mac;
261:   const PetscScalar *Di = &glb_idiag[start];
262:   PetscScalar       *XX = ptr; ptr += stride;
263:   PetscScalar       *Rl = ptr; ptr += stride;
264:   PetscScalar       *Zl = ptr; ptr += stride;
265:   PetscScalar       *Pl = ptr; ptr += stride;
266:   PetscScalar       *Rr = ptr; ptr += stride;
267:   PetscScalar       *Zr = ptr; ptr += stride;
268:   PetscScalar       *Pr = ptr; ptr += stride;
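       // The seven vectors above (solution, left/right residuals, preconditioned residuals, and search
       // directions) are consecutive slices of work_space of length 'stride', matching nwork = 7 set
       // for BICG in PCSetUp_BJKOKKOS.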

270:   /*     r <- b (x is 0) */
271:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
272:       int rowa = ic[rowb];
273:       //VecCopy(Rr,Rl);
274:       Rl[rowb-start] = Rr[rowb-start] = glb_b[rowa];
275:       XX[rowb-start] = 0;
276:     });
277:   team.team_barrier();
278:   /*     z <- Br         */
279:   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx]; });
280:   team.team_barrier();
281:   /*    dp <- r'*r       */
282:   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
283:   team.team_barrier();
284:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
285: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
286:   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
287: #endif
288:   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
289:   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
290:   i = 0;
291:   do {
292:     /*     beta <- r'z     */
293:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += Zr[idx]*PetscConj(Rl[idx]);}, beta);
294:     team.team_barrier();
295: #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
296: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
297:     Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("%7d beta = Z.R = %22.14e \n",i,(double)beta);});
298: #endif
299: #endif
300:     if (!i) {
301:       if (beta == 0.0) {
302:         metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
303:         goto done;
304:       }
305:       /*     p <- z          */
306:       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = Zr[idx]; Pl[idx] = Zl[idx];});
307:     } else {
308:       b    = beta/betaold;
309:       /*     p <- z + b* p   */
310:       b2    = PetscConj(b);
311:       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = b*Pr[idx] + Zr[idx]; Pl[idx] = b2*Pl[idx] + Zl[idx];});
312:     }
313:     team.team_barrier();
314:     betaold = beta;
315:     /*     z <- Kp         */
316:     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pr,Zr);
317:     MatMultTranspose(team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pl,Zl);
318:     /*     dpi <- z'p      */
319:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Zr[idx]*PetscConj(Pl[idx]);}, dpi);
320:     team.team_barrier();
321:     //
322:     a       = beta/dpi;                           /*     a = beta/p'z    */
323:     ma      = -a;
324:     mac      = PetscConj(ma);
325:     /*     x <- x + ap     */
326:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + a*Pr[idx]; Rr[idx] = Rr[idx] + ma*Zr[idx]; Rl[idx] = Rl[idx] + mac*Zl[idx];});
327:     team.team_barrier();
328:     /*    dp <- r'*r       */
329:     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum +=  Rr[idx]*PetscConj(Rr[idx]);}, dpi);
330:     team.team_barrier();
331:     dp = PetscSqrtReal(PetscRealPart(dpi));
332: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
333:     if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dp);});
334: #endif
335:     if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
336:     if (dp/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
337: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
338:     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dp,r0);}); goto done;}
339: #else
340:     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
341: #endif
342:     if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
343:     /* z <- Br  */
344:     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx];});
345:     i++;
346:     team.team_barrier();
347:   } while (i<maxit);
348:  done:
349:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
350:       int rowa = ic[rowb];
351:       glb_x[rowa] = XX[rowb-start];
352:     });
353:   metad->its = i+1;
354:   if (1) {
355:     int nnz;
356:     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
357:     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
358:   } else {
359:     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
360:   }
361:   return 0;
362: }

364: // Batched KSP solve of Ax = b on each block; xout is the output, bin is the input
365: static PetscErrorCode PCApply_BJKOKKOS(PC pc,Vec bin,Vec xout)
366: {
367:   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS*)pc->data;
368:   Mat               A   = pc->pmat;
369:   Mat_SeqAIJKokkos *aijkok;

372:   aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
373:   if (!aijkok) {
374:     SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"No aijkok");
375:   } else {
376:     using scr_mem_t  = Kokkos::DefaultExecutionSpace::scratch_memory_space;
377:     using vect2D_scr_t = Kokkos::View<PetscScalar**, Kokkos::LayoutLeft, scr_mem_t>;
378:     PetscInt          *d_bid_eqOffset, maxit = jac->ksp->max_it, scr_bytes_team, stride, global_buff_size;
379:     const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
380:     const PetscInt    nwork = jac->nwork, nBlk = jac->nBlocks;
381:     PetscScalar       *glb_xdata=NULL;
382:     PetscReal         rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
383:     const PetscScalar *glb_idiag =jac->d_idiag_k->data(), *glb_bdata=NULL;
384:     const PetscInt    *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data();
385:     const PetscScalar *glb_Aaa = aijkok->a_device_data();
386:     Kokkos::View<Batch_MetaData*, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
387:     PCFailedReason    pcreason;
388:     KSPIndex          ksp_type_idx = jac->ksp_type_idx;
389:     PetscMemType      mtype;
390:     PetscContainer    container;
391:     PetscInt          batch_sz;
392:     VecScatter        plex_batch=NULL;
393:     Vec               bvec;
394:     PetscBool         monitor = jac->monitor; // captured
395:     PetscInt          view_bid = jac->batch_target;
396:     // get field major is to map plex IO to/from block/field major
397:     PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);
398:     VecDuplicate(bin,&bvec);
399:     if (container) {
400:       PetscContainerGetPointer(container, (void **) &plex_batch);
401:       VecScatterBegin(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);
402:       VecScatterEnd(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);
403:     } else {
404:       VecCopy(bin, bvec);
405:     }
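         // If the matrix carries a "plex_batch_is" VecScatter, the right-hand side was just scattered
         // into the solver's block/field-major ordering; the solution is scattered back with
         // SCATTER_REVERSE after the solve. Otherwise bvec is simply a copy of bin.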
406:     // get x
407:     VecGetArrayAndMemType(xout,&glb_xdata,&mtype);
408: #if defined(PETSC_HAVE_CUDA)
410: #endif
411:     VecGetArrayReadAndMemType(bvec,&glb_bdata,&mtype);
412: #if defined(PETSC_HAVE_CUDA)
414: #endif
415:     // get batch size
416:     PetscObjectQuery((PetscObject) A, "batch size", (PetscObject *) &container);
417:     if (container) {
418:       PetscInt *pNf=NULL;
419:       PetscContainerGetPointer(container, (void **) &pNf);
420:       batch_sz = *pNf;
421:     } else batch_sz = 1;
423:     d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
424:     // solve each block independently
425:     if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
426:       scr_bytes_team = jac->const_block_size*nwork*sizeof(PetscScalar);
427:       stride = jac->const_block_size; // captured
428:       global_buff_size = 0;
429:     } else {
430:       scr_bytes_team = 0;
431:       stride = jac->n; // captured
432:       global_buff_size = jac->n*nwork;
433:     }
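         // Work vectors: with a constant block size, each team carves nwork vectors of length 'stride'
         // out of team scratch (level PCBJKOKKOS_SHARED_LEVEL); otherwise all teams share the global
         // buffer d_work_vecs_k below, indexed by the block's first equation with stride jac->n.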
434:     Kokkos::View<PetscScalar*, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_size); // global work vectors
435:     PetscInfo(pc,"\tn = %" PetscInt_FMT ". %d shared mem words/team. %" PetscInt_FMT " global mem words, rtol=%e, num blocks %" PetscInt_FMT ", team_size=%" PetscInt_FMT ", %" PetscInt_FMT " vector threads\n",jac->n,scr_bytes_team/sizeof(PetscScalar),global_buff_size,rtol,nBlk,
436:                team_size, PCBJKOKKOS_VEC_SIZE);
437:     PetscScalar  *d_work_vecs = scr_bytes_team ? NULL : d_work_vecs_k.data();
438:     const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
439:     Kokkos::parallel_for("Solve", Kokkos::TeamPolicy<>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team)),
440:         KOKKOS_LAMBDA (const team_member team) {
441:         const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID+1];
442:         vect2D_scr_t work_vecs(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), scr_bytes_team ? (end-start) : 0, nwork);
443:         PetscScalar *work_buff = (scr_bytes_team) ? work_vecs.data() : &d_work_vecs[start];
444:         bool        print = monitor && (blkID==view_bid);
445:         switch (ksp_type_idx) {
446:         case BATCH_KSP_BICG_IDX:
447:           BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
448:           break;
449:         case BATCH_KSP_TFQMR_IDX:
450:           BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
451:           break;
452:         case BATCH_KSP_GMRES_IDX:
453: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
454:           printf("GMRES not implemented %d\n",ksp_type_idx);
455: #else
456:           /* void */
457: #endif
458:           break;
459:         default:
460: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
461:           printf("Unknown KSP type %d\n",ksp_type_idx);
462: #else
463:           /* void */;
464: #endif
465:         }
466:     });
467:     auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
468:     Kokkos::fence();
469:     Kokkos::deep_copy (h_metadata, d_metadata);
470: #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
471: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
472:     PetscPrintf(PETSC_COMM_WORLD,"Iterations\n");
473: #endif
474:     // assume species major
475: #if PCBJKOKKOS_VERBOSE_LEVEL < 4
476:     PetscPrintf(PETSC_COMM_WORLD,"max iterations per species (%s) :",ksp_type_idx==BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr");
477: #endif
478:     for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
479:       for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
480: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
481:         PetscPrintf(PETSC_COMM_WORLD,"%2D:", s);
482:         for (int bid=0 ; bid<batch_sz ; bid++) {
483:          PetscPrintf(PETSC_COMM_WORLD,"%3D ", h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its);
484:         }
485:         PetscPrintf(PETSC_COMM_WORLD,"\n");
486: #else
487:         PetscInt count=0;
488:         for (int bid=0 ; bid<batch_sz ; bid++) {
489:           if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
490:         }
491:         PetscPrintf(PETSC_COMM_WORLD,"%3D ", count);
492: #endif
493:       }
494:       head += batch_sz*jac->dm_Nf[dmIdx];
495:     }
496: #if PCBJKOKKOS_VERBOSE_LEVEL < 4
497:     PetscPrintf(PETSC_COMM_WORLD,"\n");
498: #endif
499: #endif
500:     PetscInt count=0, mbid=0;
501:     for (int blkID=0;blkID<nBlk;blkID++) {
502:       PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops);
503:       if (jac->reason) {
504:         if (jac->batch_target==blkID) {
505:           PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %" PetscInt_FMT ", batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID%batch_sz, blkID/batch_sz);
506:         } else if (jac->batch_target==-1 && h_metadata[blkID].its > count) {
507:           count = h_metadata[blkID].its;
508:           mbid = blkID;
509:         }
510:         if (h_metadata[blkID].reason < 0) {
511:           PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n",
512:                               KSPConvergedReasons[h_metadata[blkID].reason],h_metadata[blkID].its,blkID/batch_sz,blkID%batch_sz));
513:         }
514:       }
515:     }
516:     if (jac->batch_target==-1 && jac->reason) {
517:       PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %" PetscInt_FMT ", batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[mbid].reason], h_metadata[mbid].its,mbid%batch_sz,mbid/batch_sz);
518:     }
519:     VecRestoreArrayAndMemType(xout,&glb_xdata);
520:     VecRestoreArrayReadAndMemType(bvec,&glb_bdata);
521:     {
522:       int errsum;
523:       Kokkos::parallel_reduce(nBlk, KOKKOS_LAMBDA (const int idx, int& lsum) {
524:           if (d_metadata[idx].reason < 0) ++lsum;
525:         }, errsum);
526:       pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
527:     }
528:     PCSetFailedReason(pc,pcreason);
529:     // map back to Plex space
530:     if (plex_batch) {
531:       VecCopy(xout, bvec);
532:       VecScatterBegin(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);
533:       VecScatterEnd(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);
534:     }
535:     VecDestroy(&bvec);
536:   }
537:   return 0;
538: }

540: static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
541: {
542:   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS*)pc->data;
543:   Mat               A   = pc->pmat;
544:   Mat_SeqAIJKokkos *aijkok;
545:   PetscBool         flg;

549:   PetscObjectTypeCompareAny((PetscObject)A,&flg,MATSEQAIJKOKKOS,MATMPIAIJKOKKOS,MATAIJKOKKOS,"");
551:   if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) {
552:     SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No aijkok");
553:   } else {
554:     if (!jac->vec_diag) {
555:       Vec               *subX;
556:       DM                pack,*subDM;
557:       PetscInt          nDMs, n;
558:       PetscContainer    container;
559:       PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);
560:       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
561:         MatOrderingType   rtype;
562:         IS                isrow,isicol;
563:         const PetscInt    *rowindices,*icolindices;

565:         if (container) rtype = MATORDERINGNATURAL; // if we have a vecscatter then don't reorder here (all the reorder stuff goes away in future)
566:         else rtype = MATORDERINGRCM;
567:         // get the permutation; it is not in the expected form, so invert it here
568:         MatGetOrdering(A,rtype,&isrow,&isicol);
569:         ISDestroy(&isrow);
570:         ISInvertPermutation(isicol,PETSC_DECIDE,&isrow);
571:         ISGetIndices(isrow,&rowindices);
572:         ISGetIndices(isicol,&icolindices);
573:         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isrow_k((PetscInt*)rowindices,A->rmap->n);
574:         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isicol_k ((PetscInt*)icolindices,A->rmap->n);
575:         jac->d_isrow_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isrow_k));
576:         jac->d_isicol_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isicol_k));
577:         Kokkos::deep_copy (*jac->d_isrow_k, h_isrow_k);
578:         Kokkos::deep_copy (*jac->d_isicol_k, h_isicol_k);
579:         ISRestoreIndices(isrow,&rowindices);
580:         ISRestoreIndices(isicol,&icolindices);
581:         ISDestroy(&isrow);
582:         ISDestroy(&isicol);
583:       }
584:       // get block sizes
585:       PCGetDM(pc, &pack);
587:       PetscObjectTypeCompare((PetscObject)pack,DMCOMPOSITE,&flg);
589:       DMCompositeGetNumberDM(pack,&nDMs);
590:       jac->num_dms = nDMs;
591:       DMCreateGlobalVector(pack, &jac->vec_diag);
592:       VecGetLocalSize(jac->vec_diag,&n);
593:       jac->n = n;
594:       jac->d_idiag_k = new Kokkos::View<PetscScalar*, Kokkos::LayoutRight>("idiag", n);
595:       // options
596:       PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
597:       KSPSetFromOptions(jac->ksp);
598:       PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPBICG,"");
599:       if (flg) {jac->ksp_type_idx = BATCH_KSP_BICG_IDX; jac->nwork = 7;}
600:       else {
601:         PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPTFQMR,"");
602:         if (flg) {jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX; jac->nwork = 10;}
603:         else {
604:           PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPGMRES,"");
605:           if (flg) {jac->ksp_type_idx = BATCH_KSP_GMRES_IDX; jac->nwork = 0;}
606:           else SETERRQ(PetscObjectComm((PetscObject)jac->ksp),PETSC_ERR_ARG_WRONG,"unsupported type %s", ((PetscObject)jac->ksp)->type_name);
607:         }
608:       }
609:       {
610:         PetscViewer       viewer;
611:         PetscBool         flg;
612:         PetscViewerFormat format;
613:         PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_converged_reason",&viewer,&format,&flg);
614:         jac->reason = flg;
615:         PetscViewerDestroy(&viewer);
616:         PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_monitor",&viewer,&format,&flg);
617:         jac->monitor = flg;
618:         PetscViewerDestroy(&viewer);
619:         PetscOptionsGetInt(((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_batch_target",&jac->batch_target,&flg);
621:         if (!jac->monitor && !flg) jac->batch_target = -1; // no explicit target: only report the block with the most iterations
622:       }
623:       // get blocks - jac->d_bid_eqOffset_k
624:       PetscMalloc(sizeof(*subX)*nDMs, &subX);
625:       PetscMalloc(sizeof(*subDM)*nDMs, &subDM);
626:       PetscMalloc(sizeof(*jac->dm_Nf)*nDMs, &jac->dm_Nf);
627:       PetscInfo(pc, "Have %" PetscInt_FMT " DMs, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name);
628:       DMCompositeGetEntriesArray(pack,subDM);
629:       jac->nBlocks = 0;
630:       for (PetscInt ii=0;ii<nDMs;ii++) {
631:         PetscSection section;
632:         PetscInt Nf;
633:         DM dm = subDM[ii];
634:         DMGetLocalSection(dm, &section);
635:         PetscSectionGetNumFields(section, &Nf);
636:         jac->nBlocks += Nf;
637: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
638:         if (ii==0) PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks);
639: #else
640:         PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks);
641: #endif
642:         jac->dm_Nf[ii] = Nf;
643:       }
644:       { // d_bid_eqOffset_k
645:         Kokkos::View<PetscInt*, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks+1);
646:         DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);
647:         h_block_offsets[0] = 0;
648:         jac->const_block_size = -1;
649:         for (PetscInt ii=0, idx = 0;ii<nDMs;ii++) {
650:           PetscInt nloc,nblk;
651:           VecGetSize(subX[ii],&nloc);
652:           nblk = nloc/jac->dm_Nf[ii];
654:           for (PetscInt jj=0;jj<jac->dm_Nf[ii];jj++, idx++) {
655:             h_block_offsets[idx+1] = h_block_offsets[idx] + nblk;
656: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
657:             if (idx==0) PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);
658: #else
659:             PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);
660: #endif
661:             if (jac->const_block_size == -1) jac->const_block_size = nblk;
662:             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
663:           }
664:         }
665:         DMCompositeRestoreAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);
666:         PetscFree(subX);
667:         PetscFree(subDM);
668:         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt*, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(),h_block_offsets));
669:         Kokkos::deep_copy (*jac->d_bid_eqOffset_k, h_block_offsets);
670:       }
671:     }
672:     { // get jac->d_idiag_k (PC setup),
673:       const PetscInt    *d_ai=aijkok->i_device_data(), *d_aj=aijkok->j_device_data();
674:       const PetscScalar *d_aa = aijkok->a_device_data();
675:       const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
676:       PetscInt          *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
677:       PetscScalar       *d_idiag = jac->d_idiag_k->data();
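           // One team per block: for each permuted row rowb, scan the nonzeros of the original row
           // ic[rowb] and store 1/a_ii in d_idiag where the permuted column r[aj[j]] equals rowb; the
           // reduction counts how many diagonal entries were found (exactly one is expected per row).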
678:       Kokkos::parallel_for("Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA (const team_member team) {
679:           const PetscInt blkID = team.league_rank();
680:           Kokkos::parallel_for
681:             (Kokkos::TeamThreadRange(team,d_bid_eqOffset[blkID],d_bid_eqOffset[blkID+1]),
682:              [=] (const int rowb) {
683:                const PetscInt    rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
684:                const PetscScalar *aa  = d_aa + ai;
685:                const PetscInt    nrow = d_ai[rowa + 1] - ai;
686:                int found;
687:                Kokkos::parallel_reduce
688:                  (Kokkos::ThreadVectorRange (team, nrow),
689:                   [=] (const int& j, int &count) {
690:                     const PetscInt colb = r[aj[j]];
691:                     if (colb==rowb) {
692:                       d_idiag[rowb] = 1./aa[j];
693:                       count++;
694:                     }}, found);
695: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
696:                if (found!=1) Kokkos::single (Kokkos::PerThread (team), [=] () {printf("ERROR row %d) found = %d\n",rowb,found);});
697: #endif
698:              });
699:         });
700:     }
701:   }
702:   return 0;
703: }

705: /* Reset: free everything created in PCSetUp_BJKOKKOS (also called from PCDestroy_BJKOKKOS) */
706: static PetscErrorCode PCReset_BJKOKKOS(PC pc)
707: {
708:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;

710:   KSPDestroy(&jac->ksp);
711:   VecDestroy(&jac->vec_diag);
712:   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
713:   if (jac->d_idiag_k) delete jac->d_idiag_k;
714:   if (jac->d_isrow_k) delete jac->d_isrow_k;
715:   if (jac->d_isicol_k) delete jac->d_isicol_k;
716:   jac->d_bid_eqOffset_k = NULL;
717:   jac->d_idiag_k = NULL;
718:   jac->d_isrow_k = NULL;
719:   jac->d_isicol_k = NULL;
720:   PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",NULL); // not published now (causes configure errors)
721:   PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",NULL);
722:   PetscFree(jac->dm_Nf);
723:   jac->dm_Nf = NULL;
724:   return 0;
725: }

727: static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
728: {
729:   PCReset_BJKOKKOS(pc);
730:   PetscFree(pc->data);
731:   return 0;
732: }

734: static PetscErrorCode PCView_BJKOKKOS(PC pc,PetscViewer viewer)
735: {
736:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
737:   PetscBool      iascii;

739:   if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
740:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);
741:   if (iascii) {
742:     PetscViewerASCIIPrintf(viewer,"  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n");
743:     PetscCall(PetscViewerASCIIPrintf(viewer,"\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n",jac->nwork,jac->ksp->rtol,
744:                                    jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
745:                                    ((PetscObject)jac->ksp)->type_name));
746:   }
747:   return 0;
748: }

750: static PetscErrorCode PCSetFromOptions_BJKOKKOS(PetscOptionItems *PetscOptionsObject,PC pc)
751: {
752:   PetscOptionsHead(PetscOptionsObject,"PC BJKOKKOS options");
753:   PetscOptionsTail();
754:   return 0;
755: }

757: static PetscErrorCode  PCBJKOKKOSSetKSP_BJKOKKOS(PC pc,KSP ksp)
758: {
759:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;

761:   PetscObjectReference((PetscObject)ksp);
762:   KSPDestroy(&jac->ksp);
763:   jac->ksp = ksp;
764:   return 0;
765: }

767: /*@C
768:    PCBJKOKKOSSetKSP - Sets the KSP context for a PCBJKOKKOS preconditioner.

770:    Collective on PC

772:    Input Parameters:
773: +  pc - the preconditioner context
774: -  ksp - the KSP solver

776:    Notes:
777:    The PC and the KSP must have the same communicator

779:    Level: advanced

781: @*/
782: PetscErrorCode  PCBJKOKKOSSetKSP(PC pc,KSP ksp)
783: {
787:   PetscTryMethod(pc,"PCBJKOKKOSSetKSP_C",(PC,KSP),(pc,ksp));
788:   return 0;
789: }

791: static PetscErrorCode  PCBJKOKKOSGetKSP_BJKOKKOS(PC pc,KSP *ksp)
792: {
793:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;

795:   if (!jac->ksp) PCBJKOKKOSCreateKSP_BJKOKKOS(pc);
796:   *ksp = jac->ksp;
797:   return 0;
798: }

800: /*@C
801:    PCBJKOKKOSGetKSP - Gets the KSP context for a PCBJKOKKOS preconditioner.

803:    Not Collective, but the KSP returned is parallel if the PC was parallel

805:    Input Parameter:
806: .  pc - the preconditioner context

808:    Output Parameters:
809: .  ksp - the KSP solver

811:    Notes:
812:    You must call KSPSetUp() before calling PCBJKOKKOSGetKSP().

814:    If the PC is not a PCBJKOKKOS object it raises an error

816:    Level: advanced

818: @*/
819: PetscErrorCode  PCBJKOKKOSGetKSP(PC pc,KSP *ksp)
820: {
823:   PetscUseMethod(pc,"PCBJKOKKOSGetKSP_C",(PC,KSP*),(pc,ksp));
824:   return 0;
825: }

827: /* ----------------------------------------------------------------------------------*/

829: /*MC
830:      PCBJKOKKOS -  Defines a preconditioner that applies a batched Krylov solver, with Jacobi preconditioning, to each block of a sequential AIJ (Kokkos) matrix on the GPU.

832:    Options Database Key:
833: .     -pc_bjkokkos_ - prefix for the options of the inner (batched) KSP solver

835:    Level: intermediate

837:    Notes:
838:     For use with an outer -ksp_type preonly so that no solver work is done on the CPU
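
     A typical (illustrative) set of options, using the inner-solver prefix listed above:
.vb
     -ksp_type preonly -pc_type bjkokkos -pc_bjkokkos_ksp_type tfqmr -pc_bjkokkos_ksp_rtol 1e-3
.ve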

840:    Developer Notes:

842: .seealso:  PCCreate(), PCSetType(), PCType (for list of available types), PC,
843:            PCSHELL, PCCOMPOSITE, PCSetUseAmat(), PCBJKOKKOSGetKSP()

845: M*/

847: PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
848: {
849:   PC_PCBJKOKKOS *jac;

851:   PetscNewLog(pc,&jac);
852:   pc->data = (void*)jac;

854:   jac->ksp              = NULL;
855:   jac->vec_diag         = NULL;
856:   jac->d_bid_eqOffset_k = NULL;
857:   jac->d_idiag_k        = NULL;
858:   jac->d_isrow_k        = NULL;
859:   jac->d_isicol_k       = NULL;
860:   jac->nBlocks          = 1;

862:   PetscMemzero(pc->ops,sizeof(struct _PCOps));
863:   pc->ops->apply           = PCApply_BJKOKKOS;
864:   pc->ops->applytranspose  = NULL;
865:   pc->ops->setup           = PCSetUp_BJKOKKOS;
866:   pc->ops->reset           = PCReset_BJKOKKOS;
867:   pc->ops->destroy         = PCDestroy_BJKOKKOS;
868:   pc->ops->setfromoptions  = PCSetFromOptions_BJKOKKOS;
869:   pc->ops->view            = PCView_BJKOKKOS;

871:   PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",PCBJKOKKOSGetKSP_BJKOKKOS);
872:   PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",PCBJKOKKOSSetKSP_BJKOKKOS);
873:   return 0;
874: }