forked from pradana.aumars/videocr
move util functions to utils.py
This commit is contained in:
parent
9360ebdd40
commit
f8e99465c7
@ -1,5 +1,6 @@
|
|||||||
from urllib.request import urlopen
|
from urllib.request import urlopen
|
||||||
import shutil
|
import shutil
|
||||||
|
import datetime
|
||||||
|
|
||||||
from . import constants
|
from . import constants
|
||||||
|
|
||||||
@ -19,3 +20,27 @@ def download_lang_data(lang: str):
|
|||||||
|
|
||||||
with urlopen(url) as res, open(filepath, 'w+b') as f:
|
with urlopen(url) as res, open(filepath, 'w+b') as f:
|
||||||
shutil.copyfileobj(res, f)
|
shutil.copyfileobj(res, f)
|
||||||
|
|
||||||
|
|
||||||
|
# convert time string to frame index
|
||||||
|
def get_frame_index(time_str: str, fps: float) -> int:
    """Convert a time string into a frame index.

    Args:
        time_str: Time in ``HH:MM:SS`` or ``MM:SS`` form. Each component is
            parsed with ``float``, so fractional seconds such as
            ``"1:02.5"`` are accepted.
        fps: Frames per second used to scale seconds into frames.

    Returns:
        The frame index, truncated toward zero.

    Raises:
        ValueError: If ``time_str`` does not have two or three numeric
            components separated by ``:``, or if it denotes a negative time.
    """
    try:
        parts = [float(piece) for piece in time_str.split(':')]
    except ValueError:
        # Report non-numeric components with the same message as a wrong
        # component count, instead of float()'s generic conversion error.
        raise ValueError(
            'Time data "{}" does not match format "%H:%M:%S"'.format(time_str)
        ) from None

    if len(parts) == 3:
        td = datetime.timedelta(
            hours=parts[0], minutes=parts[1], seconds=parts[2])
    elif len(parts) == 2:
        td = datetime.timedelta(minutes=parts[0], seconds=parts[1])
    else:
        raise ValueError(
            'Time data "{}" does not match format "%H:%M:%S"'.format(time_str))

    seconds = td.total_seconds()
    # A negative offset would produce a negative frame index, which is not a
    # valid position in a video; reject it here rather than pass it along.
    if seconds < 0:
        raise ValueError(
            'Time data "{}" must not be negative'.format(time_str))

    return int(seconds * fps)
|
||||||
|
|
||||||
|
|
||||||
|
# convert frame index into SRT timestamp
|
||||||
|
def get_srt_timestamp(frame_index: int, fps: float) -> str:
    """Format a frame index as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        frame_index: Index of the frame; expected to be non-negative.
        fps: Frames per second used to convert the index into seconds.

    Returns:
        The timestamp string with zero-padded hours, minutes and seconds and
        a three-digit millisecond field (sub-millisecond part truncated).
    """
    td = datetime.timedelta(seconds=frame_index / fps)
    ms = td.microseconds // 1000
    # timedelta.seconds excludes whole days, so fold td.days back in;
    # otherwise any timestamp at or beyond 24 hours would wrap around.
    whole_seconds = td.days * 86400 + td.seconds
    m, s = divmod(whole_seconds, 60)
    h, m = divmod(m, 60)
    return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from multiprocessing import Pool
|
import multiprocessing
|
||||||
import datetime
|
|
||||||
import pytesseract
|
import pytesseract
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from . import constants
|
from . import constants
|
||||||
|
from . import utils
|
||||||
from .models import PredictedFrame, PredictedSubtitle
|
from .models import PredictedFrame, PredictedSubtitle
|
||||||
from .opencv_adapter import Capture
|
from .opencv_adapter import Capture
|
||||||
|
|
||||||
@ -27,12 +27,12 @@ class Video:
|
|||||||
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
|
||||||
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||||
conf_threshold:int, use_fullframe: bool) -> None:
|
conf_threshold: int, use_fullframe: bool) -> None:
|
||||||
self.lang = lang
|
self.lang = lang
|
||||||
self.use_fullframe = use_fullframe
|
self.use_fullframe = use_fullframe
|
||||||
|
|
||||||
ocr_start = self._frame_index(time_start) if time_start else 0
|
ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
|
||||||
ocr_end = self._frame_index(time_end) if time_end else self.num_frames
|
ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames
|
||||||
|
|
||||||
if ocr_end < ocr_start:
|
if ocr_end < ocr_start:
|
||||||
raise ValueError('time_start is later than time_end')
|
raise ValueError('time_start is later than time_end')
|
||||||
@ -46,31 +46,11 @@ class Video:
|
|||||||
# perform ocr to frames in parallel
|
# perform ocr to frames in parallel
|
||||||
it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
|
it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
|
||||||
self.pred_frames = [
|
self.pred_frames = [
|
||||||
PredictedFrame(i + ocr_start, data, conf_threshold)
|
PredictedFrame(i + ocr_start, data, conf_threshold)
|
||||||
for i, data in enumerate(it_ocr)]
|
for i, data in enumerate(it_ocr)
|
||||||
|
]
|
||||||
|
|
||||||
v.release()
|
def _image_to_data(self, img) -> str:
|
||||||
|
|
||||||
# convert time str to frame index
|
|
||||||
def _frame_index(self, time: str) -> int:
|
|
||||||
t = time.split(':')
|
|
||||||
t = list(map(float, t))
|
|
||||||
if len(t) == 3:
|
|
||||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
|
||||||
elif len(t) == 2:
|
|
||||||
td = datetime.timedelta(minutes=t[0], seconds=t[1])
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
'time data "{}" does not match format "%H:%M:%S"'.format(time))
|
|
||||||
|
|
||||||
index = int(td.total_seconds() * self.fps)
|
|
||||||
if index > self.num_frames or index < 0:
|
|
||||||
raise ValueError(
|
|
||||||
'time data "{}" exceeds video duration'.format(time))
|
|
||||||
|
|
||||||
return index
|
|
||||||
|
|
||||||
def _single_frame_ocr(self, img) -> str:
|
|
||||||
if not self.use_fullframe:
|
if not self.use_fullframe:
|
||||||
# only use bottom half of the frame by default
|
# only use bottom half of the frame by default
|
||||||
img = img[self.height // 2:, :]
|
img = img[self.height // 2:, :]
|
||||||
@ -82,8 +62,8 @@ class Video:
|
|||||||
return ''.join(
|
return ''.join(
|
||||||
'{}\n{} --> {}\n{}\n\n'.format(
|
'{}\n{} --> {}\n{}\n\n'.format(
|
||||||
i,
|
i,
|
||||||
self._srt_timestamp(sub.index_start),
|
utils.get_srt_timestamp(sub.index_start, self.fps),
|
||||||
self._srt_timestamp(sub.index_end),
|
utils.get_srt_timestamp(sub.index_end, self.fps),
|
||||||
sub.text)
|
sub.text)
|
||||||
for i, sub in enumerate(self.pred_subs))
|
for i, sub in enumerate(self.pred_subs))
|
||||||
|
|
||||||
@ -133,10 +113,3 @@ class Video:
|
|||||||
sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
|
sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
|
||||||
|
|
||||||
self.pred_subs.append(sub)
|
self.pred_subs.append(sub)
|
||||||
|
|
||||||
def _srt_timestamp(self, frame_index: int) -> str:
|
|
||||||
td = datetime.timedelta(seconds=frame_index / self.fps)
|
|
||||||
ms = td.microseconds // 1000
|
|
||||||
m, s = divmod(td.seconds, 60)
|
|
||||||
h, m = divmod(m, 60)
|
|
||||||
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user