File size: 4,315 Bytes
917a8c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
from fastapi import FastAPI, File, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from groq import Groq
import io
# Set up the Groq client.
# Read the API key once and fail with an explicit message when it is missing.
# Writing the result of os.getenv() back into os.environ is a no-op when the
# variable is set and raises an opaque "TypeError: str expected, not NoneType"
# when it is not, so the key is checked directly instead.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is None:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")
client = Groq(api_key=_groq_api_key)
app = FastAPI()
# Pydantic model for the transcription result
class TranscriptionResponse(BaseModel):
    """Response payload carrying the transcribed text of an audio clip."""

    # Text produced by the transcription request.
    transcription: str
class FeedbackResponse(BaseModel):
    """Response payload carrying grammar and vocabulary feedback with scores."""

    # Free-text feedback on grammar.
    grammar_feedback: str
    # Free-text feedback on vocabulary.
    vocabulary_feedback: str
    # Integer grammar score.
    grammar_score: int
    # Integer vocabulary score.
    vocabulary_score: int
@app.get("/")
async def index():
    """Serve a static welcome message at the API root."""
    welcome = {"message": "Welcome to the Audio Transcription API!"}
    return welcome
@app.post("/transcribe")
async def transcribe_audio(audio_data: bytes = File(...), language: str = Form(...)):
    """Transcribe an uploaded audio clip using the Groq transcription API.

    Args:
        audio_data: Raw bytes of the uploaded audio; forwarded to Groq under
            the file name "audio.wav".
        language: Language value passed through to the transcription request.

    Returns:
        A JSON response shaped like ``{"transcription": "<text>"}``.

    Raises:
        HTTPException: Status 500 carrying the underlying error message when
            the transcription request (or building the response) fails.
    """
    # Function-scope import keeps this block self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # The Groq client call is synchronous. Invoking it directly inside an
        # ``async def`` endpoint would block the event loop for the whole
        # network round trip, so it is run in the worker threadpool instead.
        # Transcribe the audio based on the selected language.
        transcription = await run_in_threadpool(
            client.audio.transcriptions.create,
            file=("audio.wav", audio_data),
            model="whisper-large-v3",
            prompt="Transcribe the audio accurately based on the selected language.",
            response_format="text",
            language=language,
        )
        return JSONResponse(content=TranscriptionResponse(transcription=transcription).dict())
    except Exception as e:
        # Top-level boundary: surface any failure as a 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/check_grammar")
async def check_grammar(transcription: str = Form(...), language: str = Form(...)):
    """Ask two Groq chat models for grammar and vocabulary feedback on a text.

    Args:
        transcription: The text to evaluate.
        language: The language the text is expected to be written in.

    Returns:
        A JSON response with ``grammar_feedback``, ``vocabulary_feedback``,
        ``grammar_score`` and ``vocabulary_score``. Scores are parsed from the
        model replies by ``calculate_score``.

    Raises:
        HTTPException: Status 400 when either field is empty or contains only
            whitespace; status 500 carrying the underlying error message when
            a model request or score extraction fails.
    """
    # Whitespace-only input is treated as missing: there is nothing to check.
    if not (transcription and transcription.strip()) or not (language and language.strip()):
        raise HTTPException(status_code=400, detail="Missing transcription or language selection")
    try:
        # Grammar check
        grammar_prompt = (
            f"Briefly check the grammar of the following text in {language}: {transcription}. "
            "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
            "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the grammar accuracy, reducing points for incorrect words, "
            "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
        )
        grammar_feedback = await _request_feedback(
            "llama3-groq-70b-8192-tool-use-preview", grammar_prompt
        )
        # Vocabulary check
        # NOTE(review): this prompt also says "check the grammar deeply and
        # carefully" -- looks copied from the grammar prompt; confirm intent
        # before changing the wording.
        vocabulary_prompt = (
            f"Check the vocabulary accuracy of the following text in {language}: {transcription}. "
            "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
            "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the vocabulary accuracy, reducing points for incorrect words, "
            "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
        )
        vocabulary_feedback = await _request_feedback(
            "llama-3.1-70b-versatile", vocabulary_prompt
        )
        # Calculate scores from feedback
        grammar_score = calculate_score(grammar_feedback)
        vocabulary_score = calculate_score(vocabulary_feedback)
        return JSONResponse(content=FeedbackResponse(
            grammar_feedback=grammar_feedback,
            vocabulary_feedback=vocabulary_feedback,
            grammar_score=grammar_score,
            vocabulary_score=vocabulary_score,
        ).dict())
    except Exception as e:
        # Top-level boundary: surface any failure as a 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e


async def _request_feedback(model, prompt):
    """Send one single-turn user prompt to a Groq chat model and return the stripped reply."""
    # Function-scope import keeps this block self-contained; fastapi is
    # already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # The Groq client call is synchronous; run it in the worker threadpool so
    # the event loop is not blocked while waiting on the network.
    response = await run_in_threadpool(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    # Guard against a reply without text content (presumably possible when the
    # model answers with a tool call -- verify against the Groq SDK types) so
    # that .strip() cannot fail on None.
    return (response.choices[0].message.content or "").strip()
def calculate_score(feedback) -> int:
    """Extract the numeric score from model feedback text.

    Looks for the first occurrence of the word "score" (in any letter case)
    followed by ``=`` or ``:`` and an integer, tolerating optional whitespace
    around the separator -- e.g. ``SCORE=8``, ``score: 7``, ``Score : 6`` or
    ``SCORE = 9``.

    Args:
        feedback: Feedback text to search. ``None`` or an empty string is
            treated as containing no score.

    Returns:
        The matched score as an integer, or 0 when no score is found.
    """
    import re

    # Nothing to parse; also avoids a TypeError from re.search(…, None).
    if not feedback:
        return 0
    # Case-insensitive with flexible spacing so that common variations of the
    # requested "SCORE=" marker are not silently scored as 0.
    match = re.search(r'score\s*[=:]\s*(\d+)', feedback, re.IGNORECASE)
    if match:
        # Extract and return the score as an integer
        return int(match.group(1))
    # Return a default score of 0 if no score is found in the feedback
    return 0
|