diff --git a/videocr/video.py b/videocr/video.py
index 2031239..8535b87 100644
--- a/videocr/video.py
+++ b/videocr/video.py
@@ -14,10 +14,13 @@ class Video:
     use_fullframe: bool
     num_frames: int
     fps: float
+    ocr_frame_start: int
+    num_ocr_frames: int
     pred_frames: List[PredictedFrame]
     pred_subs: List[PredictedSubtitle]
 
-    def __init__(self, path, lang, use_fullframe=False):
+    def __init__(self, path: str, lang: str, use_fullframe=False,
+                 time_start='0:00', time_end=''):
         self.path = path
         self.lang = lang
         self.use_fullframe = use_fullframe
@@ -26,24 +29,53 @@ class Video:
         self.fps = v.get(cv2.CAP_PROP_FPS)
         v.release()
 
+        self.ocr_frame_start = self._frame_index(time_start)
+
+        if time_end:
+            ocr_end = self._frame_index(time_end)
+        else:
+            ocr_end = self.num_frames
+        self.num_ocr_frames = ocr_end - self.ocr_frame_start
+
+        if self.num_ocr_frames < 0:
+            raise ValueError('time_start is later than time_end')
+
+    def _frame_index(self, time: str) -> int:
+        t = time.split(':')
+        t = list(map(int, t))
+        if len(t) == 3:
+            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
+        elif len(t) == 2:
+            td = datetime.timedelta(minutes=t[0], seconds=t[1])
+        else:
+            raise ValueError(
+                'time data "{}" does not match format "%H:%M:%S"'.format(time))
+
+        index = int(td.total_seconds() * self.fps)
+        if index > self.num_frames or index < 0:
+            raise ValueError(
+                'time data "{}" exceeds video duration'.format(time))
+
+        return index
+
+    def run_ocr(self) -> None:
+        v = cv2.VideoCapture(self.path)
+        v.set(cv2.CAP_PROP_POS_FRAMES, self.ocr_frame_start)
+        frames = (v.read()[1] for _ in range(self.num_ocr_frames))
+
+        # perform ocr to all frames in parallel
+        with futures.ProcessPoolExecutor() as pool:
+            ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
+            self.pred_frames = [PredictedFrame(i + self.ocr_frame_start, data)
+                                for i, data in enumerate(ocr_map)]
+
+        v.release()
+
     def _single_frame_ocr(self, img) -> str:
         if not self.use_fullframe:
             # only use bottom half of the frame by default
             img = img[img.shape[0] // 2:, :]
-        data = pytesseract.image_to_data(img, lang=self.lang)
-        return data
-
-    def run_ocr(self) -> None:
-        v = cv2.VideoCapture(self.path)
-        frames = (v.read()[1] for _ in range(self.num_frames))
-
-        # perform ocr to all frames in parallel
-        with futures.ProcessPoolExecutor() as pool:
-            frames_ocr = pool.map(self._single_frame_ocr, frames, chunksize=10)
-            self.pred_frames = [PredictedFrame(i, data)
-                                for i, data in enumerate(frames_ocr)]
-
-        v.release()
+        return pytesseract.image_to_data(img, lang=self.lang)
 
     def get_subtitles(self) -> str:
         self._generate_subtitles()
@@ -67,7 +99,7 @@ class Video:
         bound = WIN_BOUND
         i = 0
         j = 1
-        while j < self.num_frames:
+        while j < self.num_ocr_frames:
             fi, fj = self.pred_frames[i], self.pred_frames[j]
 
             if fi.is_similar_to(fj):
@@ -86,7 +118,7 @@ class Video:
            j += 1
 
         # also handle the last remaining frames
-        if i < self.num_frames - 1:
+        if i < self.num_ocr_frames - 1:
             self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
 
     def _append_sub(self, sub: PredictedSubtitle) -> None:
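
For reference, a minimal usage sketch of the new time-range parameters. The `Video` constructor, `run_ocr()`, and `get_subtitles()` are taken from this diff; the import path is only assumed from the file location (`videocr/video.py`), and the file name and language code are placeholders:

```python
# Usage sketch: OCR only a window of the video instead of every frame.
# Assumptions: Video is importable as below, 'sample.mp4' is any input
# readable by OpenCV, and 'eng' is the standard Tesseract English code.
from videocr.video import Video

video = Video('sample.mp4', lang='eng', time_start='1:30', time_end='2:05')
video.run_ocr()               # seeks to ocr_frame_start, reads num_ocr_frames frames
print(video.get_subtitles())
```

Per `_frame_index`, a timestamp maps to `int(total_seconds * fps)`, so at 25 fps `'1:30'` becomes frame `int(90 * 25) = 2250`; seeking once with `CAP_PROP_POS_FRAMES` then lets `run_ocr` skip decoding everything before `time_start`.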