"""
Douglas N. Arnold, 2011-11-17
Revised by Richard Falk February 23, 2017

Poisson equation with Dirichlet conditions, study of convergence in
L2 and H1 norms.

-Laplace(u) = f on the unit square.
u = u0 on the boundary.

We take f with a 1/r^0.2 singularity at (1/3, 1/3) (f = (r^2)^(-0.1),
where r is the distance to that point), and u0 = 0, so the solution is
singular.

We take a solution on a fine mesh and with high degree elements as
the exact solution.
"""

from dolfin import *
from numpy import log

# Define source term.
# The expression is (r^2)^(-0.1) = r^(-0.2), with r the distance to the point
# (1/3, 1/3).  r^2 is clamped from below by DOLFIN_EPS so that pow() is never
# evaluated with a zero base and a negative exponent.
f = Expression("pow(max( (x[0]-1./3.)*(x[0]-1./3.)+(x[1]-1./3.)*(x[1]-1./3.), DOLFIN_EPS),-.1)", degree = 2)

nmeshes = 5  # number of meshes to compute with
deg = 1      # degree of the elements to compute with

# Compute the "exact solution" on a very fine mesh with higher degree elements.
# The finest mesh used in the convergence loop has 2**(nmeshes+2) subdivisions
# per side, so this reference mesh is one uniform refinement finer, and the
# reference space uses elements of degree deg+1.
n = 2**(nmeshes+3)
mesh = UnitSquareMesh(n, n)
Vex = FunctionSpace(mesh, 'Lagrange', deg+1)
bc = DirichletBC(Vex, Constant(0.), DomainBoundary())  # u = 0 on the whole boundary
v = TestFunction(Vex)
u = TrialFunction(Vex)
b = inner(grad(u), grad(v))*dx  # bilinear form of -Laplace(u)
F = f*v*dx                      # load functional
uex = Function(Vex)
solve(b == F, uex, bc)
# A single parenthesised argument prints identically under Python 2 and 3.
print("Computed comparison solution using {} triangles and degree {}".format(mesh.num_cells(), deg+1))

plot(uex, interactive=True)

# Norms of the reference solution, used later to report relative errors.
# They depend only on uex, not on the test mesh, so assemble them once here
# rather than re-assembling them on the fine mesh in every loop iteration.
L2norm = sqrt( assemble( uex*uex*dx ) )
H1seminorm = sqrt( assemble( inner(grad(uex),grad(uex))*dx ) )

errors = [] # list of [h, L2 error, H1 seminorm error], one entry per mesh
for i in range(nmeshes):

    # Create mesh and define function space; n doubles with every pass,
    # so the mesh size h = 1/n is halved from one mesh to the next.
    n = 2**(i+3)
    mesh = UnitSquareMesh(n, n)
    Vh = FunctionSpace(mesh, 'Lagrange', deg)
    bc = DirichletBC(Vh, Constant(0.), DomainBoundary())

    # Define variational problem
    v = TestFunction(Vh)
    u = TrialFunction(Vh)
    b = inner(grad(u), grad(v))*dx
    F = f*v*dx

    # Compute solution
    uh = Function(Vh)
    solve(b == F, uh, bc)

    # Compare with the reference solution: bring uh into the reference
    # space Vex so that the difference lives on a single mesh.
    err = uex - interpolate(uh, Vex)

    L2normerr = sqrt( assemble( err*err*dx))
    H1seminormerr = sqrt( assemble( inner(grad(err),grad(err))*dx ) )
    errors.append([1.0/n, L2normerr, H1seminormerr])
    
# Print the convergence table: one row per mesh with the mesh size h, the
# absolute L2 and H1-seminorm errors (and their size relative to the norms of
# the reference solution, in percent), and -- from the second row on -- the
# observed convergence rates.
# print(...) with a single argument behaves the same under Python 2 and 3.
print("\n     h      L2 error           H1 error            L2 rate H1 rate\n")
row_fmt = "  {:7.5f}   {:4.2e} ({:5.2f}%)  {:4.2e} ({:5.2f}%)"
rate_fmt = "  {:5.2f}   {:5.2f}"
previous = None  # (h, L2 error, H1 error) of the preceding, coarser mesh
for h, l2_err, h1_err in errors:
    line = row_fmt.format(h, l2_err, 100*l2_err/L2norm, h1_err, 100*h1_err/H1seminorm)
    if previous is not None:
        h_prev, l2_prev, h1_prev = previous
        # Observed rate = log(error ratio) / log(mesh-size ratio).  For the
        # meshes built above h_prev/h is exactly 2, so this equals the usual
        # log(e_prev/e)/log(2), but it stays correct for other refinements.
        refinement = log(h_prev/h)
        line += rate_fmt.format(log(l2_prev/l2_err)/refinement, log(h1_prev/h1_err)/refinement)
    print(line)
    previous = (h, l2_err, h1_err)
