From 77362dce1a38c7cdc3674b2b42883203a277531b Mon Sep 17 00:00:00 2001
From: Yi Ge
Date: Mon, 29 Apr 2019 03:05:02 +0200
Subject: [PATCH] make conf_threshold adjustable through api

---
 videocr/api.py    | 12 +++++++-----
 videocr/models.py |  2 +-
 videocr/video.py  | 11 ++++++++---
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/videocr/api.py b/videocr/api.py
index c581890..cd1c8d0 100644
--- a/videocr/api.py
+++ b/videocr/api.py
@@ -5,8 +5,8 @@ from . import constants
 from .video import Video
 
 
-def get_subtitles(video_path: str, lang='eng',
-                  time_start='0:00', time_end='', use_fullframe=False) -> str:
+def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='',
+                  conf_threshold=65, use_fullframe=False) -> str:
     # download tesseract data file to ~/tessdata if necessary
     fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
     if not fpath.is_file():
@@ -19,13 +19,15 @@ def get_subtitles(video_path: str, lang='eng',
             shutil.copyfileobj(res, f)
 
     v = Video(video_path)
-    v.run_ocr(lang, time_start, time_end, use_fullframe)
+    v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
     return v.get_subtitles()
 
 
 def save_subtitles_to_file(
         video_path: str, file_path='subtitle.srt', lang='eng',
-        time_start='0:00', time_end='', use_fullframe=False) -> None:
+        time_start='0:00', time_end='', conf_threshold=65,
+        use_fullframe=False) -> None:
     with open(file_path, 'w+') as f:
         f.write(get_subtitles(
-            video_path, lang, time_start, time_end, use_fullframe))
+            video_path, lang, time_start, time_end, conf_threshold,
+            use_fullframe))
diff --git a/videocr/models.py b/videocr/models.py
index 5b5549f..0dc4298 100644
--- a/videocr/models.py
+++ b/videocr/models.py
@@ -17,7 +17,7 @@ class PredictedFrame:
     confidence: int  # total confidence of all words
     text: str
 
-    def __init__(self, index: int, pred_data: str, conf_threshold=70):
+    def __init__(self, index: int, pred_data: str, conf_threshold: int):
         self.index = index
         self.words = []
 
diff --git a/videocr/video.py b/videocr/video.py
index fb2fe04..493a53c 100644
--- a/videocr/video.py
+++ b/videocr/video.py
@@ -14,18 +14,22 @@ class Video:
     use_fullframe: bool
     num_frames: int
     fps: float
+    height: int
     pred_frames: List[PredictedFrame]
     pred_subs: List[PredictedSubtitle]
 
     def __init__(self, path: str):
         self.path = path
         v = cv2.VideoCapture(path)
+        if not v.isOpened():
+            raise IOError('can not open video format {}'.format(path))
         self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
         self.fps = v.get(cv2.CAP_PROP_FPS)
+        self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
         v.release()
 
     def run_ocr(self, lang: str, time_start: str, time_end: str,
-                use_fullframe: bool) -> None:
+                conf_threshold: int, use_fullframe: bool) -> None:
         self.lang = lang
         self.use_fullframe = use_fullframe
 
@@ -44,8 +48,9 @@ class Video:
         # perform ocr to frames in parallel
         with futures.ProcessPoolExecutor() as pool:
             ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
-            self.pred_frames = [PredictedFrame(i + ocr_start, data)
-                                for i, data in enumerate(ocr_map)]
+            self.pred_frames = [
+                PredictedFrame(i + ocr_start, data, conf_threshold)
+                for i, data in enumerate(ocr_map)]
 
         v.release()
 