Tri4 committed on
Commit 341df5e · verified · 1 Parent(s): 2751952

Update main.py

Files changed (1)
  1. main.py +23 -21
main.py CHANGED
@@ -4,6 +4,8 @@ from huggingface_hub import InferenceClient
 # Initialize Flask app
 app = Flask(__name__)

+print("\nHello welcome to Sema AI\n", flush=True)  # Flush to ensure immediate output
+
 # Initialize InferenceClient
 client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

@@ -15,7 +17,9 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt

-def generate_stream(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+    print(f"\nUser: {prompt}\n")
+
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
@@ -32,28 +36,26 @@ def generate_stream(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):

     formatted_prompt = format_prompt(prompt, history)

-    # Get response from Mistral model
-    response = client.text_generation(
-        formatted_prompt,
-        **generate_kwargs,
-        stream=True,
-        details=True,
-        return_full_text=False
-    )
+    try:
+        # Get response from Mistral model
+        response = client.text_generation(
+            formatted_prompt,
+            **generate_kwargs,
+            stream=True,
+            details=True,
+            return_full_text=False
+        )

-    def generate():
         output = ""
-        try:
-            for token in response:
-                if hasattr(token, 'token') and hasattr(token.token, 'text'):
-                    output += token.token.text
-                    yield output  # Yield intermediate response
-                else:
-                    print(f"Unexpected token structure: {token}")
-        except Exception as e:
-            print(f"Error while processing streaming response: {str(e)}")
+        for token in response:
+            output += token.token.text
+            yield token.token.text  # Yield each token for streaming

-    return generate
+        # Print AI response
+        print(f"\nSema AI: {output}\n")
+    except Exception as e:
+        print(f"Exception during generation: {str(e)}")
+        yield "Error occurred"

 @app.route("/generate", methods=["POST"])
 def generate_text():
@@ -66,7 +68,7 @@ def generate_text():
     repetition_penalty = data.get("repetition_penalty", 1.0)

     try:
-        return Response(stream_with_context(generate_stream(
+        return Response(stream_with_context(generate(
             prompt,
             history,
             temperature=temperature,
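
For readers of this commit, the hunks above can be assembled into the following sketch of the post-commit main.py. Only the lines shown in the diff are taken from the change itself; the imports, the generate_kwargs dictionary, the history loop in format_prompt, the request parsing in the route, the mimetype, and the app.run call sit outside the changed regions and are reconstructed here as assumptions, not as the file's actual contents.

# Sketch only: assembled from the hunks above plus assumed glue code.
from flask import Flask, Response, request, stream_with_context
from huggingface_hub import InferenceClient

# Initialize Flask app
app = Flask(__name__)

print("\nHello welcome to Sema AI\n", flush=True)  # Flush to ensure immediate output

# Initialize InferenceClient
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    # The loop over history is assumed; only the final line appears in the diff context.
    prompt = ""
    for user_turn, bot_turn in history:
        prompt += f"[INST] {user_turn} [/INST] {bot_turn} "
    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    print(f"\nUser: {prompt}\n")

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2

    # generate_kwargs is built outside the visible hunks; this shape is an assumption.
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history)

    try:
        # Get response from Mistral model, streamed token by token
        response = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False
        )

        output = ""
        for token in response:
            output += token.token.text
            yield token.token.text  # Yield each token for streaming

        # Print AI response
        print(f"\nSema AI: {output}\n")
    except Exception as e:
        print(f"Exception during generation: {str(e)}")
        yield "Error occurred"

@app.route("/generate", methods=["POST"])
def generate_text():
    # Request parsing is outside the hunks; field names follow the data.get calls in the diff.
    data = request.get_json()
    prompt = data.get("prompt", "")
    history = data.get("history", [])
    temperature = data.get("temperature", 0.9)
    max_new_tokens = data.get("max_new_tokens", 256)
    top_p = data.get("top_p", 0.95)
    repetition_penalty = data.get("repetition_penalty", 1.0)

    try:
        return Response(stream_with_context(generate(
            prompt,
            history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty
        )), mimetype="text/plain")  # mimetype is an assumption
    except Exception as e:
        return {"error": str(e)}, 500  # error-handling shape is an assumption

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)  # host/port are assumptions (typical for a Space)

A caller can then consume the response incrementally, for example (URL and port are assumptions):

import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"prompt": "Hello", "history": []},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)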