diff --git a/videocr/video.py b/videocr/video.py index 8535b87..86d4826 100644 --- a/videocr/video.py +++ b/videocr/video.py @@ -14,32 +14,45 @@ class Video: use_fullframe: bool num_frames: int fps: float - ocr_frame_start: int - num_ocr_frames: int pred_frames: List[PredictedFrame] pred_subs: List[PredictedSubtitle] - def __init__(self, path: str, lang: str, use_fullframe=False, - time_start='0:00', time_end=''): + def __init__(self, path: str): self.path = path - self.lang = lang - self.use_fullframe = use_fullframe v = cv2.VideoCapture(path) self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT)) self.fps = v.get(cv2.CAP_PROP_FPS) v.release() - self.ocr_frame_start = self._frame_index(time_start) + def run_ocr(self, lang: str, use_fullframe=False, + time_start='0:00', time_end='') -> None: + self.lang = lang + self.use_fullframe = use_fullframe + + ocr_start = self._frame_index(time_start) if time_end: ocr_end = self._frame_index(time_end) + if ocr_end < ocr_start: + raise ValueError('time_start is later than time_end') else: ocr_end = self.num_frames - self.num_ocr_frames = ocr_end - self.ocr_frame_start + num_ocr_frames = ocr_end - ocr_start - if self.num_ocr_frames < 0: - raise ValueError('time_start is later than time_end') + # get frames from ocr_start to ocr_end + v = cv2.VideoCapture(self.path) + v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start) + frames = (v.read()[1] for _ in range(num_ocr_frames)) + # perform ocr to frames in parallel + with futures.ProcessPoolExecutor() as pool: + ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10) + self.pred_frames = [PredictedFrame(i + ocr_start, data) + for i, data in enumerate(ocr_map)] + + v.release() + + # convert time str to frame index def _frame_index(self, time: str) -> int: t = time.split(':') t = list(map(int, t)) @@ -58,19 +71,6 @@ class Video: return index - def run_ocr(self) -> None: - v = cv2.VideoCapture(self.path) - v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start) - frames = (v.read()[1] for _ in range(self.num_ocr_frames)) - - # perform ocr to all frames in parallel - with futures.ProcessPoolExecutor() as pool: - ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10) - self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data) - for i, data in enumerate(ocr_map)] - - v.release() - def _single_frame_ocr(self, img) -> str: if not self.use_fullframe: # only use bottom half of the frame by default @@ -99,7 +99,7 @@ class Video: bound = WIN_BOUND i = 0 j = 1 - while j < self.num_ocr_frames: + while j < len(self.pred_frames): fi, fj = self.pred_frames[i], self.pred_frames[j] if fi.is_similar_to(fj): @@ -118,7 +118,7 @@ class Video: j += 1 # also handle the last remaining frames - if i < self.num_ocr_frames - 1: + if i < len(self.pred_frames) - 1: self._append_sub(PredictedSubtitle(self.pred_frames[i:])) def _append_sub(self, sub: PredictedSubtitle) -> None: