diff --git a/app.py b/app.py
index 01279c9..02a81f4 100644
--- a/app.py
+++ b/app.py
@@ -1,64 +1,83 @@
-from flask import Flask, request, jsonify, render_template, Response, stream_with_context
-from transformers import AutoTokenizer
 import os
-import requests
+from flask import Flask, request, jsonify, render_template, Response, stream_with_context
 from flask_limiter import Limiter
 from flask_limiter.util import get_remote_address
+from transformers import AutoTokenizer
+import requests
+import logging
 
 app = Flask(__name__)
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Initialize rate limiter
 limiter = Limiter(
     get_remote_address,
     app=app,
-    storage_uri="memory://"
+    storage_uri="memory://",
+    default_limits=[os.getenv('RATE_LIMIT', '15 per minute')]
 )
 
 # Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained(os.environ.get('TOKENIZER', 'gpt2'))
+tokenizer = AutoTokenizer.from_pretrained(os.getenv('TOKENIZER', 'gpt2'))
 
-api_url = os.environ.get('API_URL', 'https://api.openai.com/v1')
-api_key = os.environ.get('API_KEY')
-api_model = os.environ.get('API_MODEL', 'gpt-3.5-turbo')
-temperature = int(os.environ.get('TEMPERATURE', 0))
+# API configuration
+API_URL = os.getenv('API_URL', 'https://api.openai.com/v1')
+API_KEY = os.getenv('API_KEY')
+API_MODEL = os.getenv('API_MODEL', 'gpt-3.5-turbo')
+TEMPERATURE = float(os.getenv('TEMPERATURE', 0))
+
+logger.info(f"Chat initialized using endpoint: {API_URL}, model: {API_MODEL}, temperature: {TEMPERATURE}")
 
 @app.route('/v1/tokenizer/count', methods=['POST'])
 def token_count():
-    data = request.json
-    messages = data.get('messages', [])
-    full_text = " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
-    tokens = tokenizer.encode(full_text)
-    token_count = len(tokens)
-    return jsonify({"token_count": token_count})
+    try:
+        data = request.json
+        messages = data.get('messages', [])
+        full_text = " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+        tokens = tokenizer.encode(full_text)
+        return jsonify({"token_count": len(tokens)})
+    except Exception as e:
+        logger.error(f"Error in token_count: {str(e)}")
+        return jsonify({"error": "Invalid request"}), 400
 
 @app.route('/v1/chat/completions', methods=['POST'])
-@limiter.limit(os.environ.get('RATE_LIMIT', '20/minute'))
+@limiter.limit(os.getenv('RATE_LIMIT', '15/minute'))
 def proxy_chat_completions():
     headers = {
-        'Authorization': f'Bearer {api_key}',
+        'Authorization': f'Bearer {API_KEY}',
         'Content-Type': 'application/json'
     }
 
-    request_data = request.json
+    try:
+        request_data = request.json
+        request_data['model'] = API_MODEL
+        request_data['temperature'] = TEMPERATURE
+        request_data['stream'] = True
 
-    request_data['model'] = api_model
-    request_data['temperature'] = temperature
+        response = requests.post(f"{API_URL}/chat/completions",
+                                 json=request_data,
+                                 headers=headers,
+                                 stream=True)
 
-    request_data['stream'] = True
+        response.raise_for_status()
 
-    response = requests.post(f"{api_url}/chat/completions",
-                             json=request_data,
-                             headers=headers,
-                             stream=True)
+        def generate():
+            for chunk in response.iter_content(chunk_size=8):
+                if chunk:
+                    yield chunk
 
-    # Stream the response back to the client
-    def generate():
-        for chunk in response.iter_content(chunk_size=8):
-            if chunk:
-                yield chunk
+        return Response(stream_with_context(generate()),
+                        content_type=response.headers['content-type'])
 
-    return Response(stream_with_context(generate()),
-                    content_type=response.headers['content-type'])
+    except requests.RequestException as e:
+        logger.error(f"API request failed: {str(e)}")
+        return jsonify({"error": "Failed to connect to the API"}), 503
+    except Exception as e:
+        logger.error(f"Unexpected error: {str(e)}")
+        return jsonify({"error": "An unexpected error occurred"}), 500
 
 @app.route('/')
 def index():
@@ -68,5 +87,9 @@ def index():
 def serve_static(filename):
     return app.send_static_file(filename)
 
+@app.errorhandler(429)
+def ratelimit_handler(e):
+    return jsonify({"error": "Rate limit exceeded. Please try again later."}), 429
+
 if __name__ == '__main__':
-    app.run(debug=False, port=5000)
+    app.run(debug=False, port=int(os.getenv('PORT', 5000)))
diff --git a/static/css/index.css b/static/css/index.css
index fe17994..8eaacc9 100644
--- a/static/css/index.css
+++ b/static/css/index.css
@@ -366,3 +366,51 @@ p {
   transform: rotate(90deg);
   width: 2rem;
 }
+
+.error-toast {
+  position: fixed;
+  bottom: 20px;
+  right: 20px;
+  background-color: var(--red-color);
+  color: var(--foreground-color);
+  padding: 1rem;
+  border-radius: 10px;
+  max-width: 300px;
+  box-shadow: 0 8px 15px rgba(0, 0, 0, 0.2);
+  z-index: 1000;
+  cursor: pointer;
+  transition: all 0.3s ease;
+}
+
+.error-toast:hover {
+  transform: translateY(-5px);
+  box-shadow: 0 12px 20px rgba(0, 0, 0, 0.3);
+}
+
+@media (max-width: 640px) {
+  .error-toast {
+    left: 20px;
+    right: 20px;
+    max-width: none;
+  }
+}
+
+@keyframes shake {
+  0%,
+  100% {
+    transform: translateX(0);
+  }
+  10%,
+  30%,
+  50%,
+  70%,
+  90% {
+    transform: translateX(-5px);
+  }
+  20%,
+  40%,
+  60%,
+  80% {
+    transform: translateX(5px);
+  }
+}
diff --git a/static/js/index.js b/static/js/index.js
index df91971..9b6c941 100644
--- a/static/js/index.js
+++ b/static/js/index.js
@@ -12,7 +12,7 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     generating: false,
     endpoint: window.location.origin + "/v1",
-    model: "llama3-8b-8192", // This doesen't matter anymore as the backend handles it now
+    model: "llama3-8b-8192", // This doesn't matter anymore as the backend handles it now
     stopToken: "<|eot_id|>", // We may need this for some models
 
     // performance tracking
@@ -20,6 +20,9 @@ document.addEventListener("alpine:init", () => {
     tokens_per_second: 0,
     total_tokens: 0,
 
+    // New property for error messages
+    errorMessage: null,
+
     removeHistory(cstate) {
       const index = this.histories.findIndex((state) => {
         return state.time === cstate.time;
@@ -37,6 +40,7 @@ document.addEventListener("alpine:init", () => {
       if (this.generating) return;
       this.generating = true;
+      this.errorMessage = null; // Clear any previous error messages
 
       if (this.home === 0) this.home = 1;
 
       // ensure that going back in history will go back to home
@@ -56,48 +60,56 @@ document.addEventListener("alpine:init", () => {
       let tokens = 0;
       this.tokens_per_second = 0;
 
-      // start receiving server sent events
-      let gottenFirstChunk = false;
-      for await (const chunk of this.openaiChatCompletion(
-        this.cstate.messages,
-      )) {
-        if (!gottenFirstChunk) {
-          this.cstate.messages.push({ role: "assistant", content: "" });
-          gottenFirstChunk = true;
-        }
+      try {
+        // start receiving server sent events
+        let gottenFirstChunk = false;
+        for await (const chunk of this.openaiChatCompletion(
+          this.cstate.messages,
+        )) {
+          if (!gottenFirstChunk) {
+            this.cstate.messages.push({ role: "assistant", content: "" });
+            gottenFirstChunk = true;
+          }
 
-        // add chunk to the last message
-        this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
+          // add chunk to the last message
+          this.cstate.messages[this.cstate.messages.length - 1].content +=
+            chunk;
 
-        // calculate performance tracking
-        tokens += 1;
-        this.total_tokens += 1;
-        if (start_time === 0) {
-          start_time = Date.now();
-          this.time_till_first = start_time - prefill_start;
-        } else {
-          const diff = Date.now() - start_time;
-          if (diff > 0) {
-            this.tokens_per_second = tokens / (diff / 1000);
+          // calculate performance tracking
+          tokens += 1;
+          this.total_tokens += 1;
+          if (start_time === 0) {
+            start_time = Date.now();
+            this.time_till_first = start_time - prefill_start;
+          } else {
+            const diff = Date.now() - start_time;
+            if (diff > 0) {
+              this.tokens_per_second = tokens / (diff / 1000);
+            }
           }
         }
-      }
 
-      // update the state in histories or add it if it doesn't exist
-      const index = this.histories.findIndex((cstate) => {
-        return cstate.time === this.cstate.time;
-      });
-      this.cstate.time = Date.now();
-      if (index !== -1) {
-        // update the time
-        this.histories[index] = this.cstate;
-      } else {
-        this.histories.push(this.cstate);
+        // update the state in histories or add it if it doesn't exist
+        const index = this.histories.findIndex((cstate) => {
+          return cstate.time === this.cstate.time;
+        });
+        this.cstate.time = Date.now();
+        if (index !== -1) {
+          // update the time
+          this.histories[index] = this.cstate;
+        } else {
+          this.histories.push(this.cstate);
+        }
+        // update in local storage
+        localStorage.setItem("histories", JSON.stringify(this.histories));
+      } catch (error) {
+        console.error("Error in handleSend:", error);
+        this.showError(
+          error.message || "An error occurred processing your request.",
+        );
+      } finally {
+        this.generating = false;
       }
-      // update in local storage
-      localStorage.setItem("histories", JSON.stringify(this.histories));
-
-      this.generating = false;
     },
 
     async handleEnter(event) {
@@ -108,67 +120,90 @@ document.addEventListener("alpine:init", () => {
       }
     },
 
+    showError(message) {
+      this.errorMessage = message;
+      setTimeout(() => {
+        this.errorMessage = null;
+      }, 3000); // Hide after 3 seconds
+    },
+
     updateTotalTokens(messages) {
       fetch(`${window.location.origin}/v1/tokenizer/count`, {
         method: "POST",
        headers: { "Content-Type": "application/json" },
         body: JSON.stringify({ messages }),
       })
-        .then((response) => response.json())
+        .then((response) => {
+          if (!response.ok) {
+            throw new Error("Failed to count tokens");
+          }
+          return response.json();
+        })
         .then((data) => {
           this.total_tokens = data.token_count;
         })
-        .catch(console.error);
+        .catch((error) => {
+          console.error("Error updating total tokens:", error);
+          this.showError("Failed to update token count. Please try again.");
+        });
     },
 
     async *openaiChatCompletion(messages) {
-      // stream response
-      const response = await fetch(`${this.endpoint}/chat/completions`, {
-        method: "POST",
-        headers: {
-          "Content-Type": "application/json",
-          Authorization: `Bearer ${this.apiKey}`,
-        },
-        body: JSON.stringify({
-          model: this.model,
-          messages: messages,
-          stream: true,
-          stop: [this.stopToken],
-        }),
-      });
-      if (!response.ok) {
-        throw new Error("Failed to fetch");
-      }
+      try {
+        const response = await fetch(`${this.endpoint}/chat/completions`, {
+          method: "POST",
+          headers: {
+            "Content-Type": "application/json",
+            Authorization: `Bearer ${this.apiKey}`,
+          },
+          body: JSON.stringify({
+            model: this.model,
+            messages: messages,
+            stream: true,
+            stop: [this.stopToken],
+          }),
+        });
 
-      const reader = response.body.getReader();
-      const decoder = new TextDecoder("utf-8");
-      let buffer = "";
-
-      while (true) {
-        const { done, value } = await reader.read();
-        if (done) {
-          break;
+        if (!response.ok) {
+          const errorData = await response.json();
+          throw new Error(errorData.error || "Failed to fetch");
        }
-        buffer += decoder.decode(value, { stream: true });
-        const lines = buffer.split("\n");
-        buffer = lines.pop();
-        for (const line of lines) {
-          if (line.startsWith("data: ")) {
-            const data = line.slice(6);
-            if (data === "[DONE]") {
-              return;
-            }
-            try {
-              const json = JSON.parse(data);
-              if (json.choices && json.choices[0].delta.content) {
-                yield json.choices[0].delta.content;
-              }
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder("utf-8");
+        let buffer = "";
+
+        while (true) {
+          const { done, value } = await reader.read();
+          if (done) break;
+
+          buffer += decoder.decode(value, { stream: true });
+          const lines = buffer.split("\n");
+          buffer = lines.pop();
+
+          for (const line of lines) {
+            if (line.startsWith("data: ")) {
+              const data = line.slice(6);
+              if (data === "[DONE]") return;
+
+              try {
+                const json = JSON.parse(data);
+                if (json.choices && json.choices[0].delta.content) {
+                  yield json.choices[0].delta.content;
+                }
+              } catch (error) {
+                console.error("Error parsing JSON:", error);
+              }
+            }
-            } catch (error) {
-              console.error("Error parsing JSON:", error);
-            }
           }
         }
-      }
+      } catch (error) {
+        console.error("Error in openaiChatCompletion:", error);
+        this.showError(
+          error.message ||
+            "An error occurred while communicating with the server.",
+        );
+        throw error;
+      }
     },
   }));
diff --git a/templates/index.html b/templates/index.html
index f2c79aa..d60aed0 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -1,178 +1,190 @@
[templates/index.html hunk not recoverable: the markup was stripped in this rendering. Both the old and new versions carry the "cchat" title and the TTFT / TOKENS/SEC / TOKENS stats panel; the file grows from 178 to 190 lines.]
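For reference, a rough client-side sketch of the two endpoints this patch touches. It assumes the Flask app is running locally on its default port and that the upstream API streams OpenAI-style "data: " lines (which is what the front-end parser above expects); the base URL and sample message are illustrative only.

import json
import requests

BASE_URL = "http://localhost:5000"  # adjust if PORT was overridden
messages = [{"role": "user", "content": "Hello!"}]

# Count tokens for a message list via /v1/tokenizer/count.
resp = requests.post(f"{BASE_URL}/v1/tokenizer/count", json={"messages": messages})
print(resp.json())  # {"token_count": ...} on success, {"error": ...} with 400 otherwise

# Stream a completion through the proxy; failures now come back as JSON
# ({"error": ...}) with 429, 503, or 500 instead of an opaque error.
with requests.post(f"{BASE_URL}/v1/chat/completions",
                   json={"messages": messages, "stream": True},
                   stream=True) as r:
    if not r.ok:
        print(r.status_code, r.json().get("error"))
    else:
        for line in r.iter_lines():
            if not line.startswith(b"data: "):
                continue
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            delta = json.loads(payload)["choices"][0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)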