"""Assembly AI transcription wrapper.
Python package requirements:
- requests
Optional external program requirements:
- exiftool (to display estimated transcription time)
"""
import argparse
import dataclasses
import json
import os
import subprocess
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from time import sleep
import requests
# File extensions (lower case, including the leading dot) that this script
# will accept for upload. Kept in alphabetical order, one entry per
# extension, so that accidental duplicates are easy to spot.
SUPPORTED_EXTENSIONS = {
    ".3ga",
    ".8svx",
    ".aac",
    ".ac3",
    ".aif",
    ".aiff",
    ".alac",
    ".amr",
    ".ape",
    ".au",
    ".dss",
    ".flac",
    ".flv",
    ".m2ts",
    ".m4a",
    ".m4b",
    ".m4p",
    ".m4r",
    ".m4v",
    ".mogg",
    ".mov",
    ".mp2",
    ".mp3",
    ".mp4",
    ".mpga",
    ".mts",
    ".mxf",
    ".oga",
    ".ogg",
    ".opus",
    ".qcp",
    ".ts",
    ".tta",
    ".voc",
    ".wav",
    ".webm",
    ".wma",
    ".wv",
}
# Largest file this script will upload: 2.2GB, counted in decimal
# (10**9-byte) gigabytes.
UPLOAD_SIZE_LIMIT_IN_BYTES = 2_200_000_000
# Language codes accepted by the --language option, optionally carrying a
# region specifier after an underscore. Kept in alphabetical order.
ACCEPTABLE_LANGUAGES = {
    "de",
    "en",
    "en_au",
    "en_uk",
    "en_us",
    "es",
    "fr",
    "fr_ca",
    "fr_fr",
    "hi",
    "it",
    "ja",
    "nl",
    "pt",
}
# Root of the AssemblyAI v2 API; endpoint paths such as "/upload" and
# "/transcript" are appended to it.
BASE_URL = "https://api.assemblyai.com/v2"
@dataclass(frozen=True)
class Word:
    """A single transcribed word, the element type of an utterance's word list."""

    # Confidence score attached to this word.
    confidence: float
    # End offset of the word (presumably milliseconds -- TODO confirm).
    end: int
    # Label of the speaker the word is attributed to.
    speaker: str
    # Start offset of the word (presumably milliseconds -- TODO confirm).
    start: int
    # The word's text.
    text: str
@dataclass(frozen=True)
class Utterance:
    """One speaker-attributed section of a transcript.

    Instances are built by unpacking an entry of the API response's
    "utterances" list, so the field names are the keys of that entry.
    """

    # Confidence score attached to this utterance.
    confidence: float
    # End offset of the utterance (presumably milliseconds -- TODO confirm).
    end: int
    # Label of the speaker; shown in the formatted transcript.
    speaker: str
    # Start offset in milliseconds; rendered as the section's timestamp.
    start: int
    # Full text of the utterance.
    text: str
    # Per-word breakdown of the utterance.
    # NOTE(review): when built straight from the response JSON these are
    # plain dicts rather than Word instances -- confirm whether that matters.
    words: list[Word]
@dataclass(frozen=True)
class Args:
file: Path
language: str
n_speakers: int | None
key: str | None
def parse_arguments() -> Args:
    """Parse and validate the command-line arguments.

    Returns:
        The validated arguments.

    Raises:
        RuntimeError: If the file does not exist, has an unsupported
            extension, is too large to upload, or the language code is not
            one of ACCEPTABLE_LANGUAGES.
    """
    parser = argparse.ArgumentParser(description="transcribe an audio file")
    parser.add_argument(
        "-f",
        "--file",
        metavar="FILE",
        type=Path,
        required=True,
        help="file to be transcribed",
    )
    parser.add_argument(
        "-n",
        "--number-speakers",
        dest="n_speakers",
        metavar="N",
        type=int,
        help="number of expected speakers",
    )
    parser.add_argument(
        "-l",
        "--language",
        dest="language",
        metavar="CODE",
        default="en_us",
        help="language to be transcribed (can include region specifier, eg en_uk)",
    )
    parser.add_argument(
        "-k",
        "--key",
        dest="key",
        metavar="API_KEY",
        help="AssemblyAI API key",
    )
    args = parser.parse_args()

    # Validate with explicit raises rather than `assert`: assertions are
    # stripped when Python runs with -O, which would silently skip the checks.
    file = args.file
    if not file.exists():
        raise RuntimeError(f"{file} does not exist")
    # Compare case-insensitively so that e.g. "talk.MP3" is accepted.
    if file.suffix.lower() not in SUPPORTED_EXTENSIONS:
        raise RuntimeError(f"Unsupported filetype: {file.suffix}")
    if file.stat().st_size >= UPLOAD_SIZE_LIMIT_IN_BYTES:
        raise RuntimeError("File is too large; upload limit is 2.2GB.")

    lang = args.language
    if lang not in ACCEPTABLE_LANGUAGES:
        langs = ", ".join(sorted(ACCEPTABLE_LANGUAGES))
        msg = f"Unsupported language: {lang}\n" f"Must be one of: {langs}"
        raise RuntimeError(msg)

    return Args(
        file=file,
        language=lang,
        n_speakers=args.n_speakers,
        key=args.key,
    )
