merge new sub to the last subs if they are similar

This commit is contained in:
Yi Ge 2019-04-26 00:07:25 +02:00
parent 0d86e14fbc
commit 3a73f1f508
2 changed files with 64 additions and 18 deletions

View File

@ -38,18 +38,22 @@ class PredictedFrame:
# handle line breaks # handle line breaks
if block < block_num: if block < block_num:
block = block_num block = block_num
self.words.append(PredictedWord(0, '\n')) if self.words and self.words[-1].text != '\n':
self.words.append(PredictedWord(0, '\n'))
if conf >= CONF_THRESHOLD: if conf >= CONF_THRESHOLD:
self.words.append(PredictedWord(conf, text)) self.words.append(PredictedWord(conf, text))
self.confidence = sum(word.confidence for word in self.words) self.confidence = sum(word.confidence for word in self.words)
self.text = ''.join(word.text + ' ' for word in self.words).strip()
def is_similar_to(self, other: PredictedFrame, threshold=60) -> bool: self.text = ' '.join(word.text for word in self.words)
if len(self.text) == 0 or len(other.text) == 0: # remove chars that are obviously ocr errors
return False translate_table = {ord(c): None for c in '<>{};`@#$%^*_=\\'}
return fuzz.ratio(self.text, other.text) >= threshold translate_table[ord('|')] = 'I'
self.text = self.text.translate(translate_table).strip()
def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
return fuzz.partial_ratio(self.text, other.text) >= threshold
class PredictedSubtitle: class PredictedSubtitle:
@ -76,3 +80,9 @@ class PredictedSubtitle:
if self.frames: if self.frames:
return self.frames[-1].index return self.frames[-1].index
return 0 return 0
def is_similar_to(self, other: PredictedSubtitle, threshold=70) -> bool:
return fuzz.partial_ratio(self.text, other.text) >= threshold
def __repr__(self):
return '{} - {}. {}'.format(self.index_start, self.index_end, self.text)

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from concurrent import futures from concurrent import futures
import datetime
import pytesseract import pytesseract
import cv2 import cv2
import timeit import timeit
@ -7,24 +8,28 @@ import timeit
from .models import PredictedFrame, PredictedSubtitle from .models import PredictedFrame, PredictedSubtitle
SUBTITLE_BOUND = 10
class Video: class Video:
path: str path: str
lang: str lang: str
use_fullframe: bool
num_frames: int num_frames: int
fps: float
pred_frames: List[PredictedFrame] pred_frames: List[PredictedFrame]
pred_subs: List[PredictedSubtitle]
def __init__(self, path, lang): def __init__(self, path, lang, use_fullframe=False):
self.path = path self.path = path
self.lang = lang self.lang = lang
self.use_fullframe = use_fullframe
v = cv2.VideoCapture(path) v = cv2.VideoCapture(path)
self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT)) self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps = v.get(cv2.CAP_PROP_FPS)
v.release() v.release()
def _single_frame_ocr(self, img) -> str: def _single_frame_ocr(self, img) -> str:
img = img[img.shape[0] // 2:, :] # only use bottom half of the frame if not self.use_fullframe:
# only use bottom half of the frame by default
img = img[img.shape[0] // 2:, :]
data = pytesseract.image_to_data(img, lang=self.lang) data = pytesseract.image_to_data(img, lang=self.lang)
return data return data
@ -45,36 +50,67 @@ class Video:
raise AttributeError( raise AttributeError(
'Please call self.run_ocr() first to generate ocr of frames') 'Please call self.run_ocr() first to generate ocr of frames')
self.pred_subs = []
# divide ocr of frames into subtitle paragraphs using sliding window # divide ocr of frames into subtitle paragraphs using sliding window
WIN_BOUND = int(self.fps / 2) # 1/2 sec sliding window boundary
bound = WIN_BOUND
i = 0 i = 0
j = 1 j = 1
bound = SUBTITLE_BOUND
while j < self.num_frames: while j < self.num_frames:
fi, fj = self.pred_frames[i], self.pred_frames[j] fi, fj = self.pred_frames[i], self.pred_frames[j]
if fi.is_similar_to(fj): if fi.is_similar_to(fj):
bound = SUBTITLE_BOUND bound = WIN_BOUND
elif bound > 0: elif bound > 0:
bound -= 1 bound -= 1
else: else:
# divide subtitle paragraphs # divide subtitle paragraphs
para_new = j - SUBTITLE_BOUND para_new = j - WIN_BOUND
print(PredictedSubtitle(self.pred_frames[i:para_new]).text) self._append_sub(
PredictedSubtitle(self.pred_frames[i:para_new]))
i = para_new i = para_new
j = i j = i
bound = SUBTITLE_BOUND bound = WIN_BOUND
j += 1 j += 1
if i < self.num_frames - 1: if i < self.num_frames - 1:
print(PredictedSubtitle(self.pred_frames[i:]).text) self._append_sub(PredictedSubtitle(self.pred_frames[i:]))
for i, sub in enumerate(self.pred_subs):
print('{}\n{} --> {}\n{}\n'.format(
i,
self._srt_timestamp(sub.index_start),
self._srt_timestamp(sub.index_end),
sub.text))
return '' return ''
def _append_sub(self, sub: PredictedSubtitle) -> None:
if len(sub.text) == 0:
return
# merge new sub to the last subs if they are similar
while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
lsub = self.pred_subs[-1]
del self.pred_subs[-1]
sub = PredictedSubtitle(lsub.frames + sub.frames)
self.pred_subs.append(sub)
def _srt_timestamp(self, frame_index) -> str:
time = str(datetime.timedelta(seconds=frame_index / self.fps))
return time.replace('.', ',') # srt uses comma as fractional separator
time_start = timeit.default_timer() time_start = timeit.default_timer()
v = Video('1.mp4', 'HanS') v = Video('1.mp4', 'HanS')
v.run_ocr() v.run_ocr()
time_stop = timeit.default_timer()
print('time for ocr: ', time_stop - time_start)
time_start = timeit.default_timer()
v.get_subtitles() v.get_subtitles()
time_stop = timeit.default_timer() time_stop = timeit.default_timer()
print(time_stop - time_start) print('time for get sub: ', time_stop - time_start)