add api definition

This commit is contained in:
Yi Ge 2019-04-28 15:46:24 +02:00
parent bc84ee39ff
commit bd6f15978b
3 changed files with 28 additions and 33 deletions

17
videocr/api.py Normal file
View File

@ -0,0 +1,17 @@
from .video import Video
def get_subtitles(video_path: str, lang='eng',
time_start='0:00', time_end='', use_fullframe=False) -> str:
v = Video(video_path)
v.run_ocr(lang, time_start, time_end, use_fullframe)
return v.get_subtitles()
def save_subtitles_to_file(
video_path: str, file_path='subtitle.srt', lang='eng',
time_start='0:00', time_end='', use_fullframe=False) -> None:
with open(file_path, 'w+') as f:
f.write(get_subtitles(
video_path, lang, time_start, time_end, use_fullframe))

View File

@ -48,9 +48,8 @@ class PredictedFrame:
self.text = ' '.join(word.text for word in self.words)
# remove chars that are obviously ocr errors
translate_table = {ord(c): None for c in '<>{}[];`@#$%^*_=~\\'}
translate_table[ord('|')] = 'I'
self.text = self.text.translate(translate_table).strip()
table = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')
self.text = self.text.translate(table).replace(' \n ', '\n').strip()
def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
return fuzz.ratio(self.text, other.text) >= threshold
@ -58,14 +57,13 @@ class PredictedFrame:
class PredictedSubtitle:
frames: List[PredictedFrame]
text: str
def __init__(self, frames: List[PredictedFrame]):
self.frames = [f for f in frames if f.confidence > 0]
if self.frames:
conf_max = max(f.confidence for f in self.frames)
self.text = next(f.text for f in self.frames
if f.confidence == conf_max)
self.text = max(self.frames, key=lambda f: f.confidence).text
else:
self.text = ''

View File

@ -3,7 +3,6 @@ from concurrent import futures
import datetime
import pytesseract
import cv2
import timeit
from .models import PredictedFrame, PredictedSubtitle
@ -24,19 +23,16 @@ class Video:
self.fps = v.get(cv2.CAP_PROP_FPS)
v.release()
def run_ocr(self, lang: str, use_fullframe=False,
time_start='0:00', time_end='') -> None:
def run_ocr(self, lang: str, time_start: str, time_end: str,
use_fullframe: bool) -> None:
self.lang = lang
self.use_fullframe = use_fullframe
ocr_start = self._frame_index(time_start)
ocr_start = self._frame_index(time_start) if time_start else 0
ocr_end = self._frame_index(time_end) if time_end else self.num_frames
if time_end:
ocr_end = self._frame_index(time_end)
if ocr_end < ocr_start:
raise ValueError('time_start is later than time_end')
else:
ocr_end = self.num_frames
if ocr_end < ocr_start:
raise ValueError('time_start is later than time_end')
num_ocr_frames = ocr_end - ocr_start
# get frames from ocr_start to ocr_end
@ -55,7 +51,7 @@ class Video:
# convert time str to frame index
def _frame_index(self, time: str) -> int:
t = time.split(':')
t = list(map(int, t))
t = list(map(float, t))
if len(t) == 3:
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
elif len(t) == 2:
@ -139,19 +135,3 @@ class Video:
m, s = divmod(td.seconds, 60)
h, m = divmod(m, 60)
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
def save_subtitles_to_file(self, path='subtitle.srt') -> None:
with open(path, 'w+') as f:
f.write(self.get_subtitles())
time_start = timeit.default_timer()
v = Video('1.mp4', 'HanS')
v.run_ocr()
time_stop = timeit.default_timer()
print('time for ocr: ', time_stop - time_start)
time_start = timeit.default_timer()
v.save_subtitles_to_file()
time_stop = timeit.default_timer()
print('time for save sub: ', time_stop - time_start)