Skip to content

Possible updates to PyTorch 1.10 #9

@jhualberta

Description

@jhualberta

Thank you for your lecture; it is very useful.
I know it has already been 5 years since these notes were originally released, and some functions in PyTorch have since been deprecated.
The following are possible updates for PyTorch 1.10 and Python 3.11:
schrodinger.py

schrodinger.py

Replaced the deprecated `torch.symeig(H, eigenvectors=True)` with ✅ `torch.linalg.eigh(H)`.
#########################################################

import numpy as np 
import torch
torch.set_default_dtype(torch.float64)
import torch.nn as nn
import matplotlib.pyplot as plt

class Schrodinger1D(nn.Module):
    """Learn a 1D potential whose ground-state density matches a target.

    The Hamiltonian on a uniform mesh is ``H = diag(V) + K``.

    - ``V`` (``self.potential``) is the only trainable parameter. It is
      initialized to ``x**2``.
    - ``K`` is the second-order finite-difference form of ``-1/2 d^2/dx^2``:
      diagonal ``1/h^2`` and first off-diagonals ``-1/(2 h^2)``.
    """

    def __init__(self, xmesh):
        """Build the model on a mesh.

        Args:
            xmesh: 1D tensor of mesh points. The spacing ``h`` is taken from
                the first two points, so the mesh is assumed to be uniform.
        """
        super().__init__()

        self.xmesh = xmesh
        self.potential = nn.Parameter(xmesh**2)

        nmesh = xmesh.shape[0]
        h2 = (xmesh[1] - xmesh[0]) ** 2
        # Build K on the mesh's dtype *and* device (the original only matched dtype).
        opts = dict(dtype=xmesh.dtype, device=xmesh.device)
        kinetic = (torch.diag(1 / h2 * torch.ones(nmesh, **opts), diagonal=0)
                   - torch.diag(0.5 / h2 * torch.ones(nmesh - 1, **opts), diagonal=1)
                   - torch.diag(0.5 / h2 * torch.ones(nmesh - 1, **opts), diagonal=-1))
        # Register as a buffer so model.to(device)/.double() move K together with
        # the potential. persistent=False keeps it out of state_dict, so existing
        # checkpoints still load unchanged. It remains accessible as self.K.
        self.register_buffer("K", kinetic, persistent=False)

    def _solve(self):
        """Return the ground-state eigenvector of ``H = diag(V) + K``.

        ``torch.linalg.eigh`` returns eigenvalues in ascending order and
        unit-norm eigenvectors as columns. Column 0 is therefore the normalized
        eigenvector of the smallest eigenvalue. Its overall sign is arbitrary.
        """
        H = torch.diag(self.potential) + self.K
        _, eigvecs = torch.linalg.eigh(H)  # replaces the deprecated torch.symeig
        return eigvecs[:, 0]

    def forward(self, target):
        """Return the L1 distance ``sum |psi**2 - target|``.

        Here ``psi**2`` is the current ground-state density. ``target`` is
        assumed to be a density on the same mesh (same length as ``xmesh``).
        """
        psi = self._solve()
        return (psi**2 - target).abs().sum()

    def plot(self, target):
        """Redraw the target density, current density and potential.

        Draws on the current matplotlib axes. The potential is plotted
        divided by 10000.
        """
        def to_np(t):
            # Detach and move to CPU so plotting also works for GPU tensors.
            return t.detach().cpu().numpy()

        psi = self._solve().detach()
        x = to_np(self.xmesh)

        plt.cla()
        plt.plot(x, to_np(target), label='Target Density')
        plt.plot(x, to_np(psi.square()), label='Current Density')
        plt.plot(x, to_np(self.potential) / 10000, label='Potential (V/10000)')
        plt.legend()
        plt.draw()

if __name__ == '__main__':
    # --- Uniform mesh on [xmin, xmax] -------------------------------------
    xmin, xmax, Nmesh = -1, 1, 500
    xmesh = torch.linspace(xmin, xmax, Nmesh)

    # --- Target density: triangular bump 1 - |x| on |x| < 0.5 -------------
    # Normalize the amplitude and square it, giving a density that sums to 1.
    abs_x = torch.abs(xmesh)
    inside = abs_x < 0.5
    target = torch.zeros(Nmesh)
    target[inside] = 1. - abs_x[inside]
    target = (target / torch.norm(target))**2

    # --- Model and optimizer -----------------------------------------------
    model = Schrodinger1D(xmesh)
    lbfgs_options = dict(
        max_iter=10,
        tolerance_change=1E-7,
        tolerance_grad=1E-7,
        line_search_fn='strong_wolfe',
    )
    optimizer = torch.optim.LBFGS(model.parameters(), **lbfgs_options)

    def evaluate_loss():
        """LBFGS closure: recompute the density mismatch and its gradient."""
        optimizer.zero_grad()
        mismatch = model(target)
        mismatch.backward()
        return mismatch

    # --- Optimize, redrawing the live plot after every outer step ---------
    n_steps = 50
    plt.ion()
    for step_idx in range(n_steps):
        step_loss = optimizer.step(evaluate_loss)
        print(step_idx, step_loss.item())
        model.plot(target)
        plt.pause(0.01)

    # Final, blocking view of the converged result.
    plt.ioff()
    model.plot(target)
    plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions