make sim_threshold adjustable through api

This commit is contained in:
Yi Ge 2019-04-29 03:50:06 +02:00
parent 77362dce1a
commit efd7223624
3 changed files with 20 additions and 16 deletions

View File

@ -5,8 +5,9 @@ from . import constants
from .video import Video from .video import Video
def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='', def get_subtitles(
conf_threshold=65, use_fullframe=False) -> str: video_path: str, lang='eng', time_start='0:00', time_end='',
conf_threshold=65, sim_threshold=90, use_fullframe=False) -> str:
# download tesseract data file to ~/tessdata if necessary # download tesseract data file to ~/tessdata if necessary
fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang) fpath = constants.TESSDATA_DIR / '{}.traineddata'.format(lang)
if not fpath.is_file(): if not fpath.is_file():
@ -20,14 +21,14 @@ def get_subtitles(video_path: str, lang='eng', time_start='0:00', time_end='',
v = Video(video_path) v = Video(video_path)
v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe) v.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
return v.get_subtitles() return v.get_subtitles(sim_threshold)
def save_subtitles_to_file( def save_subtitles_to_file(
video_path: str, file_path='subtitle.srt', lang='eng', video_path: str, file_path='subtitle.srt', lang='eng',
time_start='0:00', time_end='', conf_threshold=65, time_start='0:00', time_end='', conf_threshold=65, sim_threshold=90,
use_fullframe=False) -> None: use_fullframe=False) -> None:
with open(file_path, 'w+') as f: with open(file_path, 'w+') as f:
f.write(get_subtitles( f.write(get_subtitles(
video_path, lang, time_start, time_end, conf_threshold, video_path, lang, time_start, time_end, conf_threshold,
use_fullframe)) sim_threshold, use_fullframe))

View File

@ -54,10 +54,12 @@ class PredictedFrame:
class PredictedSubtitle: class PredictedSubtitle:
frames: List[PredictedFrame] frames: List[PredictedFrame]
sim_threshold: int
text: str text: str
def __init__(self, frames: List[PredictedFrame]): def __init__(self, frames: List[PredictedFrame], sim_threshold: int):
self.frames = [f for f in frames if f.confidence > 0] self.frames = [f for f in frames if f.confidence > 0]
self.sim_threshold = sim_threshold
if self.frames: if self.frames:
self.text = max(self.frames, key=lambda f: f.confidence).text self.text = max(self.frames, key=lambda f: f.confidence).text
@ -76,8 +78,8 @@ class PredictedSubtitle:
return self.frames[-1].index return self.frames[-1].index
return 0 return 0
def is_similar_to(self, other: PredictedSubtitle, threshold=90) -> bool: def is_similar_to(self, other: PredictedSubtitle) -> bool:
return fuzz.partial_ratio(self.text, other.text) >= threshold return fuzz.partial_ratio(self.text, other.text) >= self.sim_threshold
def __repr__(self): def __repr__(self):
return '{} - {}. {}'.format(self.index_start, self.index_end, self.text) return '{} - {}. {}'.format(self.index_start, self.index_end, self.text)

View File

@ -80,8 +80,8 @@ class Video:
config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR) config = '--tessdata-dir "{}"'.format(constants.TESSDATA_DIR)
return pytesseract.image_to_data(img, lang=self.lang, config=config) return pytesseract.image_to_data(img, lang=self.lang, config=config)
def get_subtitles(self) -> str: def get_subtitles(self, sim_threshold: int) -> str:
self._generate_subtitles() self._generate_subtitles(sim_threshold)
return ''.join( return ''.join(
'{}\n{} --> {}\n{}\n\n'.format( '{}\n{} --> {}\n{}\n\n'.format(
i, i,
@ -90,7 +90,7 @@ class Video:
sub.text) sub.text)
for i, sub in enumerate(self.pred_subs)) for i, sub in enumerate(self.pred_subs))
def _generate_subtitles(self) -> None: def _generate_subtitles(self, sim_threshold: int) -> None:
self.pred_subs = [] self.pred_subs = []
if self.pred_frames is None: if self.pred_frames is None:
@ -112,8 +112,8 @@ class Video:
else: else:
# divide subtitle paragraphs # divide subtitle paragraphs
para_new = j - WIN_BOUND para_new = j - WIN_BOUND
self._append_sub( self._append_sub(PredictedSubtitle(
PredictedSubtitle(self.pred_frames[i:para_new])) self.pred_frames[i:para_new], sim_threshold))
i = para_new i = para_new
j = i j = i
bound = WIN_BOUND bound = WIN_BOUND
@ -122,7 +122,8 @@ class Video:
# also handle the last remaining frames # also handle the last remaining frames
if i < len(self.pred_frames) - 1: if i < len(self.pred_frames) - 1:
self._append_sub(PredictedSubtitle(self.pred_frames[i:])) self._append_sub(PredictedSubtitle(
self.pred_frames[i:], sim_threshold))
def _append_sub(self, sub: PredictedSubtitle) -> None: def _append_sub(self, sub: PredictedSubtitle) -> None:
if len(sub.text) == 0: if len(sub.text) == 0:
@ -132,7 +133,7 @@ class Video:
while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]): while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
ls = self.pred_subs[-1] ls = self.pred_subs[-1]
del self.pred_subs[-1] del self.pred_subs[-1]
sub = PredictedSubtitle(ls.frames + sub.frames) sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
self.pred_subs.append(sub) self.pred_subs.append(sub)