summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorcathook <b01902109@csie.ntu.edu.tw>2014-11-04 21:27:07 +0800
committercathook <b01902109@csie.ntu.edu.tw>2014-11-04 21:27:07 +0800
commitb03e0b80e59fc649a7d26c880d10b545aeee6024 (patch)
tree141eac55ab2f1297f9061a072f7d342e7ca1b1d7
downloadvim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.gz
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.bz2
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.lz
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.xz
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.zst
vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.zip
Init commit, gives a prototype of the shared vim.
-rw-r--r--shared_vim.vim257
-rwxr-xr-xshared_vim_server995
2 files changed, 1252 insertions, 0 deletions
diff --git a/shared_vim.vim b/shared_vim.vim
new file mode 100644
index 0000000..d2b06f5
--- /dev/null
+++ b/shared_vim.vim
@@ -0,0 +1,257 @@
+function! SharedVimConnect(server_name, port, identity)
+ let b:shared_vim_server_name = a:server_name
+ let b:shared_vim_port = a:port
+ let b:shared_vim_identity = a:identity
+ let b:shared_vim_init = 1
+ call SharedVimSync()
+endfunction
+
+
+function! SharedVimDisconnect()
+ unlet! b:shared_vim_server_name
+ unlet! b:shared_vim_port
+ unlet! b:shared_vim_identity
+ unlet! b:shared_vim_init
+endfunction
+
+
+function! SharedVimSync()
+python << EOF
+import json
+import re
+import socket
+import vim
+import zlib
+
+
+class JSON_TOKEN: # pylint:disable=W0232
+ """Enumeration the Ttken strings for json object."""
+ CURSOR = 'cursor' # cursor position
+ CURSORS = 'cursors' # other users' cursor position
+ ERROR = 'error' # error string
+ IDENTITY = 'identity' # identity of myself
+ INIT = 'init' # initialize connect flag
+ TEXT = 'text' # text content in the buffer
+
+def vim_input(prompt='', default_value=''):
+ vim.command('call inputsave()')
+ vim.command("let user_input = input('%s','%s')" % (prompt, default_value))
+ vim.command('call inputrestore()')
+ return vim.eval('user_input')
+
+class StringTCP(object):
+ """Send/receive strings by tcp connection.
+
+ Attributes:
+ _connection: The tcp connection.
+ """
+ ENCODING = 'utf-8'
+ COMPRESS_LEVEL = 2
+ HEADER_LENGTH = 10
+
+ def __init__(self, sock=None, servername=None, port=None):
+ """Constructor.
+
+ Args:
+ sock: The tcp connection. if it is None, the constructor will
+ automatically creates an tcp connection to servername:port.
+ servername: The server name if needs.
+ port: The server port if needs.
+ """
+ if sock is None:
+ self._connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._connection.connect((servername, port))
+ else:
+ self._connection = sock
+
+ def send_string(self, string):
+ """Sends a string to the tcp-connection.
+
+ Args:
+ string: The string to be sent.
+ """
+ body = StringTCP._create_body_from_string(string)
+ header = StringTCP._create_header_from_body(body)
+ self._connection.send(header + body)
+
+ def recv_string(self):
+ """Receives a string from the tcp-connection.
+
+ Returns:
+ The string received.
+ """
+ header = StringTCP._recv_header_string(self._connection)
+ body = StringTCP._recv_body_string(self._connection, header)
+ return body
+
+ def close(self):
+ """Closes the socket."""
+ self._connection.close()
+
+ @staticmethod
+ def _create_body_from_string(string):
+ """Creates package body from data string.
+
+ Args:
+ string: Data string.
+ """
+ byte_string = string.encode(StringTCP.ENCODING)
+ return zlib.compress(byte_string, StringTCP.COMPRESS_LEVEL)
+
+ @staticmethod
+ def _create_header_from_body(body):
+ """Creates package header from package body.
+
+ Args:
+ body: Package body.
+ """
+ header_string = ('%%0%dd' % StringTCP.HEADER_LENGTH) % len(body)
+ return header_string.encode(StringTCP.ENCODING)
+
+ @staticmethod
+ def _recv_header_string(conn):
+ """Receives package header from specified tcp connection.
+
+ Args:
+ conn: The specified tcp connection.
+
+ Returns:
+ Package header.
+ """
+ return conn.recv(StringTCP.HEADER_LENGTH).decode(StringTCP.ENCODING)
+
+ @staticmethod
+ def _recv_body_string(conn, header):
+ """Receives package body from specified tcp connection and header.
+
+ Args:
+ conn: The specified tcp connection.
+ header: The package header.
+
+ Returns:
+ Package body.
+ """
+ body_length = int(header)
+ body = conn.recv(body_length)
+ body_byte = zlib.decompress(body)
+ return body_byte.decode(StringTCP.ENCODING)
+
+
+def rc_to_num(lines, rc):
+ """Transforms cursor position from row-col format to numeric format.
+
+ Args:
+ lines: List of lines of the context.
+ rc: A 2-tuple for row-col position.
+
+ Returns:
+ A number for cursor position.
+ """
+ result = 0
+ for row in range(0, rc[0]):
+ result += len(lines[row]) + 1
+ result += rc[1]
+ return result
+
+def nums_to_rcs(lines, nums):
+ """Transforms cursor positions from numeric format to row-col format.
+
+ Args:
+ lines: List of lines of the context.
+ nums: List of number of each positions to be transformed.
+
+ Returns:
+ A list of 2-tuple row-col format cursor positions.
+ """
+ sorted_index = sorted(range(len(nums)), key=lambda x: nums[x])
+ now_index, max_index = 0, len(sorted_index)
+ num = 0
+ rcs = [None] * len(nums)
+ for row in range(len(lines)):
+ num_max = num + len(lines[row])
+ while now_index < max_index:
+ if nums[sorted_index[now_index]] > num_max:
+ break
+ rcs[sorted_index[now_index]] = \
+ (row, nums[sorted_index[now_index]] - num)
+ now_index += 1
+ else:
+ break
+ num = num_max + 1
+ return rcs
+
+
+def main():
+ """Main process."""
+ try:
+ # Fetches information.
+ server_name = vim.current.buffer.vars['shared_vim_server_name']
+ port = vim.current.buffer.vars['shared_vim_port']
+ identity = vim.current.buffer.vars['shared_vim_identity']
+ cursor_position = rc_to_num(vim.current.buffer[:],
+ (vim.current.window.cursor[0] - 1,
+ vim.current.window.cursor[1]))
+ text = '\n'.join(vim.current.buffer[:])
+ init_flag = vim.current.buffer.vars['shared_vim_init']
+
+ if text and init_flag:
+ result = vim_input('It will clear the buffer, would you want to ' +
+ 'continue? [Y/n] ', 'Y')
+ if result == 'y' or result == 'Y':
+ vim.current.buffer[0 : len(vim.current.buffer)] = ['']
+ text = ''
+ else:
+ return
+ print('')
+
+ # Creates request.
+ request = {
+ JSON_TOKEN.IDENTITY : identity,
+ JSON_TOKEN.TEXT : text,
+ JSON_TOKEN.CURSOR : cursor_position,
+ JSON_TOKEN.INIT : bool(init_flag),
+ }
+
+ # Connects to the server and gets the response.
+ print('Connect to %s:%d' % (server_name, port))
+ conn = StringTCP(servername=server_name, port=port)
+ conn.send_string(json.dumps(request))
+ response = json.loads(conn.recv_string())
+ conn.close()
+
+ # Sync.
+ if JSON_TOKEN.ERROR in response:
+ raise Exception('from server: ' + response[JSON_TOKEN.ERROR])
+ else:
+ lines = re.split(r'\n', response[JSON_TOKEN.TEXT])
+ rcs = nums_to_rcs(lines, response[JSON_TOKEN.CURSORS])
+ my_rc = nums_to_rcs(lines, [response[JSON_TOKEN.CURSOR]])[0]
+
+ other_cursors_ptrn = '/%s/' % ('\\|'.join(
+ ['\\%%%dl\\%%%dc' % (rc[0] + 1, rc[1] + 1) for rc in rcs]))
+
+ vim.command('match SharedVimOthersCursors %s' % other_cursors_ptrn)
+ vim.current.buffer[0 : len(vim.current.buffer)] = lines
+ vim.current.window.cursor = (my_rc[0] + 1, my_rc[1])
+
+ except Exception as e:
+ print(e)
+
+main()
+EOF
+ let b:shared_vim_init = 0
+endfunction
+
+
+function! SharedVimEventsHandler(event_name)
+ if exists('b:shared_vim_server_name')
+ if a:event_name == 'VimCursorMoved'
+ call SharedVimSync()
+ endif
+ endif
+endfunction
+
+
+autocmd CursorMoved * call SharedVimEventsHandler('VimCursorMoved')
+
+highlight SharedVimOthersCursors ctermbg=darkred
diff --git a/shared_vim_server b/shared_vim_server
new file mode 100755
index 0000000..dd88155
--- /dev/null
+++ b/shared_vim_server
@@ -0,0 +1,995 @@
+#! /usr/bin/env python3
+
+"""Shared Vim server."""
+
+import difflib
+import json
+import re
+import socket
+import sys
+import threading
+import zlib
+
+
+BACKLOG = 1024
+PROMPT = '> '
+
+class AUTHORITY: # pylint:disable=W0232
+ """Enumeration the types of authority."""
+ READONLY = 1 # can only read.
+ READWRITE = 2 # can read and write.
+
+class JSON_TOKEN: # pylint:disable=W0232
+ """Enumeration the Ttken strings for json object."""
+ CURSOR = 'cursor' # cursor position
+ CURSORS = 'cursors' # other users' cursor position
+ ERROR = 'error' # error string
+ IDENTITY = 'identity' # identity of myself
+ INIT = 'init' # initialize connect flag
+ TEXT = 'text' # text content in the buffer
+
+
+def normal_print(string=None, prompt=True):
+ """Prints string to stdout.
+
+ It will also prints a prompt string for CUI if needs.
+
+ Args:
+ string: The string to be printed.
+ prompt: True of it needs to print a prompt string.
+ """
+ string = '' if string is None else string
+ if prompt:
+ sys.stdout.write(string + PROMPT)
+ else:
+ sys.stdout.write(string)
+ sys.stdout.flush()
+
+
+class StringTCP(object):
+ """Send/receive strings by tcp connection.
+
+ Attributes:
+ _connection: The tcp connection.
+ """
+ ENCODING = 'utf-8'
+ COMPRESS_LEVEL = 2
+ HEADER_LENGTH = 10
+
+ def __init__(self, sock=None, servername=None, port=None):
+ """Constructor.
+
+ Args:
+ sock: The tcp connection. if it is None, the constructor will
+ automatically creates an tcp connection to servername:port.
+ servername: The server name if needs.
+ port: The server port if needs.
+ """
+ if sock is None:
+ self._connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._connection.connect((servername, port))
+ else:
+ self._connection = sock
+
+ def send_string(self, string):
+ """Sends a string to the tcp-connection.
+
+ Args:
+ string: The string to be sent.
+ """
+ body = StringTCP._create_body_from_string(string)
+ header = StringTCP._create_header_from_body(body)
+ self._connection.send(header + body)
+
+ def recv_string(self):
+ """Receives a string from the tcp-connection.
+
+ Returns:
+ The string received.
+ """
+ header = StringTCP._recv_header_string(self._connection)
+ body = StringTCP._recv_body_string(self._connection, header)
+ return body
+
+ def close(self):
+ """Closes the socket."""
+ self._connection.close()
+
+ @staticmethod
+ def _create_body_from_string(string):
+ """Creates package body from data string.
+
+ Args:
+ string: Data string.
+
+ Returns:
+ Package body.
+ """
+ byte_string = string.encode(StringTCP.ENCODING)
+ return zlib.compress(byte_string, StringTCP.COMPRESS_LEVEL)
+
+ @staticmethod
+ def _create_header_from_body(body):
+ """Creates package header from package body.
+
+ Args:
+ body: Package body.
+
+ Returns:
+ Package header.
+ """
+ header_string = ('%%0%dd' % StringTCP.HEADER_LENGTH) % len(body)
+ return header_string.encode(StringTCP.ENCODING)
+
+ @staticmethod
+ def _recv_header_string(conn):
+ """Receives package header from specified tcp connection.
+
+ Args:
+ conn: The specified tcp connection.
+
+ Returns:
+ Package header.
+ """
+ return conn.recv(StringTCP.HEADER_LENGTH).decode(StringTCP.ENCODING)
+
+ @staticmethod
+ def _recv_body_string(conn, header):
+ """Receives package body from specified tcp connection and header.
+
+ Args:
+ conn: The specified tcp connection.
+ header: The package header.
+
+ Returns:
+ Package body.
+ """
+ body_length = int(header)
+ body = conn.recv(body_length)
+ body_byte = zlib.decompress(body)
+ return body_byte.decode(StringTCP.ENCODING)
+
+
+class ChgStepInfo(object):
+ """Stores one step in modifying one string to anothor string.
+
+ Here a step is to replace a range of substring with another string. So from
+ the original string, we can do lots of step and then get the final result.
+
+ Attributes:
+ _begin: Begin of the replacement in the original string.
+ _end: End of the replacement in the original string.
+ _new_str: The string to replace.
+ _begin2: Where the _new_str at the final string.
+ """
+ def __init__(self, begin, end, begin2, new_str):
+ """Constructor.
+
+ Args:
+ begin: Begin of the replacement in the original string.
+ end: End of the replacement in the original string.
+ begin2: Where the _new_str at the final string.
+ new_str: The string to replace.
+ """
+ self._begin = begin
+ self._end = end
+ self._begin2 = begin2
+ self._new_str = new_str
+
+ def rebase(self, step2):
+ """Inserts a modify event before it.
+
+ It will update the self._begin, self._end to prevent confliction.
+
+ Args:
+ step: The other modify event.
+ """
+ if self._begin < step2.begin:
+ self._end = min([self._end, step2.begin])
+ else:
+ self._begin = max([self._begin, step2.end])
+ self._end = min([self._begin, self._end])
+ self._begin += step2.increased_length
+ self._end += step2.increased_length
+
+ @property
+ def begin(self):
+ """Gets the begin of the replacing range."""
+ return self._begin
+
+ @property
+ def end(self):
+ """Gets the end of the replacing range."""
+ return self._end
+
+ @property
+ def begin2(self):
+ """Gets the begin of the replacing string."""
+ return self._begin2
+
+ @property
+ def new_str(self):
+ """Gets the replacing string."""
+ return self._new_str
+
+ @property
+ def increased_length(self):
+ """Gets the amount of length increased if applying this step."""
+ return len(self.new_str) - (self.end - self.begin)
+
+
+class CursorPosInfo_RelativeToNew(object):
+ """Stores a cursor position.
+
+ Attributes:
+ step_id: The step id in that commit.
+ delta: Difference between the step.begin and the real position.
+ """
+ def __init__(self, step_id, delta):
+ """Constructor.
+
+ Args:
+ step_id: The step id in that commit.
+ delta: Difference between the step.begin and the real position.
+ """
+ self.step_id = step_id
+ self.delta = delta
+
+class CursorPosInfo_RelativeToOld(object):
+ """Stores a cursor position.
+
+ Attributes:
+ pos: The cursor position
+ """
+ def __init__(self, position):
+ """Constructor.
+
+ Args:
+ pos: The cursor position
+ """
+ self.pos = position
+
+
+class TextCommit(object):
+ """Contains a commit information.
+
+ Attributes:
+ _text: Text of this commit.
+ _steps: List of instance of ChgStepInfo between it and the prevoius
+ commit.
+ count: Number of user points to it.
+ """
+ def __init__(self, prev_text, my_text, count=1):
+ """Constructor.
+
+ Args:
+ prev_text: Text in previous commit.
+ my_text: Text in this commit.
+ count: Default count.
+ """
+ self._text = my_text
+ self._steps = []
+ diff = difflib.SequenceMatcher(a=prev_text, b=my_text)
+ for tag, begin, end, begin2, end2 in diff.get_opcodes():
+ if tag in ('replace', 'delete', 'insert'):
+ self._steps += [ChgStepInfo(begin, end,
+ begin2, my_text[begin2 : end2])]
+
+ self.count = count
+
+ def rebase(self, commits):
+ """Rebase to specified commits.
+
+ Like git-rebase.
+
+ Args:
+ commits: List of instance of TextCommit.
+ """
+ for commit in commits:
+ self._rebase_steps(commit._steps)
+ if commits:
+ self._rebase_text(commits[-1].text)
+
+ def get_cursor_pos_info(self, cursor_pos):
+ """Gets the cursor position information.
+
+ Args:
+ cursor_pos: A number for cursor position.
+ """
+ delta = 0
+ for ind in range(len(self._steps)):
+ step = self._steps[ind]
+ if step.begin2 <= cursor_pos < step.begin2 + len(step.new_str):
+ return CursorPosInfo_RelativeToNew(ind,
+ cursor_pos - step.begin2)
+ elif cursor_pos < step.begin2:
+ break
+ delta += self._steps[ind].increased_length
+ return CursorPosInfo_RelativeToOld(cursor_pos - delta)
+
+ def get_cursor_pos(self, cursor_pos_info):
+ """Gets the cursor position.
+
+ Args:
+ cursor_pos_info: A instance of one of the CursorPosInfo_*
+
+ Returns:
+ A number for the cursor position.
+ """
+ if isinstance(cursor_pos_info, CursorPosInfo_RelativeToNew):
+ return (self._steps[cursor_pos_info.step_id].begin +
+ cursor_pos_info.delta)
+ else:
+ return cursor_pos_info.pos
+
+ def rebase_cursor_pos_info(self, cursor_pos_info):
+ """Rebases the cursor position info by applying steps in this commit.
+
+ Args:
+ cursor_pos_info: An instance of one of the CursorPosInfo_*
+
+ Returns:
+ An instance of one of the CursorPosInfo_*
+ """
+ if isinstance(cursor_pos_info, CursorPosInfo_RelativeToOld):
+ cursor_pos_info.pos = self.rebase_cursor_pos(cursor_pos_info.pos)
+
+ def rebase_cursor_pos(self, cursor_pos):
+ """Rebases the cursor position by applying steps in this commit.
+
+ Args:
+ cursor_pos: Cursor position to be rebased.
+
+ Returns:
+ Rebased cursor position.
+ """
+ for step in self._steps:
+ cursor_pos = max([cursor_pos, step.end]) + step.increased_length
+ return cursor_pos
+
+ @property
+ def text(self):
+ """Gets the text of this commit."""
+ return self._text
+
+ def _rebase_steps(self, steps):
+ """Rebases the specified steps only.
+
+ Args:
+ steps: List of instance of ChgStepInfo.
+ """
+ for step in steps:
+ for my_step in self._steps:
+ my_step.rebase(step)
+
+ def _rebase_text(self, past_text):
+ """Rebases the text by the specified text.
+
+ Args:
+ past_text: The specified text.
+ """
+ end_index = 0
+ self._text = ''
+ for step in self._steps:
+ self._text += past_text[end_index : step.begin]
+ self._text += step.new_str
+ end_index = step.end
+ self._text += past_text[end_index : ]
+
+ def __str__(self):
+ s = '\n'.join(['Replace (%d, %d) with %r' % (e.begin, e.end, e.new_str)
+ for e in self._steps] +
+ ['Become %r' % self._text])
+ return s
+
+
+class Snapshot(object):
+ """Stores informations about current state.
+
+ Attrbutes:
+ _text: Text.
+ _cursor: Cursor position.
+ _cursors: Other users' cursor position.
+ """
+ def __init__(self, text, cursor_pos, other_cursors):
+ """Constructor.
+
+ Args:
+ text: Text.
+ cursor_pos: Cursor position.
+ other_cursors: Other users' cursor positions.
+ """
+ self._text = text
+ self._cursor = cursor_pos
+ self._cursors = other_cursors
+
+ @property
+ def text(self):
+ """Gets the text."""
+ return self._text
+
+ @property
+ def cursor(self):
+ """Gets the cursor position."""
+ return self._cursor
+
+ @property
+ def cursors(self):
+ """Gets the other users' cursor positions."""
+ return self._cursors
+
+
+class TextChain(object):
+ """Contains a list of text.
+
+ You can "commit" a new version of text string from a very old text string,
+ because it will automatically merge it.
+
+ Attributes:
+ _commits_id_max: Maximum index of the commits.
+ _commits: Dict with key be the commit and the value be the commit.
+ _saved_filename: Name of the file for initialize and backup.
+ _lock: Lock for preventing multiple threads commit at the same time.
+ """
+ def __init__(self, filename):
+ """Constructor.
+
+ Args:
+ filename: Name of the file for initialize and backup.
+ """
+ self._commits_id_max = 0
+ self._commits = {self._commits_id_max : TextCommit('', '')}
+ self._saved_filename = filename
+ self._lock = threading.RLock()
+ try:
+ with open(filename, 'r') as f:
+ self.commit(self._commits_id_max, Snapshot(f.read(), 0, []))
+ self.del_reference(self._commits_id_max)
+ except IOError:
+ pass
+
+ def commit(self, last_commit_id, snapshot):
+ """Commits a text.
+
+ Args:
+ last_commit_id: The id of the original commit.
+ snapshot: The snapshot of the state, includes:
+ text: The new text.
+ cursor: The user's cursor position.
+ cursors: Other users' cursor positions in the previous commit.
+
+ Returns:
+ A 2-tuple for new_commit_id, new_snapshot
+ """
+ with self._lock:
+ new_commit = TextCommit(self._commits[last_commit_id].text,
+ snapshot.text)
+ cursor_pos_info = new_commit.get_cursor_pos_info(snapshot.cursor)
+ new_commit.rebase(self._get_commits_after(last_commit_id))
+ new_commit_id = self._push_commit(new_commit)
+ new_snapshot = Snapshot(
+ new_commit.text,
+ self._commit_rebase_cursor(last_commit_id, cursor_pos_info),
+ [new_commit.rebase_cursor_pos(x) for x in snapshot.cursors])
+ self.del_reference(last_commit_id)
+ self._save()
+ return (new_commit_id, new_snapshot)
+
+ def text(self, commit_id):
+ """Gets the text of a specified commit id.
+
+ Args:
+ commit_id: The specified commit id.
+
+ Returns:
+ text.
+ """
+ return self._commits[commit_id].text
+
+ def _push_commit(self, new_commit):
+ """Pushes a commit into commit chain.
+
+ Args:
+ new_commit: The commit to be push.
+
+ Returns:
+ The commit of this new commit.
+ """
+ self._commits_id_max += 1
+ self._commits[self._commits_id_max] = new_commit
+ return self._commits_id_max
+
+ def _commit_rebase_cursor(self, last_commit_id, cursor_pos_info):
+ """Updates the cursor by gived the gived cursor position.
+
+ Args:
+ last_commit_it: Last commit id.
+ cursor_info: An instance of CursorPosInfo
+
+ Returns:
+ New cursor position.
+ """
+ commits_for_rebase = self._get_commits_after(last_commit_id)
+ for commit in commits_for_rebase:
+ commit.rebase_cursor_pos_info(cursor_pos_info)
+ return commits_for_rebase[-1].get_cursor_pos(cursor_pos_info)
+
+ def del_reference(self, commit_id):
+ """Removes a reference from a commit id.
+
+ Args:
+ commit_id: Commit id.
+ """
+ with self._lock:
+ self._commits[commit_id].count -= 1
+ self._merge_unnecessary_commits()
+
+ def _get_commits_after(self, threshold):
+ """Returns the commit after a specified id.
+
+ Args:
+ threshold: The specified commit id.
+
+ Returns:
+ A list of commits.
+ """
+ valid_commits = [c for c in self._commits.items() if c[0] > threshold]
+ return [cmt[1] for cmt in sorted(valid_commits, key=lambda x: x[0])]
+
+ def _save(self):
+ """Save the last commit into the file."""
+ try:
+ with open(self._saved_filename, 'w') as f:
+ f.write(self._commits[self._commits_id_max].text)
+ except IOError:
+ pass
+
+ def _merge_unnecessary_commits(self):
+ """Merges the unnecessary commits."""
+ commit_ids = sorted(self._commits.keys())
+ try:
+ for index in range(1, len(commit_ids) - 1):
+ pre_id = commit_ids[index - 1]
+ cur_id = commit_ids[index]
+ nxt_id = commit_ids[index + 1]
+ print('%d---%d(%d)---%d' %
+ (pre_id, cur_id, self._commits[cur_id].count, nxt_id))
+ if self._commits[cur_id].count == 0:
+ del self._commits[cur_id]
+ self._commits[nxt_id] = TextCommit(
+ self._commits[pre_id].text,
+ self._commits[nxt_id].text,
+ self._commits[nxt_id].count)
+ except KeyError as e:
+ print('Shit!! %r' % e)
+ print(sorted(self._commits.keys()))
+
+
+class UserInfo(object):
+ """Informations about a user.
+
+ Attributes:
+ _authority: Authority of this user.
+ last_commit_id: Previous commit id.
+ last_cursor_pos: Last cursor position.
+ """
+ def __init__(self, authority):
+ """Constructor.
+
+ Args:
+ authroity: Authority of this user.
+ """
+ self._authority = authority
+ self.last_commit_id = 0
+ self.last_cursor_pos = 0
+
+ def reset(self):
+ """Reset the commit and cursor to the default value."""
+ self.last_commit_id = 0
+ self.last_cursor_pos = 0
+
+ @property
+ def authority(self):
+ """Gets the autority of this user."""
+ return self._authority
+
+
+class UsersPoolException(Exception):
+ """Exception raised by UsersPool."""
+ pass
+
+class UsersPool(object):
+ """A pool for users.
+
+ Attributes:
+ _users: A dict (key=name, value=an instance of User) of users.
+ _text_chain: A instance of TextChain.
+ _lock: A lock for preventing multiple thread trying to sync at the same
+ time.
+ """
+ def __init__(self, text_chain):
+ """Constructor.
+
+ Args:
+ text_chain: A instance of TextChain
+ """
+ self._users = {}
+ self._text_chain = text_chain
+ self._lock = threading.Lock()
+
+ def add(self, identity, authority):
+ """Adds an user.
+
+ Args:
+ identity: The identity of this new user.
+ authority: The authority of this new user.
+ """
+ with self._lock:
+ if identity in self._users:
+ raise UsersPoolException('User %r already in the users list' %
+ identity)
+ self._users[identity] = UserInfo(authority)
+
+ def delete(self, identity):
+ """Deletes an user.
+
+ Args:
+ identity: The identity of the user.
+ """
+ with self._lock:
+ if identity not in self._users:
+ raise UsersPoolException('User %r not in the users list' %
+ identity)
+ del self._users[identity]
+
+ def sync(self, identity, text, cursor_pos, reset):
+ """Sync the specified user.
+
+ Args:
+ identity: Identity of the user.
+ text: The new text commit from that user.
+ cursor_pos: The cursor position of that user.
+ reset: A flag for reset or not.
+
+ Returns
+ An instance of Snapshot.
+ """
+ with self._lock:
+ if identity not in self._users:
+ raise UsersPoolException('User %r not in the users list' %
+ identity)
+ user = self._users[identity]
+ if reset:
+ user.reset()
+ if user.authority < AUTHORITY.READWRITE:
+ text = self._text_chain.text(user.last_commit_id)
+ cursors = [v.last_cursor_pos
+ for k, v in self._users.items() if k != identity]
+ user.last_commit_id, new_snapshot = self._text_chain.commit(
+ user.last_commit_id, Snapshot(text, cursor_pos, cursors))
+ user.last_cursor_pos = new_snapshot.cursor
+ others = [v for k, v in self._users.items() if k != identity]
+ for index in range(len(others)):
+ others[index].last_cursor_pos = new_snapshot.cursors[index]
+ return new_snapshot
+
+ @property
+ def users(self):
+ """Gets a dict with key=identity, value=authority of users."""
+ ret = {}
+ for identity, user in self._users.items():
+ ret[identity] = user.authority
+ return ret
+
+
+def load_user_list(users_pool, filename):
+ """Loads user list from the file.
+
+ Args:
+ users_pool: A instance of UsersPool.
+ filename: File name.
+ """
+ with open(filename, 'r') as f:
+ while True:
+ line = f.readline()
+ if not line:
+ break
+ if line.endswith('\n'):
+ line = line[:-1]
+ words = [word for word in re.split(r'[ \t\n]', line) if word]
+ if len(words) != 2:
+ continue
+ identity = words[0]
+ authority = AuthorityStringTransformer.to_number(words[1])
+ if authority is None:
+ continue
+ try:
+ users_pool.add(identity, authority)
+ except UsersPoolException:
+ continue
+
+
+class AuthorityStringTransformer: # pylint:disable=W0232
+ """Transforms authority between number and strin format."""
+ @staticmethod
+ def to_string(authority):
+ """Transform number authority value to string value.
+
+ Args:
+ authority: Authority in number format.
+
+ Returns:
+ authority: Corrosponding authority in string format.
+ """
+ if authority == AUTHORITY.READONLY:
+ return 'RO'
+ elif authority == AUTHORITY.READWRITE:
+ return 'RW'
+
+ @staticmethod
+ def to_number(string):
+ """Transform string authority value to number value.
+
+ Args:
+ authority: Authority in string format.
+
+ Returns:
+ authority: Corrosponding authority in number format.
+ """
+ if string == 'RO':
+ return AUTHORITY.READONLY
+ elif string == 'RW':
+ return AUTHORITY.READWRITE
+ return None
+
+class CommandHandler(object):
+ """Base class for CommandHandle_*
+
+ Attributes:
+ _prefix: Prefix string of this command if exists.
+ """
+ def __init__(self, prefix=None):
+ self._prefix = prefix
+
+ @property
+ def prefix(self):
+ """Gets the prefix of this command handler."""
+ return self._prefix
+
+ def is_mine(self, cmd):
+ """Checks whether the specified command is mine."""
+ return cmd.startswith(self._prefix)
+
+ def handle(self, unused_cmd, unused_users_pool):
+ """Run the command."""
+ raise NotImplementedError
+
+
+class CommandHandler_Exit(CommandHandler):
+ """Command handler for exiting."""
+ def __init__(self):
+ super(CommandHandler_Exit, self).__init__('exit')
+
+ def handle(self, unused_cmd, unused_userspool):
+ normal_print('bye~\n', prompt=False)
+ exit()
+
+class CommandHandler_Add(CommandHandler):
+ """Command handler for adding a user."""
+ def __init__(self):
+ super(CommandHandler_Add, self).__init__('add ')
+
+ def handle(self, cmd, users_pool):
+ words = [word for word in re.split(r'[ \t\n]', cmd) if word]
+ if len(words) != 3:
+ return 'Wrong number of arguments.'
+ identity = words[1]
+ authority = AuthorityStringTransformer.to_number(words[2])
+ if authority is None:
+ return 'Unknown flag.'
+ try:
+ users_pool.add(identity, authority)
+ except UsersPoolException as e:
+ return str(e)
+ return 'Done.'
+
+class CommandHandler_Delete(CommandHandler):
+ """Command handler for deleting a user."""
+ def __init__(self):
+ super(CommandHandler_Delete, self).__init__('delete ')
+
+ def handle(self, cmd, users_pool):
+ words = [word for word in re.split(r'[ \t\n]', cmd) if word]
+ if len(words) != 2:
+ return 'Wrong number of arguments.'
+ identity = words[1]
+ try:
+ users_pool.delete(identity)
+ except UsersPoolException as e:
+ return str(e)
+ return 'Done.'
+
+class CommandHandler_Load(CommandHandler):
+ """Command handler for loading user list from a file."""
+ def __init__(self):
+ super(CommandHandler_Load, self).__init__('load ')
+
+ def handle(self, cmd, users_pool):
+ words = [word for word in re.split(r'[ \t\n]', cmd) if word]
+ if len(words) != 2:
+ return 'Wrong number of arguments.'
+ filename = words[1]
+ try:
+ load_user_list(users_pool, filename)
+ except IOError as e:
+ return str(e)
+ return 'Reads user list from file %r successfully.' % filename
+
+class CommandHandler_Save(CommandHandler):
+ """Command handler for saving user list to a file."""
+ def __init__(self):
+ super(CommandHandler_Save, self).__init__('save ')
+
+ def handle(self, cmd, users_pool):
+ words = [word for word in re.split(r'[ \t\n]', cmd) if word]
+ if len(words) != 2:
+ return 'Wrong number of arguments.'
+ filename = words[1]
+ try:
+ with open(filename, 'w') as f:
+ for identity, auth in users_pool.users.items():
+ f.write('%s %s\n' %
+ (identity,
+ AuthorityStringTransformer.to_string(auth)))
+ except IOError as e:
+ return str(e)
+ return 'Write user list to file %r successfully' % filename
+
+class CommandHandler_List(CommandHandler):
+ """Command handler for listing user list to a file."""
+ def __init__(self):
+ super(CommandHandler_List, self).__init__('list')
+
+ def handle(self, unused_cmd, users_pool):
+ ret = ''
+ for identity, auth in users_pool.users.items():
+ ret += '%r => %s\n' % (identity,
+ AuthorityStringTransformer.to_string(auth))
+ return ret
+
+
+def cui_thread(users_pool):
+ """Threading function for CUI.
+
+ Args:
+ users_pool: An instance of UsersPool.
+ """
+ commands = [
+ CommandHandler_Exit(),
+ CommandHandler_Add(),
+ CommandHandler_Delete(),
+ CommandHandler_Load(),
+ CommandHandler_Save(),
+ CommandHandler_List(),
+ ]
+ normal_print()
+ while True:
+ line_str = sys.stdin.readline()
+ if line_str.endswith('\n'):
+ line_str = line_str[:-1]
+ else:
+ CommandHandler_Exit().handle('', users_pool)
+ for command in commands:
+ if command.is_mine(line_str):
+ result = command.handle(line_str, users_pool)
+ normal_print(result + '\n')
+ break
+ else:
+ normal_print('Unknown command %r\n' % line_str)
+
+def start_cui(users_pool):
+ """Starts a thread for running CUI.
+
+ Returns:
+ A instance of Thread.
+ """
+ cui = threading.Thread(target=cui_thread, args=(users_pool,))
+ cui.start()
+ return cui
+
+
+def tcp_server_accepter_thread(port, users_pool):
+ """Accepts tcp connection forever.
+
+ Args:
+ users_pool: An instance of UsersPool.
+ """
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(('', port))
+ s.listen(BACKLOG)
+ except socket.error:
+ normal_print('Cannot start the TCP server at port %d, abort\n' % port,
+ prompt=False)
+ exit(1)
+ normal_print('Server started (port = %d)\n' % port)
+ while True:
+ conn, addr = s.accept()
+ normal_print('Connection address: %s\n' % str(addr))
+ start_tcp_receiver(conn, users_pool)
+
+def start_tcp_server_accepter(port, users_pool):
+ """Starts a thread for running tcp server.
+
+ Args:
+ port: Port to listen.
+ users_pool: An instance of UsersPool.
+ """
+ thr = threading.Thread(target=tcp_server_accepter_thread,
+ args=(port, users_pool))
+ thr.daemon = True
+ thr.start()
+
+
+def tcp_receiver_thread(conn, users_pool):
+ """Handles client request.
+
+ Args:
+ conn: connection.
+ users_pool: An instance of UsersPool.
+ """
+ tcp = StringTCP(conn)
+ request = json.loads(tcp.recv_string())
+ try:
+ snapshot = users_pool.sync(request[JSON_TOKEN.IDENTITY],
+ request[JSON_TOKEN.TEXT],
+ request[JSON_TOKEN.CURSOR],
+ reset=request[JSON_TOKEN.INIT])
+ response = {
+ JSON_TOKEN.TEXT : snapshot.text,
+ JSON_TOKEN.CURSOR : snapshot.cursor,
+ JSON_TOKEN.CURSORS : snapshot.cursors,
+ }
+ except UsersPoolException as e:
+ response = {JSON_TOKEN.ERROR : str(e)}
+ print('request = %r' % request)
+ print('response = %r' % response)
+ tcp.send_string(json.dumps(response))
+ tcp.close()
+
+
+def start_tcp_receiver(conn, users_pool):
+ """Starts a thread for handling client request.
+
+ Args:
+ conn: connection.
+ users_pool: An instance of UsersPool.
+ """
+ thr = threading.Thread(target=tcp_receiver_thread, args=(conn, users_pool))
+ thr.daemon = True
+ thr.start()
+
+
+def main():
+ """Program entry point."""
+ help_str = '[USAGE] %s <port_number> <user_list> <save_file>' % sys.argv[0]
+
+ if len(sys.argv) != 4:
+ normal_print('Wrong arguments!!\n%s' % help_str, prompt=False)
+ exit(1)
+
+ try:
+ port = int(sys.argv[1])
+ except ValueError:
+ normal_print('Wrong port!!\n%s' % help_str, prompt=False)
+ exit(1)
+
+ users_pool = UsersPool(TextChain(sys.argv[3]))
+ load_user_list(users_pool, sys.argv[2])
+ start_tcp_server_accepter(port, users_pool)
+ cui_thr = start_cui(users_pool)
+ cui_thr.join()
+
+
+if __name__ == '__main__':
+ main()