move util functions to utils.py

This commit is contained in:
Yi Ge 2019-12-15 21:38:48 +08:00
parent 9360ebdd40
commit f8e99465c7
2 changed files with 36 additions and 38 deletions

View File

@ -1,5 +1,6 @@
from urllib.request import urlopen from urllib.request import urlopen
import shutil import shutil
import datetime
from . import constants from . import constants
@ -19,3 +20,27 @@ def download_lang_data(lang: str):
with urlopen(url) as res, open(filepath, 'w+b') as f: with urlopen(url) as res, open(filepath, 'w+b') as f:
shutil.copyfileobj(res, f) shutil.copyfileobj(res, f)
# convert time string to frame index
def get_frame_index(time_str: str, fps: float):
t = time_str.split(':')
t = list(map(float, t))
if len(t) == 3:
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
elif len(t) == 2:
td = datetime.timedelta(minutes=t[0], seconds=t[1])
else:
raise ValueError(
'Time data "{}" does not match format "%H:%M:%S"'.format(time_str))
index = int(td.total_seconds() * fps)
return index
# convert frame index into SRT timestamp
def get_srt_timestamp(frame_index: int, fps: float):
td = datetime.timedelta(seconds=frame_index / fps)
ms = td.microseconds // 1000
m, s = divmod(td.seconds, 60)
h, m = divmod(m, 60)
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)

View File

@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
from multiprocessing import Pool import multiprocessing
import datetime
import pytesseract import pytesseract
import cv2 import cv2
from . import constants from . import constants
from . import utils
from .models import PredictedFrame, PredictedSubtitle from .models import PredictedFrame, PredictedSubtitle
from .opencv_adapter import Capture from .opencv_adapter import Capture
@ -27,12 +27,12 @@ class Video:
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
def run_ocr(self, lang: str, time_start: str, time_end: str, def run_ocr(self, lang: str, time_start: str, time_end: str,
conf_threshold:int, use_fullframe: bool) -> None: conf_threshold: int, use_fullframe: bool) -> None:
self.lang = lang self.lang = lang
self.use_fullframe = use_fullframe self.use_fullframe = use_fullframe
ocr_start = self._frame_index(time_start) if time_start else 0 ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
ocr_end = self._frame_index(time_end) if time_end else self.num_frames ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames
if ocr_end < ocr_start: if ocr_end < ocr_start:
raise ValueError('time_start is later than time_end') raise ValueError('time_start is later than time_end')
@ -46,31 +46,11 @@ class Video:
# perform ocr to frames in parallel # perform ocr to frames in parallel
it_ocr = pool.imap(self._image_to_data, frames, chunksize=10) it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
self.pred_frames = [ self.pred_frames = [
PredictedFrame(i + ocr_start, data, conf_threshold) PredictedFrame(i + ocr_start, data, conf_threshold)
for i, data in enumerate(it_ocr)] for i, data in enumerate(it_ocr)
]
v.release() def _image_to_data(self, img) -> str:
# convert time str to frame index
def _frame_index(self, time: str) -> int:
t = time.split(':')
t = list(map(float, t))
if len(t) == 3:
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
elif len(t) == 2:
td = datetime.timedelta(minutes=t[0], seconds=t[1])
else:
raise ValueError(
'time data "{}" does not match format "%H:%M:%S"'.format(time))
index = int(td.total_seconds() * self.fps)
if index > self.num_frames or index < 0:
raise ValueError(
'time data "{}" exceeds video duration'.format(time))
return index
def _single_frame_ocr(self, img) -> str:
if not self.use_fullframe: if not self.use_fullframe:
# only use bottom half of the frame by default # only use bottom half of the frame by default
img = img[self.height // 2:, :] img = img[self.height // 2:, :]
@ -82,8 +62,8 @@ class Video:
return ''.join( return ''.join(
'{}\n{} --> {}\n{}\n\n'.format( '{}\n{} --> {}\n{}\n\n'.format(
i, i,
self._srt_timestamp(sub.index_start), utils.get_srt_timestamp(sub.index_start, self.fps),
self._srt_timestamp(sub.index_end), utils.get_srt_timestamp(sub.index_end, self.fps),
sub.text) sub.text)
for i, sub in enumerate(self.pred_subs)) for i, sub in enumerate(self.pred_subs))
@ -133,10 +113,3 @@ class Video:
sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold) sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
self.pred_subs.append(sub) self.pred_subs.append(sub)
def _srt_timestamp(self, frame_index: int) -> str:
td = datetime.timedelta(seconds=frame_index / self.fps)
ms = td.microseconds // 1000
m, s = divmod(td.seconds, 60)
h, m = divmod(m, 60)
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)