add api definition
This commit is contained in:
parent
bc84ee39ff
commit
bd6f15978b
17
videocr/api.py
Normal file
17
videocr/api.py
Normal file
@ -0,0 +1,17 @@
|
||||
|
||||
from .video import Video
|
||||
|
||||
|
||||
def get_subtitles(video_path: str, lang='eng',
|
||||
time_start='0:00', time_end='', use_fullframe=False) -> str:
|
||||
v = Video(video_path)
|
||||
v.run_ocr(lang, time_start, time_end, use_fullframe)
|
||||
return v.get_subtitles()
|
||||
|
||||
|
||||
def save_subtitles_to_file(
|
||||
video_path: str, file_path='subtitle.srt', lang='eng',
|
||||
time_start='0:00', time_end='', use_fullframe=False) -> None:
|
||||
with open(file_path, 'w+') as f:
|
||||
f.write(get_subtitles(
|
||||
video_path, lang, time_start, time_end, use_fullframe))
|
@ -48,9 +48,8 @@ class PredictedFrame:
|
||||
|
||||
self.text = ' '.join(word.text for word in self.words)
|
||||
# remove chars that are obviously ocr errors
|
||||
translate_table = {ord(c): None for c in '<>{}[];`@#$%^*_=~\\'}
|
||||
translate_table[ord('|')] = 'I'
|
||||
self.text = self.text.translate(translate_table).strip()
|
||||
table = str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')
|
||||
self.text = self.text.translate(table).replace(' \n ', '\n').strip()
|
||||
|
||||
def is_similar_to(self, other: PredictedFrame, threshold=70) -> bool:
|
||||
return fuzz.ratio(self.text, other.text) >= threshold
|
||||
@ -58,14 +57,13 @@ class PredictedFrame:
|
||||
|
||||
class PredictedSubtitle:
|
||||
frames: List[PredictedFrame]
|
||||
text: str
|
||||
|
||||
def __init__(self, frames: List[PredictedFrame]):
|
||||
self.frames = [f for f in frames if f.confidence > 0]
|
||||
|
||||
if self.frames:
|
||||
conf_max = max(f.confidence for f in self.frames)
|
||||
self.text = next(f.text for f in self.frames
|
||||
if f.confidence == conf_max)
|
||||
self.text = max(self.frames, key=lambda f: f.confidence).text
|
||||
else:
|
||||
self.text = ''
|
||||
|
||||
|
@ -3,7 +3,6 @@ from concurrent import futures
|
||||
import datetime
|
||||
import pytesseract
|
||||
import cv2
|
||||
import timeit
|
||||
|
||||
from .models import PredictedFrame, PredictedSubtitle
|
||||
|
||||
@ -24,19 +23,16 @@ class Video:
|
||||
self.fps = v.get(cv2.CAP_PROP_FPS)
|
||||
v.release()
|
||||
|
||||
def run_ocr(self, lang: str, use_fullframe=False,
|
||||
time_start='0:00', time_end='') -> None:
|
||||
def run_ocr(self, lang: str, time_start: str, time_end: str,
|
||||
use_fullframe: bool) -> None:
|
||||
self.lang = lang
|
||||
self.use_fullframe = use_fullframe
|
||||
|
||||
ocr_start = self._frame_index(time_start)
|
||||
ocr_start = self._frame_index(time_start) if time_start else 0
|
||||
ocr_end = self._frame_index(time_end) if time_end else self.num_frames
|
||||
|
||||
if time_end:
|
||||
ocr_end = self._frame_index(time_end)
|
||||
if ocr_end < ocr_start:
|
||||
raise ValueError('time_start is later than time_end')
|
||||
else:
|
||||
ocr_end = self.num_frames
|
||||
if ocr_end < ocr_start:
|
||||
raise ValueError('time_start is later than time_end')
|
||||
num_ocr_frames = ocr_end - ocr_start
|
||||
|
||||
# get frames from ocr_start to ocr_end
|
||||
@ -55,7 +51,7 @@ class Video:
|
||||
# convert time str to frame index
|
||||
def _frame_index(self, time: str) -> int:
|
||||
t = time.split(':')
|
||||
t = list(map(int, t))
|
||||
t = list(map(float, t))
|
||||
if len(t) == 3:
|
||||
td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
|
||||
elif len(t) == 2:
|
||||
@ -139,19 +135,3 @@ class Video:
|
||||
m, s = divmod(td.seconds, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return '{:02d}:{:02d}:{:02d},{:03d}'.format(h, m, s, ms)
|
||||
|
||||
def save_subtitles_to_file(self, path='subtitle.srt') -> None:
|
||||
with open(path, 'w+') as f:
|
||||
f.write(self.get_subtitles())
|
||||
|
||||
|
||||
time_start = timeit.default_timer()
|
||||
v = Video('1.mp4', 'HanS')
|
||||
v.run_ocr()
|
||||
time_stop = timeit.default_timer()
|
||||
print('time for ocr: ', time_stop - time_start)
|
||||
|
||||
time_start = timeit.default_timer()
|
||||
v.save_subtitles_to_file()
|
||||
time_stop = timeit.default_timer()
|
||||
print('time for save sub: ', time_stop - time_start)
|
||||
|
Loading…
Reference in New Issue
Block a user