File size: 3,111 Bytes
99c818b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import ast
import asyncio
import os
from typing import List, Tuple

from fastapi import FastAPI, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from text_generation import Client

# The Hugging Face API token is mandatory: abort at import time without it.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")

# Hosted Inference API endpoint for the chosen model.
model_id = "codellama/CodeLlama-34b-Instruct-hf"
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"

# Shared generation client; every request carries the bearer token.
client = Client(API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"})

# Text markers that stop token collection when they appear in a streamed token.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"

app = FastAPI()

# CORS policy: accept any origin, method and header, with credentials allowed.
# NOTE(review): the wildcard origin list is meant for development only —
# replace "*" in allow_origins with the frontend's URL(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Pydantic model for request body
class ChatRequest(BaseModel):
    """Request-body schema: a new prompt plus the earlier conversation turns.

    NOTE(review): in the code shown in this module, nothing references this
    model — the ``/generate/`` endpoint reads ``prompt`` and ``history`` as
    form fields instead. Either wire it into the endpoint or remove it.
    """
    prompt: str  # the user's new message
    # earlier turns; presumably (user_input, response) pairs as consumed by
    # get_prompt — confirm against the frontend
    history: List[Tuple[str, str]]

# System instructions injected into the <<SYS>> section of every prompt.
# Two paragraphs separated by a blank line; no leading or trailing newline.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant with a deep knowledge "
    "of code and software design. Always answer as helpfully as possible, "
    "while being safe. Your answers should not include any harmful, unethical, "
    "racist, sexist, toxic, dangerous, or illegal content. Please ensure that "
    "your responses are socially unbiased and positive in nature."
    "\n\n"
    "If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. If you don't "
    "know the answer to a question, please don't share false information."
)

def get_prompt(message: str, chat_history: List[Tuple[str, str]],
               system_prompt: str) -> str:
    """Render a conversation as a Llama-2-style ``[INST]`` chat prompt.

    The system prompt sits inside ``<<SYS>>`` tags in the first ``[INST]``
    block. Each earlier ``(user_input, response)`` turn is closed with
    ``</s>`` and a fresh ``<s>[INST]`` is opened for the next one. The new
    ``message`` ends the prompt with an open ``[/INST]`` for the model to
    complete.

    Whitespace: the first user text (the first history entry, or ``message``
    when the history is empty) is kept verbatim. Every later user text and
    every assistant response is ``strip()``-ed.

    Args:
        message: the new user message.
        chat_history: iterable of prior ``(user_input, response)`` pairs.
        system_prompt: instructions placed in the ``<<SYS>>`` section.

    Returns:
        The fully assembled prompt string.
    """
    prompt_parts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    turns_seen = 0
    for user_text, bot_text in chat_history:
        # Only the very first user text escapes trimming.
        if turns_seen:
            user_text = user_text.strip()
        turns_seen += 1
        prompt_parts.append(
            f'{user_text} [/INST] {bot_text.strip()} </s><s>[INST] ')
    final_message = message.strip() if turns_seen else message
    prompt_parts.append(f'{final_message} [/INST]')
    return ''.join(prompt_parts)

def _parse_history(raw_history: str) -> List[Tuple[str, str]]:
    """Safely decode the form-encoded chat history into string pairs.

    ``raw_history`` is expected to be the text of a list literal of
    ``[user_input, response]`` pairs, e.g. ``[["hi", "hello"]]``. Both
    JSON-style and Python-repr-style lists of strings parse. Decoding uses
    ``ast.literal_eval``, which accepts only literals; ``eval`` would execute
    arbitrary client-supplied code on the server.

    Raises:
        HTTPException: status 400 if the text is not a literal, or is not a
            list/tuple of two-element string pairs.
    """
    error_detail = "history must be a list of [user_input, response] string pairs"
    try:
        parsed = ast.literal_eval(raw_history)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as exc:
        raise HTTPException(status_code=400, detail=error_detail) from exc
    if not isinstance(parsed, (list, tuple)):
        raise HTTPException(status_code=400, detail=error_detail)
    pairs = []
    for turn in parsed:
        if (not isinstance(turn, (list, tuple)) or len(turn) != 2
                or not all(isinstance(text, str) for text in turn)):
            raise HTTPException(status_code=400, detail=error_detail)
        pairs.append((turn[0], turn[1]))
    return pairs

def _generate_text(prompt_text: str) -> str:
    """Stream a completion for ``prompt_text`` and return the collected text.

    Tokens are accumulated until one contains ``EOS_STRING`` or
    ``EOT_STRING``; that token and everything after it are discarded. The
    stream is iterated synchronously and blocks on network I/O, so callers on
    an event loop should run this in a worker thread.
    """
    end_markers = (EOS_STRING, EOT_STRING)
    stream = client.generate_stream(
        prompt_text,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        temperature=0.1,
    )
    pieces = []
    for response in stream:
        token_text = response.token.text
        if any(marker in token_text for marker in end_markers):
            break
        pieces.append(token_text)
    return "".join(pieces)

@app.post("/generate/")
async def generate_response(prompt: str = Form(...), history: str = Form(...)):
    """Generate the assistant's reply to ``prompt`` given earlier turns.

    Form fields:
        prompt: the new user message.
        history: text of a list literal of ``[user_input, response]`` pairs.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 for a malformed ``history``; 500 (with the error
            text as detail) if prompt building or generation fails.
    """
    # Parsed outside the try so a 400 is not re-wrapped as a 500 below.
    chat_history = _parse_history(history)
    try:
        prompt_text = get_prompt(prompt, chat_history, DEFAULT_SYSTEM_PROMPT)
        # The text_generation Client used here is synchronous. Running it in
        # a worker thread keeps this coroutine from stalling every other
        # request on the event loop while tokens stream in.
        loop = asyncio.get_running_loop()
        output = await loop.run_in_executor(None, _generate_text, prompt_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": output}