vykanand committed on
Commit 3e753c0 · 1 Parent(s): dac9de5

modified app.py

Files changed (4)
  1. README.md +46 -9
  2. app.py +21 -22
  3. requirements.txt +0 -1
  4. start.sh +0 -1
README.md CHANGED
@@ -3,27 +3,64 @@ title: LLaMA 7B Server
 emoji: 🤖
 colorFrom: blue
 colorTo: purple
-sdk: gradio
-sdk_version: "4.17.0"
-app_file: app.py
+sdk: fastapi
+sdk_version: "0.95.0"
+app_file: main.py
 pinned: false
 ---
 
 # LLaMA 7B Server
 
-A web interface for interacting with the LLaMA 7B model.
+A FastAPI-based server for interacting with the LLaMA 7B model.
 
 ## Features
 
 - [x] Text generation
-- [x] Chat interface
 - [x] Model parameters configuration
+- [x] REST API interface
 
-## How to Use
+## API Usage
 
-1. Enter your prompt in the text box
-2. Click "Generate" or press Enter
-3. View the model's response below
+### Text Generation
+
+Make a POST request to `/generate` with the following JSON body:
+
+```json
+{
+    "prompt": "your prompt here",
+    "max_length": 2048,
+    "num_beams": 3,
+    "early_stopping": true,
+    "no_repeat_ngram_size": 3
+}
+```
+
+Example using curl:
+
+```bash
+curl -X POST http://localhost:7860/generate \
+  -H "Content-Type: application/json" \
+  -d '{"prompt": "Hello, how are you?"}'
+```
+
+Example using Python:
+
+```python
+import requests
+
+url = "http://localhost:7860/generate"
+data = {
+    "prompt": "Hello, how are you?",
+    "max_length": 2048,
+    "num_beams": 3,
+    "early_stopping": True,
+    "no_repeat_ngram_size": 3
+}
+
+response = requests.post(url, json=data)
+result = response.json()
+print(result["generated_text"])  # This will contain your generated text
+```
 
 ## Model Details
app.py CHANGED
@@ -1,39 +1,38 @@
-import gradio as gr
+from fastapi import FastAPI
+from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
+app = FastAPI()
+
 # Load model and tokenizer once on startup
 tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5p-220m")
 model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5p-220m")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
-def generate(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+class GenerationRequest(BaseModel):
+    prompt: str
+    max_length: int = 2048
+    num_beams: int = 3
+    early_stopping: bool = True
+    no_repeat_ngram_size: int = 3
+
+@app.post("/generate")
+async def generate_text(request: GenerationRequest):
+    inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
     outputs = model.generate(
         **inputs,
-        max_length=2048,
-        num_beams=3,
-        early_stopping=True,
-        no_repeat_ngram_size=3,
+        max_length=request.max_length,
+        num_beams=request.num_beams,
+        early_stopping=request.early_stopping,
+        no_repeat_ngram_size=request.no_repeat_ngram_size,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
     )
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return output_text
-
-# Create Gradio interface
-iface = gr.Interface(
-    fn=generate,
-    inputs=gr.Textbox(lines=10, label="Input Prompt"),
-    outputs=gr.Textbox(label="Generated Output"),
-    title="LLaMA 7B Server",
-    description="A web interface for interacting with the LLaMA 7B model.",
-    allow_flagging="never",
-    allow_api=True,
-    queue=False
-)
+    return {"generated_text": output_text}
 
-# Launch the interface
 if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0", server_port=7860)
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -2,4 +2,3 @@ fastapi
 uvicorn[standard]
 transformers
 torch
-gradio>=4.17.0
start.sh CHANGED
@@ -1,3 +1,2 @@
 #!/bin/bash
-python app.py
 uvicorn app:app --host 0.0.0.0 --port 7860
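
The new `/generate` endpoint can be exercised without starting a server via FastAPI's test client. The sketch below is not part of this commit and makes a few assumptions: the module is importable as `app` (the README metadata now points at `main.py`, while the code and start.sh use `app.py`), `httpx` is available for `TestClient`, and the test name, prompt, and `max_length` value are illustrative only.

```python
# Minimal smoke test for the /generate endpoint (a sketch, not from the commit).
# Assumes the FastAPI instance lives in app.py; importing it also loads the
# CodeT5+ model, so the first run downloads weights.
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)

def test_generate_returns_text():
    # Keep max_length small so beam search finishes quickly on CPU.
    response = client.post(
        "/generate",
        json={"prompt": "def add(a, b):", "max_length": 64},
    )
    assert response.status_code == 200
    assert "generated_text" in response.json()
```

Run it with `pytest` (not listed in requirements.txt) to catch request-validation or model-loading regressions before the Space redeploys.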