import gradio as gr
# from huggingface_hub import InferenceClient
from transformers import pipeline
import os

# Retrieve the Hugging Face API token from environment variables
hf_token = os.getenv("HF_TOKEN")

if not hf_token:
    raise ValueError("API token is not set. Please set the HF_TOKEN environment variable in Space Settings.")

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Requires a Space hardware upgrade to use large hosted models (TODO):
# client = InferenceClient("mistralai/Mistral-Large-Instruct-2407")
# Note: the pipeline is now instantiated inside authenticate_and_generate()
# below rather than at module level:
# text_generator = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct", token=hf_token, trust_remote_code=True)
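#
# For reference, a minimal sketch of the hosted-inference route mentioned in
# the docstring above (assumes the model is reachable through the serverless
# Inference API and huggingface_hub >= 0.22; untested here):
#
#   from huggingface_hub import InferenceClient
#   client = InferenceClient(model="microsoft/Phi-3-mini-4k-instruct", token=hf_token)
#   completion = client.chat_completion(
#       messages=[{"role": "user", "content": "Hello"}],
#       max_tokens=256, temperature=0.7, top_p=0.95,
#   )
#   reply = completion.choices[0].message.content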

def authenticate_and_generate(message, history, system_message, max_tokens, temperature, top_p):
    # Initialize the text-generation pipeline once and cache it on the function
    # object so the model is not reloaded on every message (use_auth_token is
    # deprecated in recent transformers, so pass token= instead)
    if not hasattr(authenticate_and_generate, "text_generator"):
        authenticate_and_generate.text_generator = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct", token=hf_token, trust_remote_code=True)
    text_generator = authenticate_and_generate.text_generator

    # Ensure that system_message is a string
    system_message = str(system_message)
    
    # Construct the prompt with system message, history, and user input
    history_str = "\n".join([f"User: {str(msg[0])}\nAssistant: {str(msg[1])}" for msg in history if isinstance(msg, (tuple, list)) and len(msg) == 2])
    prompt = system_message + "\n" + history_str
    prompt += f"\nUser: {message}\nAssistant:"

    # Generate a response using the model; max_new_tokens bounds only the new
    # completion (max_length would also count the prompt and can cut it off)
    response = text_generator(prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True, truncation=True)

    # The pipeline echoes the prompt before the completion, so strip the
    # prompt prefix to keep only the newly generated text (splitting on the
    # first "Assistant:" would break once the history contains earlier turns)
    assistant_response = response[0]['generated_text'][len(prompt):].strip()
    # Stop at the next simulated turn if the model keeps writing dialogue
    assistant_response = assistant_response.split("User:", 1)[0].strip()
    return assistant_response
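
# A hedged alternative to the hand-rolled "User:/Assistant:" prompt above:
# Phi-3-mini is an instruct model whose tokenizer ships a chat template, so
# formatting turns with apply_chat_template matches the tokens the model was
# trained on. Sketch only; build_chat_prompt is a hypothetical helper that is
# not wired into the app, and it assumes pair-style (user, assistant) history.
def build_chat_prompt(message, history, system_message):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", token=hf_token)
    messages = [{"role": "system", "content": str(system_message)}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": str(user_msg)})
        messages.append({"role": "assistant", "content": str(assistant_msg)})
    messages.append({"role": "user", "content": message})
    # add_generation_prompt appends the assistant header so the model replies
    # as the assistant; if the template rejects the system role, fold the
    # system message into the first user turn instead
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)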

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
athena = gr.ChatInterface(
    fn=authenticate_and_generate,
    additional_inputs=[
        gr.Textbox(value=
                   """
                   You are a marketing-minded content writer for Plan.com (a UK telecommunications company).
                   You will be provided with a bullet-point list of guidelines from which to generate an article for the News section of the company website.
                   Please follow these guidelines:
                   - Always speak using British English expressions, syntax, and spelling.
                   - Make the articles engaging and fun, but also professional and informative.
                   To provide relevant contextual information about the company, please source information from the following websites:
                   - https://plan.com/our-story
                   - https://plan.com/products-services
                   - https://plan.com/features/productivity-and-performance
                   - https://plan.com/features/security-and-connectivity
                   - https://plan.com/features/connectivity-and-cost
                   """, 
                   label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    athena.launch()