# Copy-paste residue from a Hugging Face Spaces page (not Python code), kept as comments:
# Spaces:
# Running
# Running
# generate_transcript.py
import os
import pickle
import warnings

import torch
import transformers
from accelerate import Accelerator
from tqdm import tqdm

warnings.filterwarnings('ignore')
class TranscriptGenerator:
    """
    Generate a conversational podcast transcript from cleaned text.

    Reads a cleaned text file, sends it to an instruction-tuned LLM with a
    podcast-writer system prompt, and pickles the generated transcript to
    ``./resources/data.pkl``.
    """

    def __init__(self, text_file_path, model_name="meta-llama/Llama-3.1-70B-Instruct"):
        """
        Initialize with the path to the cleaned text file and the model name.

        Args:
            text_file_path (str): Path to the file containing cleaned PDF text.
            model_name (str): Name of the language model to use.
        """
        self.text_file_path = text_file_path
        self.output_path = './resources/data.pkl'
        self.model_name = model_name
        self.accelerator = Accelerator()
        self.model = transformers.pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto"
        )
        self.system_prompt = """
You are a world-class podcast writer, you have worked as a ghost writer for Joe Rogan, Lex Fridman, Ben Shapiro, Tim Ferris.
We are in an alternate universe where actually you have been writing every line they say and they just stream it into their brains.
Your job is to write word by word, even "umm, hmmm, right" interruptions by the second speaker based on the PDF upload.
Keep it extremely engaging, with realistic anecdotes, tangents, and interruptions.
Speaker 1: Leads and teaches. Speaker 2: Asks follow-up questions, gets excited or confused.
ALWAYS START YOUR RESPONSE DIRECTLY WITH SPEAKER 1:
STRICTLY THE DIALOGUES.
"""

    def load_text(self):
        """
        Read the cleaned text file, trying several common encodings in order.

        Returns:
            str | None: Content of the file, or None when the file is missing
            or cannot be decoded with any attempted encoding.
        """
        encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1']
        for encoding in encodings:
            try:
                with open(self.text_file_path, 'r', encoding=encoding) as file:
                    content = file.read()
                print(f"Successfully read file using {encoding} encoding.")
                return content
            except FileNotFoundError:
                # A missing file will not appear under another encoding;
                # report it accurately instead of falling through to the
                # misleading "could not decode" message.
                print(f"Error: File not found: '{self.text_file_path}'")
                return None
            except UnicodeDecodeError:
                # Try the next candidate encoding. Note latin-1 accepts any
                # byte sequence, so in practice the loop rarely exhausts.
                continue
        print(f"Error: Could not decode file '{self.text_file_path}' with any common encoding.")
        return None

    def generate_transcript(self):
        """
        Generate a podcast-style transcript and save it as a pickled file.

        Returns:
            str | None: Path to the file where the transcript is saved, or
            None when the input text could not be loaded.
        """
        input_text = self.load_text()
        if input_text is None:
            return None
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": input_text},
        ]
        output = self.model(
            messages,
            max_new_tokens=8126,
            temperature=1,
        )
        # NOTE(review): with chat-style input, recent transformers pipelines
        # return the full message list in "generated_text"; if only the
        # assistant reply is wanted, this may need `[-1]["content"]` — confirm
        # against the installed transformers version.
        transcript = output[0]["generated_text"]
        # Ensure the output directory exists; pickle.dump/open will not
        # create './resources' and would raise FileNotFoundError otherwise.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # Save the transcript as a pickle file.
        with open(self.output_path, 'wb') as f:
            pickle.dump(transcript, f)
        return self.output_path