Introduce videocr/opencv_adapter.py with a Capture context manager that opens
a cv2.VideoCapture on entry (raising IOError if it cannot be opened) and
releases it on exit; use it in Video.__init__ and Video.run_ocr.

NOTE(review): run_ocr now references multiprocessing.Pool and
self._image_to_data; neither the corresponding import nor the method
definition is visible in this patch. Confirm both exist in video.py.
NOTE(review): line breaks and indentation were restored from a collapsed
copy of this patch; whitespace in context lines should be checked against
the target tree before applying.

diff --git a/videocr/opencv_adapter.py b/videocr/opencv_adapter.py
new file mode 100644
index 0000000..34889e5
--- /dev/null
+++ b/videocr/opencv_adapter.py
@@ -0,0 +1,15 @@
+import cv2
+
+
+class Capture:
+    def __init__(self, video_path):
+        self.path = video_path
+
+    def __enter__(self):
+        self.cap = cv2.VideoCapture(self.path)
+        if not self.cap.isOpened():
+            raise IOError('Can not open video {}.'.format(self.path))
+        return self.cap
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.cap.release()
diff --git a/videocr/video.py b/videocr/video.py
index 43f9888..f79eaf7 100644
--- a/videocr/video.py
+++ b/videocr/video.py
@@ -6,6 +6,7 @@ import cv2
 
 from . import constants
 from .models import PredictedFrame, PredictedSubtitle
+from .opencv_adapter import Capture
 
 
 class Video:
@@ -20,13 +21,10 @@ class Video:
 
     def __init__(self, path: str):
         self.path = path
-        v = cv2.VideoCapture(path)
-        if not v.isOpened():
-            raise IOError('can not open video format {}'.format(path))
-        self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
-        self.fps = v.get(cv2.CAP_PROP_FPS)
-        self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        v.release()
+        with Capture(path) as v:
+            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
+            self.fps = v.get(cv2.CAP_PROP_FPS)
+            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     def run_ocr(self, lang: str, time_start: str, time_end: str,
                 conf_threshold:int, use_fullframe: bool) -> None:
@@ -41,13 +39,12 @@ class Video:
         num_ocr_frames = ocr_end - ocr_start
 
         # get frames from ocr_start to ocr_end
-        v = cv2.VideoCapture(self.path)
-        v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
-        frames = (v.read()[1] for _ in range(num_ocr_frames))
+        with Capture(self.path) as v, multiprocessing.Pool() as pool:
+            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
+            frames = (v.read()[1] for _ in range(num_ocr_frames))
 
-        # perform ocr to frames in parallel
-        with Pool() as pool:
-            it_ocr = pool.imap(self._single_frame_ocr, frames, chunksize=10)
+            # perform ocr to frames in parallel
+            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
             self.pred_frames = [
                 PredictedFrame(i + ocr_start, data, conf_threshold)
                 for i, data in enumerate(it_ocr)]