Fix backward compatibility.
[vuplus_transtreamproxy] / src / main.cpp
1 /*
2  * main.cpp
3  *
4  *  Created on: 2014. 6. 10.
5  *      Author: oskwon
6  */
7
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <pthread.h>
#include <poll.h>
#include <errno.h>
#include <signal.h>
15
16 #include <string>
17
18 #include "Util.h"
19 #include "Logger.h"
20
21 #include "Http.h"
22 #include "Mpeg.h"
23
24 #include "Demuxer.h"
25 #include "Encoder.h"
26 #include "UriDecoder.h"
27
28 using namespace std;
29 //----------------------------------------------------------------------
30
/* Size of one read/write chunk: 256 packets of 188 bytes (the MPEG-TS packet
 * size). The macro keeps its historical spelling because the rest of the
 * file refers to it by this name. */
#define BUFFFER_SIZE (188 * 256)

void show_help();
void signal_handler(int sig_no);

void *source_thread_main(void *params);
void *streaming_thread_main(void *params);

int streaming_write(const char *buffer, size_t buffer_len, bool enable_log = false);
//----------------------------------------------------------------------

/*
 * Shutdown flag shared by main(), both worker threads and signal_handler().
 *
 * It is written from a signal handler. The only object type the C and C++
 * standards allow a handler to store to is `volatile sig_atomic_t`. A plain
 * bool is undefined behaviour there. Code elsewhere keeps assigning
 * true/false, which convert to 1/0.
 *
 * It starts out "terminated" (non-zero) until main() is about to launch the
 * worker threads.
 *
 * NOTE(review): volatile does not make cross-thread access race-free in the
 * C++11 memory model; prefer std::atomic<bool> if the toolchain allows C++11.
 */
