File size: 4,185 Bytes
d903275
0a2c880
2ff483e
727c299
2ff483e
d903275
 
 
 
2ff483e
7b6b4a2
727c299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b6b4a2
727c299
 
7b6b4a2
727c299
 
 
 
97b7fcf
dc5958a
5bd8a1a
 
dc5958a
727c299
ed2673b
 
 
 
 
 
727c299
 
ed2673b
5bd8a1a
727c299
ed2673b
 
 
 
 
 
7b6b4a2
0a2c880
 
 
 
 
 
 
3b25749
727c299
dc5958a
7b6b4a2
ed2673b
 
 
 
 
 
7b6b4a2
ed2673b
 
 
 
7b6b4a2
0a2c880
 
 
 
5bf06ec
0a2c880
 
 
 
 
 
 
 
b99a30a
1592e82
 
 
 
 
0a2c880
8139f95
 
ed2673b
 
 
 
5bd8a1a
8139f95
 
0cdb0a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests

# Lightweight causal LM used to generate chat replies.
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

def get_movie_info(movie_title):
    """Look up *movie_title* on TMDb and return a summary dict.

    Returns ``{"title", "year", "genre", "tmdb_link"}`` on success, or
    ``{"error": <message>}`` when no match is found or the request fails.
    """
    # NOTE(review): hard-coded API key — move to an environment variable
    # before sharing/deploying this file.
    api_key = "20e959f0f28e6b3e3de49c50f358538a"
    search_url = "https://api.themoviedb.org/3/search/movie"

    # Search query parameters for TMDb.
    params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }

    try:
        # A timeout keeps the Gradio callback from hanging on a slow network.
        search_response = requests.get(search_url, params=params, timeout=10)
        search_data = search_response.json()

        # Bail out early if nothing matched.
        if not search_data.get("results"):
            return {"error": "Movie not found"}

        movie_id = search_data["results"][0]["id"]

        # Fetch detailed information using the movie ID.
        details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
        details_params = {
            "api_key": api_key,
            "language": "en-US",
        }
        details_response = requests.get(details_url, params=details_params, timeout=10)
        details_data = details_response.json()

        # Extract the fields the chat response needs.
        title = details_data.get("title", "Unknown Title")
        # BUG FIX: the original sliced the default too —
        # "Unknown Year"[:4] == "Unkn". Only slice a real release date.
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] if release_date else "Unknown Year"
        genre = ", ".join(g["name"] for g in details_data.get("genres", []))
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"

        return {
            "title": title,
            "year": year,
            "genre": genre,
            "tmdb_link": tmdb_link,
        }

    except Exception as e:
        # Boundary handler: surface any network/JSON failure to the caller
        # as an error dict instead of crashing the UI callback.
        return {"error": f"Error: {e}"}

def process_image(movie_info):
    """Return an image URL for *movie_info*.

    Poster lookup is not implemented yet, so a fixed placeholder image
    is returned regardless of the input.
    """
    placeholder_url = "https://via.placeholder.com/150"
    return placeholder_url

def generate_response(prompt):
    """Generate a movie-aware chat reply for *prompt* (a movie title).

    Enriches the prompt with TMDb data, runs the language model, and
    returns ``(text, image_url)``. On lookup failure, returns
    ``(error_text, None)``.
    """
    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        "ASSISTANT:"
    )

    # Call the get_movie_info function to enrich the response
    movie_info = get_movie_info(prompt)

    if "error" in movie_info:
        return f"Error: {movie_info['error']}", None

    # Image handling is delegated to a separate helper.
    image_url = process_image(movie_info)

    # Append the movie facts so the model can condition on them.
    input_text_template += (
        f" Movie Info: Title: {movie_info['title']}, Year: {movie_info['year']}, "
        f"Genre: {movie_info['genre']}\nFind more info here: {movie_info['tmdb_link']}"
    )

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)

    gen_conf = {
        "top_k": 20,
        # BUG FIX: the original used max_length=20, which caps the TOTAL
        # sequence (prompt + completion). The assembled prompt alone far
        # exceeds 20 tokens, so generate() could never emit new text.
        # max_new_tokens bounds only the completion.
        "max_new_tokens": 64,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }

    output = model.generate(**model_inputs, **gen_conf)

    # Decode only the newly generated tokens, not the echoed prompt —
    # otherwise the [:100] truncation below would show only prompt text.
    prompt_len = model_inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    return (
        f"Movie Info: {movie_info['title']}, {movie_info['year']}, {movie_info['genre']}\n\n"
        f"Generated Response: {generated_text[:100]}",  # Truncate to 100 characters
        image_url,
    )

# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Gradio ChatInterface callback: answer *message* with movie info.

    BUG FIX: gr.ChatInterface expects the callback to return a single
    reply value (string/markdown) — the original returned a
    ``(text, image)`` tuple and mutated ``history``, which ChatInterface
    manages itself (it also appended the same user message twice).
    """
    response_text, response_image = generate_response(message)
    if response_image:
        # Embed the image via markdown so the chat window renders it.
        return f"{response_text}\n\n![poster]({response_image})"
    return response_text

# Create the Gradio chat interface around the movie-info chat callback.
chat_interface = gr.ChatInterface(chat_function)
# share=True publishes a temporary public URL in addition to localhost.
chat_interface.launch(share=True)