move util functions to utils.py
This commit is contained in:
parent
9360ebdd40
commit
f8e99465c7
@ -1,5 +1,6 @@
|
||||
from urllib.request import urlopen
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
from . import constants
|
||||
|
||||
@ -19,3 +20,27 @@ def download_lang_data(lang: str):
|
||||
|
||||
with urlopen(url) as res, open(filepath, 'w+b') as f:
|
||||
shutil.copyfileobj(res, f)
|
||||
|
||||
|
||||
# convert time string to frame index
|
||||
def get_frame_index(time_str: str, fps: float):
|
||||
t = time_str.split(':')
|
||||
t = list(map(float, t))
|
||||
if len(t) == 3:
|
||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
||||
elif len(t) == 2:
|
||||
td = datetime.timedelta(minutes=t[0], seconds=t[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
'Time data "{}" does not match format "%H:%M:%S"'.format(time_str))
|
||||
index = int(td.total_seconds() * fps)
|
||||
return index
|
||||
|
||||
|
||||
# convert frame index into SRT timestamp
|
||||
def get_srt_timestamp(frame_index: int, fps: float):
|
||||
td = datetime.timedelta(seconds=frame_index / fps)
|
||||
ms = td.microseconds // 1000
|
||||
m, s = divmod(td.seconds, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
||||
|
@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from multiprocessing import Pool
|
||||
import datetime
|
||||
import multiprocessing
|
||||
import pytesseract
|
||||
import cv2
|
||||
|
||||
from . import constants
|
||||
from . import utils
|
||||
from .models import PredictedFrame, PredictedSubtitle
|
||||
from .opencv_adapter import Capture
|
||||
|
||||
@ -27,12 +27,12 @@ class Video:
|
||||
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||
conf_threshold:int, use_fullframe: bool) -> None:
|
||||
conf_threshold: int, use_fullframe: bool) -> None:
|
||||
self.lang = lang
|
||||
self.use_fullframe = use_fullframe
|
||||
|
||||
ocr_start = self._frame_index(time_start) if time_start else 0
|
||||
ocr_end = self._frame_index(time_end) if time_end else self.num_frames
|
||||
ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
|
||||
ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames
|
||||
|
||||
if ocr_end < ocr_start:
|
||||
raise ValueError('time_start is later than time_end')
|
||||
@ -46,31 +46,11 @@ class Video:
|
||||
# perform ocr to frames in parallel
|
||||
it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
|
||||
self.pred_frames = [
|
||||
PredictedFrame(i + ocr_start, data, conf_threshold)
|
||||
for i, data in enumerate(it_ocr)]
|
||||
PredictedFrame(i + ocr_start, data, conf_threshold)
|
||||
for i, data in enumerate(it_ocr)
|
||||
]
|
||||
|
||||
v.release()
|
||||
|
||||
# convert time str to frame index
|
||||
def _frame_index(self, time: str) -> int:
|
||||
t = time.split(':')
|
||||
t = list(map(float, t))
|
||||
if len(t) == 3:
|
||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
||||
elif len(t) == 2:
|
||||
td = datetime.timedelta(minutes=t[0], seconds=t[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
'time data "{}" does not match format "%H:%M:%S"'.format(time))
|
||||
|
||||
index = int(td.total_seconds() * self.fps)
|
||||
if index > self.num_frames or index < 0:
|
||||
raise ValueError(
|
||||
'time data "{}" exceeds video duration'.format(time))
|
||||
|
||||
return index
|
||||
|
||||
def _single_frame_ocr(self, img) -> str:
|
||||
def _image_to_data(self, img) -> str:
|
||||
if not self.use_fullframe:
|
||||
# only use bottom half of the frame by default
|
||||
img = img[self.height // 2:, :]
|
||||
@ -82,8 +62,8 @@ class Video:
|
||||
return ''.join(
|
||||
'{}\n{} --> {}\n{}\n\n'.format(
|
||||
i,
|
||||
self._srt_timestamp(sub.index_start),
|
||||
self._srt_timestamp(sub.index_end),
|
||||
utils.get_srt_timestamp(sub.index_start, self.fps),
|
||||
utils.get_srt_timestamp(sub.index_end, self.fps),
|
||||
sub.text)
|
||||
for i, sub in enumerate(self.pred_subs))
|
||||
|
||||
@ -133,10 +113,3 @@ class Video:
|
||||
sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
|
||||
|
||||
self.pred_subs.append(sub)
|
||||
|
||||
def _srt_timestamp(self, frame_index: int) -> str:
|
||||
td = datetime.timedelta(seconds=frame_index / self.fps)
|
||||
ms = td.microseconds // 1000
|
||||
m, s = divmod(td.seconds, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
||||
|
Loading…
Reference in New Issue
Block a user