static volatile sig_atomic_t is_terminated = 1;
static int source_thread_id, stream_thread_id;
static pthread_t source_thread_handle, stream_thread_handle;
45 //----------------------------------------------------------------------
46
47 int main(int argc, char **argv)
48 {
49         if (argc > 1) {
50                 if (strcmp(argv[1], "-h") == 0)
51                         show_help();
52                 exit(0);
53         }
54         Logger::instance()->init("/tmp/transtreamproxy", Logger::WARNING);
55
56         signal(SIGINT, signal_handler);
57
58         HttpHeader header;
59         std::string req = HttpHeader::read_request();
60
61         DEBUG("request head :\n%s", req.c_str());
62
63         try {
64                 if (req.find("\r\n\r\n") == std::string::npos) {
65                         throw(http_trap("no found request done code.", 400, "Bad Request"));
66                 }
67
68                 if (header.parse_request(req) == false) {
69                         throw(http_trap("request parse error.", 400, "Bad Request"));
70                 }
71
72                 if (header.method != "GET") {
73                         throw(http_trap("not support request type.", 400, "Bad Request, not support request"));
74                 }
75
76                 Encoder encoder;
77                 Source *source = 0;
78                 ThreadParams thread_params = { 0, &encoder, &header };
79
80                 int video_pid = 0, audio_pid = 0, pmt_pid = 0;
81
82                 switch(header.type) {
83                 case HttpHeader::TRANSCODING_FILE:
84                         try {
85                                 std::string uri = UriDecoder().decode(header.page_params["file"].c_str());
86                                 Mpeg *ts = new Mpeg(uri, false);
87                                 pmt_pid   = ts->pmt_pid;
88                                 video_pid = ts->video_pid;
89                                 audio_pid = ts->audio_pid;
90                                 source = ts;
91                         }
92                         catch (const trap &e) {
93                                 throw(http_trap(e.what(), 404, "Not Found"));
94                         }
95                         break;
96                 case HttpHeader::TRANSCODING_LIVE:
97                         try {
98                                 Demuxer *dmx = new Demuxer(&header);
99                                 pmt_pid   = dmx->pmt_pid;
100                                 video_pid = dmx->video_pid;
101                                 audio_pid = dmx->audio_pid;
102                                 source = dmx;
103                         }
104                         catch (const http_trap &e) {
105                                 throw(e);
106                         }
107                         break;
108                 case HttpHeader::M3U:
109                         try {
110                                 std::string response = header.build_response((Mpeg*) source);
111                                 if (response != "") {
112                                         streaming_write(response.c_str(), response.length(), true);
113                                 }
114                         }
115                         catch (...) {
116                         }
117                         exit(0);
118                 default:
119                         throw(http_trap(std::string("not support source type : ") + Util::ultostr(header.type), 400, "Bad Request"));
120                 }
121                 thread_params.source = source;
122
123                 if (!encoder.retry_open(2, 3)) {
124                         throw(http_trap("encoder open fail.", 503, "Service Unavailable"));
125                 }
126
127                 if (encoder.state == Encoder::ENCODER_STAT_OPENED) {
128                         std::string response = header.build_response((Mpeg*) source);
129                         if (response == "") {
130                                 throw(http_trap("response build fail.", 503, "Service Unavailable"));
131                         }
132
133                         streaming_write(response.c_str(), response.length(), true);
134
135                         if (header.type == HttpHeader::TRANSCODING_FILE) {
136                                 ((Mpeg*) source)->seek(header);
137                         }
138
139                         if (!encoder.ioctl(Encoder::IOCTL_SET_VPID, video_pid)) {
140                                 throw(http_trap("video pid setting fail.", 503, "Service Unavailable"));
141                         }
142                         if (!encoder.ioctl(Encoder::IOCTL_SET_APID, audio_pid)) {
143                                 throw(http_trap("audio pid setting fail.", 503, "Service Unavailable"));
144                         }
145                         if (!encoder.ioctl(Encoder::IOCTL_SET_PMTPID, pmt_pid)) {
146                                 throw(http_trap("pmt pid setting fail.", 503, "Service Unavailable"));
147                         }
148                 }
149
150                 is_terminated = false;
151                 source_thread_id = pthread_create(&source_thread_handle, 0, source_thread_main, (void *)&thread_params);
152                 if (source_thread_id < 0) {
153                         is_terminated = true;
154                         throw(http_trap("souce thread create fail.", 503, "Service Unavailable"));
155                 }
156                 else {
157                         pthread_detach(source_thread_handle);
158                         if (!encoder.ioctl(Encoder::IOCTL_START_TRANSCODING, 0)) {
159                                 is_terminated = true;
160                                 throw(http_trap("start transcoding fail.", 503, "Service Unavailable"));
161                         }
162                         else {
163                                 stream_thread_id = pthread_create(&stream_thread_handle, 0, streaming_thread_main, (void *)&thread_params);
164                                 if (stream_thread_id < 0) {
165                                         is_terminated = true;
166                                         throw(http_trap("stream thread create fail.", 503, "Service Unavailable"));
167                                 }
168                         }
169                 }
170                 pthread_join(stream_thread_handle, 0);
171                 is_terminated = true;
172
173                 if (source != 0) {
174                         delete source;
175                         source = 0;
176                 }
177         }
178         catch (const http_trap &e) {
179                 ERROR("%s", e.message.c_str());
180                 std::string error = "";
181                 if (e.http_error == 401 && header.authorization.length() > 0) {
182                         error = header.authorization;
183                 }
184                 else {
185                         error = HttpUtil::http_error(e.http_error, e.http_header);
186                 }
187                 streaming_write(error.c_str(), error.length(), true);
188                 exit(-1);
189         }
190         catch (...) {
191                 ERROR("unknown exception...");
192                 std::string error = HttpUtil::http_error(400, "Bad request");
193                 streaming_write(error.c_str(), error.length(), true);
194                 exit(-1);
195         }
196         return 0;
197 }
198 //----------------------------------------------------------------------
199
200 void *streaming_thread_main(void *params)
201 {
202         if (is_terminated) return 0;
203
204         INFO("streaming thread start.");
205         Encoder *encoder = ((ThreadParams*) params)->encoder;
206         HttpHeader *header = ((ThreadParams*) params)->request;
207
208         try {
209                 int poll_state, rc, wc;
210                 struct pollfd poll_fd[2];
211                 unsigned char buffer[BUFFFER_SIZE];
212
213                 poll_fd[0].fd = encoder->get_fd();
214                 poll_fd[0].events = POLLIN | POLLHUP;
215
216                 while(!is_terminated) {
217                         poll_state = poll(poll_fd, 1, 1000);
218                         if (poll_state == -1) {
219                                 throw(trap("poll error."));
220                         }
221                         else if (poll_state == 0) {
222                                 continue;
223                         }
224                         if (poll_fd[0].revents & POLLIN) {
225                                 rc = wc = 0;
226                                 rc = read(encoder->get_fd(), buffer, BUFFFER_SIZE - 1);
227                                 if (rc <= 0) {
228                                         break;
229                                 }
230                                 else if (rc > 0) {
231                                         wc = streaming_write((const char*) buffer, rc);
232                                         if (wc < rc) {
233                                                 //DEBUG("need rewrite.. remain (%d)", rc - wc);
234                                                 int retry_wc = 0;
235                                                 for (int remain_len = rc - wc; rc != wc; remain_len -= retry_wc) {
236                                                         poll_fd[0].revents = 0;
237
238                                                         retry_wc = streaming_write((const char*) (buffer + rc - remain_len), remain_len);
239                                                         wc += retry_wc;
240                                                 }
241                                                 LOG("re-write result : %d - %d", wc, rc);
242                                         }
243                                 }
244                         }
245                         else if (poll_fd[0].revents & POLLHUP)
246                         {
247                                 if (encoder->state == Encoder::ENCODER_STAT_STARTED) {
248                                         DEBUG("stop transcoding..");
249                                         encoder->ioctl(Encoder::IOCTL_STOP_TRANSCODING, 0);
250                                 }
251                                 break;
252                         }
253                 }
254         }
255         catch (const trap &e) {
256                 ERROR("%s %s (%d)", e.what(), strerror(errno), errno);
257         }
258         is_terminated = true;
259         INFO("streaming thread stop.");
260
261         if (encoder->state == Encoder::ENCODER_STAT_STARTED) {
262                 DEBUG("stop transcoding..");
263                 encoder->ioctl(Encoder::IOCTL_STOP_TRANSCODING, 0);
264         }
265
266         pthread_exit(0);
267
268         return 0;
269 }
270 //----------------------------------------------------------------------
271
272 void *source_thread_main(void *params)
273 {
274         Source *source = ((ThreadParams*) params)->source;
275         Encoder *encoder = ((ThreadParams*) params)->encoder;
276         HttpHeader *header = ((ThreadParams*) params)->request;
277
278         INFO("source thread start.");
279
280         try {
281                 int poll_state, rc, wc;
282                 struct pollfd poll_fd[2];
283                 unsigned char buffer[BUFFFER_SIZE];
284
285                 poll_fd[0].fd = encoder->get_fd();
286                 poll_fd[0].events = POLLOUT;
287
288                 poll_fd[1].fd = source->get_fd();
289                 poll_fd[1].events = POLLIN;
290
291                 while(!is_terminated) {
292                         poll_state = poll(poll_fd, 2, 1000);
293                         if (poll_state == -1) {
294                                 throw(trap("poll error."));
295                         }
296                         else if (poll_state == 0) {
297                                 continue;
298                         }
299
300                         if (poll_fd[0].revents & POLLOUT) {
301                                 rc = wc = 0;
302                                 if (poll_fd[1].revents & POLLIN) {
303                                         rc = read(source->get_fd(), buffer, BUFFFER_SIZE - 1);
304                                         if (rc == 0) {
305                                                 break;
306                                         }
307                                         else if (rc > 0) {
308                                                 wc = write(encoder->get_fd(), buffer, rc);
309                                                 //DEBUG("write : %d", wc);
310                                                 if (wc < rc) {
311                                                         //DEBUG("need rewrite.. remain (%d)", rc - wc);
312                                                         int retry_wc = 0;
313                                                         for (int remain_len = rc - wc; rc != wc; remain_len -= retry_wc) {
314                                                                 poll_fd[0].revents = 0;
315
316                                                                 poll_state = poll(poll_fd, 1, 1000);
317                                                                 if (poll_fd[0].revents & POLLOUT) {
318                                                                         retry_wc = write(encoder->get_fd(), (buffer + rc - remain_len), remain_len);
319                                                                         wc += retry_wc;
320                                                                 }
321                                                         }
322                                                         LOG("re-write result : %d - %d", wc, rc);
323                                                         usleep(500000);
324                                                 }
325                                         }
326                                 }
327                         }
328                 }
329         }
330         catch (const trap &e) {
331                 ERROR("%s %s (%d)", e.what(), strerror(errno), errno);
332         }
333         INFO("source thread stop.");
334
335         pthread_exit(0);
336
337         return 0;
338 }
339 //----------------------------------------------------------------------
340
341 int streaming_write(const char *buffer, size_t buffer_len, bool enable_log)
342 {
343         if (enable_log) {
344                 DEBUG("response data :\n%s", buffer);
345         }
346         return write(1, buffer, buffer_len);
347 }
348 //----------------------------------------------------------------------
349
350 void signal_handler(int sig_no)
351 {
352         INFO("signal no : %d", sig_no);
353         is_terminated = true;
354 }
355 //----------------------------------------------------------------------
356
/*
 * Prints the command-line usage text to stdout, one line at a time.
 * NOTE(review): the text mentions both /tmp/debug_on and /tmp/.debug_on;
 * confirm which path the Logger actually reads.
 */
void show_help()
{
	static const char *const help_text[] = {
		"usage : transtreamproxy [-h]\n",
		"\n",
		" * To active debug mode, input NUMBER on /tmp/debug_on file. (default : warning)\n",
		"   NUMBER : error(1), warning(2), info(3), debug(4), log(5)\n",
		"\n",
		" ex > echo \"4\" > /tmp/.debug_on\n",
	};
	const size_t line_count = sizeof(help_text) / sizeof(help_text[0]);

	for (size_t idx = 0; idx < line_count; ++idx) {
		fputs(help_text[idx], stdout);
	}
}
366 //----------------------------------------------------------------------