def get_duration(file: Path) -> int:
    """Return the duration of `file`, rounded to the nearest second.

    The duration is read by running the external `exiftool` program and
    parsing its output as a number.
    """
    command = ["exiftool", "-n", "-S", "-t", "-Duration", file]
    output = subprocess.check_output(command, encoding="utf-8")
    seconds = float(output.strip())
    return round(seconds)
def calculate_processing_bounds(seconds: int) -> tuple[int, int]:
    """Estimate how long transcribing `seconds` of audio will take.

    AssemblyAI states their transcription time is 15-30% of the file's
    length; this deliberately over-estimates by using a third of the
    duration as the upper bound and half of that as the lower bound.

    Returns:
        A `(lower, upper)` pair of whole seconds.
    """
    upper_bound = seconds // 3
    return (upper_bound // 2, upper_bound)
def format_seconds(s: int) -> str:
    """Render whole seconds as `XmYYs`, zero-padding the seconds part."""
    minutes, remainder = divmod(s, 60)
    return f"{minutes}m{remainder:02}s"
def log(message: str) -> None:
    """Print `message` to stdout, prefixed with the current local HH:MM time."""
    timestamp = f"{datetime.now():%H:%M}"
    print(f"{timestamp} :: {message}")
def millis_to_timestamp(m: int) -> str:
    """Render a millisecond count as an `HH:MM:SS` duration string.

    Fractions of a second are discarded; hours are not capped at two digits.
    """
    whole_minutes, seconds = divmod(m // 1000, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return f"{hours:02}:{minutes:02}:{seconds:02}"
def _format_part(part: Utterance, opening: str, closing: str) -> str:
    """Render one utterance as a timestamp line followed by `speaker: text`.

    The speaker label is wrapped in the `opening` and `closing` markup so
    that callers can choose how it is emphasised.
    """
    header = f"[{millis_to_timestamp(part.start)}]"
    label = f"{opening}{part.speaker}{closing}"
    return f"{header}\n{label}: {part.text}"
def format_part_md(part: Utterance) -> str:
    """Render `part` for Markdown, with the speaker label in bold."""
    bold = "**"
    return _format_part(part, opening=bold, closing=bold)
def format_part_bb(part: Utterance) -> str:
    """Render `part` for BBCode, with the speaker label in bold tags."""
    return _format_part(part, opening="[B]", closing="[/B]")
def attempt_to_print_time_estimate(file: Path) -> None:
    """Log the expected transcription time for `file`, if it can be determined.

    This is best-effort: the estimate depends on the optional `exiftool`
    program, so when the duration cannot be read nothing is logged and no
    exception is raised.
    """
    try:
        duration = get_duration(file)
    except (OSError, subprocess.CalledProcessError, ValueError):
        # OSError: exiftool is not installed (FileNotFoundError) or could not
        #   be executed -- subprocess raises this, not CalledProcessError.
        # CalledProcessError: exiftool ran but exited with a non-zero status.
        # ValueError: exiftool's output could not be parsed as a number.
        return
    lower_s, upper_s = calculate_processing_bounds(duration)
    lower = format_seconds(lower_s)
    upper = format_seconds(upper_s)
    log(f"Transcription expected to take between {lower} and {upper}.")
def upload_audio(file: Path, headers: dict[str, str]) -> str:
    """Upload `file` to the API and return the URL it was stored at.

    Args:
        file: Path of the audio file to upload.
        headers: HTTP headers to send with the request (carries the API key).

    Returns:
        The "upload_url" reported by the API.

    Raises:
        RuntimeError: If the API response contains an "error" entry.
    """
    log("Uploading audio file...")
    # Use the `file` parameter (not a module-level global) and hand requests
    # the open file handle so the upload is streamed from disk instead of
    # the whole file being held in memory.
    with file.open("rb") as audio:
        upload_response = requests.post(
            BASE_URL + "/upload", headers=headers, data=audio
        )
    body = upload_response.json()
    if error := body.get("error"):
        raise RuntimeError(f"Error response from API: {error}")
    url: str = body["upload_url"]
    log(f"Audio uploaded to {url}")
    return url
@dataclass(frozen=True)
class SubmissionResponse:
    """Identifies a transcript that has been submitted for processing."""

    # Transcript ID assigned by the API.
    id: str
    # URL of the transcript resource; polled to check on progress.
    url: str
@dataclass(frozen=True)
class TranscriptPayload:
audio_url: str
language_code: str
speaker_labels: bool
speakers_expected: int | None
def submit_file_for_transcription(
    payload: TranscriptPayload, headers: dict[str, str]
) -> SubmissionResponse:
    """Submit a transcription job and return its ID and polling URL.

    Args:
        payload: Description of the job; sent to the API as a JSON body.
        headers: HTTP headers to send with the request (carries the API key).

    Raises:
        RuntimeError: If the API response contains an "error" entry.
    """
    endpoint = f"{BASE_URL}/transcript"
    reply = requests.post(
        endpoint,
        headers=headers,
        json=dataclasses.asdict(payload),
    )
    body = reply.json()

    error = body.get("error")
    if error:
        raise RuntimeError(f"Error response from API: {error}")

    transcript_id = body["id"]
    return SubmissionResponse(
        id=transcript_id,
        url=f"{BASE_URL}/transcript/{transcript_id}",
    )
def _wait_for_transcript(polling_endpoint: str, headers: dict[str, str]) -> dict:
    """Poll `polling_endpoint` every 3 seconds and return the completed response.

    Raises:
        RuntimeError: If the API reports that the transcript failed.
    """
    while True:
        status_response = requests.get(polling_endpoint, headers=headers).json()
        status = status_response["status"]
        if status == "completed":
            log("Completed transcription.")
            return status_response
        if status == "error":
            error = status_response["error"]
            raise RuntimeError(f"Transcript failed: {error}")
        sleep(3)


def _render_transcripts(status_response: dict) -> tuple[str, str]:
    """Return the (Markdown, BBCode) renderings of a completed transcript."""
    raw_utterances = status_response.get("utterances")
    if not raw_utterances:
        # Transcript not split up by speaker, so use the whole text as-is.
        text = status_response["text"]
        return text, text

    # Keep only the keys Utterance declares, so that any additional keys in
    # the response do not make the constructor raise TypeError.
    known = {field.name for field in dataclasses.fields(Utterance)}
    utterances = [
        Utterance(**{k: v for k, v in obj.items() if k in known})
        for obj in raw_utterances
    ]
    output_md = "\n\n".join(format_part_md(u) for u in utterances)
    output_bb = (
        "[QUOTE]\n"
        + "\n\n".join(format_part_bb(u) for u in utterances)
        + "\n[/QUOTE]"
    )
    return output_md, output_bb


def main(args: Args) -> None:
    """Upload `args.file`, wait for its transcript and write the results.

    Three files are written next to the input file, sharing its name but
    with the suffixes `.json` (raw API response), `.md` (Markdown
    transcript) and `.bbcode` (BBCode transcript).

    Raises:
        RuntimeError: If no API key is available, or the API reports an error.
    """
    # Key provided as an argument overrides the environment variable.
    key = args.key or os.getenv("ASSEMBLY_AI_KEY")
    # `not key` also rejects an empty string, which could never authenticate.
    if not key:
        raise RuntimeError(
            "ASSEMBLY_AI_KEY environment variable must be set"
            " or --key must be given as an argument."
        )
    headers = {"authorization": key}

    audio_url = upload_audio(args.file, headers)
    payload = TranscriptPayload(
        audio_url=audio_url,
        language_code=args.language,
        speaker_labels=args.n_speakers is not None,
        speakers_expected=args.n_speakers,
    )
    transcript_response = submit_file_for_transcription(payload, headers)
    attempt_to_print_time_estimate(args.file)
    # Print the transcript_id in case you need to look it up in the web UI.
    log(f"Transcript ID: {transcript_response.id}")
    log("Polling for completion...")
    status_response = _wait_for_transcript(transcript_response.url, headers)

    # Every file is written as UTF-8 explicitly: the platform default
    # encoding may be unable to represent non-ASCII transcript text.

    # Write out the raw JSON response
    json_file = args.file.with_suffix(".json")
    log(f"Writing response JSON to {json_file}")
    json_file.write_text(json.dumps(status_response, indent=2), encoding="utf-8")

    # Write out the transcription text
    output_md, output_bb = _render_transcripts(status_response)
    md_file = args.file.with_suffix(".md")
    log(f"Writing Markdown transcript to {md_file}")
    md_file.write_text(output_md, encoding="utf-8")
    bb_file = args.file.with_suffix(".bbcode")
    log(f"Writing BBCode transcript to {bb_file}")
    bb_file.write_text(output_bb, encoding="utf-8")
# Script entry point: parse the command line, then run the transcription.
if __name__ == "__main__":
    args = parse_arguments()
    main(args)