Spaces:
Sleeping
Sleeping
File size: 4,219 Bytes
e873032 e17a026 c04572a 4e926ed c04572a e17a026 c04572a 4e926ed e17a026 4e926ed e17a026 c04572a e17a026 506a8d3 e17a026 c04572a 4e926ed c04572a 4e926ed c04572a 3094d88 c04572a 4e926ed c04572a e873032 c04572a e873032 c04572a e873032 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import os.path
import pickle
import torch
from openai import OpenAI
# Hugging Face identifiers: the Llama 3.2 base checkpoint and the adapter loaded on top of it.
base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "HiGenius/Headline-Generation-Model"

# Credentials are read from the environment; each is None when the variable is unset.
hf_token = os.getenv('HF_TOKEN')
openai_api_key = os.getenv('OPENAI_API_KEY')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_model():
    """Load the base Llama model, attach the headline PEFT adapter, and build the tokenizer.

    Returns:
        (tokenizer, model): the base-model tokenizer configured for left-side
        padding/truncation, and the adapter-wrapped model moved to ``device``.
    """
    # `token` is the supported keyword; `use_auth_token` is deprecated in
    # transformers / huggingface_hub and emits a FutureWarning.
    base_model = AutoModelForCausalLM.from_pretrained(base_model_id, token=hf_token)
    model = PeftModel.from_pretrained(base_model, model_id, token=hf_token).to(device)

    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
    # Reuse EOS as the padding token (presumably the base tokenizer ships without one).
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps prompts flush against the generated tokens; left
    # truncation drops the *beginning* of an over-long prompt.
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = "left"
    return tokenizer, model
def summarize_content(content):
    """Condense article text with GPT-4o so it fits the headline model's prompt budget.

    Args:
        content: the article body to summarize.

    Returns:
        The summary text, or ``content`` unchanged if the API returned no text.
    """
    # Use the client as a context manager so its HTTP connection pool is closed
    # after each call instead of waiting for garbage collection.
    with OpenAI(api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "Summarize the following article content concisely while preserving key information:"},
                {"role": "user", "content": content},
            ],
            max_tokens=600,
            temperature=0.3,
        )
    summary = response.choices[0].message.content
    # `message.content` is Optional in the OpenAI SDK (e.g. refusals). Never let a
    # missing summary become the literal string "None" in the headline prompt.
    return summary if summary else content
# Read the small guidelines file before the expensive model load, so a missing
# guidelines.txt is reported immediately rather than after the weights are loaded.
guideline_path = "./guidelines.txt"
with open(guideline_path, 'r', encoding='utf-8') as f:
    guidelines = f.read()

# Loaded once at import time; generate_headlines reuses these module-level objects.
tokenizer, model = load_model()
def process_prompt(tokenizer, content, video_summary='', guidelines=None, *,
                   max_content_tokens=850, max_total_tokens=900):
    """Build the chat-formatted prompt for the headline model.

    When the article is too long, it is first condensed with ``summarize_content``.
    Too long means more than ``max_content_tokens`` tokens on its own, or more than
    ``max_total_tokens`` together with the video summary.

    Args:
        tokenizer: object providing ``encode`` and ``apply_chat_template``.
        content: article body text.
        video_summary: optional video summary; '' is rendered as "None".
        guidelines: optional text appended to the system prompt.
        max_content_tokens: article-only token budget before summarizing (default 850).
        max_total_tokens: article + video-summary budget before summarizing (default 900).

    Returns:
        The result of ``tokenizer.apply_chat_template(..., tokenize=False,
        add_generation_prompt=True)``, i.e. the formatted prompt string.
    """
    # Check token lengths against the budgets.
    content_tokens = len(tokenizer.encode(content))
    total_tokens = content_tokens
    if video_summary:
        total_tokens += len(tokenizer.encode(video_summary))
    # Only the article is summarized; the video summary is passed through unchanged.
    if content_tokens > max_content_tokens or total_tokens > max_total_tokens:
        content = summarize_content(content)

    if guidelines:
        system_prompt = "You are a helpful assistant that writes engaging headlines. To maximize engagement, you may follow these proven guidelines:\n" + guidelines
    else:
        system_prompt = "You are a helpful assistant that writes engaging headlines."

    video_text = 'None' if video_summary == '' else video_summary
    user_prompt = (
        "Below is an article and its accompanying video summary:\n\n"
        f"Article Content:\n{content}\n\n"
        f"Video Summary:\n{video_text}\n\n"
        "Write ONLY a single engaging headline that accurately reflects the article. Do not include any additional text, explanations, or options."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def generate_headlines(content, video_summary):
    """Generate five sampled headline candidates for one Gradio request.

    Args:
        content: article text (required; blank or None is rejected).
        video_summary: optional video summary; blank or None is treated as absent.

    Returns:
        One string containing "Headline N: ..." entries separated by blank lines,
        or a validation message when ``content`` is empty.
    """
    if not content or not content.strip():
        return "Please enter valid article content."
    if not video_summary or not video_summary.strip():
        video_summary = ''

    num_headlines = 5
    prompt = process_prompt(tokenizer, content, video_summary, guidelines)
    # The chat template output already contains the BOS token, so skip the
    # tokenizer's own special tokens to avoid a duplicated BOS.
    # NOTE(review): load_model sets truncation_side="left", so a prompt longer than
    # 1024 tokens loses its beginning (chat header / system prompt). Confirm that
    # guidelines + the summarized content fit this limit.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        add_special_tokens=False,
    ).to(device)
    prompt_len = inputs['input_ids'].shape[1]

    # One batched sampling call instead of five sequential generate() calls.
    outputs = model.generate(
        **inputs,
        max_new_tokens=60,
        num_return_sequences=num_headlines,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode only the newly generated tokens of each returned sequence.
    completions = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    headlines = []
    for i, completion in enumerate(completions, start=1):
        headline = completion.replace('"', '').strip()
        headlines.append(f"Headline {i}: {headline}")
    return "\n\n".join(headlines)
# Gradio UI: an article textbox plus an optional video-summary textbox, both fed to
# generate_headlines; its result is shown in a single output textbox.
demo = gr.Interface(
    fn=generate_headlines,
    inputs=[
        gr.Textbox(
            label="Article Content",
            placeholder="Type the main content of the article here...",
        ),
        gr.Textbox(
            label="Video Summary (Optional)",
            placeholder="Type the summary of the video related to the article...",
        ),
    ],
    outputs=gr.Textbox(label="Generated Headlines"),
    description="Write catchy headlines from content and video summary.",
    title="Article Headline Writer",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|