File size: 4,098 Bytes
6dcddee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d36e009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb15562
 
93340df
d36e009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93340df
bb15562
 
d36e009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb15562
 
d36e009
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Inference client example: a standalone script that POSTs prompts to the /generate endpoint defined below.


import requests
import json

def send_request_to_flask(prompt, history, temperature=0.7, max_new_tokens=100, top_p=0.9, repetition_penalty=1.2):
    # URL of the Flask endpoint
    url = "https://jikoni-llamasms.hf.space/generate"  # Adjust the URL if needed

    # Create the payload
    payload = {
        "prompt": prompt,
        "history": history,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty
    }

    try:
        # Send the POST request
        response = requests.post(url, json=payload)
        
        # Check if the request was successful
        if response.status_code == 200:
            result = response.json()
            return result["response"]
        else:
            print("Failed to get response from Flask app.")
            print("Status Code:", response.status_code)
            print("Response Text:", response.text)
            return None
    
    except requests.RequestException as e:
        print("An error occurred:", e)
        return None

if __name__ == "__main__":
    history = []  # Initialize an empty history list
    
    while True:
        # Prompt the user for input
        prompt = input("You: ")
        
        if prompt.lower() in ['exit', 'quit', 'stop']:
            print("Exiting the chat.")
            break
        
        # Send request and get response
        response_text = send_request_to_flask(prompt, history)

        if response_text:
            print("Response from Flask app:")
            print(response_text)
            # Update history
            history.append((prompt, response_text))
        else:
            print("No response received.")

"""


from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient

# Module-level Flask application; the /generate route below is registered on it.
app = Flask(__name__)

# Startup banner printed at import time. flush=True pushes it out immediately,
# even when stdout is block-buffered (e.g. not attached to a terminal).
print("\nHello welcome to Sema AI\n", flush=True)

# Shared Hugging Face inference client for Mistral-7B-Instruct-v0.1, created
# once at import time and used by generate() for every request.
# NOTE(review): no token is passed here, so authentication presumably comes
# from the environment / cached login — confirm for the deployment.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    """Build a Mistral-Instruct prompt string from the conversation so far.

    The result starts with a single ``<s>``. Each prior turn is rendered as
    ``[INST] <user> [/INST] <bot></s> ``, and the new ``message`` is appended
    as a final ``[INST] <message> [/INST]`` turn.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_prompt, bot_response)`` pairs (tuples or
            2-element lists, as decoded from JSON). ``None`` is treated as an
            empty history.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError, TypeError: If a history entry cannot be unpacked into
            exactly two items.
    """
    parts = ["<s>"]
    # ``history or ()`` lets a JSON ``null`` history behave like an empty one.
    for user_prompt, bot_response in history or ():
        parts.append(f"[INST] {user_prompt} [/INST] {bot_response}</s> ")
    parts.append(f"[INST] {message} [/INST]")
    # Join once instead of repeated string concatenation.
    return "".join(parts)

def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, seed=42):
    """Generate a model reply to ``prompt`` given the prior conversation.

    The prompt and history are rendered with ``format_prompt`` and sent to the
    module-level ``client`` via ``text_generation`` in streaming mode. The
    streamed token texts, excluding special tokens, are concatenated into the
    returned reply. The user prompt and the reply are both echoed to stdout.

    Args:
        prompt: The new user message.
        history: Prior ``(user_prompt, bot_response)`` pairs, passed to
            ``format_prompt``.
        temperature: Sampling temperature. It is coerced to ``float`` and
            clamped to a minimum of 0.01.
        max_new_tokens: Maximum number of tokens to generate. It is coerced to
            ``int`` unless ``None``.
        top_p: Nucleus-sampling probability mass, coerced to ``float``.
        repetition_penalty: Repetition penalty, coerced to ``float`` unless
            ``None``.
        seed: Sampling seed forwarded to the model. The fixed default (42)
            keeps the previous behaviour; pass ``None`` to leave it unset.

    Returns:
        The generated reply text.

    Raises:
        ValueError, TypeError: If a numeric parameter cannot be coerced.
        Any exception raised by ``client.text_generation`` propagates.
    """
    # Echo the incoming prompt to the server log.
    print(f"\nUser: {prompt}\n")

    # Parameters may arrive from JSON as strings, so coerce them to the numeric
    # types the API expects. The temperature floor avoids a zero/negative value
    # while sampling (presumably rejected by the backend — kept from before).
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    if max_new_tokens is not None:
        max_new_tokens = int(max_new_tokens)
    if repetition_penalty is not None:
        repetition_penalty = float(repetition_penalty)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=seed,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream with per-token details so each chunk exposes ``token.text`` and
    # ``token.special``.
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    # Drop special tokens (such as an end-of-sequence marker). Otherwise they
    # leak into the reply and are fed back to the model through ``history``.
    output = "".join(chunk.token.text for chunk in stream if not chunk.token.special)

    # Echo the model's reply to the server log.
    print(f"\nSema AI: {output}\n")
    return output

@app.route("/generate", methods=["POST"])
def generate_text():
    """Handle ``POST /generate`` by running ``generate`` on a JSON request body.

    Expected body: a JSON object with a ``prompt`` string and optional
    ``history``, ``temperature``, ``max_new_tokens``, ``top_p`` and
    ``repetition_penalty`` fields. Their defaults are ``[]``, 0.9, 256, 0.95
    and 1.0 respectively.

    Returns:
        One of:
        - ``{"response": <text>}`` with status 200 on success.
        - ``{"error": <message>}`` with status 400 if the body is not a JSON
          object.
        - ``{"error": <message>}`` with status 500 if generation raises.
    """
    # silent=True returns None for a missing or unparsable JSON body instead of
    # raising. Malformed requests then get a JSON 400 rather than an HTML page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    prompt = data.get("prompt", "")
    # Treat an explicit "history": null the same as an omitted history.
    history = data.get("history") or []
    temperature = data.get("temperature", 0.9)
    max_new_tokens = data.get("max_new_tokens", 256)
    top_p = data.get("top_p", 0.95)
    repetition_penalty = data.get("repetition_penalty", 1.0)

    try:
        response_text = generate(
            prompt,
            history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty
        )
        return jsonify({"response": response_text})
    except Exception as e:
        # Request boundary: log the failure and report it as a JSON 500.
        print(f"Error: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    import os

    # Development-server settings can be overridden without editing the file.
    # The defaults reproduce the previous hard-coded debug=True, port=5000.
    # NOTE: debug mode enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution. Set FLASK_DEBUG=0 anywhere the server could be
    # reached by untrusted clients.
    debug = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in {"0", "false", "no", "off"}
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug, port=port)