diff --git a/videocr/utils.py b/videocr/utils.py index ee6161c..13e955f 100644 --- a/videocr/utils.py +++ b/videocr/utils.py @@ -1,5 +1,6 @@ from urllib.request import urlopen import shutil +import datetime from . import constants @@ -19,3 +20,27 @@ def download_lang_data(lang: str): with urlopen(url) as res, open(filepath, 'w+b') as f: shutil.copyfileobj(res, f) + + +# convert time string to frame index +def get_frame_index(time_str: str, fps: float): + t = time_str.split(':') + t = list(map(float, t)) + if len(t) == 3: + td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2]) + elif len(t) == 2: + td = datetime.timedelta(minutes=t[0], seconds=t[1]) + else: + raise ValueError( + 'Time data "{}" does not match format "%H:%M:%S"'.format(time_str)) + index = int(td.total_seconds() * fps) + return index + + +# convert frame index into SRT timestamp +def get_srt_timestamp(frame_index: int, fps: float): + td = datetime.timedelta(seconds=frame_index / fps) + ms = td.microseconds // 1000 + m, s = divmod(td.seconds, 60) + h, m = divmod(m, 60) + return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms) diff --git a/videocr/video.py b/videocr/video.py index f79eaf7..754a17a 100644 --- a/videocr/video.py +++ b/videocr/video.py @@ -1,10 +1,10 @@ from __future__ import annotations -from multiprocessing import Pool -import datetime +import multiprocessing import pytesseract import cv2 from . import constants +from . import utils from .models import PredictedFrame, PredictedSubtitle from .opencv_adapter import Capture @@ -27,12 +27,12 @@ class Video: self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT)) def run_ocr(self, lang: str, time_start: str, time_end: str, - conf_threshold:int, use_fullframe: bool) -> None: + conf_threshold: int, use_fullframe: bool) -> None: self.lang = lang self.use_fullframe = use_fullframe - ocr_start = self._frame_index(time_start) if time_start else 0 - ocr_end = self._frame_index(time_end) if time_end else self.num_frames + ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0 + ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames if ocr_end < ocr_start: raise ValueError('time_start is later than time_end') @@ -46,31 +46,11 @@ class Video: # perform ocr to frames in parallel it_ocr = pool.imap(self._image_to_data, frames, chunksize=10) self.pred_frames = [ - PredictedFrame(i + ocr_start, data, conf_threshold) - for i, data in enumerate(it_ocr)] + PredictedFrame(i + ocr_start, data, conf_threshold) + for i, data in enumerate(it_ocr) + ] - v.release() - - # convert time str to frame index - def _frame_index(self, time: str) -> int: - t = time.split(':') - t = list(map(float, t)) - if len(t) == 3: - td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2]) - elif len(t) == 2: - td = datetime.timedelta(minutes=t[0], seconds=t[1]) - else: - raise ValueError( - 'time data "{}" does not match format "%H:%M:%S"'.format(time)) - - index = int(td.total_seconds() * self.fps) - if index > self.num_frames or index < 0: - raise ValueError( - 'time data "{}" exceeds video duration'.format(time)) - - return index - - def _single_frame_ocr(self, img) -> str: + def _image_to_data(self, img) -> str: if not self.use_fullframe: # only use bottom half of the frame by default img = img[self.height // 2:, :] @@ -82,8 +62,8 @@ class Video: return ''.join( '{}\n{} --> {}\n{}\n\n'.format( i, - self._srt_timestamp(sub.index_start), - self._srt_timestamp(sub.index_end), + utils.get_srt_timestamp(sub.index_start, self.fps), + utils.get_srt_timestamp(sub.index_end, self.fps), sub.text) for i, sub in enumerate(self.pred_subs)) @@ -133,10 +113,3 @@ class Video: sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold) self.pred_subs.append(sub) - - def _srt_timestamp(self, frame_index: int) -> str: - td = datetime.timedelta(seconds=frame_index / self.fps) - ms = td.microseconds // 1000 - m, s = divmod(td.seconds, 60) - h, m = divmod(m, 60) - return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)