I think I found a way to do this, based on this post:
https://fenicsproject.org/qa/12282/change-values-function-mesh-function-defined-boundarymesh
There's probably a better way to do this, and I'm still checking whether it is valid, but maybe it helps someone. Any suggestions for improving this code are appreciated!
# fenics is star-imported first so that the explicit imports below take
# precedence over any same-named symbols it exports (dolfin may export e.g.
# a `time` function).
from fenics import *
import os
import time

import numpy as np
# Allow Functions to be evaluated slightly outside their own mesh; presumably
# needed for the interpolations between the parent mesh, its submeshes and
# the boundary/interface meshes further down.
parameters['allow_extrapolation'] = True
ofilename = 'dg_jump_mesh.xml'
# The marker files are expected next to the mesh file as
# <stem>_physical_region.xml and <stem>_facet_region.xml. splitext only strips
# the final extension, unlike split('.')[0], which breaks on paths such as
# './mesh.xml' or names containing extra dots.
mesh_stem = os.path.splitext(ofilename)[0]
mesh = Mesh(ofilename)
subdomains0 = MeshFunction("size_t", mesh, "%s_physical_region.xml" % mesh_stem)
boundaries0 = MeshFunction("size_t", mesh, "%s_facet_region.xml" % mesh_stem)

# Renumber the GMSH physical tags to the ids used by the measures dx/ds/dS.
# Both functions start at 0, so every tag not listed here maps to 0.
subdomain_tag_map = {4: 0, 5: 1}
boundary_tag_map = {1: 1, 2: 2, 3: 3}

subdomains = CellFunction("size_t", mesh, 0)
subdomain_values = subdomains.array()
gmsh_subdomain_values = subdomains0.array()
for gmsh_tag, new_tag in subdomain_tag_map.items():
    subdomain_values[gmsh_subdomain_values == gmsh_tag] = new_tag

boundaries = FacetFunction("size_t", mesh, 0)
boundary_values = boundaries.array()
gmsh_boundary_values = boundaries0.array()
for gmsh_tag, new_tag in boundary_tag_map.items():
    boundary_values[gmsh_boundary_values == gmsh_tag] = new_tag
# Integration measures that carry the markers, so that dx(i) integrates over
# cells tagged i, ds(i) over exterior facets tagged i and dS(i) over interior
# facets tagged i.
dx = Measure('dx', domain=mesh, subdomain_data=subdomains)
ds = Measure('ds', domain=mesh, subdomain_data=boundaries)
dS = Measure('dS', domain=mesh, subdomain_data=boundaries)
# Discontinuous solution space (DG2) with its trial/test functions, plus a
# continuous CG1 space on the same mesh.
V = FunctionSpace(mesh, 'DG', 2)
u, v = TrialFunction(V), TestFunction(V)
Vcg = FunctionSpace(mesh, 'CG', 1)
# CG1 spaces on each subdomain, extracted as separate submeshes
# (tag 1 -> dom1, tag 0 -> dom2).
dom1 = SubMesh(mesh, subdomains, 1)
Vc = FunctionSpace(dom1, 'CG', 1)
dom2 = SubMesh(mesh, subdomains, 0)
Vm = FunctionSpace(dom2, 'CG', 1)

# Boundary mesh of the full mesh. The original referenced an undefined name
# `R`. The vertex map bmesh.entity_map(0) is later composed with
# vertex_to_dof_map(Vcg), where Vcg lives on `mesh`, so the parent must be
# `mesh` itself.
bmesh = BoundaryMesh(mesh, 'exterior')
Vb = FunctionSpace(bmesh, 'CG', 1)

# SubMesh needs a marker defined on the cells of bmesh. `subdomains` is
# defined on the cells of `mesh`, so transfer it: each bmesh cell is an
# exterior facet of `mesh` and belongs to exactly one mesh cell. Tag it with
# that cell's subdomain id.
tdim = mesh.topology().dim()
mesh.init(tdim - 1, tdim)  # facet -> cell connectivity for Facet.entities
bmesh_to_mesh_facet = bmesh.entity_map(tdim - 1).array()
owner_cell = np.array(
    [Facet(mesh, int(facet)).entities(tdim)[0] for facet in bmesh_to_mesh_facet],
    dtype=int,
)
bmesh_subdomains = CellFunction("size_t", bmesh, 0)
bmesh_subdomains.array()[:] = subdomains.array()[owner_cell]

# NOTE(review): an 'exterior' BoundaryMesh contains no interior facets. The
# internal interface tagged 3 (integrated with dS(3)) is therefore not part
# of bmesh, and selecting subdomain 0 here yields the outer boundary of that
# subdomain. Confirm this is the intended "interface".
interface = SubMesh(bmesh, bmesh_subdomains, 0)
Vi = FunctionSpace(interface, 'CG', 1)
ui = TrialFunction(Vi)
vi = TestFunction(Vi)
dxi = Measure('dx', domain=interface)
# Penalty parameters: alpha for interior facets away from the interface,
# gamma for the Dirichlet boundaries and the interface.
alpha = 10.0
gamma = alpha
n = FacetNormal(mesh)
h = CellSize(mesh)
h_avg = (h('+') + h('-'))/2
# Dirichlet values imposed weakly on boundaries 1 and 2, and the source term.
D1 = 1.0
D2 = 0.0
f = Constant(0.0)
# Initial DG solution and initial interface solution (both zero).
u_0 = interpolate(Constant(0.0), V)
ui_0 = interpolate(Constant(0), Vi)
# Jump data: Jb on the boundary mesh, J (CG1) on the full mesh. The original
# first created J = Function(V) and immediately overwrote it, which is
# removed here.
Jb = Function(Vb)
J = Function(Vcg)
# Index maps used to copy Jb's values into J:
#   Vb dof -> bmesh vertex -> mesh vertex -> Vcg dof
dofb_vb = np.asarray(dof_to_vertex_map(Vb), dtype=int)
vb_v = np.asarray(bmesh.entity_map(0).array(), dtype=int)
v_dof = np.asarray(vertex_to_dof_map(Vcg), dtype=int)
# Output file named after the mesh file's stem (directory and final extension
# stripped, rather than everything after the first dot).
out_stem = os.path.splitext(os.path.basename(ofilename))[0]
ofile = 'Simulation Results/TEST_%s.pvd' % out_stem
print(' [+] Output to %s' % ofile)
vtkfile = File(ofile)
# Time stepping from t = 0 to t_end with step dt.
t_end = 2e-6
dt = 0.1e-6
t = 0.
# --- Variational forms -------------------------------------------------------
# These forms reference Functions (J, ui_0, ...) whose values are updated in
# place inside the time loop, so they are built once rather than every step.

