diff options
author | cathook <b01902109@csie.ntu.edu.tw> | 2014-11-04 21:27:07 +0800 |
---|---|---|
committer | cathook <b01902109@csie.ntu.edu.tw> | 2014-11-04 21:27:07 +0800 |
commit | b03e0b80e59fc649a7d26c880d10b545aeee6024 (patch) | |
tree | 141eac55ab2f1297f9061a072f7d342e7ca1b1d7 | |
download | vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.gz vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.bz2 vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.lz vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.xz vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.tar.zst vim-shrvim-b03e0b80e59fc649a7d26c880d10b545aeee6024.zip |
Init commit, gives a prototype of the shared vim.
-rw-r--r-- | shared_vim.vim | 257 | ||||
-rwxr-xr-x | shared_vim_server | 995 |
2 files changed, 1252 insertions, 0 deletions
diff --git a/shared_vim.vim b/shared_vim.vim new file mode 100644 index 0000000..d2b06f5 --- /dev/null +++ b/shared_vim.vim @@ -0,0 +1,257 @@ +function! SharedVimConnect(server_name, port, identity) + let b:shared_vim_server_name = a:server_name + let b:shared_vim_port = a:port + let b:shared_vim_identity = a:identity + let b:shared_vim_init = 1 + call SharedVimSync() +endfunction + + +function! SharedVimDisconnect() + unlet! b:shared_vim_server_name + unlet! b:shared_vim_port + unlet! b:shared_vim_identity + unlet! b:shared_vim_init +endfunction + + +function! SharedVimSync() +python << EOF +import json +import re +import socket +import vim +import zlib + + +class JSON_TOKEN: # pylint:disable=W0232 + """Enumeration the Ttken strings for json object.""" + CURSOR = 'cursor' # cursor position + CURSORS = 'cursors' # other users' cursor position + ERROR = 'error' # error string + IDENTITY = 'identity' # identity of myself + INIT = 'init' # initialize connect flag + TEXT = 'text' # text content in the buffer + +def vim_input(prompt='', default_value=''): + vim.command('call inputsave()') + vim.command("let user_input = input('%s','%s')" % (prompt, default_value)) + vim.command('call inputrestore()') + return vim.eval('user_input') + +class StringTCP(object): + """Send/receive strings by tcp connection. + + Attributes: + _connection: The tcp connection. + """ + ENCODING = 'utf-8' + COMPRESS_LEVEL = 2 + HEADER_LENGTH = 10 + + def __init__(self, sock=None, servername=None, port=None): + """Constructor. + + Args: + sock: The tcp connection. if it is None, the constructor will + automatically creates an tcp connection to servername:port. + servername: The server name if needs. + port: The server port if needs. + """ + if sock is None: + self._connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._connection.connect((servername, port)) + else: + self._connection = sock + + def send_string(self, string): + """Sends a string to the tcp-connection. + + Args: + string: The string to be sent. + """ + body = StringTCP._create_body_from_string(string) + header = StringTCP._create_header_from_body(body) + self._connection.send(header + body) + + def recv_string(self): + """Receives a string from the tcp-connection. + + Returns: + The string received. + """ + header = StringTCP._recv_header_string(self._connection) + body = StringTCP._recv_body_string(self._connection, header) + return body + + def close(self): + """Closes the socket.""" + self._connection.close() + + @staticmethod + def _create_body_from_string(string): + """Creates package body from data string. + + Args: + string: Data string. + """ + byte_string = string.encode(StringTCP.ENCODING) + return zlib.compress(byte_string, StringTCP.COMPRESS_LEVEL) + + @staticmethod + def _create_header_from_body(body): + """Creates package header from package body. + + Args: + body: Package body. + """ + header_string = ('%%0%dd' % StringTCP.HEADER_LENGTH) % len(body) + return header_string.encode(StringTCP.ENCODING) + + @staticmethod + def _recv_header_string(conn): + """Receives package header from specified tcp connection. + + Args: + conn: The specified tcp connection. + + Returns: + Package header. + """ + return conn.recv(StringTCP.HEADER_LENGTH).decode(StringTCP.ENCODING) + + @staticmethod + def _recv_body_string(conn, header): + """Receives package body from specified tcp connection and header. + + Args: + conn: The specified tcp connection. + header: The package header. + + Returns: + Package body. + """ + body_length = int(header) + body = conn.recv(body_length) + body_byte = zlib.decompress(body) + return body_byte.decode(StringTCP.ENCODING) + + +def rc_to_num(lines, rc): + """Transforms cursor position from row-col format to numeric format. + + Args: + lines: List of lines of the context. + rc: A 2-tuple for row-col position. + + Returns: + A number for cursor position. + """ + result = 0 + for row in range(0, rc[0]): + result += len(lines[row]) + 1 + result += rc[1] + return result + +def nums_to_rcs(lines, nums): + """Transforms cursor positions from numeric format to row-col format. + + Args: + lines: List of lines of the context. + nums: List of number of each positions to be transformed. + + Returns: + A list of 2-tuple row-col format cursor positions. + """ + sorted_index = sorted(range(len(nums)), key=lambda x: nums[x]) + now_index, max_index = 0, len(sorted_index) + num = 0 + rcs = [None] * len(nums) + for row in range(len(lines)): + num_max = num + len(lines[row]) + while now_index < max_index: + if nums[sorted_index[now_index]] > num_max: + break + rcs[sorted_index[now_index]] = \ + (row, nums[sorted_index[now_index]] - num) + now_index += 1 + else: + break + num = num_max + 1 + return rcs + + +def main(): + """Main process.""" + try: + # Fetches information. + server_name = vim.current.buffer.vars['shared_vim_server_name'] + port = vim.current.buffer.vars['shared_vim_port'] + identity = vim.current.buffer.vars['shared_vim_identity'] + cursor_position = rc_to_num(vim.current.buffer[:], + (vim.current.window.cursor[0] - 1, + vim.current.window.cursor[1])) + text = '\n'.join(vim.current.buffer[:]) + init_flag = vim.current.buffer.vars['shared_vim_init'] + + if text and init_flag: + result = vim_input('It will clear the buffer, would you want to ' + + 'continue? [Y/n] ', 'Y') + if result == 'y' or result == 'Y': + vim.current.buffer[0 : len(vim.current.buffer)] = [''] + text = '' + else: + return + print('') + + # Creates request. + request = { + JSON_TOKEN.IDENTITY : identity, + JSON_TOKEN.TEXT : text, + JSON_TOKEN.CURSOR : cursor_position, + JSON_TOKEN.INIT : bool(init_flag), + } + + # Connects to the server and gets the response. + print('Connect to %s:%d' % (server_name, port)) + conn = StringTCP(servername=server_name, port=port) + conn.send_string(json.dumps(request)) + response = json.loads(conn.recv_string()) + conn.close() + + # Sync. + if JSON_TOKEN.ERROR in response: + raise Exception('from server: ' + response[JSON_TOKEN.ERROR]) + else: + lines = re.split(r'\n', response[JSON_TOKEN.TEXT]) + rcs = nums_to_rcs(lines, response[JSON_TOKEN.CURSORS]) + my_rc = nums_to_rcs(lines, [response[JSON_TOKEN.CURSOR]])[0] + + other_cursors_ptrn = '/%s/' % ('\\|'.join( + ['\\%%%dl\\%%%dc' % (rc[0] + 1, rc[1] + 1) for rc in rcs])) + + vim.command('match SharedVimOthersCursors %s' % other_cursors_ptrn) + vim.current.buffer[0 : len(vim.current.buffer)] = lines + vim.current.window.cursor = (my_rc[0] + 1, my_rc[1]) + + except Exception as e: + print(e) + +main() +EOF + let b:shared_vim_init = 0 +endfunction + + +function! SharedVimEventsHandler(event_name) + if exists('b:shared_vim_server_name') + if a:event_name == 'VimCursorMoved' + call SharedVimSync() + endif + endif +endfunction + + +autocmd CursorMoved * call SharedVimEventsHandler('VimCursorMoved') + +highlight SharedVimOthersCursors ctermbg=darkred diff --git a/shared_vim_server b/shared_vim_server new file mode 100755 index 0000000..dd88155 --- /dev/null +++ b/shared_vim_server @@ -0,0 +1,995 @@ +#! /usr/bin/env python3 + +"""Shared Vim server.""" + +import difflib +import json +import re +import socket +import sys +import threading +import zlib + + +BACKLOG = 1024 +PROMPT = '> ' + +class AUTHORITY: # pylint:disable=W0232 + """Enumeration the types of authority.""" + READONLY = 1 # can only read. + READWRITE = 2 # can read and write. + +class JSON_TOKEN: # pylint:disable=W0232 + """Enumeration the Ttken strings for json object.""" + CURSOR = 'cursor' # cursor position + CURSORS = 'cursors' # other users' cursor position + ERROR = 'error' # error string + IDENTITY = 'identity' # identity of myself + INIT = 'init' # initialize connect flag + TEXT = 'text' # text content in the buffer + + +def normal_print(string=None, prompt=True): + """Prints string to stdout. + + It will also prints a prompt string for CUI if needs. + + Args: + string: The string to be printed. + prompt: True of it needs to print a prompt string. + """ + string = '' if string is None else string + if prompt: + sys.stdout.write(string + PROMPT) + else: + sys.stdout.write(string) + sys.stdout.flush() + + +class StringTCP(object): + """Send/receive strings by tcp connection. + + Attributes: + _connection: The tcp connection. + """ + ENCODING = 'utf-8' + COMPRESS_LEVEL = 2 + HEADER_LENGTH = 10 + + def __init__(self, sock=None, servername=None, port=None): + """Constructor. + + Args: + sock: The tcp connection. if it is None, the constructor will + automatically creates an tcp connection to servername:port. + servername: The server name if needs. + port: The server port if needs. + """ + if sock is None: + self._connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._connection.connect((servername, port)) + else: + self._connection = sock + + def send_string(self, string): + """Sends a string to the tcp-connection. + + Args: + string: The string to be sent. + """ + body = StringTCP._create_body_from_string(string) + header = StringTCP._create_header_from_body(body) + self._connection.send(header + body) + + def recv_string(self): + """Receives a string from the tcp-connection. + + Returns: + The string received. + """ + header = StringTCP._recv_header_string(self._connection) + body = StringTCP._recv_body_string(self._connection, header) + return body + + def close(self): + """Closes the socket.""" + self._connection.close() + + @staticmethod + def _create_body_from_string(string): + """Creates package body from data string. + + Args: + string: Data string. + + Returns: + Package body. + """ + byte_string = string.encode(StringTCP.ENCODING) + return zlib.compress(byte_string, StringTCP.COMPRESS_LEVEL) + + @staticmethod + def _create_header_from_body(body): + """Creates package header from package body. + + Args: + body: Package body. + + Returns: + Package header. + """ + header_string = ('%%0%dd' % StringTCP.HEADER_LENGTH) % len(body) + return header_string.encode(StringTCP.ENCODING) + + @staticmethod + def _recv_header_string(conn): + """Receives package header from specified tcp connection. + + Args: + conn: The specified tcp connection. + + Returns: + Package header. + """ + return conn.recv(StringTCP.HEADER_LENGTH).decode(StringTCP.ENCODING) + + @staticmethod + def _recv_body_string(conn, header): + """Receives package body from specified tcp connection and header. + + Args: + conn: The specified tcp connection. + header: The package header. + + Returns: + Package body. + """ + body_length = int(header) + body = conn.recv(body_length) + body_byte = zlib.decompress(body) + return body_byte.decode(StringTCP.ENCODING) + + +class ChgStepInfo(object): + """Stores one step in modifying one string to anothor string. + + Here a step is to replace a range of substring with another string. So from + the original string, we can do lots of step and then get the final result. + + Attributes: + _begin: Begin of the replacement in the original string. + _end: End of the replacement in the original string. + _new_str: The string to replace. + _begin2: Where the _new_str at the final string. + """ + def __init__(self, begin, end, begin2, new_str): + """Constructor. + + Args: + begin: Begin of the replacement in the original string. + end: End of the replacement in the original string. + begin2: Where the _new_str at the final string. + new_str: The string to replace. + """ + self._begin = begin + self._end = end + self._begin2 = begin2 + self._new_str = new_str + + def rebase(self, step2): + """Inserts a modify event before it. + + It will update the self._begin, self._end to prevent confliction. + + Args: + step: The other modify event. + """ + if self._begin < step2.begin: + self._end = min([self._end, step2.begin]) + else: + self._begin = max([self._begin, step2.end]) + self._end = min([self._begin, self._end]) + self._begin += step2.increased_length + self._end += step2.increased_length + + @property + def begin(self): + """Gets the begin of the replacing range.""" + return self._begin + + @property + def end(self): + """Gets the end of the replacing range.""" + return self._end + + @property + def begin2(self): + """Gets the begin of the replacing string.""" + return self._begin2 + + @property + def new_str(self): + """Gets the replacing string.""" + return self._new_str + + @property + def increased_length(self): + """Gets the amount of length increased if applying this step.""" + return len(self.new_str) - (self.end - self.begin) + + +class CursorPosInfo_RelativeToNew(object): + """Stores a cursor position. + + Attributes: + step_id: The step id in that commit. + delta: Difference between the step.begin and the real position. + """ + def __init__(self, step_id, delta): + """Constructor. + + Args: + step_id: The step id in that commit. + delta: Difference between the step.begin and the real position. + """ + self.step_id = step_id + self.delta = delta + +class CursorPosInfo_RelativeToOld(object): + """Stores a cursor position. + + Attributes: + pos: The cursor position + """ + def __init__(self, position): + """Constructor. + + Args: + pos: The cursor position + """ + self.pos = position + + +class TextCommit(object): + """Contains a commit information. + + Attributes: + _text: Text of this commit. + _steps: List of instance of ChgStepInfo between it and the prevoius + commit. + count: Number of user points to it. + """ + def __init__(self, prev_text, my_text, count=1): + """Constructor. + + Args: + prev_text: Text in previous commit. + my_text: Text in this commit. + count: Default count. + """ + self._text = my_text + self._steps = [] + diff = difflib.SequenceMatcher(a=prev_text, b=my_text) + for tag, begin, end, begin2, end2 in diff.get_opcodes(): + if tag in ('replace', 'delete', 'insert'): + self._steps += [ChgStepInfo(begin, end, + begin2, my_text[begin2 : end2])] + + self.count = count + + def rebase(self, commits): + """Rebase to specified commits. + + Like git-rebase. + + Args: + commits: List of instance of TextCommit. + """ + for commit in commits: + self._rebase_steps(commit._steps) + if commits: + self._rebase_text(commits[-1].text) + + def get_cursor_pos_info(self, cursor_pos): + """Gets the cursor position information. + + Args: + cursor_pos: A number for cursor position. + """ + delta = 0 + for ind in range(len(self._steps)): + step = self._steps[ind] + if step.begin2 <= cursor_pos < step.begin2 + len(step.new_str): + return CursorPosInfo_RelativeToNew(ind, + cursor_pos - step.begin2) + elif cursor_pos < step.begin2: + break + delta += self._steps[ind].increased_length + return CursorPosInfo_RelativeToOld(cursor_pos - delta) + + def get_cursor_pos(self, cursor_pos_info): + """Gets the cursor position. + + Args: + cursor_pos_info: A instance of one of the CursorPosInfo_* + + Returns: + A number for the cursor position. + """ + if isinstance(cursor_pos_info, CursorPosInfo_RelativeToNew): + return (self._steps[cursor_pos_info.step_id].begin + + cursor_pos_info.delta) + else: + return cursor_pos_info.pos + + def rebase_cursor_pos_info(self, cursor_pos_info): + """Rebases the cursor position info by applying steps in this commit. + + Args: + cursor_pos_info: An instance of one of the CursorPosInfo_* + + Returns: + An instance of one of the CursorPosInfo_* + """ + if isinstance(cursor_pos_info, CursorPosInfo_RelativeToOld): + cursor_pos_info.pos = self.rebase_cursor_pos(cursor_pos_info.pos) + + def rebase_cursor_pos(self, cursor_pos): + """Rebases the cursor position by applying steps in this commit. + + Args: + cursor_pos: Cursor position to be rebased. + + Returns: + Rebased cursor position. + """ + for step in self._steps: + cursor_pos = max([cursor_pos, step.end]) + step.increased_length + return cursor_pos + + @property + def text(self): + """Gets the text of this commit.""" + return self._text + + def _rebase_steps(self, steps): + """Rebases the specified steps only. + + Args: + steps: List of instance of ChgStepInfo. + """ + for step in steps: + for my_step in self._steps: + my_step.rebase(step) + + def _rebase_text(self, past_text): + """Rebases the text by the specified text. + + Args: + past_text: The specified text. + """ + end_index = 0 + self._text = '' + for step in self._steps: + self._text += past_text[end_index : step.begin] + self._text += step.new_str + end_index = step.end + self._text += past_text[end_index : ] + + def __str__(self): + s = '\n'.join(['Replace (%d, %d) with %r' % (e.begin, e.end, e.new_str) + for e in self._steps] + + ['Become %r' % self._text]) + return s + + +class Snapshot(object): + """Stores informations about current state. + + Attrbutes: + _text: Text. + _cursor: Cursor position. + _cursors: Other users' cursor position. + """ + def __init__(self, text, cursor_pos, other_cursors): + """Constructor. + + Args: + text: Text. + cursor_pos: Cursor position. + other_cursors: Other users' cursor positions. + """ + self._text = text + self._cursor = cursor_pos + self._cursors = other_cursors + + @property + def text(self): + """Gets the text.""" + return self._text + + @property + def cursor(self): + """Gets the cursor position.""" + return self._cursor + + @property + def cursors(self): + """Gets the other users' cursor positions.""" + return self._cursors + + +class TextChain(object): + """Contains a list of text. + + You can "commit" a new version of text string from a very old text string, + because it will automatically merge it. + + Attributes: + _commits_id_max: Maximum index of the commits. + _commits: Dict with key be the commit and the value be the commit. + _saved_filename: Name of the file for initialize and backup. + _lock: Lock for preventing multiple threads commit at the same time. + """ + def __init__(self, filename): + """Constructor. + + Args: + filename: Name of the file for initialize and backup. + """ + self._commits_id_max = 0 + self._commits = {self._commits_id_max : TextCommit('', '')} + self._saved_filename = filename + self._lock = threading.RLock() + try: + with open(filename, 'r') as f: + self.commit(self._commits_id_max, Snapshot(f.read(), 0, [])) + self.del_reference(self._commits_id_max) + except IOError: + pass + + def commit(self, last_commit_id, snapshot): + """Commits a text. + + Args: + last_commit_id: The id of the original commit. + snapshot: The snapshot of the state, includes: + text: The new text. + cursor: The user's cursor position. + cursors: Other users' cursor positions in the previous commit. + + Returns: + A 2-tuple for new_commit_id, new_snapshot + """ + with self._lock: + new_commit = TextCommit(self._commits[last_commit_id].text, + snapshot.text) + cursor_pos_info = new_commit.get_cursor_pos_info(snapshot.cursor) + new_commit.rebase(self._get_commits_after(last_commit_id)) + new_commit_id = self._push_commit(new_commit) + new_snapshot = Snapshot( + new_commit.text, + self._commit_rebase_cursor(last_commit_id, cursor_pos_info), + [new_commit.rebase_cursor_pos(x) for x in snapshot.cursors]) + self.del_reference(last_commit_id) + self._save() + return (new_commit_id, new_snapshot) + + def text(self, commit_id): + """Gets the text of a specified commit id. + + Args: + commit_id: The specified commit id. + + Returns: + text. + """ + return self._commits[commit_id].text + + def _push_commit(self, new_commit): + """Pushes a commit into commit chain. + + Args: + new_commit: The commit to be push. + + Returns: + The commit of this new commit. + """ + self._commits_id_max += 1 + self._commits[self._commits_id_max] = new_commit + return self._commits_id_max + + def _commit_rebase_cursor(self, last_commit_id, cursor_pos_info): + """Updates the cursor by gived the gived cursor position. + + Args: + last_commit_it: Last commit id. + cursor_info: An instance of CursorPosInfo + + Returns: + New cursor position. + """ + commits_for_rebase = self._get_commits_after(last_commit_id) + for commit in commits_for_rebase: + commit.rebase_cursor_pos_info(cursor_pos_info) + return commits_for_rebase[-1].get_cursor_pos(cursor_pos_info) + + def del_reference(self, commit_id): + """Removes a reference from a commit id. + + Args: + commit_id: Commit id. + """ + with self._lock: + self._commits[commit_id].count -= 1 + self._merge_unnecessary_commits() + + def _get_commits_after(self, threshold): + """Returns the commit after a specified id. + + Args: + threshold: The specified commit id. + + Returns: + A list of commits. + """ + valid_commits = [c for c in self._commits.items() if c[0] > threshold] + return [cmt[1] for cmt in sorted(valid_commits, key=lambda x: x[0])] + + def _save(self): + """Save the last commit into the file.""" + try: + with open(self._saved_filename, 'w') as f: + f.write(self._commits[self._commits_id_max].text) + except IOError: + pass + + def _merge_unnecessary_commits(self): + """Merges the unnecessary commits.""" + commit_ids = sorted(self._commits.keys()) + try: + for index in range(1, len(commit_ids) - 1): + pre_id = commit_ids[index - 1] + cur_id = commit_ids[index] + nxt_id = commit_ids[index + 1] + print('%d---%d(%d)---%d' % + (pre_id, cur_id, self._commits[cur_id].count, nxt_id)) + if self._commits[cur_id].count == 0: + del self._commits[cur_id] + self._commits[nxt_id] = TextCommit( + self._commits[pre_id].text, + self._commits[nxt_id].text, + self._commits[nxt_id].count) + except KeyError as e: + print('Shit!! %r' % e) + print(sorted(self._commits.keys())) + + +class UserInfo(object): + """Informations about a user. + + Attributes: + _authority: Authority of this user. + last_commit_id: Previous commit id. + last_cursor_pos: Last cursor position. + """ + def __init__(self, authority): + """Constructor. + + Args: + authroity: Authority of this user. + """ + self._authority = authority + self.last_commit_id = 0 + self.last_cursor_pos = 0 + + def reset(self): + """Reset the commit and cursor to the default value.""" + self.last_commit_id = 0 + self.last_cursor_pos = 0 + + @property + def authority(self): + """Gets the autority of this user.""" + return self._authority + + +class UsersPoolException(Exception): + """Exception raised by UsersPool.""" + pass + +class UsersPool(object): + """A pool for users. + + Attributes: + _users: A dict (key=name, value=an instance of User) of users. + _text_chain: A instance of TextChain. + _lock: A lock for preventing multiple thread trying to sync at the same + time. + """ + def __init__(self, text_chain): + """Constructor. + + Args: + text_chain: A instance of TextChain + """ + self._users = {} + self._text_chain = text_chain + self._lock = threading.Lock() + + def add(self, identity, authority): + """Adds an user. + + Args: + identity: The identity of this new user. + authority: The authority of this new user. + """ + with self._lock: + if identity in self._users: + raise UsersPoolException('User %r already in the users list' % + identity) + self._users[identity] = UserInfo(authority) + + def delete(self, identity): + """Deletes an user. + + Args: + identity: The identity of the user. + """ + with self._lock: + if identity not in self._users: + raise UsersPoolException('User %r not in the users list' % + identity) + del self._users[identity] + + def sync(self, identity, text, cursor_pos, reset): + """Sync the specified user. + + Args: + identity: Identity of the user. + text: The new text commit from that user. + cursor_pos: The cursor position of that user. + reset: A flag for reset or not. + + Returns + An instance of Snapshot. + """ + with self._lock: + if identity not in self._users: + raise UsersPoolException('User %r not in the users list' % + identity) + user = self._users[identity] + if reset: + user.reset() + if user.authority < AUTHORITY.READWRITE: + text = self._text_chain.text(user.last_commit_id) + cursors = [v.last_cursor_pos + for k, v in self._users.items() if k != identity] + user.last_commit_id, new_snapshot = self._text_chain.commit( + user.last_commit_id, Snapshot(text, cursor_pos, cursors)) + user.last_cursor_pos = new_snapshot.cursor + others = [v for k, v in self._users.items() if k != identity] + for index in range(len(others)): + others[index].last_cursor_pos = new_snapshot.cursors[index] + return new_snapshot + + @property + def users(self): + """Gets a dict with key=identity, value=authority of users.""" + ret = {} + for identity, user in self._users.items(): + ret[identity] = user.authority + return ret + + +def load_user_list(users_pool, filename): + """Loads user list from the file. + + Args: + users_pool: A instance of UsersPool. + filename: File name. + """ + with open(filename, 'r') as f: + while True: + line = f.readline() + if not line: + break + if line.endswith('\n'): + line = line[:-1] + words = [word for word in re.split(r'[ \t\n]', line) if word] + if len(words) != 2: + continue + identity = words[0] + authority = AuthorityStringTransformer.to_number(words[1]) + if authority is None: + continue + try: + users_pool.add(identity, authority) + except UsersPoolException: + continue + + +class AuthorityStringTransformer: # pylint:disable=W0232 + """Transforms authority between number and strin format.""" + @staticmethod + def to_string(authority): + """Transform number authority value to string value. + + Args: + authority: Authority in number format. + + Returns: + authority: Corrosponding authority in string format. + """ + if authority == AUTHORITY.READONLY: + return 'RO' + elif authority == AUTHORITY.READWRITE: + return 'RW' + + @staticmethod + def to_number(string): + """Transform string authority value to number value. + + Args: + authority: Authority in string format. + + Returns: + authority: Corrosponding authority in number format. + """ + if string == 'RO': + return AUTHORITY.READONLY + elif string == 'RW': + return AUTHORITY.READWRITE + return None + +class CommandHandler(object): + """Base class for CommandHandle_* + + Attributes: + _prefix: Prefix string of this command if exists. + """ + def __init__(self, prefix=None): + self._prefix = prefix + + @property + def prefix(self): + """Gets the prefix of this command handler.""" + return self._prefix + + def is_mine(self, cmd): + """Checks whether the specified command is mine.""" + return cmd.startswith(self._prefix) + + def handle(self, unused_cmd, unused_users_pool): + """Run the command.""" + raise NotImplementedError + + +class CommandHandler_Exit(CommandHandler): + """Command handler for exiting.""" + def __init__(self): + super(CommandHandler_Exit, self).__init__('exit') + + def handle(self, unused_cmd, unused_userspool): + normal_print('bye~\n', prompt=False) + exit() + +class CommandHandler_Add(CommandHandler): + """Command handler for adding a user.""" + def __init__(self): + super(CommandHandler_Add, self).__init__('add ') + + def handle(self, cmd, users_pool): + words = [word for word in re.split(r'[ \t\n]', cmd) if word] + if len(words) != 3: + return 'Wrong number of arguments.' + identity = words[1] + authority = AuthorityStringTransformer.to_number(words[2]) + if authority is None: + return 'Unknown flag.' + try: + users_pool.add(identity, authority) + except UsersPoolException as e: + return str(e) + return 'Done.' + +class CommandHandler_Delete(CommandHandler): + """Command handler for deleting a user.""" + def __init__(self): + super(CommandHandler_Delete, self).__init__('delete ') + + def handle(self, cmd, users_pool): + words = [word for word in re.split(r'[ \t\n]', cmd) if word] + if len(words) != 2: + return 'Wrong number of arguments.' + identity = words[1] + try: + users_pool.delete(identity) + except UsersPoolException as e: + return str(e) + return 'Done.' + +class CommandHandler_Load(CommandHandler): + """Command handler for loading user list from a file.""" + def __init__(self): + super(CommandHandler_Load, self).__init__('load ') + + def handle(self, cmd, users_pool): + words = [word for word in re.split(r'[ \t\n]', cmd) if word] + if len(words) != 2: + return 'Wrong number of arguments.' + filename = words[1] + try: + load_user_list(users_pool, filename) + except IOError as e: + return str(e) + return 'Reads user list from file %r successfully.' % filename + +class CommandHandler_Save(CommandHandler): + """Command handler for saving user list to a file.""" + def __init__(self): + super(CommandHandler_Save, self).__init__('save ') + + def handle(self, cmd, users_pool): + words = [word for word in re.split(r'[ \t\n]', cmd) if word] + if len(words) != 2: + return 'Wrong number of arguments.' + filename = words[1] + try: + with open(filename, 'w') as f: + for identity, auth in users_pool.users.items(): + f.write('%s %s\n' % + (identity, + AuthorityStringTransformer.to_string(auth))) + except IOError as e: + return str(e) + return 'Write user list to file %r successfully' % filename + +class CommandHandler_List(CommandHandler): + """Command handler for listing user list to a file.""" + def __init__(self): + super(CommandHandler_List, self).__init__('list') + + def handle(self, unused_cmd, users_pool): + ret = '' + for identity, auth in users_pool.users.items(): + ret += '%r => %s\n' % (identity, + AuthorityStringTransformer.to_string(auth)) + return ret + + +def cui_thread(users_pool): + """Threading function for CUI. + + Args: + users_pool: An instance of UsersPool. + """ + commands = [ + CommandHandler_Exit(), + CommandHandler_Add(), + CommandHandler_Delete(), + CommandHandler_Load(), + CommandHandler_Save(), + CommandHandler_List(), + ] + normal_print() + while True: + line_str = sys.stdin.readline() + if line_str.endswith('\n'): + line_str = line_str[:-1] + else: + CommandHandler_Exit().handle('', users_pool) + for command in commands: + if command.is_mine(line_str): + result = command.handle(line_str, users_pool) + normal_print(result + '\n') + break + else: + normal_print('Unknown command %r\n' % line_str) + +def start_cui(users_pool): + """Starts a thread for running CUI. + + Returns: + A instance of Thread. + """ + cui = threading.Thread(target=cui_thread, args=(users_pool,)) + cui.start() + return cui + + +def tcp_server_accepter_thread(port, users_pool): + """Accepts tcp connection forever. + + Args: + users_pool: An instance of UsersPool. + """ + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(('', port)) + s.listen(BACKLOG) + except socket.error: + normal_print('Cannot start the TCP server at port %d, abort\n' % port, + prompt=False) + exit(1) + normal_print('Server started (port = %d)\n' % port) + while True: + conn, addr = s.accept() + normal_print('Connection address: %s\n' % str(addr)) + start_tcp_receiver(conn, users_pool) + +def start_tcp_server_accepter(port, users_pool): + """Starts a thread for running tcp server. + + Args: + port: Port to listen. + users_pool: An instance of UsersPool. + """ + thr = threading.Thread(target=tcp_server_accepter_thread, + args=(port, users_pool)) + thr.daemon = True + thr.start() + + +def tcp_receiver_thread(conn, users_pool): + """Handles client request. + + Args: + conn: connection. + users_pool: An instance of UsersPool. + """ + tcp = StringTCP(conn) + request = json.loads(tcp.recv_string()) + try: + snapshot = users_pool.sync(request[JSON_TOKEN.IDENTITY], + request[JSON_TOKEN.TEXT], + request[JSON_TOKEN.CURSOR], + reset=request[JSON_TOKEN.INIT]) + response = { + JSON_TOKEN.TEXT : snapshot.text, + JSON_TOKEN.CURSOR : snapshot.cursor, + JSON_TOKEN.CURSORS : snapshot.cursors, + } + except UsersPoolException as e: + response = {JSON_TOKEN.ERROR : str(e)} + print('request = %r' % request) + print('response = %r' % response) + tcp.send_string(json.dumps(response)) + tcp.close() + + +def start_tcp_receiver(conn, users_pool): + """Starts a thread for handling client request. + + Args: + conn: connection. + users_pool: An instance of UsersPool. + """ + thr = threading.Thread(target=tcp_receiver_thread, args=(conn, users_pool)) + thr.daemon = True + thr.start() + + +def main(): + """Program entry point.""" + help_str = '[USAGE] %s <port_number> <user_list> <save_file>' % sys.argv[0] + + if len(sys.argv) != 4: + normal_print('Wrong arguments!!\n%s' % help_str, prompt=False) + exit(1) + + try: + port = int(sys.argv[1]) + except ValueError: + normal_print('Wrong port!!\n%s' % help_str, prompt=False) + exit(1) + + users_pool = UsersPool(TextChain(sys.argv[3])) + load_user_list(users_pool, sys.argv[2]) + start_tcp_server_accepter(port, users_pool) + cui_thr = start_cui(users_pool) + cui_thr.join() + + +if __name__ == '__main__': + main() |