"""
Adaptive Poisson solver using a residual-based energy-norm error
estimator

  eta_h**2 = sum_T eta_T**2

with

  eta_T**2 = h_T**2 ||R_T||_T**2 + c h_T ||R_dT||_dT**2

where

  R_T =  f + div grad uh
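  R_dT = 2 avg(grad uh * n)  (2*avg is jump, since n switches sign across edges)

In the code the edge term is summed over interior edges only, with c = 2, and
h_T on an edge is taken as the average of the sizes of the two cells sharing it.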

and a Dorfler marking strategy

By Richard Falk, who made minor modifications (changing the domain) to code
of Douglas Arnold that builds on code of Marie Rognes.  Last revised
February 23, 2017.
"""

from dolfin import *
from mshr import *
from numpy import zeros

# Stopping criteria for the adaptive loop: quit as soon as the estimated
# squared error sum(eta_T**2) falls below tolerance**2, or after
# max_iterations SOLVE-ESTIMATE-MARK-REFINE passes, whichever comes first.
max_iterations = 20
tolerance = 0.35

# Create initial mesh.
# Domain: the disk of radius r centred at (x_0, x_1) with a wedge of opening
# angle theta = fraction*pi/2 removed.  The wedge starts on the downward
# vertical ray from the centre and opens counter-clockwise, leaving a
# re-entrant corner of interior angle 2*pi - theta at the centre.
fraction = 0.5
x_0 = 0.0
x_1 = 0.0
r = 1.0
theta = fraction*pi/2.
center = Point(x_0, x_1)
circle = Circle(center, r)
# The cut-out triangle has legs of length 2*r so that it covers the whole
# sector inside the circle.  The distance from the centre to its far edge is
# 2*r*cos(theta/2), which stays >= r only while theta <= 2*pi/3, i.e.
# fraction <= 4/3; larger fractions would leave part of the sector behind.
leg = 2*r
b = Point(x_0, x_1 - leg)
c = Point(x_0 + leg*sin(theta), x_1 - leg*cos(theta))
# Vertices are listed counter-clockwise (center -> b -> c).
polygon = Polygon([center, b, c])
wedge = circle - polygon

# Generate mesh from geometry (mshr resolution parameter 10)
mesh = generate_mesh(wedge, 10)
viz_imesh = plot(mesh, title="initial mesh")
# viz_imesh.write_pdf('initialmesh')

# Dirichlet boundary condition g = 0 and right-hand side f = 1.
# f is a Constant rather than an Expression interpolated at degree 2: it
# represents the constant exactly and does not inflate the quadrature degree
# of the forms f*v*dx and the residual f + div(grad(uh)).
g = Constant(0.)
f = Constant(1.0)

# SOLVE - ESTIMATE - MARK - REFINE loop.
# Each pass solves on the current mesh and computes the squared error
# indicators eta_T**2.  It stops if their sum is below tolerance**2;
# otherwise it refines the cells chosen by the marking step.
for i in range(max_iterations):

    # *** SOLVE step
    # Piecewise-linear (CG1) approximation of -div grad u = f with u = g on
    # the whole boundary, solved on the current mesh.
    V = FunctionSpace(mesh, "CG", 1)
    u = TrialFunction(V)
    v = TestFunction(V)
    a = inner(grad(u), grad(v))*dx
    L = f*v*dx
    uh = Function(V)
    solve(a == L, uh, DirichletBC(V, g, DomainBoundary()))

    # *** ESTIMATE step
    # Cell residual
    R_T = f + div(grad(uh))
    # Unit facet normal.  On an interior facet n('+') = -n('-'), so twice the
    # average of the normal derivative is its jump across the edge.
    n = FacetNormal(mesh)
    R_dT = 2*avg(dot(grad(uh), n))
    # Testing against piecewise constants (DG0) localizes the indicator: each
    # DG0 degree of freedom collects the contribution of one cell.
    Constants = FunctionSpace(mesh, "DG", 0)
    w = TestFunction(Constants)
    h = CellSize(mesh)
    # Assemble squared error indicators eta_T**2.  dS integrates over interior
    # edges only; avg(w) = 1/2 on each of the two cells sharing an edge.
    eta2 = assemble(h**2 * R_T**2 * w * dx + 4.*avg(h) * R_dT**2 * avg(w) * dS)
    eta2 = eta2.array()
    # DG0 dof numbering is not guaranteed to coincide with cell numbering, so
    # permute the indicators into cell order before they are matched against
    # the CellFunction below (serial run assumed).
    dofmap = Constants.dofmap()
    eta2 = eta2[[dofmap.cell_dofs(k)[0] for k in range(mesh.num_cells())]]
    # Maximum indicator and sum (the estimate for the squared H1 norm of error)
    eta2_max = eta2.max()
    sum_eta2 = eta2.sum()
    # Stop when the error estimate is less than the tolerance
    if sum_eta2 < tolerance**2:
        print("Final mesh %d: %d triangles, %d vertices"
              % (i+1, mesh.num_cells(), mesh.num_vertices()))
        print("\nTolerance achieved.  Exiting.")
        break

    # *** MARK step
    # Mark every cell with eta_T**2 > frac*eta2_max, lowering frac through
    # .95, .90, ... until the marked cells account for at least `part` of the
    # total sum of eta_T**2 (Dorfler-type bulk criterion).
    frac = .95
    delfrac = .05
    part = .25
    marked = zeros(eta2.size, dtype='bool')  # no cell marked yet
    sum_marked_eta2 = 0.  # sum over marked cells of squared error indicators
    # The marked.all() guard ends the sweep even if round-off keeps the
    # marked sum just below part*sum_eta2 once every cell has been marked.
    while sum_marked_eta2 < part*sum_eta2 and not marked.all():
        new_marked = (~marked) & (eta2 > frac*eta2_max)
        sum_marked_eta2 += eta2[new_marked].sum()
        marked |= new_marked
        frac -= delfrac
    # Attach the Boolean array (entry k <-> cell k) to a cell function
    cells_marked = CellFunction("bool", mesh)
    cells_marked.array()[:] = marked

    # *** REFINE step
    mesh = refine(mesh, cells_marked)
    #plot(mesh, title="Mesh q" + str(i))
else:
    # Loop exhausted without meeting the tolerance.  `mesh` is the last
    # refined mesh; no solve has been performed on it.
    print("Tolerance not achieved after %d iterations." % max_iterations)

# Display the mesh the loop ended on: the one on which the tolerance was met,
# or the last refinement when max_iterations was exhausted.
viz_fmesh = plot(mesh,
                 title="final mesh")
# viz_fmesh.write_pdf('finalmesh')