# Interface problem posed on `interface`.
# NOTE(review): the original assigned Lb twice. The first assignment,
# dt*inner(grad(tmpi), grad(vi))*dxi + inner(grad(tmpi), grad(vi))*dxi, was
# immediately overwritten, so the jump data tmpi never entered this solve.
# Only the effective right-hand side is kept. Also, ab has no mass term and
# no boundary condition, so the system is singular (pure Neumann). Confirm
# the intended interface equation.
ab = inner(grad(ui), grad(vi))*dxi
Lb = f*vi*dxi

# Symmetric interior-penalty DG discretisation of -div(grad(u)) = f.
# Volume (stiffness) term.
a = inner(grad(v), grad(u))*dx
# Interior facets away from the interface (tag 0).
a += (- inner(jump(u, n), avg(grad(v)))*dS(0)
      - inner(jump(v, n), avg(grad(u)))*dS(0)
      + (alpha/h_avg)*inner(jump(v, n), jump(u, n))*dS(0))
# Nitsche (weak) Dirichlet conditions on boundaries 1 and 2. The
# -inner(v*n, grad(u)) terms were missing in the original. They come from
# integrating -div(grad(u))*v by parts and are needed for consistency.
a += (- inner(grad(v), u*n)*ds(1)
      - inner(v*n, grad(u))*ds(1)
      + (gamma/h)*v*u*ds(1)
      - inner(grad(v), u*n)*ds(2)
      - inner(v*n, grad(u))*ds(2)
      + (gamma/h)*v*u*ds(2))
# Interface (tag 3): same structure; the jump itself is prescribed in L.
a += (- inner(avg(grad(v)), jump(u, n))*dS(3)
      - inner(jump(v, n), avg(grad(u)))*dS(3)
      + (gamma/h_avg)*inner(jump(u, n), jump(v, n))*dS(3))

# Source term.
L = v*f*dx
# Dirichlet data D1 / D2, mirroring the boundary terms of `a`.
L += (- D1*inner(grad(v), n)*ds(1)
      + (gamma/h)*D1*v*ds(1)
      - D2*inner(grad(v), n)*ds(2)
      + (gamma/h)*D2*v*ds(2))
# Prescribed jump: jump(u, n) = J*n('+'), i.e. u('+') - u('-') = J. The
# penalty term inner(J, jump(v)) equals inner(J*n('+'), jump(v, n)), so the
# consistency term must also use n('+'). The original used n('-'), which
# flipped its sign relative to the penalty term.
# NOTE(review): dolfin does not tie a facet's '+' side to a given subdomain,
# while J is built as (subdomain 1 trace) - (subdomain 0 trace). Verify the
# '+' orientation on the tag-3 facets.
L += (- inner(J*n('+'), avg(grad(v)))*dS(3)
      + (gamma/h_avg)*inner(J, jump(v))*dS(3))

# Vcg dofs matching the dofs of Vb (loop-invariant):
#   Vb dof -> bmesh vertex -> mesh vertex -> Vcg dof
cg_dofs_of_vb = v_dof[vb_v[dofb_vb]]

# Use a fixed step count. Accumulating t += dt in floating point can leave t
# just below t_end after the intended last step, which triggers an extra step.
t_start = t
num_steps = max(0, int(np.ceil((t_end - t_start)/dt - 1e-8)))
for step in range(1, num_steps + 1):
    # Restrict the previous solution to each subdomain (CG1 on submeshes).
    # NOTE(review): cross-mesh interpolation evaluates u_0 at interface
    # vertices, where the DG field is two-valued. Which side's value each
    # submesh receives depends on point location, so verify that tmpb really
    # captures the jump.
    u_i = interpolate(u_0, Vc)
    u_o = interpolate(u_0, Vm)
    # Transfer both restrictions onto the boundary mesh.
    ub_i = interpolate(u_i, Vb)
    ub_o = interpolate(u_o, Vb)
    # Difference of the two traces approximates jump(u_0) on the boundary
    # mesh. It is projected onto the interface space but, as noted above, is
    # not used by the interface solve.
    tmpb = ub_i - ub_o
    tmpi = project(tmpb, Vi)

    # Interface solve.
    solve(ab == Lb, ui_0)
    # Interface mesh -> boundary mesh -> full-mesh CG1 function J.
    Jb.interpolate(ui_0)
    j_values = J.vector().array()
    j_values[cg_dofs_of_vb] = Jb.vector().array()
    J.vector()[:] = j_values

    # DG solve with the prescribed interface jump J.
    start = time.time()
    solve(a == L, u_0)
    vtkfile << u_0
    # vtkfile << J
    # vtkfile << tmpi
    print('Calculation Time = %s' % (time.time() - start))
    t = t_start + step*dt