2019-04-24 21:18:31 +02:00
|
|
|
from __future__ import annotations

from concurrent import futures
import datetime
from typing import List

import pytesseract
import cv2

from .models import PredictedFrame, PredictedSubtitle
|
|
|
|
|
|
|
|
|
2019-04-24 21:18:31 +02:00
|
|
|
class Video:
    """A video file whose frames can be OCR'd and turned into SRT subtitles.

    Typical use: construct with a path, call ``run_ocr`` to populate
    ``pred_frames``, then call ``get_subtitles`` to obtain SRT text.
    """

    path: str  # path handed to cv2.VideoCapture
    lang: str  # tesseract language passed to pytesseract; set by run_ocr
    use_fullframe: bool  # OCR whole frame instead of bottom half; set by run_ocr
    num_frames: int  # OpenCV's CAP_PROP_FRAME_COUNT for the video
    fps: float  # OpenCV's CAP_PROP_FPS for the video
    pred_frames: List[PredictedFrame]  # per-frame OCR results; set by run_ocr
    pred_subs: List[PredictedSubtitle]  # merged subtitles; set by get_subtitles

    def __init__(self, path: str):
        """Open *path* once to read its frame count and frame rate."""
        self.path = path
        v = cv2.VideoCapture(path)
        try:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
        finally:
            v.release()

    def run_ocr(self, lang: str, time_start: str, time_end: str,
                use_fullframe: bool) -> None:
        """OCR the frames between *time_start* and *time_end* in parallel.

        Results are stored in ``self.pred_frames`` as one ``PredictedFrame``
        per successfully read frame, indexed by absolute frame number.

        Args:
            lang: language string passed to ``pytesseract.image_to_data``.
            time_start: "H:M:S" or "M:S" start time; falsy means frame 0.
            time_end: "H:M:S" or "M:S" end time (exclusive); falsy means
                ``self.num_frames``.
            use_fullframe: OCR the whole frame instead of its bottom half.

        Raises:
            ValueError: if a time string is malformed or outside the video,
                or if time_start is later than time_end.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe

        ocr_start = self._frame_index(time_start) if time_start else 0
        ocr_end = self._frame_index(time_end) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        v = cv2.VideoCapture(self.path)
        try:
            # get frames from ocr_start to ocr_end
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = self._read_frames(v, num_ocr_frames)

            # perform ocr to frames in parallel; the bound method (and thus
            # self, including lang/use_fullframe set above) is pickled to
            # the worker processes
            with futures.ProcessPoolExecutor() as pool:
                ocr_map = pool.map(self._single_frame_ocr, frames, chunksize=10)
                self.pred_frames = [PredictedFrame(i + ocr_start, data)
                                    for i, data in enumerate(ocr_map)]
        finally:
            # release the capture even if reading or OCR raises
            v.release()

    @staticmethod
    def _read_frames(v, count: int):
        """Yield up to *count* frames from capture *v*.

        Stops at the first failed read instead of yielding ``None``. The
        reported frame count may not match what can actually be decoded,
        and a ``None`` frame would crash the OCR worker.
        """
        for _ in range(count):
            ok, frame = v.read()
            if not ok:
                break
            yield frame

    # convert time str to frame index
    def _frame_index(self, time: str) -> int:
        """Convert an "H:M:S" or "M:S" string to a frame index.

        Fields are parsed with ``float``, so fractional values are allowed.
        The result is ``int(seconds * fps)``, truncated toward zero.

        Raises:
            ValueError: if a field is not numeric, the string does not have
                2 or 3 colon-separated fields, or the index falls outside
                ``[0, num_frames]``.
        """
        t = time.split(':')
        t = list(map(float, t))
        if len(t) == 3:
            td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
        elif len(t) == 2:
            td = datetime.timedelta(minutes=t[0], seconds=t[1])
        else:
            raise ValueError(
                'time data "{}" does not match format "%H:%M:%S"'.format(time))

        index = int(td.total_seconds() * self.fps)
        if index > self.num_frames or index < 0:
            raise ValueError(
                'time data "{}" exceeds video duration'.format(time))

        return index

    def _single_frame_ocr(self, img) -> str:
        """Return ``pytesseract.image_to_data`` output for one frame.

        *img* is a frame as yielded by ``_read_frames`` (presumably a NumPy
        image array from OpenCV). Unless ``use_fullframe`` is set, only rows
        from ``img.shape[0] // 2`` downward are OCR'd.
        """
        if not self.use_fullframe:
            # only use bottom half of the frame by default
            img = img[img.shape[0] // 2:, :]
        return pytesseract.image_to_data(img, lang=self.lang)

    def get_subtitles(self) -> str:
        """Group the OCR'd frames into subtitles and return them as SRT text.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called yet.
        """
        self._generate_subtitles()
        return self._format_srt()

    def _format_srt(self) -> str:
        """Render ``self.pred_subs`` as SRT entries, numbered from 1."""
        # SRT sequence numbers start at 1
        return ''.join(
            '{}\n{} --> {}\n{}\n\n'.format(
                i,
                self._srt_timestamp(sub.index_start),
                self._srt_timestamp(sub.index_end),
                sub.text)
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self) -> None:
        """Split ``pred_frames`` into subtitle paragraphs stored in ``pred_subs``.

        Uses a sliding window of ``WIN_BOUND = int(fps // 2)`` frames.
        A paragraph starting at frame ``i`` ends once ``WIN_BOUND + 1``
        consecutive frames are not ``is_similar_to`` frame ``i``. The next
        paragraph starts at the first of those dissimilar frames. Frames left
        after the loop form a final paragraph if at least two remain.

        Raises:
            AttributeError: if ``run_ocr()`` has not been called yet.
        """
        self.pred_subs = []

        # pred_frames is only a class annotation until run_ocr assigns it,
        # so plain attribute access would raise before this friendly check
        if getattr(self, 'pred_frames', None) is None:
            raise AttributeError(
                'Please call self.run_ocr() first to perform ocr on frames')

        # divide ocr of frames into subtitle paragraphs using sliding window
        WIN_BOUND = int(self.fps // 2)  # 1/2 sec sliding window boundary
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # divide subtitle paragraphs
                para_new = j - WIN_BOUND
                self._append_sub(
                    PredictedSubtitle(self.pred_frames[i:para_new]))
                i = para_new
                j = i
                bound = WIN_BOUND

            j += 1

        # also handle the last remaining frames
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:]))

    def _append_sub(self, sub: PredictedSubtitle) -> None:
        """Append *sub* to ``pred_subs``, merging it into similar predecessors.

        Subtitles with empty text are dropped. While *sub* is similar to the
        last stored subtitle, that subtitle is removed. A new subtitle is then
        built from its frames followed by *sub*'s frames, and the check repeats.
        """
        if len(sub.text) == 0:
            return

        # merge new sub to the last subs if they are similar
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs[-1]
            del self.pred_subs[-1]
            sub = PredictedSubtitle(ls.frames + sub.frames)

        self.pred_subs.append(sub)

    def _srt_timestamp(self, frame_index: int) -> str:
        """Convert a frame index to an SRT ``HH:MM:SS,mmm`` timestamp."""
        td = datetime.timedelta(seconds=frame_index / self.fps)
        ms = td.microseconds // 1000
        # include td.days so timestamps past 24 hours don't wrap around
        m, s = divmod(td.days * 86400 + td.seconds, 60)
        h, m = divmod(m, 60)
        return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|