From bd6f15978b905d9e6c72fe6f8c2b43e80e0d44ab Mon Sep 17 00:00:00 2001 From: Yi Ge Date: Sun, 28 Apr 2019 15:46:24 +0200 Subject: [PATCH] add api definition --- videocr/api.py | 17 +++++++++++++++++ videocr/models.py | 10 ++++------ videocr/video.py | 34 +++++++--------------------------- 3 files changed, 28 insertions(+), 33 deletions(-) create mode 100644 videocr/api.py diff --git a/videocr/api.py b/videocr/api.py new file mode 100644 index 0000000..f253638 --- /dev/null +++ b/videocr/api.py @@ -0,0 +1,17 @@ + +from .video import Video + + +def get_subtitles(video_path: str, lang='eng', + time_start='0:00', time_end='', use_fullframe=False) -> str: + v = Video(video_path) + v.run_ocr(lang, time_start, time_end, use_fullframe) + return v.get_subtitles() + + +def save_subtitles_to_file( + video_path: str, file_path='subtitle.srt', lang='eng', + time_start='0:00', time_end='', use_fullframe=False) -> None: + with open(file_path, 'w+') as f: + f.write(get_subtitles( + video_path, lang, time_start, time_end, use_fullframe)) diff --git a/videocr/models.py b/videocr/models.py index 789ece1..f787712 100644 --- a/videocr/models.py +++ b/videocr/models.py @@ -48,9 +48,8 @@ class PredictedFrame: self.text = ' '.join(word.text for word in self.words) # remove chars that are obviously ocr errors - translate_table = {ord(c): None for c in '<>{}[];`@#$%^*_=~\\'} - translate_table[ord('|')] = 'I' - self.text = self.text.translate(translate_table).strip() + table = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\') + self.text = self.text.translate(table).replace(' \n ', '\n').strip() def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool: return fuzz.ratio(self.text, other.text) >= threshold @@ -58,14 +57,13 @@ class PredictedFrame: class PredictedSubtitle: frames: List[PredictedFrame] + text: str def __init__(self, frames: List[PredictedFrame]): self.frames = [f for f in frames if f.confidence > 0] if self.frames: - conf_max = max(f.confidence for f in self.frames) - self.text = next(f.text for f in self.frames - if f.confidence == conf_max) + self.text = max(self.frames, key=lambda f: f.confidence).text else: self.text = '' diff --git a/videocr/video.py b/videocr/video.py index 86d4826..2485e5d 100644 --- a/videocr/video.py +++ b/videocr/video.py @@ -3,7 +3,6 @@ from concurrent import futures import datetime import pytesseract import cv2 -import timeit from .models import PredictedFrame, PredictedSubtitle @@ -24,19 +23,16 @@ class Video: self.fps = v.get(cv2.CAP_PROP_FPS) v.release() - def run_ocr(self, lang: str, use_fullframe=False, - time_start='0:00', time_end='') -> None: + def run_ocr(self, lang: str, time_start: str, time_end: str, + use_fullframe: bool) -> None: self.lang = lang self.use_fullframe = use_fullframe - ocr_start = self._frame_index(time_start) + ocr_start = self._frame_index(time_start) if time_start else 0 + ocr_end = self._frame_index(time_end) if time_end else self.num_frames - if time_end: - ocr_end = self._frame_index(time_end) - if ocr_end < ocr_start: - raise ValueError('time_start is later than time_end') - else: - ocr_end = self.num_frames + if ocr_end < ocr_start: + raise ValueError('time_start is later than time_end') num_ocr_frames = ocr_end - ocr_start # get frames from ocr_start to ocr_end @@ -55,7 +51,7 @@ class Video: # convert time str to frame index def _frame_index(self, time: str) -> int: t = time.split(':') - t = list(map(int, t)) + t = list(map(float, t)) if len(t) == 3: td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2]) elif len(t) == 2: @@ -139,19 +135,3 @@ class Video: m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms) - - def save_subtitles_to_file(self, path='subtitle.srt') -> None: - with open(path, 'w+') as f: - f.write(self.get_subtitles()) - - -time_start = timeit.default_timer() -v = Video('1.mp4', 'HanS') -v.run_ocr() -time_stop = timeit.default_timer() -print('time for ocr: ', time_stop - time_start) - -time_start = timeit.default_timer() -v.save_subtitles_to_file() -time_stop = timeit.default_timer() -print('time for save sub: ', time_stop - time_start)