aboutsummaryrefslogtreecommitdiffstats
path: root/packages/dev-utils/src/blockchain_lifecycle.ts
blob: 4bb13609710bab04d80d81dd4853317b388be5de (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import { uniqueVersionIds, Web3Wrapper } from '@0xproject/web3-wrapper';
import { includes } from 'lodash';

/**
 * The kinds of Ethereum node that `BlockchainLifecycle` can tell apart.
 * The value is derived from the client version string reported by the node
 * and selects which save/restore strategy is used.
 */
enum NodeType {
    // Chosen when the client version string contains `uniqueVersionIds.geth`.
    Geth = 'GETH',
    // Chosen when the client version string contains `uniqueVersionIds.ganache`.
    Ganache = 'GANACHE',
}

/**
 * Records restore points for a blockchain node and rolls the node back to them.
 *
 * Every `startAsync` call pushes one restore point onto an internal stack and every
 * `revertAsync` call pops the most recent one, so start/revert pairs may be nested.
 * What a restore point is depends on the node type detected from the client version:
 * a snapshot id on Ganache, a block number on Geth.
 */
export class BlockchainLifecycle {
    private readonly _web3Wrapper: Web3Wrapper;
    // LIFO stack of restore points: snapshot ids (Ganache) or block numbers (Geth).
    private readonly _snapshotIdsStack: number[];
    /**
     * @param web3Wrapper Wrapper used for every call made to the node.
     */
    constructor(web3Wrapper: Web3Wrapper) {
        this._web3Wrapper = web3Wrapper;
        this._snapshotIdsStack = [];
    }
    /**
     * Records a new restore point on top of the stack.
     * On Ganache this takes a snapshot and stores its id; on Geth it stores the
     * current block number.
     * @throws Error if the node's client version is not recognised.
     */
    public async startAsync(): Promise<void> {
        const nodeType = await this._getNodeTypeAsync();
        // Each case has its own block so its `const` does not leak into sibling cases.
        switch (nodeType) {
            case NodeType.Ganache: {
                const snapshotId = await this._web3Wrapper.takeSnapshotAsync();
                this._snapshotIdsStack.push(snapshotId);
                break;
            }
            case NodeType.Geth: {
                const blockNumber = await this._web3Wrapper.getBlockNumberAsync();
                this._snapshotIdsStack.push(blockNumber);
                break;
            }
            default:
                throw new Error(`Unknown node type: ${nodeType}`);
        }
    }
    /**
     * Pops the most recent restore point and rolls the node back to it.
     * On Ganache the snapshot with the popped id is reverted; on Geth the chain head
     * is set to the popped block number.
     * @throws Error if there is no restore point to revert to (no matching `startAsync`),
     *         if Ganache reports that the snapshot was not reverted, or if the node's
     *         client version is not recognised.
     */
    public async revertAsync(): Promise<void> {
        const nodeType = await this._getNodeTypeAsync();
        switch (nodeType) {
            case NodeType.Ganache: {
                const snapshotId = this._popRestorePoint();
                const didRevert = await this._web3Wrapper.revertSnapshotAsync(snapshotId);
                if (!didRevert) {
                    throw new Error(`Snapshot with id #${snapshotId} failed to revert`);
                }
                break;
            }
            case NodeType.Geth: {
                const blockNumber = this._popRestorePoint();
                await this._web3Wrapper.setHeadAsync(blockNumber);
                break;
            }
            default:
                throw new Error(`Unknown node type: ${nodeType}`);
        }
    }
    /**
     * Removes and returns the most recent restore point.
     * Fails loudly on an empty stack instead of letting `undefined` reach the node.
     */
    private _popRestorePoint(): number {
        const restorePoint = this._snapshotIdsStack.pop();
        if (restorePoint === undefined) {
            throw new Error('Cannot revert: no snapshot was taken. Call startAsync() before revertAsync().');
        }
        return restorePoint;
    }
    /**
     * Determines the node type by looking for a known identifier in the client
     * version string. The Geth identifier is checked first.
     * @throws Error if the version string contains neither identifier.
     */
    private async _getNodeTypeAsync(): Promise<NodeType> {
        const version = await this._web3Wrapper.getNodeVersionAsync();
        if (includes(version, uniqueVersionIds.geth)) {
            return NodeType.Geth;
        } else if (includes(version, uniqueVersionIds.ganache)) {
            return NodeType.Ganache;
        } else {
            throw new Error(`Unknown client version: ${version}`);
        }
    }
}