"""Gradio chat app that enriches a small causal LM with TMDb movie data.

The user's message is treated as a movie title: it is looked up on TMDb,
the movie facts are placed into the prompt, and the language model's
continuation is shown in the chat together with the movie facts.
"""

import logging
import os

import gradio as gr
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

TMDB_API_BASE = "https://api.themoviedb.org/3"
TMDB_IMAGE_BASE = "https://image.tmdb.org/t/p/w500"
TMDB_TIMEOUT_SECONDS = 10
PLACEHOLDER_IMAGE_URL = "https://via.placeholder.com/150"

# SECURITY: this key was committed in the original source and is kept only as
# a backward-compatible fallback. Set TMDB_API_KEY in the environment instead,
# and rotate this key since it has been published.
_DEFAULT_TMDB_API_KEY = "20e959f0f28e6b3e3de49c50f358538a"

# Number of tokens to generate *in addition to* the prompt.
MAX_NEW_TOKENS = 40
# Maximum number of generated characters shown in the chat.
MAX_RESPONSE_CHARS = 100


def get_movie_info(movie_title):
    """Look up a movie on TMDb and return its basic facts.

    Args:
        movie_title: Free-text title used as the TMDb search query.

    Returns:
        On success, a dict with the keys ``title``, ``year``, ``genre``,
        ``tmdb_link`` and ``image_url`` (empty string when TMDb reports no
        poster). On failure, a dict with the single key ``error``.
    """
    query = (movie_title or "").strip()
    if not query:
        # Nothing to search for; skip the network round trip.
        return {"error": "Movie not found"}

    api_key = os.environ.get("TMDB_API_KEY", _DEFAULT_TMDB_API_KEY)
    search_params = {
        "api_key": api_key,
        "query": query,
        "language": "en-US",
        "page": 1,
    }
    details_params = {
        "api_key": api_key,
        "language": "en-US",
    }

    try:
        search_response = requests.get(
            f"{TMDB_API_BASE}/search/movie",
            params=search_params,
            timeout=TMDB_TIMEOUT_SECONDS,
        )
        search_response.raise_for_status()
        results = search_response.json().get("results")
        if not results:
            return {"error": "Movie not found"}

        # The first search hit is used as the match.
        movie_id = results[0]["id"]
        details_response = requests.get(
            f"{TMDB_API_BASE}/movie/{movie_id}",
            params=details_params,
            timeout=TMDB_TIMEOUT_SECONDS,
        )
        details_response.raise_for_status()
        details_data = details_response.json()

        title = details_data.get("title") or "Unknown Title"
        # Guard against a missing, null, or empty release date; slicing the
        # fallback text itself would produce "Unkn".
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] or "Unknown Year"
        genre = ", ".join(
            entry["name"] for entry in details_data.get("genres") or []
        )
        poster_path = details_data.get("poster_path")
    except (requests.RequestException, ValueError) as exc:
        # The exception text can contain the request URL, which carries the
        # API key as a query parameter, so it is logged but never returned.
        logger.warning("TMDb request failed: %s", type(exc).__name__)
        return {"error": f"TMDb request failed ({type(exc).__name__})"}
    except (AttributeError, IndexError, KeyError, TypeError) as exc:
        logger.warning("Unexpected TMDb payload: %r", exc)
        return {"error": "Unexpected response from TMDb"}

    return {
        "title": title,
        "year": year,
        "genre": genre,
        "tmdb_link": f"https://www.themoviedb.org/movie/{movie_id}",
        "image_url": f"{TMDB_IMAGE_BASE}{poster_path}" if poster_path else "",
    }


def process_image(movie_info):
    """Return the image URL to display for a movie.

    Args:
        movie_info: Result dict from ``get_movie_info``.

    Returns:
        The ``image_url`` entry when present and non-empty, otherwise a
        placeholder image URL.
    """
    poster_url = (
        movie_info.get("image_url") if isinstance(movie_info, dict) else None
    )
    return poster_url or PLACEHOLDER_IMAGE_URL


def generate_response(prompt):
    """Build the chat reply for a user message.

    Args:
        prompt: The user's message; also used as the TMDb search query.

    Returns:
        A ``(text, image_url)`` tuple. When the movie lookup fails, ``text``
        is an error message and ``image_url`` is ``None``.
    """
    movie_info = get_movie_info(prompt)
    if "error" in movie_info:
        return f"Error: {movie_info['error']}", None

    image_url = process_image(movie_info)

    # The movie facts go before the user turn so that the prompt ends on
    # "ASSISTANT:" and the model continues as the assistant, rather than
    # continuing from the trailing TMDb URL.
    input_text = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"Movie Info: Title: {movie_info['title']}, Year: {movie_info['year']}, "
        f"Genre: {movie_info['genre']}\nFind more info here: {movie_info['tmdb_link']} "
        f"USER: {prompt} "
        "ASSISTANT:"
    )

    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    gen_conf = {
        "top_k": 20,
        # max_length would count the prompt tokens too, and this prompt alone
        # is longer than 20 tokens, so bound only the newly generated part.
        "max_new_tokens": MAX_NEW_TOKENS,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
        # Explicit padding id; presumably this tokenizer defines no pad
        # token of its own - TODO confirm.
        "pad_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)

    # The returned sequence starts with the prompt tokens; decode only what
    # follows them, otherwise the "response" is just the prompt echoed back.
    prompt_length = model_inputs["input_ids"].shape[-1]
    generated_text = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    response_text = (
        f"Movie Info: {movie_info['title']}, {movie_info['year']}, {movie_info['genre']}\n\n"
        f"Generated Response: {generated_text[:MAX_RESPONSE_CHARS]}"
    )
    return response_text, image_url


def chat_function(message, history):
    """Chat callback for ``gr.ChatInterface``.

    Args:
        message: The latest user message.
        history: Conversation so far. It is owned by ``gr.ChatInterface`` and
            is intentionally neither read nor modified here.

    Returns:
        A single Markdown string: the reply text, followed by the movie image
        when one is available.
    """
    response_text, response_image = generate_response(message)
    if response_image:
        return f"{response_text}\n\n![Movie image]({response_image})"
    return response_text


# Create the Gradio chat interface.
chat_interface = gr.ChatInterface(chat_function)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    chat_interface.launch(share=True)